From e353b20391c3d0c0fb34547c05b0c5deacd58943 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 25 Oct 2023 10:50:42 -0400 Subject: [PATCH] [BugFix] Fix EXAMPLES.md (#1649) --- .../linux_examples/scripts/run_test.sh | 91 +- examples/EXAMPLES.md | 76 +- examples/cql/cql_offline.py | 2 +- examples/cql/cql_online.py | 2 +- examples/cql/online_config.yaml | 2 +- examples/cql/utils.py | 27 +- examples/ddpg/config.yaml | 2 +- examples/ddpg/utils.py | 28 +- examples/decision_transformer/utils.py | 2 +- examples/discrete_sac/discrete_sac.py | 31 +- examples/dqn/config.yaml | 1 - examples/dqn/dqn.py | 1 + examples/dreamer/dreamer_utils.py | 3 - examples/iql/iql_online.py | 110 +- examples/iql/online_config.yaml | 82 +- examples/redq/config.yaml | 130 +- examples/redq/redq.py | 91 +- examples/redq/utils.py | 1052 +++++++++++++++++ examples/sac/config.yaml | 2 +- examples/sac/utils.py | 35 +- examples/td3/config.yaml | 2 +- examples/td3/utils.py | 32 +- torchrl/envs/batched_envs.py | 8 +- torchrl/envs/transforms/transforms.py | 11 +- torchrl/objectives/value/functional.py | 5 +- torchrl/record/loggers/utils.py | 7 +- torchrl/trainers/helpers/losses.py | 6 +- torchrl/trainers/helpers/models.py | 6 +- torchrl/trainers/trainers.py | 5 +- 29 files changed, 1538 insertions(+), 314 deletions(-) create mode 100644 examples/redq/utils.py diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 39218bd82a6..e392e0c93aa 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -73,7 +73,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.collector_device=cuda:0 \ + collector.device=cuda:0 \ network.device=cuda:0 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ @@ -107,23 +107,24 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \ record_frames=4 \ buffer_size=120 python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ - total_frames=48 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=16 \ num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=2 \ + collector.device=cuda:0 \ + buffer.batch_size=10 \ + optim.steps_per_batch=1 \ + logger.record_video=True \ + logger.record_frames=4 \ + buffer.size=120 \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.collector_device=cuda:0 \ + collector.device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ @@ -152,21 +153,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ collector.frames_per_batch=16 \ collector.num_workers=4 \ collector.env_per_collector=2 \ - collector.collector_device=cuda:0 \ + collector.device=cuda:0 \ + collector.device=cuda:0 \ network.device=cuda:0 \ logger.mode=offline \ env.name=Pendulum-v1 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ - total_frames=48 \ - batch_size=10 \ - 
frames_per_batch=16 \ - num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - device=cuda:0 \ - mode=offline \ - logger= + collector.total_frames=48 \ + buffer.batch_size=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=2 \ + collector.device=cuda:0 \ + network.device=cuda:0 \ + logger.mode=offline \ + logger.backend= # With single envs python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ @@ -188,7 +189,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=1 \ - collector.collector_device=cuda:0 \ + collector.device=cuda:0 \ network.device=cuda:0 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ @@ -209,23 +210,24 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \ record_frames=4 \ buffer_size=120 python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ - total_frames=48 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=16 \ num_workers=2 \ - env_per_collector=1 \ - collector_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=1 \ + buffer.batch_size=10 \ + collector.device=cuda:0 \ + optim.steps_per_batch=1 \ + logger.record_video=True \ + logger.record_frames=4 \ + buffer.size=120 \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=1 \ - collector.collector_device=cuda:0 \ + collector.device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ network.device=cuda:0 \ @@ -235,24 +237,23 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ env.name=Pendulum-v1 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ - total_frames=48 \ - batch_size=10 \ - frames_per_batch=16 \ - num_workers=2 \ - env_per_collector=1 \ - mode=offline \ - device=cuda:0 \ - collector_device=cuda:0 \ - logger= + collector.total_frames=48 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=1 \ + collector.device=cuda:0 \ + network.device=cuda:0 \ + buffer.batch_size=10 \ + logger.mode=offline \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ - optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.num_workers=2 \ collector.env_per_collector=1 \ + collector.device=cuda:0 \ logger.mode=offline \ - collector.collector_device=cuda:0 \ + optim.batch_size=10 \ env.name=Pendulum-v1 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/mappo_ippo.py \ diff --git a/examples/EXAMPLES.md b/examples/EXAMPLES.md index 523a397f486..f875829b6e6 100644 --- a/examples/EXAMPLES.md +++ b/examples/EXAMPLES.md @@ -18,7 +18,7 @@ python sac.py ``` or similar. Hyperparameters can be easily changed by providing the arguments to hydra: ``` -python sac.py frames_per_batch=63 +python sac.py collector.frames_per_batch=63 ``` # Results @@ -32,11 +32,11 @@ We average the results over 5 different seeds and plot the standard error. 
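(Illustrative aside, not taken from the patch:) the dotted overrides used throughout this file, e.g. `collector.frames_per_batch=63`, work because the example scripts are Hydra entry points, and in the refactored ones the YAML config is organized into nested groups such as `collector`, `optim` and `logger`. A rough, self-contained sketch of that pattern, with hypothetical file names:

```
# minimal_app.py -- hypothetical sketch of the Hydra pattern the examples follow
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path=".", config_name="config", version_base="1.1")
def main(cfg: DictConfig) -> None:
    # CLI overrides such as `collector.frames_per_batch=63` are merged into
    # the nested config before this function is called.
    print(OmegaConf.to_yaml(cfg))
    print("frames per batch:", cfg.collector.frames_per_batch)


if __name__ == "__main__":
    main()
```

Assuming a `config.yaml` with a `collector` group sits next to the script, running `python minimal_app.py collector.frames_per_batch=63` prints the resolved config with the override applied.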
To reproduce a single run: ``` -python sac/sac.py env_name="HalfCheetah-v4" env_task="" env_library="gym" +python sac/sac.py env.name="HalfCheetah-v4" env.task="" env.library="gym" ``` ``` -python redq/redq.py env_name="HalfCheetah-v4" env_task="" env_library="gym" +python redq/redq.py env.name="HalfCheetah-v4" env.library="gymnasium" ``` @@ -48,39 +48,61 @@ python redq/redq.py env_name="HalfCheetah-v4" env_task="" env_library="gym" To reproduce a single run: ``` -python sac/sac.py env_name="cheetah" env_task="run" env_library="dm_control" +python sac/sac.py env.name="cheetah" env.task="run" env.library="dm_control" ``` ``` -python redq/redq.py env_name="cheetah" env_task="run" env_library="dm_control" +python redq/redq.py env.name="cheetah" env.task="run" env.library="dm_control" ``` -## Gym's Ant-v4 +[//]: # (TODO: adapt these scripts) +[//]: # (## Gym's Ant-v4) -

- -

-To reproduce a single run: +[//]: # () +[//]: # (

) -``` -python sac/sac.py env_name="Ant-v4" env_task="" env_library="gym" -``` +[//]: # () -``` -python redq/redq.py env_name="Ant-v4" env_task="" env_library="gym" -``` +[//]: # (

) -## Gym's Walker2D-v4 +[//]: # (To reproduce a single run:) -

- -

-To reproduce a single run: +[//]: # () +[//]: # (```) -``` -python sac/sac.py env_name="Walker2D-v4" env_task="" env_library="gym" -``` +[//]: # (python sac/sac.py env.name="Ant-v4" env.task="" env.library="gym") -``` -python redq/redq.py env_name="Walker2D-v4" env_task="" env_library="gym" -``` +[//]: # (```) + +[//]: # () +[//]: # (``` ) + +[//]: # (python redq/redq.py env_name="Ant-v4" env_task="" env_library="gym") + +[//]: # (```) + +[//]: # () +[//]: # (## Gym's Walker2D-v4) + +[//]: # () +[//]: # (

) + +[//]: # () + +[//]: # (

) + +[//]: # (To reproduce a single run:) + +[//]: # () +[//]: # (```) + +[//]: # (python sac/sac.py env_name="Walker2D-v4" env_task="" env_library="gym") + +[//]: # (```) + +[//]: # () +[//]: # (``` ) + +[//]: # (python redq/redq.py env_name="Walker2D-v4" env_task="" env_library="gym") + +[//]: # (```) diff --git a/examples/cql/cql_offline.py b/examples/cql/cql_offline.py index 9f9e3d6d857..122dd2579b8 100644 --- a/examples/cql/cql_offline.py +++ b/examples/cql/cql_offline.py @@ -26,7 +26,7 @@ ) -@hydra.main(config_path=".", config_name="offline_config") +@hydra.main(config_path=".", config_name="offline_config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name) logger = None diff --git a/examples/cql/cql_online.py b/examples/cql/cql_online.py index db8c8e3ad5c..beb1a71201d 100644 --- a/examples/cql/cql_online.py +++ b/examples/cql/cql_online.py @@ -27,7 +27,7 @@ ) -@hydra.main(config_path=".", config_name="online_config") +@hydra.main(version_base="1.1", config_path=".", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 exp_name = generate_exp_name("CQL-online", cfg.env.exp_name) logger = None diff --git a/examples/cql/online_config.yaml b/examples/cql/online_config.yaml index 0aa3f30467e..4528fe3fb8d 100644 --- a/examples/cql/online_config.yaml +++ b/examples/cql/online_config.yaml @@ -18,7 +18,7 @@ collector: multi_step: 0 init_random_frames: 1000 env_per_collector: 1 - collector_device: cpu + device: cpu max_frames_per_traj: 200 # logger diff --git a/examples/cql/utils.py b/examples/cql/utils.py index 23b14461da9..ac62eea28bc 100644 --- a/examples/cql/utils.py +++ b/examples/cql/utils.py @@ -12,14 +12,16 @@ from torchrl.data.datasets.d4rl import D4RLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( + CatTensors, Compose, + DMControlEnv, DoubleToFloat, EnvCreator, ParallelEnv, RewardScaling, TransformedEnv, ) -from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator from torchrl.objectives import CQLLoss, SoftUpdate @@ -32,8 +34,21 @@ # ----------------- -def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels) +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") def apply_env_transforms(env, reward_scaling=1.0): @@ -51,7 +66,7 @@ def make_environment(cfg, num_envs=1): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( num_envs, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda cfg=cfg: env_maker(cfg)), ) parallel_env.set_seed(cfg.env.seed) @@ -60,7 +75,7 @@ def make_environment(cfg, num_envs=1): eval_env = TransformedEnv( ParallelEnv( num_envs, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda cfg=cfg: env_maker(cfg)), ), train_env.transform.clone(), ) @@ -80,7 +95,7 @@ def make_collector(cfg, train_env, 
actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.collector_device, + device=cfg.collector.device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index 5997ccb8fb3..2b3713c0407 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -14,7 +14,7 @@ collector: frames_per_batch: 1000 init_env_steps: 1000 reset_at_each_iter: False - collector_device: cpu + device: cpu env_per_collector: 1 diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index 17f927eca62..2260e220b4b 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ -9,7 +9,9 @@ from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import ( + CatTensors, Compose, + DMControlEnv, DoubleToFloat, EnvCreator, InitTracker, @@ -39,13 +41,21 @@ # ----------------- -def env_maker(task, device="cpu", from_pixels=False): - with set_gym_backend("gym"): - return GymEnv( - task, - device=device, - from_pixels=from_pixels, +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") def apply_env_transforms(env, max_episode_steps=1000): @@ -65,7 +75,7 @@ def make_environment(cfg): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda cfg=cfg: env_maker(cfg)), ) parallel_env.set_seed(cfg.env.seed) @@ -76,7 +86,7 @@ def make_environment(cfg): eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda cfg=cfg: env_maker(cfg)), ), train_env.transform.clone(), ) @@ -97,7 +107,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, total_frames=cfg.collector.total_frames, - device=cfg.collector.collector_device, + device=cfg.collector.device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 720d4842e1d..d870d383213 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -179,7 +179,7 @@ def make_collector(cfg, policy): policy, frames_per_batch=collector_cfg.frames_per_batch, total_frames=collector_cfg.total_frames, - device=collector_cfg.collector_devices, + device=collector_cfg.devices, max_frames_per_traj=collector_cfg.max_frames_per_traj, postproc=transforms, ) diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index 325b789bb7e..29ccd1eca6d 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -20,9 +20,15 @@ ) from torchrl.data.replay_buffers.storages import LazyMemmapStorage -from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs import ( + CatTensors, + DMControlEnv, + 
EnvCreator, + ParallelEnv, + TransformedEnv, +) -from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import MLP, SafeModule from torchrl.modules.distributions import OneHotCategorical @@ -33,10 +39,21 @@ from torchrl.record.loggers import generate_exp_name, get_logger -def env_maker(env_name, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv( - env_name, device=device, frame_skip=frame_skip, from_pixels=from_pixels - ) +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") def make_replay_buffer( @@ -101,7 +118,7 @@ def env_factory(num_workers): # 1.2 Create env vector vec_env = ParallelEnv( - create_env_fn=EnvCreator(lambda: env_maker(env_name=cfg.env_name)), + create_env_fn=EnvCreator(lambda cfg=cfg: env_maker(cfg)), num_workers=num_workers, ) diff --git a/examples/dqn/config.yaml b/examples/dqn/config.yaml index f8c863f3ad2..d9894cf522b 100644 --- a/examples/dqn/config.yaml +++ b/examples/dqn/config.yaml @@ -16,7 +16,6 @@ lr: 3e-4 multi_step: 1 init_random_frames: 25000 from_pixels: 1 -collector_device: cpu env_per_collector: 8 num_workers: 32 lr_scheduler: "" diff --git a/examples/dqn/dqn.py b/examples/dqn/dqn.py index 0c59d96ec9e..cd178ba3bbc 100644 --- a/examples/dqn/dqn.py +++ b/examples/dqn/dqn.py @@ -160,6 +160,7 @@ def main(cfg: "DictConfig"): # noqa: F821 print(f"init seed: {cfg.seed}, final seed: {final_seed}") trainer.train() + trainer.collector.shutdown() return (logger.log_dir, trainer._log_dict) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index c16337aa087..fba4247e2a7 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -102,8 +102,6 @@ def make_env_transforms( obs_stats = stats obs_stats["standard_normal"] = True obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"]) - # if obs_norm_state_dict: - # obs_norm.load_state_dict(obs_norm_state_dict) env.append_transform(obs_norm) if norm_rewards: reward_scaling = 1.0 @@ -132,7 +130,6 @@ def make_env_transforms( env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) ) - return env diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index 6be18a66016..f27adc1789a 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -18,8 +18,14 @@ from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage -from torchrl.envs import EnvCreator, ParallelEnv -from torchrl.envs.libs.gym import GymEnv +from torchrl.envs import ( + CatTensors, + DMControlEnv, + EnvCreator, + ParallelEnv, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import MLP, ProbabilisticActor, ValueOperator from torchrl.modules.distributions import TanhNormal @@ -29,10 +35,22 @@ from torchrl.record.loggers import generate_exp_name, get_logger -def env_maker(env_name, frame_skip=1, device="cpu", from_pixels=False): - 
return GymEnv( - env_name, device=device, frame_skip=frame_skip, from_pixels=from_pixels - ) +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + frame_skip=cfg.env.frame_skip, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task, frame_skip=cfg.env.frame_skip) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") def make_replay_buffer( @@ -73,34 +91,34 @@ def make_replay_buffer( @hydra.main(version_base="1.1", config_path=".", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.device) + device = torch.device(cfg.network.device) - exp_name = generate_exp_name("Online_IQL", cfg.exp_name) + exp_name = generate_exp_name("Online_IQL", cfg.logger.exp_name) logger = None - if cfg.logger: + if cfg.logger.backend: logger = get_logger( - logger_type=cfg.logger, + logger_type=cfg.logger.backend, logger_name="iql_logging", experiment_name=exp_name, - wandb_kwargs={"mode": cfg.mode}, + wandb_kwargs={"mode": cfg.logger.mode}, ) - torch.manual_seed(cfg.seed) - np.random.seed(cfg.seed) + torch.manual_seed(cfg.optim.seed) + np.random.seed(cfg.optim.seed) def env_factory(num_workers): """Creates an instance of the environment.""" # 1.2 Create env vector vec_env = ParallelEnv( - create_env_fn=EnvCreator(lambda: env_maker(env_name=cfg.env_name)), + create_env_fn=EnvCreator(lambda cfg=cfg: env_maker(cfg=cfg)), num_workers=num_workers, ) return vec_env # Sanity check - test_env = env_factory(num_workers=5) + test_env = env_factory(num_workers=cfg.collector.env_per_collector) num_actions = test_env.action_spec.shape[-1] # Create Agent @@ -117,14 +135,14 @@ def env_factory(num_workers): dist_class = TanhNormal dist_kwargs = { - "min": action_spec.space.minimum[-1], - "max": action_spec.space.maximum[-1], - "tanh_loc": cfg.tanh_loc, + "min": action_spec.space.low[-1], + "max": action_spec.space.high[-1], + "tanh_loc": cfg.network.tanh_loc, } actor_extractor = NormalParamExtractor( - scale_mapping=f"biased_softplus_{cfg.default_policy_scale}", - scale_lb=cfg.scale_lb, + scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}", + scale_lb=cfg.network.scale_lb, ) actor_net = nn.Sequential(actor_net, actor_extractor) @@ -195,35 +213,41 @@ def env_factory(num_workers): qvalue_network=model[1], value_network=model[2], num_qvalue_nets=2, - temperature=cfg.temperature, - expectile=cfg.expectile, - loss_function="smooth_l1", + temperature=cfg.loss.temperature, + expectile=cfg.loss.expectile, + loss_function=cfg.loss.loss_function, ) - loss_module.make_value_estimator(gamma=cfg.gamma) + loss_module.make_value_estimator(gamma=cfg.loss.gamma) # Define Target Network Updater - target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak) + target_net_updater = SoftUpdate(loss_module, eps=cfg.loss.target_update_polyak) # Make Off-Policy Collector collector = SyncDataCollector( env_factory, - create_env_kwargs={"num_workers": cfg.env_per_collector}, + create_env_kwargs={"num_workers": cfg.collector.env_per_collector}, policy=model[0], - frames_per_batch=cfg.frames_per_batch, - max_frames_per_traj=cfg.max_frames_per_traj, - total_frames=cfg.total_frames, - device=cfg.collector_device, + frames_per_batch=cfg.collector.frames_per_batch, + max_frames_per_traj=cfg.collector.max_frames_per_traj, + 
total_frames=cfg.collector.total_frames, + device=cfg.collector.device, ) - collector.set_seed(cfg.seed) + collector.set_seed(cfg.optim.seed) # Make Replay Buffer replay_buffer = make_replay_buffer( - buffer_size=cfg.buffer_size, device="cpu", batch_size=cfg.batch_size + buffer_size=cfg.buffer.size, + device="cpu", + batch_size=cfg.buffer.batch_size, + prefetch=cfg.buffer.prefetch, + prb=cfg.buffer.prb, ) # Optimizers params = list(loss_module.parameters()) - optimizer = optim.Adam(params, lr=cfg.lr, weight_decay=cfg.weight_decay) + optimizer = optim.Adam( + params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps + ) rewards = [] rewards_eval = [] @@ -231,9 +255,13 @@ def env_factory(num_workers): # Main loop collected_frames = 0 - pbar = tqdm.tqdm(total=cfg.total_frames) + pbar = tqdm.tqdm(total=cfg.collector.total_frames) r0 = None loss = None + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) + env_per_collector = cfg.collector.env_per_collector + prb = cfg.buffer.prb + max_frames_per_traj = cfg.collector.max_frames_per_traj for i, tensordict in enumerate(collector): @@ -260,9 +288,13 @@ def env_factory(num_workers): value_losses, ) = ([], [], []) # optimization steps - for _ in range(cfg.frames_per_batch * int(cfg.utd_ratio)): + for _ in range(num_updates): # sample from replay buffer - sampled_tensordict = replay_buffer.sample(cfg.batch_size).clone() + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device == device: + sampled_tensordict = sampled_tensordict.clone() + else: + sampled_tensordict = sampled_tensordict.to(device, non_blocking=True) loss_td = loss_module(sampled_tensordict) @@ -284,11 +316,11 @@ def env_factory(num_workers): target_net_updater.step() # update priority - if cfg.prb: + if prb: replay_buffer.update_priority(sampled_tensordict) rewards.append( - (i, tensordict["next", "reward"].sum().item() / cfg.env_per_collector) + (i, tensordict["next", "reward"].sum().item() / env_per_collector) ) train_log = { "train_reward": rewards[-1][1], @@ -308,7 +340,7 @@ def env_factory(num_workers): with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): eval_rollout = test_env.rollout( - max_steps=cfg.max_frames_per_traj, + max_steps=max_frames_per_traj, policy=model[0], auto_cast_to_device=True, ).clone() diff --git a/examples/iql/online_config.yaml b/examples/iql/online_config.yaml index d1e49b90716..350560ea9a1 100644 --- a/examples/iql/online_config.yaml +++ b/examples/iql/online_config.yaml @@ -1,50 +1,46 @@ -env_name: Pendulum-v1 -env_library: gym -exp_name: "iql_pendulum" -seed: 42 -async_collection: 1 -record_video: 0 -frame_skip: 1 +env: + name: Pendulum-v1 + library: gym + async_collection: 1 + record_video: 0 + frame_skip: 1 -total_frames: 1000000 -init_env_steps: 10000 -init_random_frames: 5000 -# Updates -utd_ratio: 1.0 -batch_size: 256 -lr: 3e-4 -weight_decay: 0.0 -target_update_polyak: 0.995 -multi_step: 1.0 -gamma: 0.99 +logger: + exp_name: "iql_pendulum" + backend: wandb + mode: online -tanh_loc: False -default_policy_scale: 1.0 -scale_lb: 0.1 -activation: elu -from_pixels: 0 -collector_device: cuda:0 -env_per_collector: 5 -frames_per_batch: 1000 # 5*200 -max_frames_per_traj: 200 -num_workers: 1 +optim: + seed: 42 + utd_ratio: 1.0 + lr: 3e-4 + weight_decay: 0.0 + eps: 1e-4 -record_frames: 10000 -loss_function: smooth_l1 -batch_transform: 1 -buffer_prefetch: 64 -norm_stats: 1 +network: + tanh_loc: False + default_policy_scale: 1.0 + scale_lb: 0.1 + device: "cuda:0" -device: "cuda:0" 
+collector: + total_frames: 1000000 + init_random_frames: 5000 + device: cuda:0 + frames_per_batch: 1000 # 5*200 + env_per_collector: 5 + max_frames_per_traj: 200 # IQL hyperparameter -temperature: 3.0 -expectile: 0.7 +loss: + temperature: 3.0 + expectile: 0.7 + gamma: 0.99 + target_update_polyak: 0.995 + loss_function: smooth_l1 -# Logging -logger: wandb -mode: online - -# Replay Buffer -prb: 0 -buffer_size: 100000 +buffer: + prefetch: 64 + prb: 0 + size: 100000 + batch_size: 256 diff --git a/examples/redq/config.yaml b/examples/redq/config.yaml index da52aa5496a..24e9ae2a60e 100644 --- a/examples/redq/config.yaml +++ b/examples/redq/config.yaml @@ -1,35 +1,97 @@ -env_name: HalfCheetah-v4 -env_task: "" -env_library: gym -async_collection: 1 -record_video: 0 -normalize_rewards_online: 1 -normalize_rewards_online_scale: 5 -frame_skip: 1 -frames_per_batch: 1024 -optim_steps_per_batch: 1024 -batch_size: 256 -total_frames: 1000000 -prb: 1 -lr: 3e-4 -ou_exploration: 1 -multi_step: 1 -init_random_frames: 25000 -activation: elu -gSDE: 0 -from_pixels: 0 -collector_device: cpu -env_per_collector: 1 +# Seed for collector and model +seed: 0 + +# Number of workers for the whole script, split among collector num_workers: 2 -lr_scheduler: "" -value_network_update_interval: 200 -record_interval: 10 -max_frames_per_traj: -1 -weight_decay: 0.0 -annealing_frames: 1000000 -init_env_steps: 10000 -record_frames: 10000 -loss_function: smooth_l1 -batch_transform: 1 -buffer_prefetch: 64 -norm_stats: 1 + +env: + name: HalfCheetah-v4 + task: "" + library: gym + normalize_rewards_online: 1 + normalize_rewards_online_scale: 5 + normalize_rewards_online_decay: 0.999 + frame_skip: 1 + from_pixels: 0 + batch_transform: 1 + norm_stats: 1 + init_env_steps: 10000 + reward_scaling: + reward_loc: + vecnorm: False + categorical_action_encoding: + noops: + center_crop: + catframes: + image_size: + grayscale: False + +collector: + async_collection: 1 + frames_per_batch: 1024 + total_frames: 1_000_000 + device: cpu + env_per_collector: 1 + init_random_frames: 50_000 + multi_step: 1 + n_steps_return: 3 + max_frames_per_traj: -1 + exploration_mode: random + +logger: + record_video: 0 + record_interval: 10 + record_frames: 10000 + exp_name: cheetah + backend: wandb + kwargs: + offline: False + recorder_log_keys: + +optim: + optimizer: adam + steps_per_batch: 1024 + lr: 3e-4 + init_random_frames: 25000 + lr_scheduler: "" + value_network_update_interval: 200 + weight_decay: 0.0 + eps: 1e-4 + kwargs: + betas: [0.0,0.9] + clip_grad_norm: 100.0 + clip_norm: + +buffer: + batch_size: 256 + prb: 1 + sub_traj_len: + size: 500_000 + scratch_dir: + prefetch: 64 + +network: + activation: elu + tanh_loc: False + default_policy_scale: 1.0 + actor_cells: 256 + actor_depth: 2 + qvalue_cells: 256 + qvalue_depth: 2 + scale_lb: 0.05 + +exploration: + gSDE: False + ou_exploration: 1 + annealing_frames: 1000000 + ou_sigma: 0.2 + ou_theta: 0.15 + noisy: False + +loss: + loss_function: smooth_l1 + type: double + num_q_values: 10 + gamma: 0.99 + hard_update: False + value_network_update_interval: 200 diff --git a/examples/redq/redq.py b/examples/redq/redq.py index 2223d709174..913216f44a8 100644 --- a/examples/redq/redq.py +++ b/examples/redq/redq.py @@ -3,55 +3,31 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
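(Illustrative aside:) the REDQ hyperparameters above are now grouped (`env`, `collector`, `optim`, `buffer`, `network`, `exploration`, `loss`, `logger`), which is what lets the CI script address them as `collector.total_frames=48` or silence logging with an empty `logger.backend=`. A rough sketch of how such dotted overrides resolve, using plain OmegaConf to emulate the effect (assuming the `config.yaml` shown above is on disk):

```
# sketch: emulate the CI-style dotted overrides on the grouped REDQ config
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")  # the grouped config shown above
overrides = OmegaConf.from_dotlist(
    ["collector.total_frames=48", "buffer.batch_size=10", "logger.backend="]
)
cfg = OmegaConf.merge(cfg, overrides)

assert cfg.collector.total_frames == 48
# An empty backend is falsy; the example scripts use that to skip logger creation.
print(repr(cfg.logger.backend))
```

This roughly mirrors what Hydra does when `redq.py` is launched with those flags.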
-import dataclasses import uuid from datetime import datetime import hydra import torch.cuda -from hydra.core.config_store import ConfigStore +from omegaconf import OmegaConf from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.transforms import RewardScaling, TransformedEnv from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import OrnsteinUhlenbeckProcessWrapper from torchrl.record import VideoRecorder -from torchrl.record.loggers import generate_exp_name, get_logger -from torchrl.trainers.helpers.collectors import ( - make_collector_offpolicy, - OffPolicyCollectorConfig, -) -from torchrl.trainers.helpers.envs import ( +from torchrl.record.loggers import get_logger +from utils import ( correct_for_frame_skip, - EnvConfig, get_norm_state_dict, initialize_observation_norm_transforms, + make_collector_offpolicy, + make_redq_loss, + make_redq_model, + make_replay_buffer, + make_trainer, parallel_env_constructor, retrieve_observation_norms_state_dict, transformed_env_constructor, ) -from torchrl.trainers.helpers.logger import LoggerConfig -from torchrl.trainers.helpers.losses import LossConfig, make_redq_loss -from torchrl.trainers.helpers.models import make_redq_model, REDQModelConfig -from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig -from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig - -config_fields = [ - (config_field.name, config_field.type, config_field) - for config_cls in ( - TrainerConfig, - OffPolicyCollectorConfig, - EnvConfig, - LossConfig, - REDQModelConfig, - LoggerConfig, - ReplayArgsConfig, - ) - for config_field in dataclasses.fields(config_cls) -] - -Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) -cs = ConfigStore.instance() -cs.store(name="config", node=Config) DEFAULT_REWARD_SCALING = { "Hopper-v1": 5, @@ -69,8 +45,9 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg = correct_for_frame_skip(cfg) - if not isinstance(cfg.reward_scaling, float): - cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0) + if not isinstance(cfg.env.reward_scaling, float): + cfg.env.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env.name, 5.0) + cfg.env.reward_loc = 0.0 device = ( torch.device("cpu") @@ -81,26 +58,30 @@ def main(cfg: "DictConfig"): # noqa: F821 exp_name = "_".join( [ "REDQ", - cfg.exp_name, + cfg.logger.exp_name, str(uuid.uuid4())[:8], datetime.now().strftime("%y_%m_%d-%H_%M_%S"), ] ) - exp_name = generate_exp_name("REDQ", cfg.exp_name) logger = get_logger( - logger_type=cfg.logger, logger_name="redq_logging", experiment_name=exp_name + logger_type=cfg.logger.backend, + logger_name="redq_logging", + experiment_name=exp_name, + **OmegaConf.to_container(cfg.logger.kwargs), ) - video_tag = exp_name if cfg.record_video else "" + video_tag = exp_name if cfg.logger.record_video else "" key, init_env_steps, stats = None, None, None - if not cfg.vecnorm and cfg.norm_stats: - if not hasattr(cfg, "init_env_steps"): - raise AttributeError("init_env_steps missing from arguments.") - key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") - init_env_steps = cfg.init_env_steps + if not cfg.env.vecnorm and cfg.env.norm_stats: + key = ( + ("next", "pixels") + if cfg.env.from_pixels + else ("next", "observation_vector") + ) + init_env_steps = cfg.env.init_env_steps stats = {"loc": None, "scale": None} - elif cfg.from_pixels: + elif cfg.env.from_pixels: stats = {"loc": 0.5, "scale": 0.5} proof_env = 
transformed_env_constructor( @@ -121,20 +102,20 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_module, target_net_updater = make_redq_loss(model, cfg) actor_model_explore = model[0] - if cfg.ou_exploration: - if cfg.gSDE: + if cfg.exploration.ou_exploration: + if cfg.exploration.gSDE: raise RuntimeError("gSDE and ou_exploration are incompatible") actor_model_explore = OrnsteinUhlenbeckProcessWrapper( actor_model_explore, - annealing_num_steps=cfg.annealing_frames, - sigma=cfg.ou_sigma, - theta=cfg.ou_theta, + annealing_num_steps=cfg.exploration.annealing_frames, + sigma=cfg.exploration.ou_sigma, + theta=cfg.exploration.ou_theta, ).to(device) if device == torch.device("cpu"): # mostly for debugging actor_model_explore.share_memory() - if cfg.gSDE: + if cfg.exploration.gSDE: with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): # get dimensions to build the parallel env proof_td = actor_model_explore(proof_env.reset().to(device)) @@ -155,10 +136,6 @@ def main(cfg: "DictConfig"): # noqa: F821 make_env=create_env_fn, actor_model_explore=actor_model_explore, cfg=cfg, - # make_env_kwargs=[ - # {"device": device} if device >= 0 else {} - # for device in args.env_rendering_devices - # ], ) replay_buffer = make_replay_buffer("cpu", cfg) @@ -201,11 +178,9 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg, ) - final_seed = collector.set_seed(cfg.seed) - print(f"init seed: {cfg.seed}, final seed: {final_seed}") - trainer.train() - return (logger.log_dir, trainer._log_dict) + if logger is not None: + return (logger.log_dir, trainer._log_dict) if __name__ == "__main__": diff --git a/examples/redq/utils.py b/examples/redq/utils.py new file mode 100644 index 00000000000..076d3bf75b3 --- /dev/null +++ b/examples/redq/utils.py @@ -0,0 +1,1052 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
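(Illustrative aside:) the first helper defined below, `correct_for_frame_skip`, divides every frame-counting argument (total frames, frames per batch, init random frames, ...) by `env.frame_skip`, so that a 1M-frame budget with `frame_skip=2` does not silently turn into 2M environment steps. A stripped-down sketch of the same idea on a minimal config (hypothetical helper name and keys, no error handling):

```
# sketch: the frame-skip correction idea on a minimal hypothetical config
from omegaconf import OmegaConf


def scale_frame_counts(cfg, fields=("collector.total_frames", "collector.frames_per_batch")):
    """Divide the listed frame-count entries by env.frame_skip (illustration only)."""
    frame_skip = cfg.env.frame_skip
    if frame_skip == 1:
        return cfg
    for field in fields:
        value = OmegaConf.select(cfg, field)
        if value is not None:
            OmegaConf.update(cfg, field, value // frame_skip)
    return cfg


cfg = OmegaConf.create(
    {"env": {"frame_skip": 2}, "collector": {"total_frames": 1_000_000, "frames_per_batch": 1024}}
)
cfg = scale_frame_counts(cfg)
assert cfg.collector.total_frames == 500_000
```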
+from __future__ import annotations + +from copy import copy +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +import torch +from omegaconf import OmegaConf +from tensordict.nn import ( + InteractionType, + ProbabilisticTensorDictSequential, + TensorDictModule, + TensorDictModuleWrapper, +) +from torch import distributions as d, nn, optim +from torch.optim.lr_scheduler import CosineAnnealingLR +from torchrl._utils import VERBOSE +from torchrl.collectors.collectors import DataCollectorBase + +from torchrl.data import ReplayBuffer, TensorDictReplayBuffer +from torchrl.data.postprocs import MultiStep +from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs import ParallelEnv +from torchrl.envs.common import EnvBase +from torchrl.envs.env_creator import env_creator, EnvCreator +from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms import ( + CatFrames, + CatTensors, + CenterCrop, + Compose, + DoubleToFloat, + GrayScale, + NoopResetEnv, + ObservationNorm, + Resize, + RewardScaling, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.envs.transforms.transforms import ( + FlattenObservation, + gSDENoise, + InitTracker, + StepCounter, +) +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ( + ActorCriticOperator, + ActorValueOperator, + NoisyLinear, + NormalParamWrapper, + SafeModule, + SafeSequential, +) +from torchrl.modules.distributions import TanhNormal +from torchrl.modules.distributions.continuous import SafeTanhTransform +from torchrl.modules.models.exploration import LazygSDEModule +from torchrl.modules.models.models import DdpgCnnActor, DdpgCnnQNet, MLP +from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator +from torchrl.objectives import HardUpdate, SoftUpdate +from torchrl.objectives.common import LossModule +from torchrl.objectives.deprecated import REDQLoss_deprecated +from torchrl.objectives.utils import TargetNetUpdater +from torchrl.record.loggers import Logger +from torchrl.record.recorder import VideoRecorder +from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector +from torchrl.trainers.trainers import ( + BatchSubSampler, + ClearCudaCache, + CountFramesLog, + LogReward, + Recorder, + ReplayBufferTrainer, + RewardNormalizer, + Trainer, + UpdateWeights, +) + +LIBS = { + "gym": GymEnv, + "dm_control": DMControlEnv, +} +ACTIVATIONS = { + "elu": nn.ELU, + "tanh": nn.Tanh, + "relu": nn.ReLU, +} +OPTIMIZERS = { + "adam": optim.Adam, + "sgd": optim.SGD, + "adamax": optim.Adamax, +} + + +def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig": # noqa: F821 + """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip. + + This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targetting a total number of frames + of 1M but actually collecting frame_skip * 1M frames. + + Args: + cfg (DictConfig): DictConfig containing some frame-counting argument, including: + "max_frames_per_traj", "total_frames", "frames_per_batch", "record_frames", "annealing_frames", + "init_random_frames", "init_env_steps" + + Returns: + the input DictConfig, modified in-place. 
+ + """ + + def _hasattr(field): + local_cfg = cfg + fields = field.split(".") + for f in fields: + if not hasattr(local_cfg, f): + return False + local_cfg = getattr(local_cfg, f) + else: + return True + + def _getattr(field): + local_cfg = cfg + fields = field.split(".") + for f in fields: + local_cfg = getattr(local_cfg, f) + return local_cfg + + def _setattr(field, val): + local_cfg = cfg + fields = field.split(".") + for f in fields[:-1]: + local_cfg = getattr(local_cfg, f) + setattr(local_cfg, field[-1], val) + + # Adapt all frame counts wrt frame_skip + frame_skip = cfg.env.frame_skip + if frame_skip != 1: + fields = [ + "collector.max_frames_per_traj", + "collector.total_frames", + "collector.frames_per_batch", + "logger.record_frames", + "exploration.annealing_frames", + "collector.init_random_frames", + "env.init_env_steps", + "env.noops", + ] + for field in fields: + if _hasattr(cfg, field): + _setattr(field, _getattr(field) // frame_skip) + return cfg + + +def make_trainer( + collector: DataCollectorBase, + loss_module: LossModule, + recorder: EnvBase | None, + target_net_updater: TargetNetUpdater | None, + policy_exploration: TensorDictModuleWrapper | TensorDictModule | None, + replay_buffer: ReplayBuffer | None, + logger: Logger | None, + cfg: "DictConfig", # noqa: F821 +) -> Trainer: + """Creates a Trainer instance given its constituents. + + Args: + collector (DataCollectorBase): A data collector to be used to collect data. + loss_module (LossModule): A TorchRL loss module + recorder (EnvBase, optional): a recorder environment. + target_net_updater (TargetNetUpdater): A target network update object. + policy_exploration (TDModule or TensorDictModuleWrapper): a policy to be used for recording and exploration + updates (should be synced with the learnt policy). + replay_buffer (ReplayBuffer): a replay buffer to be used to collect data. + logger (Logger): a Logger to be used for logging. + cfg (DictConfig): a DictConfig containing the arguments of the script. + + Returns: + A trainer built with the input objects. The optimizer is built by this helper function using the cfg provided. 
+ + Examples: + >>> import torch + >>> import tempfile + >>> from torchrl.trainers.loggers import TensorboardLogger + >>> from torchrl.trainers import Trainer + >>> from torchrl.envs import EnvCreator + >>> from torchrl.collectors.collectors import SyncDataCollector + >>> from torchrl.data import TensorDictReplayBuffer + >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper + >>> from torchrl.objectives.common import LossModule + >>> from torchrl.objectives.utils import TargetNetUpdater + >>> from torchrl.objectives import DDPGLoss + >>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0")) + >>> env_proof = env_maker() + >>> obs_spec = env_proof.observation_spec + >>> action_spec = env_proof.action_spec + >>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1]) + >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1) # for the purpose of testing + >>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"]) + >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"]) + >>> collector = SyncDataCollector(env_maker, policy, total_frames=100) + >>> loss_module = DDPGLoss(policy, value, gamma=0.99) + >>> recorder = env_proof + >>> target_net_updater = None + >>> policy_exploration = EGreedyWrapper(policy) + >>> replay_buffer = TensorDictReplayBuffer() + >>> dir = tempfile.gettempdir() + >>> logger = TensorboardLogger(exp_name=dir) + >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration, + ... replay_buffer, logger) + >>> print(trainer) + + """ + + optimizer = OPTIMIZERS[cfg.optim.optimizer]( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + **OmegaConf.to_container(cfg.optim.kwargs), + ) + device = next(loss_module.parameters()).device + if cfg.optim.lr_scheduler == "cosine": + optim_scheduler = CosineAnnealingLR( + optimizer, + T_max=int( + cfg.collector.total_frames + / cfg.collector.frames_per_batch + * cfg.optim.steps_per_batch + ), + ) + elif cfg.optim.lr_scheduler == "": + optim_scheduler = None + else: + raise NotImplementedError(f"lr scheduler {cfg.optim.lr_scheduler}") + + if VERBOSE: + print( + f"collector = {collector}; \n" + f"loss_module = {loss_module}; \n" + f"recorder = {recorder}; \n" + f"target_net_updater = {target_net_updater}; \n" + f"policy_exploration = {policy_exploration}; \n" + f"replay_buffer = {replay_buffer}; \n" + f"logger = {logger}; \n" + f"cfg = {cfg}; \n" + ) + + if logger is not None: + # log hyperparams + logger.log_hparams(cfg) + + trainer = Trainer( + collector=collector, + frame_skip=cfg.env.frame_skip, + total_frames=cfg.collector.total_frames * cfg.env.frame_skip, + loss_module=loss_module, + optimizer=optimizer, + logger=logger, + optim_steps_per_batch=cfg.optim.steps_per_batch, + clip_grad_norm=cfg.optim.clip_grad_norm, + clip_norm=cfg.optim.clip_norm, + ) + + if torch.cuda.device_count() > 0: + trainer.register_op("pre_optim_steps", ClearCudaCache(1)) + + trainer.register_op("batch_process", lambda batch: batch.cpu()) + + if replay_buffer is not None: + # replay buffer is used 2 or 3 times: to register data, to sample + # data and to update priorities + rb_trainer = ReplayBufferTrainer( + replay_buffer, + cfg.buffer.batch_size, + flatten_tensordicts=False, + memmap=False, + device=device, + ) + + trainer.register_op("batch_process", 
rb_trainer.extend) + trainer.register_op("process_optim_batch", rb_trainer.sample) + trainer.register_op("post_loss", rb_trainer.update_priority) + else: + # trainer.register_op("batch_process", mask_batch) + trainer.register_op( + "process_optim_batch", + BatchSubSampler( + batch_size=cfg.buffer.batch_size, sub_traj_len=cfg.buffer.sub_traj_len + ), + ) + trainer.register_op("process_optim_batch", lambda batch: batch.to(device)) + + if optim_scheduler is not None: + trainer.register_op("post_optim", optim_scheduler.step) + + if target_net_updater is not None: + trainer.register_op("post_optim", target_net_updater.step) + + if cfg.env.normalize_rewards_online: + # if used the running statistics of the rewards are computed and the + # rewards used for training will be normalized based on these. + reward_normalizer = RewardNormalizer( + scale=cfg.env.normalize_rewards_online_scale, + decay=cfg.env.normalize_rewards_online_decay, + ) + trainer.register_op("batch_process", reward_normalizer.update_reward_stats) + trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward) + + if policy_exploration is not None and hasattr(policy_exploration, "step"): + trainer.register_op( + "post_steps", policy_exploration.step, frames=cfg.collector.frames_per_batch + ) + + trainer.register_op( + "post_steps_log", lambda *cfg: {"lr": optimizer.param_groups[0]["lr"]} + ) + + if recorder is not None: + # create recorder object + recorder_obj = Recorder( + record_frames=cfg.logger.record_frames, + frame_skip=cfg.env.frame_skip, + policy_exploration=policy_exploration, + environment=recorder, + record_interval=cfg.logger.record_interval, + log_keys=cfg.logger.recorder_log_keys, + ) + # register recorder + trainer.register_op( + "post_steps_log", + recorder_obj, + ) + # call recorder - could be removed + recorder_obj(None) + # create explorative recorder - could be optional + recorder_obj_explore = Recorder( + record_frames=cfg.logger.record_frames, + frame_skip=cfg.env.frame_skip, + policy_exploration=policy_exploration, + environment=recorder, + record_interval=cfg.logger.record_interval, + exploration_type=ExplorationType.RANDOM, + suffix="exploration", + out_keys={("next", "reward"): "r_evaluation_exploration"}, + ) + # register recorder + trainer.register_op( + "post_steps_log", + recorder_obj_explore, + ) + # call recorder - could be removed + recorder_obj_explore(None) + + trainer.register_op( + "post_steps", UpdateWeights(collector, update_weights_interval=1) + ) + + trainer.register_op("pre_steps_log", LogReward()) + trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip)) + + return trainer + + +def make_redq_model( + proof_environment: EnvBase, + cfg: "DictConfig", # noqa: F821 + device: DEVICE_TYPING = "cpu", + in_keys: Sequence[str] | None = None, + actor_net_kwargs=None, + qvalue_net_kwargs=None, + observation_key=None, + **kwargs, +) -> nn.ModuleList: + """Actor and Q-value model constructor helper function for REDQ. + + Follows default parameters proposed in REDQ original paper: https://openreview.net/pdf?id=AY8zfZm0tDd. + Other configurations can easily be implemented by modifying this function at will. + A single instance of the Q-value model is returned. It will be multiplicated by the loss function. + + Args: + proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec + cfg (DictConfig): contains arguments of the REDQ script + device (torch.device, optional): device on which the model must be cast. Default is "cpu". 
+ in_keys (iterable of strings, optional): observation key to be read by the actor, usually one of + `'observation_vector'` or `'pixels'`. If none is provided, one of these two keys is chosen + based on the `cfg.from_pixels` argument. + actor_net_kwargs (dict, optional): kwargs of the actor MLP. + qvalue_net_kwargs (dict, optional): kwargs of the qvalue MLP. + + Returns: + A nn.ModuleList containing the actor, qvalue operator(s) and the value operator. + + """ + torch.manual_seed(cfg.seed) + tanh_loc = cfg.network.tanh_loc + default_policy_scale = cfg.network.default_policy_scale + gSDE = cfg.exploration.gSDE + + action_spec = proof_environment.action_spec + + if actor_net_kwargs is None: + actor_net_kwargs = {} + if qvalue_net_kwargs is None: + qvalue_net_kwargs = {} + + linear_layer_class = torch.nn.Linear if not cfg.exploration.noisy else NoisyLinear + + out_features_actor = (2 - gSDE) * action_spec.shape[-1] + if cfg.env.from_pixels: + if in_keys is None: + in_keys_actor = ["pixels"] + else: + in_keys_actor = in_keys + actor_net_kwargs_default = { + "mlp_net_kwargs": { + "layer_class": linear_layer_class, + "activation_class": ACTIVATIONS[cfg.network.activation], + }, + "conv_net_kwargs": { + "activation_class": ACTIVATIONS[cfg.network.activation] + }, + } + actor_net_kwargs_default.update(actor_net_kwargs) + actor_net = DdpgCnnActor(out_features_actor, **actor_net_kwargs_default) + gSDE_state_key = "hidden" + out_keys_actor = ["param", "hidden"] + + value_net_default_kwargs = { + "mlp_net_kwargs": { + "layer_class": linear_layer_class, + "activation_class": ACTIVATIONS[cfg.network.activation], + }, + "conv_net_kwargs": { + "activation_class": ACTIVATIONS[cfg.network.activation] + }, + } + value_net_default_kwargs.update(qvalue_net_kwargs) + + in_keys_qvalue = ["pixels", "action"] + qvalue_net = DdpgCnnQNet(**value_net_default_kwargs) + else: + if in_keys is None: + in_keys_actor = ["observation_vector"] + else: + in_keys_actor = in_keys + + actor_net_kwargs_default = { + "num_cells": [cfg.network.actor_cells] * cfg.network.actor_depth, + "out_features": out_features_actor, + "activation_class": ACTIVATIONS[cfg.network.activation], + } + actor_net_kwargs_default.update(actor_net_kwargs) + actor_net = MLP(**actor_net_kwargs_default) + out_keys_actor = ["param"] + gSDE_state_key = in_keys_actor[0] + + qvalue_net_kwargs_default = { + "num_cells": [cfg.network.qvalue_cells] * cfg.network.qvalue_depth, + "out_features": 1, + "activation_class": ACTIVATIONS[cfg.network.activation], + } + qvalue_net_kwargs_default.update(qvalue_net_kwargs) + qvalue_net = MLP( + **qvalue_net_kwargs_default, + ) + in_keys_qvalue = in_keys_actor + ["action"] + + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.low, + "max": action_spec.space.high, + "tanh_loc": tanh_loc, + } + + if not gSDE: + actor_net = NormalParamWrapper( + actor_net, + scale_mapping=f"biased_softplus_{default_policy_scale}", + scale_lb=cfg.network.scale_lb, + ) + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=["loc", "scale"] + out_keys_actor[1:], + ) + + else: + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=["action"] + out_keys_actor[1:], # will be overwritten + ) + + if action_spec.domain == "continuous": + min = action_spec.space.low + max = action_spec.space.high + transform = SafeTanhTransform() + if (min != -1).any() or (max != 1).any(): + transform = d.ComposeTransform( + transform, + d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2), + ) + 
else: + raise RuntimeError("cannot use gSDE with discrete actions") + + actor_module = SafeSequential( + actor_module, + SafeModule( + LazygSDEModule(transform=transform), + in_keys=["action", gSDE_state_key, "_eps_gSDE"], + out_keys=["loc", "scale", "action", "_eps_gSDE"], + ), + ) + + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_type=InteractionType.RANDOM, + return_log_prob=True, + ) + qvalue = ValueOperator( + in_keys=in_keys_qvalue, + module=qvalue_net, + ) + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = proof_environment.fake_tensordict() + td = td.unsqueeze(-1) + td = td.to(device) + for net in model: + net(td) + del td + return model + + +def transformed_env_constructor( + cfg: "DictConfig", # noqa: F821 + video_tag: str = "", + logger: Logger | None = None, + stats: dict | None = None, + norm_obs_only: bool = False, + use_env_creator: bool = False, + custom_env_maker: Callable | None = None, + custom_env: EnvBase | None = None, + return_transformed_envs: bool = True, + action_dim_gsde: int | None = None, + state_dim_gsde: int | None = None, + batch_dims: int | None = 0, + obs_norm_state_dict: dict | None = None, +) -> Union[Callable, EnvCreator]: + """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. + + Args: + cfg (DictConfig): a DictConfig containing the arguments of the script. + video_tag (str, optional): video tag to be passed to the Logger object + logger (Logger, optional): logger associated with the script + stats (dict, optional): a dictionary containing the :obj:`loc` and :obj:`scale` for the `ObservationNorm` transform + norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online. + Default is `False`. + use_env_creator (bool, optional): wheter the `EnvCreator` class should be used. By using `EnvCreator`, + one can make sure that running statistics will be put in shared memory and accessible for all workers + when using a `VecNorm` transform. Default is `True`. + custom_env_maker (callable, optional): if your env maker is not part + of torchrl env wrappers, a custom callable + can be passed instead. In this case it will override the + constructor retrieved from `args`. + custom_env (EnvBase, optional): if an existing environment needs to be + transformed_in, it can be passed directly to this helper. `custom_env_maker` + and `custom_env` are exclusive features. + return_transformed_envs (bool, optional): if ``True``, a transformed_in environment + is returned. + action_dim_gsde (int, Optional): if gSDE is used, this can present the action dim to initialize the noise. + Make sure this is indicated in environment executed in parallel. + state_dim_gsde: if gSDE is used, this can present the state dim to initialize the noise. + Make sure this is indicated in environment executed in parallel. + batch_dims (int, optional): number of dimensions of a batch of data. If a single env is + used, it should be 0 (default). If multiple envs are being transformed in parallel, + it should be set to 1 (or the number of dims of the batch). 
+ obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the + environment + """ + + def make_transformed_env(**kwargs) -> TransformedEnv: + env_name = cfg.env.name + env_task = cfg.env.task + env_library = LIBS[cfg.env.library] + frame_skip = cfg.env.frame_skip + from_pixels = cfg.env.from_pixels + categorical_action_encoding = cfg.env.categorical_action_encoding + + if custom_env is None and custom_env_maker is None: + if isinstance(cfg.collector.device, str): + device = cfg.collector.device + elif isinstance(cfg.collector.device, Sequence): + device = cfg.collector.device[0] + else: + raise ValueError( + "collector_device must be either a string or a sequence of strings" + ) + env_kwargs = { + "env_name": env_name, + "device": device, + "frame_skip": frame_skip, + "from_pixels": from_pixels or len(video_tag), + "pixels_only": from_pixels, + } + if env_library is GymEnv: + env_kwargs.update( + {"categorical_action_encoding": categorical_action_encoding} + ) + elif categorical_action_encoding: + raise NotImplementedError( + "categorical_action_encoding=True is currently only compatible with GymEnvs." + ) + if env_library is DMControlEnv: + env_kwargs.update({"task_name": env_task}) + env_kwargs.update(kwargs) + env = env_library(**env_kwargs) + elif custom_env is None and custom_env_maker is not None: + env = custom_env_maker(**kwargs) + elif custom_env_maker is None and custom_env is not None: + env = custom_env + else: + raise RuntimeError("cannot provive both custom_env and custom_env_maker") + + if cfg.env.noops and custom_env is None: + # this is a bit hacky: if custom_env is not None, it is probably a ParallelEnv + # that already has its NoopResetEnv set for the contained envs. + # There is a risk however that we're just skipping the NoopsReset instantiation + env = TransformedEnv(env, NoopResetEnv(cfg.env.noops)) + if not return_transformed_envs: + return env + + return make_env_transforms( + env, + cfg, + video_tag, + logger, + env_name, + stats, + norm_obs_only, + env_library, + action_dim_gsde, + state_dim_gsde, + batch_dims=batch_dims, + obs_norm_state_dict=obs_norm_state_dict, + ) + + if use_env_creator: + return env_creator(make_transformed_env) + return make_transformed_env + + +def get_norm_state_dict(env): + """Gets the normalization loc and scale from the env state_dict.""" + sd = env.state_dict() + sd = { + key: val + for key, val in sd.items() + if key.endswith("loc") or key.endswith("scale") + } + return sd + + +def initialize_observation_norm_transforms( + proof_environment: EnvBase, + num_iter: int = 1000, + key: Union[str, Tuple[str, ...]] = None, +): + """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`. + + If an :obj:`ObservationNorm` already has non-null :obj:`loc` or :obj:`scale`, a call to :obj:`initialize_observation_norm_transforms` will be a no-op. + Similarly, if the transformed environment does not contain any :obj:`ObservationNorm`, a call to this function will have no effect. + If no key is provided but the observations of the :obj:`EnvBase` contains more than one key, an exception will + be raised. + + Args: + proof_environment (EnvBase instance, optional): if provided, this env will + be used ot execute the rollouts. If not, it will be created using + the cfg object. + num_iter (int): Number of iterations used for initializing the :obj:`ObservationNorms` + key (str, optional): if provided, the stats of this key will be gathered. 
+            If not, it is expected that only one key exists in `env.observation_spec`.
+
+    """
+    if not isinstance(proof_environment.transform, Compose) and not isinstance(
+        proof_environment.transform, ObservationNorm
+    ):
+        return
+
+    if key is None:
+        keys = list(proof_environment.base_env.observation_spec.keys(True, True))
+        key = keys.pop()
+        if len(keys):
+            raise RuntimeError(
+                f"More than one key exists in the observation_specs: {[key] + keys} were found, "
+                "thus initialize_observation_norm_transforms cannot infer which to compute the stats of."
+            )
+
+    if isinstance(proof_environment.transform, Compose):
+        for transform in proof_environment.transform:
+            if isinstance(transform, ObservationNorm) and not transform.initialized:
+                transform.init_stats(num_iter=num_iter, key=key)
+    elif not proof_environment.transform.initialized:
+        proof_environment.transform.init_stats(num_iter=num_iter, key=key)
+
+
+def parallel_env_constructor(
+    cfg: "DictConfig", **kwargs  # noqa: F821
+) -> Union[ParallelEnv, EnvCreator]:
+    """Returns a parallel environment built from a DictConfig of script arguments.
+
+    Args:
+        cfg (DictConfig): config containing user-defined arguments
+        kwargs: keyword arguments for the `transformed_env_constructor` function.
+    """
+    batch_transform = cfg.env.batch_transform
+    if not batch_transform:
+        raise NotImplementedError(
+            "batch_transform must be set to True for the recorder to be synced "
+            "with the collection envs."
+        )
+    if cfg.collector.env_per_collector == 1:
+        kwargs.update({"cfg": cfg, "use_env_creator": True})
+        make_transformed_env = transformed_env_constructor(**kwargs)
+        return make_transformed_env
+    kwargs.update({"cfg": cfg, "use_env_creator": True})
+    make_transformed_env = transformed_env_constructor(
+        return_transformed_envs=not batch_transform, **kwargs
+    )
+    parallel_env = ParallelEnv(
+        num_workers=cfg.collector.env_per_collector,
+        create_env_fn=make_transformed_env,
+        create_env_kwargs=None,
+        pin_memory=False,
+    )
+    if batch_transform:
+        kwargs.update(
+            {
+                "cfg": cfg,
+                "use_env_creator": False,
+                "custom_env": parallel_env,
+                "batch_dims": 1,
+            }
+        )
+        env = transformed_env_constructor(**kwargs)()
+        return env
+    return parallel_env
+
+
+def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
+    """Traverses the transforms of the environment and retrieves the :obj:`ObservationNorm` state dicts.
+
+    Returns a list of tuples (idx, state_dict) for each :obj:`ObservationNorm` transform in proof_environment.
+    If the environment transforms do not contain any :obj:`ObservationNorm`, returns an empty list.
+
+    Args:
+        proof_environment (TransformedEnv instance): the :obj:`TransformedEnv` to retrieve the :obj:`ObservationNorm`
+            state dicts from
+    """
+    obs_norm_state_dicts = []
+
+    if isinstance(proof_environment.transform, Compose):
+        for idx, transform in enumerate(proof_environment.transform):
+            if isinstance(transform, ObservationNorm):
+                obs_norm_state_dicts.append((idx, transform.state_dict()))
+
+    if isinstance(proof_environment.transform, ObservationNorm):
+        obs_norm_state_dicts.append((0, proof_environment.transform.state_dict()))
+
+    return obs_norm_state_dicts
+
+
+def make_env_transforms(
+    env,
+    cfg,
+    video_tag,
+    logger,
+    env_name,
+    stats,
+    norm_obs_only,
+    env_library,
+    action_dim_gsde,
+    state_dim_gsde,
+    batch_dims=0,
+    obs_norm_state_dict=None,
+):
+    """Creates the typical transforms for an env."""
+    env = TransformedEnv(env)
+
+    from_pixels = cfg.env.from_pixels
+    vecnorm = cfg.env.vecnorm
+    norm_rewards = vecnorm and cfg.env.norm_rewards
+    _norm_obs_only = norm_obs_only or not norm_rewards
+    reward_scaling = cfg.env.reward_scaling
+    reward_loc = cfg.env.reward_loc
+
+    if len(video_tag):
+        center_crop = cfg.env.center_crop
+        if center_crop:
+            center_crop = center_crop[0]
+        env.append_transform(
+            VideoRecorder(
+                logger=logger,
+                tag=f"{video_tag}_{env_name}_video",
+                center_crop=center_crop,
+            ),
+        )
+
+    if from_pixels:
+        if not cfg.env.catframes:
+            raise RuntimeError(
+                "this env builder currently only accepts positive catframes values "
+                "when pixels are being used."
+            )
+        env.append_transform(ToTensorImage())
+        if cfg.env.center_crop:
+            env.append_transform(CenterCrop(*cfg.env.center_crop))
+        env.append_transform(Resize(cfg.env.image_size, cfg.env.image_size))
+        if cfg.env.grayscale:
+            env.append_transform(GrayScale())
+        env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
+        env.append_transform(CatFrames(N=cfg.env.catframes, in_keys=["pixels"], dim=-3))
+        if stats is None and obs_norm_state_dict is None:
+            obs_stats = {}
+        elif stats is None:
+            obs_stats = copy(obs_norm_state_dict)
+        else:
+            obs_stats = copy(stats)
+        obs_stats["standard_normal"] = True
+        obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"])
+        env.append_transform(obs_norm)
+    if norm_rewards:
+        reward_scaling = 1.0
+        reward_loc = 0.0
+    if norm_obs_only:
+        reward_scaling = 1.0
+        reward_loc = 0.0
+    if reward_scaling is not None:
+        env.append_transform(RewardScaling(reward_loc, reward_scaling))
+
+    if not from_pixels:
+        selected_keys = [
+            key
+            for key in env.observation_spec.keys(True, True)
+            if ("pixels" not in key) and (key not in env.state_spec.keys(True, True))
+        ]
+
+        # even if there is a single tensor, it'll be renamed to "observation_vector"
+        out_key = "observation_vector"
+        env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))
+
+        if not vecnorm:
+            if stats is None and obs_norm_state_dict is None:
+                _stats = {}
+            elif stats is None:
+                _stats = copy(obs_norm_state_dict)
+            else:
+                _stats = copy(stats)
+            _stats.update({"standard_normal": True})
+            obs_norm = ObservationNorm(
+                **_stats,
+                in_keys=[out_key],
+            )
+            env.append_transform(obs_norm)
+        else:
+            env.append_transform(
+                VecNorm(
+                    in_keys=[out_key, "reward"] if not _norm_obs_only else [out_key],
+                    decay=0.9999,
+                )
+            )
+
+        env.append_transform(DoubleToFloat())
+
+        if hasattr(cfg.env, "catframes") and cfg.env.catframes:
"catframes") and cfg.env.catframes: + env.append_transform( + CatFrames(N=cfg.env.catframes, in_keys=[out_key], dim=-1) + ) + + else: + env.append_transform(DoubleToFloat()) + + if hasattr(cfg, "gSDE") and cfg.exploration.gSDE: + env.append_transform( + gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde) + ) + + env.append_transform(StepCounter()) + env.append_transform(InitTracker()) + + return env + + +def make_redq_loss( + model, cfg +) -> Tuple[REDQLoss_deprecated, Optional[TargetNetUpdater]]: + """Builds the REDQ loss module.""" + loss_kwargs = {} + loss_kwargs.update({"loss_function": cfg.loss.loss_function}) + loss_kwargs.update({"delay_qvalue": cfg.loss.type == "double"}) + loss_class = REDQLoss_deprecated + if isinstance(model, ActorValueOperator): + actor_model = model.get_policy_operator() + qvalue_model = model.get_value_operator() + elif isinstance(model, ActorCriticOperator): + raise RuntimeError( + "Although REDQ Q-value depends upon selected actions, using the" + "ActorCriticOperator will lead to resampling of the actions when" + "computing the Q-value loss, which we don't want. Please use the" + "ActorValueOperator instead." + ) + else: + actor_model, qvalue_model = model + + loss_module = loss_class( + actor_network=actor_model, + qvalue_network=qvalue_model, + num_qvalue_nets=cfg.loss.num_q_values, + gSDE=cfg.exploration.gSDE, + **loss_kwargs, + ) + loss_module.make_value_estimator(gamma=cfg.loss.gamma) + target_net_updater = make_target_updater(cfg, loss_module) + return loss_module, target_net_updater + + +def make_target_updater( + cfg: "DictConfig", loss_module: LossModule # noqa: F821 +) -> TargetNetUpdater | None: + """Builds a target network weight update object.""" + if cfg.loss.type == "double": + if not cfg.loss.hard_update: + target_net_updater = SoftUpdate( + loss_module, eps=1 - 1 / cfg.loss.value_network_update_interval + ) + else: + target_net_updater = HardUpdate( + loss_module, + value_network_update_interval=cfg.loss.value_network_update_interval, + ) + else: + if cfg.hard_update: + raise RuntimeError( + "hard/soft-update are supposed to be used with double SAC loss. " + "Consider using --loss=double or discarding the hard_update flag." + ) + target_net_updater = None + return target_net_updater + + +def make_collector_offpolicy( + make_env: Callable[[], EnvBase], + actor_model_explore: TensorDictModuleWrapper | ProbabilisticTensorDictSequential, + cfg: "DictConfig", # noqa: F821 + make_env_kwargs: Dict | None = None, +) -> DataCollectorBase: + """Returns a data collector for off-policy algorithms. 
+ + Args: + make_env (Callable): environment creator + actor_model_explore (SafeModule): Model instance used for evaluation and exploration update + cfg (DictConfig): config for creating collector object + make_env_kwargs (dict): kwargs for the env creator + + """ + if cfg.collector.async_collection: + collector_helper = sync_async_collector + else: + collector_helper = sync_sync_collector + + if cfg.collector.multi_step: + ms = MultiStep( + gamma=cfg.loss.gamma, + n_steps=cfg.collector.n_steps_return, + ) + else: + ms = None + + env_kwargs = {} + if make_env_kwargs is not None and isinstance(make_env_kwargs, dict): + env_kwargs.update(make_env_kwargs) + elif make_env_kwargs is not None: + env_kwargs = make_env_kwargs + cfg.collector.device = ( + cfg.collector.device + if len(cfg.collector.device) > 1 + else cfg.collector.device[0] + ) + collector_helper_kwargs = { + "env_fns": make_env, + "env_kwargs": env_kwargs, + "policy": actor_model_explore, + "max_frames_per_traj": cfg.collector.max_frames_per_traj, + "frames_per_batch": cfg.collector.frames_per_batch, + "total_frames": cfg.collector.total_frames, + "postproc": ms, + "num_env_per_collector": 1, + # we already took care of building the make_parallel_env function + "num_collectors": -cfg.num_workers // -cfg.collector.env_per_collector, + "device": cfg.collector.device, + "storing_device": cfg.collector.device, + "init_random_frames": cfg.collector.init_random_frames, + "split_trajs": True, + # trajectories must be separated if multi-step is used + "exploration_type": ExplorationType.from_str(cfg.collector.exploration_mode), + } + + collector = collector_helper(**collector_helper_kwargs) + collector.set_seed(cfg.seed) + return collector + + +def make_replay_buffer( + device: DEVICE_TYPING, cfg: "DictConfig" # noqa: F821 +) -> ReplayBuffer: # noqa: F821 + """Builds a replay buffer using the config built from ReplayArgsConfig.""" + device = torch.device(device) + if not cfg.buffer.prb: + sampler = RandomSampler() + else: + sampler = PrioritizedSampler( + max_capacity=cfg.buffer.size, + alpha=0.7, + beta=0.5, + ) + buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage( + cfg.buffer.size, + scratch_dir=cfg.buffer.scratch_dir, + # device=device, # when using prefetch, this can overload the GPU memory + ), + sampler=sampler, + pin_memory=device != torch.device("cpu"), + prefetch=cfg.buffer.prefetch, + batch_size=cfg.buffer.batch_size, + ) + return buffer diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 2d3425a2151..dfd0ae30c14 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -13,7 +13,7 @@ collector: init_random_frames: 25000 frames_per_batch: 1000 init_env_steps: 1000 - collector_device: cpu + device: cpu env_per_collector: 1 reset_at_each_iter: False diff --git a/examples/sac/utils.py b/examples/sac/utils.py index f07d3715866..69c7b7c7658 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -10,7 +10,15 @@ from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage -from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv +from torchrl.envs import ( + CatTensors, + Compose, + DMControlEnv, + DoubleToFloat, + EnvCreator, + ParallelEnv, + TransformedEnv, +) from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter from torchrl.envs.utils 
import ExplorationType, set_exploration_type @@ -25,12 +33,21 @@ # ----------------- -def env_maker(task, device="cpu"): - with set_gym_backend("gym"): - return GymEnv( - task, - device=device, +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") def apply_env_transforms(env, max_episode_steps=1000): @@ -50,7 +67,7 @@ def make_environment(cfg): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda cfg=cfg: env_maker(cfg)), ) parallel_env.set_seed(cfg.env.seed) @@ -59,7 +76,7 @@ def make_environment(cfg): eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator(lambda cfg=cfg: env_maker(cfg)), ), train_env.transform.clone(), ) @@ -79,7 +96,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - device=cfg.collector.collector_device, + device=cfg.collector.device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 4ef557ed50c..210d865c11d 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -14,7 +14,7 @@ collector: init_env_steps: 1000 frames_per_batch: 1000 reset_at_each_iter: False - collector_device: cpu + device: cpu env_per_collector: 1 num_workers: 1 diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 090529782fd..36d3ef99a9a 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -12,7 +12,9 @@ from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import ( + CatTensors, Compose, + DMControlEnv, DoubleToFloat, EnvCreator, InitTracker, @@ -41,17 +43,21 @@ # ----------------- -def env_maker( - task, - device="cpu", - from_pixels=False, -): - with set_gym_backend("gym"): - return GymEnv( - task, - device=device, - from_pixels=from_pixels, +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") def apply_env_transforms(env, max_episode_steps): @@ -71,7 +77,7 @@ def make_environment(cfg): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda task=cfg.env.name: env_maker(task=task)), + EnvCreator(lambda cfg=cfg: env_maker(cfg)), ) parallel_env.set_seed(cfg.env.seed) @@ -82,7 +88,7 @@ def make_environment(cfg): eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda task=cfg.env.name: env_maker(task=task)), + EnvCreator(lambda cfg=cfg: env_maker(cfg)), ), train_env.transform.clone(), ) 
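As a usage note on the reworked `env_maker` helpers changed in the hunks above (they now receive the full Hydra config rather than a task string), here is a minimal, illustrative sketch; the OmegaConf dict and the gym/Pendulum-v1 choices are assumptions that simply mirror the `env.name` / `env.task` / `env.library` keys the helper reads, not part of the patch itself:

# Illustrative sketch: drive a cfg-based env maker with a hand-built config.
from omegaconf import OmegaConf

from torchrl.envs.libs.gym import GymEnv, set_gym_backend

cfg = OmegaConf.create({"env": {"name": "Pendulum-v1", "task": "", "library": "gym"}})


def env_maker(cfg, device="cpu"):
    # trimmed copy of the example helper: gym/gymnasium branch only
    lib = cfg.env.library
    if lib in ("gym", "gymnasium"):
        with set_gym_backend(lib):
            return GymEnv(cfg.env.name, device=device)
    raise NotImplementedError(f"Unknown lib {lib}.")


env = env_maker(cfg)
env.rollout(3)  # smoke test: three random steps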
@@ -103,7 +109,7 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, - device=cfg.collector.collector_device, + device=cfg.collector.device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index c490dd0e16c..f0e132eb092 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -814,7 +814,9 @@ def step_and_maybe_reset( # as this mechanism can be used by a policy to set anticipatively the # keys of the next call (eg, with recurrent nets) if key in self._env_input_keys or ( - isinstance(key, tuple) and key[0] == "next" + isinstance(key, tuple) + and key[0] == "next" + and key in self.shared_tensordict_parent.keys(True, True) ): val = tensordict.get(key) self.shared_tensordict_parent.set_(key, val) @@ -854,7 +856,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # as this mechanism can be used by a policy to set anticipatively the # keys of the next call (eg, with recurrent nets) if key in self._env_input_keys or ( - isinstance(key, tuple) and key[0] == "next" + isinstance(key, tuple) + and key[0] == "next" + and key in self.shared_tensordict_parent.keys(True, True) ): val = tensordict.get(key) self.shared_tensordict_parent.set_(key, val) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7634606c1af..c295adc007f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4011,8 +4011,9 @@ class TensorDictPrimer(Transform): tensordict with the desired features. Args: - primers (dict, optional): a dictionary containing key-spec pairs which will - be used to populate the input tensordict. + primers (dict or CompositeSpec, optional): a dictionary containing + key-spec pairs which will be used to populate the input tensordict. + :class:`~torchrl.data.CompositeSpec` instances are supported too. random (bool, optional): if ``True``, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. @@ -4073,7 +4074,7 @@ class TensorDictPrimer(Transform): def __init__( self, - primers: dict = None, + primers: dict | CompositeSpec = None, random: bool = False, default_value: float = 0.0, reset_key: NestedKey | None = None, @@ -4087,7 +4088,9 @@ def __init__( "as kwargs." ) kwargs = primers - self.primers = CompositeSpec(kwargs) + if not isinstance(kwargs, CompositeSpec): + kwargs = CompositeSpec(kwargs) + self.primers = kwargs self.random = random self.default_value = default_value self.reset_key = reset_key diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 318ba09d02c..7c33895e965 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -406,7 +406,7 @@ def td0_return_estimate( gamma: float, next_state_value: torch.Tensor, reward: torch.Tensor, - terminated: torch.Tensor, + terminated: torch.Tensor | None = None, *, done: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -431,7 +431,8 @@ def td0_return_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if done is not None: + if done is not None and terminated is None: + terminated = done warnings.warn( "done for td0_return_estimate is deprecated. Pass ``terminated`` instead." 
) diff --git a/torchrl/record/loggers/utils.py b/torchrl/record/loggers/utils.py index c405297a110..ec7321f5bbd 100644 --- a/torchrl/record/loggers/utils.py +++ b/torchrl/record/loggers/utils.py @@ -31,7 +31,8 @@ def get_logger( """Get a logger instance of the provided `logger_type`. Args: - logger_type (str): One of tensorboard / csv / wandb / mlflow + logger_type (str): One of tensorboard / csv / wandb / mlflow. + If empty, ``None`` is returned. logger_name (str): Name to be used as a log_dir experiment_name (str): Name of the experiment kwargs (dict[str]): might contain either `wandb_kwargs` or `mlflow_kwargs` @@ -60,6 +61,8 @@ def get_logger( exp_name=experiment_name, **mlflow_kwargs, ) + elif logger_type in ("", None): + return None else: - raise NotImplementedError(f"Unsupported logger_type: {logger_type}") + raise NotImplementedError(f"Unsupported logger_type: '{logger_type}'") return logger diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index 1021698012c..0adff694d3f 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import warnings from dataclasses import dataclass from typing import Any, Optional, Tuple @@ -10,7 +11,6 @@ from torchrl.objectives import DistributionalDQNLoss, DQNLoss, HardUpdate, SoftUpdate from torchrl.objectives.common import LossModule from torchrl.objectives.deprecated import REDQLoss_deprecated - from torchrl.objectives.utils import TargetNetUpdater @@ -42,6 +42,10 @@ def make_redq_loss( model, cfg ) -> Tuple[REDQLoss_deprecated, Optional[TargetNetUpdater]]: """Builds the REDQ loss module.""" + warnings.warn( + "This helper function will be deprecated in v0.4. Consider using the local helper in the REDQ example.", + category=DeprecationWarning, + ) loss_kwargs = {} if hasattr(cfg, "distributional") and cfg.distributional: raise NotImplementedError diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 3951aa88c32..ee343aa438e 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -2,8 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - import itertools +import warnings from dataclasses import dataclass from typing import Optional, Sequence @@ -290,6 +290,10 @@ def make_redq_model( is_shared=False) """ + warnings.warn( + "This helper function will be deprecated in v0.4. Consider using the local helper in the REDQ example.", + category=DeprecationWarning, + ) tanh_loc = cfg.tanh_loc default_policy_scale = cfg.default_policy_scale gSDE = cfg.gSDE diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 6a7f47843d2..669a16ca4cd 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -465,7 +465,10 @@ def train(self): self.collector.shutdown() def __del__(self): - self.collector.shutdown() + try: + self.collector.shutdown() + except Exception: + pass def shutdown(self): if VERBOSE:
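A short, hedged sketch of the `get_logger` behaviour introduced above: an empty `logger_type` (what the example scripts pass via `logger.backend=`) now returns `None` instead of raising; the logger and experiment names below are placeholders:

# Sketch: an empty backend string yields no logger instead of a NotImplementedError.
from torchrl.record.loggers.utils import get_logger

assert get_logger(logger_type="", logger_name="runs", experiment_name="demo") is None
# A concrete backend still returns a logger instance, e.g.:
# get_logger(logger_type="csv", logger_name="runs", experiment_name="demo")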