diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py
index 6193bd96afa..42dc35bfac5 100644
--- a/torchrl/data/datasets/atari_dqn.py
+++ b/torchrl/data/datasets/atari_dqn.py
@@ -60,7 +60,7 @@ class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
         root (Path or str, optional): The AtariDQN dataset root directory.
             The actual dataset memory-mapped files will be saved under
             `<root>/<dataset_id>`. If none is provided, it defaults to
-            ``~/.cache/torchrl/atari`.
+            ``~/.cache/torchrl/atari``.
         num_procs (int, optional): number of processes to launch for preprocessing.
             Has no effect whenever the data is already downloaded. Defaults to 0
             (no multiprocessing used).
diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py
index a1567aa0385..0189678dd1d 100644
--- a/torchrl/data/datasets/d4rl.py
+++ b/torchrl/data/datasets/d4rl.py
@@ -106,7 +106,7 @@ class D4RLExperienceReplay(BaseDatasetExperienceReplay):
         root (Path or str, optional): The D4RL dataset root directory. The
             actual dataset memory-mapped files will be saved under
             `<root>/<dataset_id>`. If none is provided, it defaults to
-            ``~/.cache/torchrl/d4rl`.
+            ``~/.cache/torchrl/d4rl``.
         download (bool, optional): Whether the dataset should be downloaded if
             not found. Defaults to ``True``.
         **env_kwargs (key-value pairs): additional kwargs for
diff --git a/torchrl/data/datasets/gen_dgrl.py b/torchrl/data/datasets/gen_dgrl.py
index ddde2e9d65d..672ec8465bf 100644
--- a/torchrl/data/datasets/gen_dgrl.py
+++ b/torchrl/data/datasets/gen_dgrl.py
@@ -60,7 +60,7 @@ class GenDGRLExperienceReplay(BaseDatasetExperienceReplay):
             dataset root directory. The actual dataset memory-mapped files
             will be saved under `<root>/<dataset_id>`. If none is provided,
             it defaults to
-            ``~/.cache/torchrl/gen_dgrl`.
+            ``~/.cache/torchrl/gen_dgrl``.
         download (bool or str, optional): Whether the dataset should be downloaded
             if not found. Defaults to ``True``. Download can also be passed as
             ``"force"``, in which case the downloaded data will be overwritten.
diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index 113c2595b77..126d0dadc93 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -66,7 +66,7 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay):
         root (Path or str, optional): The Minari dataset root directory. The
             actual dataset memory-mapped files will be saved under
             `<root>/<dataset_id>`. If none is provided, it defaults to
-            ``~/.cache/torchrl/minari`.
+            ``~/.cache/torchrl/minari``.
         download (bool or str, optional): Whether the dataset should be downloaded
             if not found. Defaults to ``True``. Download can also be passed as
             ``"force"``, in which case the downloaded data will be overwritten.
diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py
index 01f5fdf98ce..c5b5dc9e8fc 100644
--- a/torchrl/data/datasets/openx.py
+++ b/torchrl/data/datasets/openx.py
@@ -123,7 +123,7 @@ class
         for more information on how to interact with non-tensor data
         root (Path or str, optional): The OpenX dataset root directory. The
             actual dataset memory-mapped files will be saved under
             `<root>/<dataset_id>`. If none is provided, it defaults to
-            ``~/.cache/torchrl/openx`.
+            ``~/.cache/torchrl/openx``.
         streaming (bool, optional): if ``True``, the data won't be downloaded but
             read from a stream instead.
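For reference, the ``root`` argument these hunks document can be exercised as below. This is a minimal standalone sketch, not part of the diff; the dataset id and the custom path are illustrative assumptions.

```python
# Sketch: redirect a dataset's memory-mapped files away from the default
# ``~/.cache/torchrl/d4rl`` root. Dataset id and path are assumptions.
from torchrl.data.datasets import D4RLExperienceReplay

data = D4RLExperienceReplay(
    "halfcheetah-medium-v2",       # hypothetical dataset id
    batch_size=32,
    root="/tmp/my_torchrl_cache",  # files land under <root>/<dataset_id>
    download=True,                 # fetch if not found, per the docstring
)
sample = data.sample()             # a batch of 32 transitions
```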
diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py
index f967ddfa1fc..1a83c302860 100644
--- a/torchrl/data/datasets/roboset.py
+++ b/torchrl/data/datasets/roboset.py
@@ -57,7 +57,7 @@ class RobosetExperienceReplay(BaseDatasetExperienceReplay):
         root (Path or str, optional): The Roboset dataset root directory. The
             actual dataset memory-mapped files will be saved under
             `<root>/<dataset_id>`. If none is provided, it defaults to
-            ``~/.cache/torchrl/roboset`.
+            ``~/.cache/torchrl/roboset``.
         download (bool or str, optional): Whether the dataset should be downloaded
             if not found. Defaults to ``True``. Download can also be passed as
             ``"force"``, in which case the downloaded data will be overwritten.
diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py
index 7290a714155..f1b2cc673cc 100644
--- a/torchrl/data/datasets/vd4rl.py
+++ b/torchrl/data/datasets/vd4rl.py
@@ -63,7 +63,7 @@ class VD4RLExperienceReplay(BaseDatasetExperienceReplay):
         root (Path or str, optional): The V-D4RL dataset root directory. The
             actual dataset memory-mapped files will be saved under
             `<root>/<dataset_id>`. If none is provided, it defaults to
-            ``~/.cache/torchrl/vd4rl`.
+            ``~/.cache/torchrl/vd4rl``.
         download (bool or str, optional): Whether the dataset should be downloaded
             if not found. Defaults to ``True``. Download can also be passed as
             ``"force"``, in which case the downloaded data will be overwritten.
diff --git a/torchrl/data/map/query.py b/torchrl/data/map/query.py
index ff0fb4dfe24..6c4c2f9e0e2 100644
--- a/torchrl/data/map/query.py
+++ b/torchrl/data/map/query.py
@@ -80,12 +80,12 @@ class QueryModule(TensorDictModuleBase):
            If a single ``hash_module`` is provided but no aggregator is passed,
            it will take the value of the hash_module. If no ``hash_module`` or a
            list of ``hash_modules`` is provided but no aggregator is passed, it
            will default to ``SipHash``.
-        clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be 
+        clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
            returned. This can be used to retrieve the integer index within the storage,
            corresponding to a given input tensordict. This can be overridden at runtime by
            providing the ``clone`` argument to the forward method. Defaults to ``False``.
-    d
+
    Examples:
        >>> query_module = QueryModule(
        ...     in_keys=["key1", "key2"],
@@ -106,6 +106,7 @@ class QueryModule(TensorDictModuleBase):
        >>> # The last three pairs of key1 and key2 have at least one mismatching value
        >>> assert res["index"][1] != res["index"][2]
        >>> assert res["index"][2] != res["index"][3]
+
    """

    def __init__(
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 4e0ee36cd4a..e2f2918b3dc 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -99,6 +99,7 @@ class ReplayBuffer:
            is used with PyTree structures (see example below).
        batch_size (int, optional): the batch size to be used when sample() is
            called.
+
            .. note:: The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
@@ -108,6 +109,7 @@ class ReplayBuffer:
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.
+
        dim_extend (int, optional): indicates the dim to consider for extension
            when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
            When using ``dim_extend > 0``, we recommend using the ``ndim``
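To make the batch-size note concrete: a minimal standalone sketch contrasting a construction-time ``batch_size`` with a per-call override; the storage size and contents are arbitrary.

```python
# Sketch: batch_size fixed at construction vs. passed per call.
# A construction-time batch_size is required for prefetching and for
# samplers with a drop_last argument, as the docstring note explains.
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=8)
rb.extend(torch.arange(100))
fixed = rb.sample()              # uses the construction-time batch_size (8)
adhoc = rb.sample(batch_size=4)  # per-call override for this sample only
assert fixed.shape[0] == 8 and adhoc.shape[0] == 4
```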
@@ -128,6 +130,7 @@ class ReplayBuffer:
        >>> for d in data.unbind(1):
        ...     rb.add(d)
        >>> rb.extend(data)
+
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
@@ -582,6 +585,7 @@ def register_save_hook(self, hook: Callable[[Any], Any]):
        .. note:: Hooks are currently not serialized when saving a replay buffer: they must
          be manually re-initialized every time the buffer is created.
+
        """
        self._storage.register_save_hook(hook)
@@ -926,8 +930,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.
-            .. note::
-                The batch-size can be specified at construction time via the
+
+            .. note:: The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
              be preferred whenever the batch-size is consistent across the
              experiment. If the batch-size is likely to change, it can be
@@ -935,6 +939,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.
+
        dim_extend (int, optional): indicates the dim to consider for extension
            when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
            When using ``dim_extend > 0``, we recommend using the ``ndim``
@@ -1051,6 +1056,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.
+
            .. note:: The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
@@ -1060,6 +1066,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.
+
        priority_key (str, optional): the key at which priority is assumed to
            be stored within TensorDicts added to this ReplayBuffer.
            This is to be used when the sampler is of type
@@ -1085,6 +1092,7 @@ class TensorDictReplayBuffer(ReplayBuffer):
        >>> for d in data.unbind(1):
        ...     rb.add(d)
        >>> rb.extend(data)
+
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
@@ -1394,6 +1402,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.
+
            .. note:: The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
@@ -1403,6 +1412,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.
+
        priority_key (str, optional): the key at which priority is assumed to
            be stored within TensorDicts added to this ReplayBuffer.
            This is to be used when the sampler is of type
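The recurring ``generator`` argument above enables per-buffer seeding. A minimal standalone sketch of the intended use; the seed values are arbitrary.

```python
# Sketch: a dedicated generator keeps RB sampling reproducible even when
# the global torch seed changes, as the docstring describes.
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

gen = torch.Generator()
gen.manual_seed(0)
rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=4, generator=gen)
rb.extend(torch.arange(100))

torch.manual_seed(12345)  # changing the global seed...
first = rb.sample()
gen.manual_seed(0)        # ...and resetting only the dedicated generator
again = rb.sample()       # should reproduce the same draw as `first`
```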
@@ -1431,6 +1441,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
        >>> for d in data.unbind(1):
        ...     rb.add(d)
        >>> rb.extend(data)
+
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
@@ -1669,6 +1680,7 @@ class ReplayBufferEnsemble(ReplayBuffer):
            Defaults to ``None`` (global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.
+
        shared (bool, optional): whether the buffer will be shared using multiprocessing
            or not. Defaults to ``False``.
diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py
index 26fb710f94c..bbde6761f4a 100644
--- a/torchrl/data/rlhf/utils.py
+++ b/torchrl/data/rlhf/utils.py
@@ -198,6 +198,7 @@ class RolloutFromModel:
            batch_size=torch.Size([4, 50]),
            device=cpu,
            is_shared=False)
+
        """

    EOS_TOKEN_ID = 50256
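Closing out, the ``priority_key`` argument documented in the hunks above can be exercised as follows. A minimal standalone sketch; ``"td_error"`` is the documented default key and the data is made up.

```python
# Sketch: priorities are read from `priority_key` in the TensorDicts
# written to the buffer; alpha/beta are the usual PER exponents.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.9,
    priority_key="td_error",
    storage=LazyTensorStorage(10),
    batch_size=4,
)
data = TensorDict(
    {"obs": torch.randn(10, 3), "td_error": torch.rand(10)},
    batch_size=[10],
)
rb.extend(data)
sample = rb.sample()  # entries with larger "td_error" are drawn more often
```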