Skip to content

Commit

Permalink
[BugFix] Fix storage device (#1650)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 24, 2023
1 parent b7d148b commit 105e861
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,9 @@ def set( # noqa: F811
self._init(data)
if not isinstance(cursor, (*INT_CLASSES, slice)):
if not isinstance(cursor, torch.Tensor):
cursor = torch.tensor(cursor)
cursor = torch.tensor(cursor, dtype=torch.long, device=self.device)
elif cursor.dtype != torch.long:
cursor = cursor.to(dtype=torch.long, device=self.device)
if len(cursor) > len(self._storage):
warnings.warn(
"A cursor of length superior to the storage capacity was provided. "
Expand Down

0 comments on commit 105e861

Please sign in to comment.