Skip to content

Commit 1f1f8bf

Browse files
committed
[Performance] Expose Isaac collector profiling knobs
ghstack-source-id: 1ff492c Pull-Request: #3726
1 parent 15edca3 commit 1f1f8bf

1 file changed

Lines changed: 72 additions & 7 deletions

File tree

examples/collectors/profile_isaaclab_collector.py

Lines changed: 72 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import os
1717
import sys
18+
import typing
1819
from pathlib import Path
1920

2021
os.environ.setdefault("TORCHRL_PROFILING", "1")
@@ -56,6 +57,36 @@
5657
default="cuda:0",
5758
help="Device passed to the TorchRL IsaacLabWrapper.",
5859
)
60+
# Profiling knobs for the env wrapper and the collector. Each flag below is
# read back from the parsed namespace under its dashed name with underscores
# (e.g. ``--track-traj-ids`` -> ``track_traj_ids``).
parser.add_argument(
    "--native-autoreset",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Use Isaac Lab native autoreset observations without synthetic reset calls.",
)
parser.add_argument(
    "--track-traj-ids",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Track collector trajectory ids in emitted batches.",
)
parser.add_argument(
    "--policy-mode",
    default="rand_action",
    choices=["rand_action", "fixed_zero"],
    help="Policy used by the collector during profiling.",
)
parser.add_argument(
    "--no-cuda-sync",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Disable CUDA synchronization in collector timing and profiling hooks.",
)
# BooleanOptionalAction builds the negated spelling by prefixing "no-", so the
# only generated way to turn synchronization back on is "--no-no-cuda-sync".
# Offer a readable alias that writes False to the same destination; the
# generated spelling keeps working, and the default stays True.
parser.add_argument(
    "--cuda-sync",
    dest="no_cuda_sync",
    action="store_false",
    help="Enable CUDA synchronization (same effect as --no-no-cuda-sync).",
)
parser.add_argument(
    "--trust-policy",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Trust the policy output and skip selected collector safety checks.",
)
5990
parser.add_argument(
6091
"--activities",
6192
nargs="+",
@@ -103,17 +134,46 @@
103134
import torch
104135

105136
from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
137+
from tensordict import TensorDictBase
106138
from torchrl._utils import logger as torchrl_logger
107139
from torchrl.collectors import Collector
108140
from torchrl.envs.libs.isaac_lab import IsaacLabWrapper
109141

110142

111-
def make_env(task: str, device: str):
143+
class FixedZeroPolicy:
    """Policy that fills the action entry with ``action_spec.zero()``.

    The action key and action spec are read from the environment once, at
    construction time. Every call requests a new zero action from the spec
    and writes it into the tensordict it was given.
    """

    def __init__(self, env: IsaacLabWrapper):
        self.action_key = env.action_key
        self.action_spec = env.action_spec

    def __call__(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Write a zero action into ``tensordict`` in place and return it."""
        zero_action = self.action_spec.zero()
        if not isinstance(zero_action, TensorDictBase):
            # Non-tensordict action: store it under the env's action key.
            tensordict.set(self.action_key, zero_action)
            return tensordict
        # The spec produced a tensordict: merge its entries into the input.
        tensordict.update(zero_action)
        return tensordict
155+
156+
157+
def make_env(task: str, device: str, native_autoreset: bool):
    """Create the gym environment for ``task`` and wrap it for TorchRL.

    Args:
        task: Gym task id. ``"Isaac-Ant-v0"`` is created with an explicit
            ``AntEnvCfg()``; any other id is passed to ``gym.make`` alone.
        device: Device string, converted with ``torch.device`` and handed to
            the wrapper.
        native_autoreset: Forwarded unchanged to ``IsaacLabWrapper``.

    Returns:
        The ``IsaacLabWrapper`` built around the created environment.
    """
    gym_env = (
        gym.make(task, cfg=AntEnvCfg()) if task == "Isaac-Ant-v0" else gym.make(task)
    )
    torch_device = torch.device(device)
    return IsaacLabWrapper(
        gym_env,
        device=torch_device,
        native_autoreset=native_autoreset,
    )
167+
168+
169+
def make_policy(
    env: IsaacLabWrapper, policy_mode: typing.Literal["rand_action", "fixed_zero"]
):
    """Return the policy callable selected by ``policy_mode``.

    Args:
        env: Wrapped environment the policy will act on.
        policy_mode: ``"rand_action"`` returns the bound ``env.rand_action``;
            ``"fixed_zero"`` returns a ``FixedZeroPolicy`` built from ``env``.

    Raises:
        ValueError: If ``policy_mode`` is neither of the two supported modes.
    """
    if policy_mode == "fixed_zero":
        return FixedZeroPolicy(env)
    if policy_mode != "rand_action":
        raise ValueError(f"Unknown policy mode {policy_mode}.")
    return env.rand_action
117177

118178

119179
def main() -> None:
@@ -127,17 +187,22 @@ def main() -> None:
127187
trace_path = args_cli.output_dir / "collector_worker_{worker_idx}.json"
128188
worker_0_trace_path = args_cli.output_dir / "collector_worker_0.json"
129189

130-
torchrl_logger.info(f"Creating Isaac Lab env {args_cli.task}.")
131-
env = make_env(args_cli.task, args_cli.env_device)
190+
torchrl_logger.info(
191+
f"Creating Isaac Lab env {args_cli.task} "
192+
f"(native_autoreset={args_cli.native_autoreset})."
193+
)
194+
env = make_env(args_cli.task, args_cli.env_device, args_cli.native_autoreset)
132195
env.set_seed(args_cli.seed)
196+
policy = make_policy(env, args_cli.policy_mode)
133197

134198
collector = Collector(
135199
env,
136-
env.rand_action,
200+
policy,
137201
frames_per_batch=args_cli.frames_per_batch,
138202
total_frames=-1,
139-
no_cuda_sync=True,
140-
trust_policy=True,
203+
no_cuda_sync=args_cli.no_cuda_sync,
204+
trust_policy=args_cli.trust_policy,
205+
track_traj_ids=args_cli.track_traj_ids,
141206
)
142207
collector.enable_profile(
143208
workers=[0],

0 commit comments

Comments
 (0)