[BugFix] Remove select() in favor of empty() (#1811)
vmoens authored Jan 17, 2024
1 parent baea10b commit 93748e9
Showing 8 changed files with 13 additions and 14 deletions.
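For context: `TensorDict.select()` called with no keys returns a new tensordict with the same batch size and device as the original but no entries; `TensorDict.empty()` states that intent directly and skips the key-selection machinery. A minimal sketch of the two patterns, assuming the `tensordict` package is installed:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"obs": torch.zeros(4, 3)}, batch_size=[4])

    # Old pattern: no-arg select() keeps the batch size and device but no keys.
    empty_old = td.select()
    # New pattern: empty() expresses the same thing directly.
    empty_new = td.empty()

    assert empty_old.batch_size == empty_new.batch_size == torch.Size([4])
    assert not list(empty_new.keys())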
8 changes: 4 additions & 4 deletions test/mocking_classes.py
@@ -495,7 +495,7 @@ def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase:
         state = torch.zeros(self.size) + self.counter
         if tensordict is None:
             tensordict = TensorDict({}, self.batch_size, device=self.device)
-        tensordict = tensordict.select().set(self.out_key, self._get_out_obs(state))
+        tensordict = tensordict.empty().set(self.out_key, self._get_out_obs(state))
         tensordict = tensordict.set(self._out_key, self._get_out_obs(state))
         tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
         tensordict.set(
@@ -514,7 +514,7 @@ def _step(
         assert (a.sum(-1) == 1).all()

         obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
-        tensordict = tensordict.select()  # empty tensordict
+        tensordict = tensordict.empty()  # empty tensordict

         tensordict.set(self.out_key, self._get_out_obs(obs))
         tensordict.set(self._out_key, self._get_out_obs(obs))
@@ -602,7 +602,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         # state = torch.zeros(self.size) + self.counter
         if tensordict is None:
             tensordict = TensorDict({}, self.batch_size, device=self.device)
-        tensordict = tensordict.select()
+        tensordict = tensordict.empty()
         tensordict.update(self.observation_spec.rand())
         # tensordict.set("next_" + self.out_key, self._get_out_obs(state))
         # tensordict.set("next_" + self._out_key, self._get_out_obs(state))
@@ -621,7 +621,7 @@ def _step(
         a = tensordict.get("action")

         obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)
-        tensordict = tensordict.select()  # empty tensordict
+        tensordict = tensordict.empty()  # empty tensordict

         tensordict.set(self.out_key, self._get_out_obs(obs))
         tensordict.set(self._out_key, self._get_out_obs(obs))
3 changes: 1 addition & 2 deletions test/test_shared.py
@@ -62,8 +62,7 @@ def test_shared(self, indexing_method):
             subtd = TensorDict(
                 source={key: item[0] for key, item in td.items()},
                 batch_size=[],
-                _is_shared=True,
-            )
+            ).share_memory_()
         elif indexing_method == 1:
             subtd = td.get_sub_tensordict(0)
         elif indexing_method == 2:
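The test above drops the private `_is_shared` constructor flag in favor of the public `share_memory_()` method, which moves the tensors into shared memory in place and returns the same tensordict, so it can be chained off the constructor. A minimal sketch, assuming the `tensordict` package:

    import torch
    from tensordict import TensorDict

    # Mirrors the test: build the tensordict, then share it explicitly.
    subtd = TensorDict({"a": torch.zeros(3)}, batch_size=[]).share_memory_()
    assert subtd.is_shared()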
2 changes: 1 addition & 1 deletion test/test_transforms.py
@@ -6990,7 +6990,7 @@ def test_vip_parallel_reward(self, model, device, dtype_fixture):  # noqa
             KeyError,
             match=r"VIPRewardTransform.* requires .* key to be present in the input tensordict",
         ):
-            _ = transformed_env.reset(tensordict_reset.select())
+            _ = transformed_env.reset(tensordict_reset.empty())

         td = transformed_env.reset(tensordict_reset)
         assert td.device == device
2 changes: 1 addition & 1 deletion torchrl/data/datasets/d4rl.py
@@ -305,7 +305,7 @@ def _get_dataset_direct(self, name, env_kwargs):
         else:
             self.metadata = {}
         dataset.rename_key_("observations", "observation")
-        dataset.set("next", dataset.select())
+        dataset.create_nested("next")
         dataset.rename_key_("next_observations", ("next", "observation"))
         dataset.rename_key_("terminals", "terminated")
         if "timeouts" in dataset.keys():
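`create_nested("next")` registers an empty sub-tensordict under the "next" key, sharing the parent's batch size, which the following `rename_key_` call then populates. A hedged sketch of the pattern; the dataset contents here are illustrative stand-ins for the real D4RL fields:

    import torch
    from tensordict import TensorDict

    data = TensorDict({"next_observations": torch.zeros(10, 4)}, batch_size=[10])
    data.create_nested("next")  # empty sub-tensordict at key "next"
    data.rename_key_("next_observations", ("next", "observation"))
    assert ("next", "observation") in data.keys(include_nested=True)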
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
@@ -3364,7 +3364,7 @@ def encode(
         self, vals: Dict[str, Any], *, ignore_device: bool = False
     ) -> Dict[str, torch.Tensor]:
         if isinstance(vals, TensorDict):
-            out = vals.select()  # create an empty tensordict similar to vals
+            out = vals.empty()  # create an empty tensordict similar to vals
         else:
             out = TensorDict({}, torch.Size([]), _run_checks=False)
         for key, item in vals.items():
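In `encode()`, `vals.empty()` produces an output container matching the batch size and device of `vals`, which the loop then fills key by key. A self-contained sketch of that pattern (the input contents are illustrative):

    import torch
    from tensordict import TensorDict

    vals = TensorDict({"a": torch.arange(4), "b": torch.ones(4, 2)}, batch_size=[4])
    out = vals.empty()  # same batch size and device, no entries yet
    for key, item in vals.items():
        out.set(key, torch.as_tensor(item))
    assert out.batch_size == vals.batch_size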
2 changes: 1 addition & 1 deletion torchrl/envs/batched_envs.py
@@ -1210,7 +1210,7 @@ def _run_worker_pipe_shared_mem(
             raise RuntimeError(
                 "tensordict must be placed in shared memory (share_memory_() or memmap_())"
             )
-            shared_tensordict = shared_tensordict.clone(False)
+            shared_tensordict = shared_tensordict.clone(False).unlock_()

             initialized = True
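`clone(recurse=False)` copies the tensordict's structure while keeping references to the same shared-memory tensors; since a shared tensordict is locked against structural changes, the clone can inherit that lock, and chaining `unlock_()` makes the worker's copy writable again. A hedged sketch:

    import torch
    from tensordict import TensorDict

    shared = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2]).share_memory_()
    local = shared.clone(False).unlock_()  # same tensors, writable structure
    local.set("extra", torch.ones(2, 1))   # would raise on a locked tensordict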
4 changes: 2 additions & 2 deletions torchrl/envs/common.py
@@ -1483,8 +1483,8 @@ def reset(
         if tensordict_reset is tensordict:
             raise RuntimeError(
                 "EnvBase._reset should return outplace changes to the input "
-                "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or "
-                "tensordict.select()) inside _reset before writing new tensors onto this new instance."
+                "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty()) "
+                "inside _reset before writing new tensors onto this new instance."
             )
         if not isinstance(tensordict_reset, TensorDictBase):
             raise RuntimeError(
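The pattern the message recommends is for a custom `_reset` to build its output from `tensordict.empty()` rather than mutating its input. A hypothetical `_reset` body, modeled on the mocking-class diffs above:

    import torch
    from tensordict import TensorDictBase

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        # Hypothetical sketch: return out-of-place changes, never the input.
        out = tensordict.empty()  # same batch size/device, no entries
        out.update(self.observation_spec.rand())
        out.set("done", torch.zeros(*out.batch_size, 1, dtype=torch.bool))
        return out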
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
@@ -3119,7 +3119,7 @@ class DTypeCastTransform(Transform):
         ...         obs = self.observation_spec.rand()
         ...         assert reward.dtype == torch.float64
         ...         assert obs["obs"].dtype == torch.float64
-        ...         return obs.select().set("next", obs.update({"reward": reward, "done": done}))
+        ...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
         ...     def _set_seed(self, seed):
         ...         pass
         >>> env = TransformedEnv(MyEnv(), DTypeCastTransform(torch.double, torch.float))
@@ -3480,7 +3480,7 @@ class DoubleToFloat(DTypeCastTransform):
         ...         obs = self.observation_spec.rand()
         ...         assert reward.dtype == torch.float64
         ...         assert obs["obs"].dtype == torch.float64
-        ...         return obs.select().set("next", obs.update({"reward": reward, "done": done}))
+        ...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
         ...     def _set_seed(self, seed):
         ...         pass
         >>> env = TransformedEnv(MyEnv(), DoubleToFloat())
