1515import logging
1616import os
1717import sys
18+ import typing
1819from pathlib import Path
1920
2021os .environ .setdefault ("TORCHRL_PROFILING" , "1" )
5657 default = "cuda:0" ,
5758 help = "Device passed to the TorchRL IsaacLabWrapper." ,
5859)
# Collector / wrapper toggles exposed on the command line so that different
# configurations can be profiled without editing this script.

# Forwarded to the IsaacLabWrapper constructor by make_env().
parser.add_argument(
    "--native-autoreset",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Use Isaac Lab native autoreset observations without synthetic reset calls.",
)

# Forwarded to the Collector as ``track_traj_ids``.
parser.add_argument(
    "--track-traj-ids",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Track collector trajectory ids in emitted batches.",
)

# Selects which callable make_policy() hands to the collector.
parser.add_argument(
    "--policy-mode",
    choices=["rand_action", "fixed_zero"],
    default="rand_action",
    help="Policy used by the collector during profiling.",
)

# NOTE: with BooleanOptionalAction the negated spelling of this flag is
# ``--no-no-cuda-sync``; the parsed value lands in ``args.no_cuda_sync``.
parser.add_argument(
    "--no-cuda-sync",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Disable CUDA synchronization in collector timing and profiling hooks.",
)

# Forwarded to the Collector as ``trust_policy``.
parser.add_argument(
    "--trust-policy",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Trust the policy output and skip selected collector safety checks.",
)
5990parser .add_argument (
6091 "--activities" ,
6192 nargs = "+" ,
103134import torch
104135
105136from isaaclab_tasks .manager_based .classic .ant .ant_env_cfg import AntEnvCfg
137+ from tensordict import TensorDictBase
106138from torchrl ._utils import logger as torchrl_logger
107139from torchrl .collectors import Collector
108140from torchrl .envs .libs .isaac_lab import IsaacLabWrapper
109141
110142
111- def make_env (task : str , device : str ):
class FixedZeroPolicy:
    """Callable policy that always writes the action spec's zero value.

    The action key and action spec are read from the environment once at
    construction time; each call asks the spec for a fresh ``zero()`` value
    and stores it in the tensordict that was passed in.
    """

    def __init__(self, env: IsaacLabWrapper):
        """Capture ``action_key`` and ``action_spec`` from ``env``."""
        self.action_key = env.action_key
        self.action_spec = env.action_spec

    def __call__(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Write a zero action into ``tensordict`` in place and return it."""
        zero_action = self.action_spec.zero()
        if not isinstance(zero_action, TensorDictBase):
            # Plain value: store it under the environment's action key.
            tensordict.set(self.action_key, zero_action)
            return tensordict
        # The spec produced a tensordict of its own: merge its entries in.
        tensordict.update(zero_action)
        return tensordict
155+
156+
def make_env(task: str, device: str, native_autoreset: bool):
    """Create the gym environment for ``task`` and wrap it for TorchRL.

    Args:
        task: Gym task id. ``"Isaac-Ant-v0"`` is built with an explicit
            ``AntEnvCfg()``; any other id is passed to ``gym.make`` alone.
        device: Device string, converted with ``torch.device`` before being
            handed to the wrapper.
        native_autoreset: Forwarded unchanged to ``IsaacLabWrapper``.

    Returns:
        The ``IsaacLabWrapper`` built around the created environment.
    """
    # Only the Ant task receives an explicit configuration object.
    make_kwargs = {"cfg": AntEnvCfg()} if task == "Isaac-Ant-v0" else {}
    base_env = gym.make(task, **make_kwargs)
    return IsaacLabWrapper(
        base_env,
        device=torch.device(device),
        native_autoreset=native_autoreset,
    )
167+
168+
def make_policy(
    env: IsaacLabWrapper, policy_mode: typing.Literal["rand_action", "fixed_zero"]
):
    """Return the policy callable selected by ``policy_mode``.

    Args:
        env: Wrapped environment the policy will act on.
        policy_mode: ``"rand_action"`` selects the bound ``env.rand_action``
            method; ``"fixed_zero"`` selects a new ``FixedZeroPolicy(env)``.

    Returns:
        The selected policy callable.

    Raises:
        ValueError: If ``policy_mode`` is neither of the two known modes.
    """
    if policy_mode == "fixed_zero":
        return FixedZeroPolicy(env)
    if policy_mode != "rand_action":
        raise ValueError(f"Unknown policy mode { policy_mode } .")
    return env.rand_action
117177
118178
119179def main () -> None :
@@ -127,17 +187,22 @@ def main() -> None:
127187 trace_path = args_cli .output_dir / "collector_worker_{worker_idx}.json"
128188 worker_0_trace_path = args_cli .output_dir / "collector_worker_0.json"
129189
130- torchrl_logger .info (f"Creating Isaac Lab env { args_cli .task } ." )
131- env = make_env (args_cli .task , args_cli .env_device )
190+ torchrl_logger .info (
191+ f"Creating Isaac Lab env { args_cli .task } "
192+ f"(native_autoreset={ args_cli .native_autoreset } )."
193+ )
194+ env = make_env (args_cli .task , args_cli .env_device , args_cli .native_autoreset )
132195 env .set_seed (args_cli .seed )
196+ policy = make_policy (env , args_cli .policy_mode )
133197
134198 collector = Collector (
135199 env ,
136- env . rand_action ,
200+ policy ,
137201 frames_per_batch = args_cli .frames_per_batch ,
138202 total_frames = - 1 ,
139- no_cuda_sync = True ,
140- trust_policy = True ,
203+ no_cuda_sync = args_cli .no_cuda_sync ,
204+ trust_policy = args_cli .trust_policy ,
205+ track_traj_ids = args_cli .track_traj_ids ,
141206 )
142207 collector .enable_profile (
143208 workers = [0 ],
0 commit comments