diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 5dd855d65e2..31f681b8a48 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -495,7 +495,7 @@ def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase: state = torch.zeros(self.size) + self.counter if tensordict is None: tensordict = TensorDict({}, self.batch_size, device=self.device) - tensordict = tensordict.select().set(self.out_key, self._get_out_obs(state)) + tensordict = tensordict.empty().set(self.out_key, self._get_out_obs(state)) tensordict = tensordict.set(self._out_key, self._get_out_obs(state)) tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)) tensordict.set( @@ -514,7 +514,7 @@ def _step( assert (a.sum(-1) == 1).all() obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep - tensordict = tensordict.select() # empty tensordict + tensordict = tensordict.empty() # empty tensordict tensordict.set(self.out_key, self._get_out_obs(obs)) tensordict.set(self._out_key, self._get_out_obs(obs)) @@ -602,7 +602,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: # state = torch.zeros(self.size) + self.counter if tensordict is None: tensordict = TensorDict({}, self.batch_size, device=self.device) - tensordict = tensordict.select() + tensordict = tensordict.empty() tensordict.update(self.observation_spec.rand()) # tensordict.set("next_" + self.out_key, self._get_out_obs(state)) # tensordict.set("next_" + self._out_key, self._get_out_obs(state)) @@ -621,7 +621,7 @@ def _step( a = tensordict.get("action") obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a) - tensordict = tensordict.select() # empty tensordict + tensordict = tensordict.empty() # empty tensordict tensordict.set(self.out_key, self._get_out_obs(obs)) tensordict.set(self._out_key, self._get_out_obs(obs)) diff --git a/test/test_shared.py b/test/test_shared.py index dcfb798e35c..e7cfa77b137 100644 --- a/test/test_shared.py +++ 
b/test/test_shared.py @@ -62,8 +62,7 @@ def test_shared(self, indexing_method): subtd = TensorDict( source={key: item[0] for key, item in td.items()}, batch_size=[], - _is_shared=True, - ) + ).share_memory_() elif indexing_method == 1: subtd = td.get_sub_tensordict(0) elif indexing_method == 2: diff --git a/test/test_transforms.py b/test/test_transforms.py index 3ef633eee98..6340dec842e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6990,7 +6990,7 @@ def test_vip_parallel_reward(self, model, device, dtype_fixture): # noqa KeyError, match=r"VIPRewardTransform.* requires .* key to be present in the input tensordict", ): - _ = transformed_env.reset(tensordict_reset.select()) + _ = transformed_env.reset(tensordict_reset.empty()) td = transformed_env.reset(tensordict_reset) assert td.device == device diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 3afa680c88d..c2646f8366b 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -305,7 +305,7 @@ def _get_dataset_direct(self, name, env_kwargs): else: self.metadata = {} dataset.rename_key_("observations", "observation") - dataset.set("next", dataset.select()) + dataset.create_nested("next") dataset.rename_key_("next_observations", ("next", "observation")) dataset.rename_key_("terminals", "terminated") if "timeouts" in dataset.keys(): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index d59155e5d5e..26d15e61c0a 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3364,7 +3364,7 @@ def encode( self, vals: Dict[str, Any], *, ignore_device: bool = False ) -> Dict[str, torch.Tensor]: if isinstance(vals, TensorDict): - out = vals.select() # create and empty tensordict similar to vals + out = vals.empty() # create an empty tensordict similar to vals else: out = TensorDict({}, torch.Size([]), _run_checks=False) for key, item in vals.items(): diff --git a/torchrl/envs/batched_envs.py 
b/torchrl/envs/batched_envs.py index dd96e2a7a5c..dc368ffc3a6 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1210,7 +1210,7 @@ def _run_worker_pipe_shared_mem( raise RuntimeError( "tensordict must be placed in shared memory (share_memory_() or memmap_())" ) - shared_tensordict = shared_tensordict.clone(False) + shared_tensordict = shared_tensordict.clone(False).unlock_() initialized = True diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index eda8c859692..633ac2f78a3 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1483,8 +1483,8 @@ def reset( if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " - "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or " - "tensordict.select()) inside _reset before writing new tensors onto this new instance." + "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty()) " + "inside _reset before writing new tensors onto this new instance." ) if not isinstance(tensordict_reset, TensorDictBase): raise RuntimeError( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 21bb542cb1d..6f115ec118e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3119,7 +3119,7 @@ class DTypeCastTransform(Transform): ... obs = self.observation_spec.rand() ... assert reward.dtype == torch.float64 ... assert obs["obs"].dtype == torch.float64 - ... return obs.select().set("next", obs.update({"reward": reward, "done": done})) + ... return obs.empty().set("next", obs.update({"reward": reward, "done": done})) ... def _set_seed(self, seed): ... pass >>> env = TransformedEnv(MyEnv(), DTypeCastTransform(torch.double, torch.float)) @@ -3480,7 +3480,7 @@ class DoubleToFloat(DTypeCastTransform): ... obs = self.observation_spec.rand() ... assert reward.dtype == torch.float64 ... 
assert obs["obs"].dtype == torch.float64 - ... return obs.select().set("next", obs.update({"reward": reward, "done": done})) + ... return obs.empty().set("next", obs.update({"reward": reward, "done": done})) ... def _set_seed(self, seed): ... pass >>> env = TransformedEnv(MyEnv(), DoubleToFloat())