diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 6fcae219a8e..0fa5beb8017 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -605,6 +605,7 @@ to be able to create this other composition:
     RewardSum
     Reward2GoTransform
     SelectTransform
+    SignTransform
     SqueezeTransform
     StepCounter
     TargetReturn
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 78840bd6ba2..3ef633eee98 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -95,6 +95,7 @@
     RewardSum,
     SelectTransform,
     SerialEnv,
+    SignTransform,
     SqueezeTransform,
     StepCounter,
     TargetReturn,
@@ -9594,6 +9595,168 @@ def test_transform_inverse(self):
         raise pytest.skip("No inverse for BurnInTransform")
 
 
+class TestSignTransform(TransformBase):
+    @staticmethod
+    def check_sign_applied(tensor):
+        return torch.logical_or(
+            torch.logical_or(tensor == -1, tensor == 1), tensor == 0.0
+        ).all()
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        torch.manual_seed(0)
+        rb = rbclass(storage=LazyTensorStorage(20))
+
+        t = Compose(
+            SignTransform(
+                in_keys=["observation", "reward"],
+                out_keys=["obs_sign", "reward_sign"],
+                in_keys_inv=["input"],
+                out_keys_inv=["input_sign"],
+            )
+        )
+        rb.append_transform(t)
+        data = TensorDict({"observation": 1, "reward": 2, "input": 3}, [])
+        rb.add(data)
+        sample = rb.sample(20)
+
+        assert (sample["observation"] == 1).all()
+        assert self.check_sign_applied(sample["obs_sign"])
+
+        assert (sample["reward"] == 2).all()
+        assert self.check_sign_applied(sample["reward_sign"])
+
+        assert (sample["input"] == 3).all()
+        assert self.check_sign_applied(sample["input_sign"])
+
+    def test_single_trans_env_check(self):
+        env = ContinuousActionVecMockEnv()
+        env = TransformedEnv(
+            env,
+            SignTransform(
+                in_keys=["observation", "reward"],
+                in_keys_inv=["observation_orig"],
+            ),
+        )
+        check_env_specs(env)
+
+    def test_transform_compose(self):
+        t = Compose(
+            SignTransform(
+                in_keys=["observation", "reward"],
+                out_keys=["obs_sign", "reward_sign"],
+            )
+        )
+        data = TensorDict({"observation": 1, "reward": 2}, [])
+        data = t(data)
+        assert data["observation"] == 1
+        assert self.check_sign_applied(data["obs_sign"])
+        assert data["reward"] == 2
+        assert self.check_sign_applied(data["reward_sign"])
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_transform_env(self, device):
+        base_env = ContinuousActionVecMockEnv(device=device)
+        env = TransformedEnv(
+            base_env,
+            SignTransform(
+                in_keys=["observation", "reward"],
+            ),
+        )
+        r = env.rollout(3)
+        assert r.device == device
+        assert self.check_sign_applied(r["observation"])
+        assert self.check_sign_applied(r["next", "observation"])
+        assert self.check_sign_applied(r["next", "reward"])
+        check_env_specs(env)
+
+    def test_transform_inverse(self):
+        t = SignTransform(
+            in_keys_inv=["observation", "reward"],
+            out_keys_inv=["obs_sign", "reward_sign"],
+        )
+        data = TensorDict({"observation": 1, "reward": 2}, [])
+        data = t.inv(data)
+        assert data["observation"] == 1
+        assert self.check_sign_applied(data["obs_sign"])
+        assert data["reward"] == 2
+        assert self.check_sign_applied(data["reward_sign"])
+
+    def test_transform_model(self):
+        t = nn.Sequential(
+            SignTransform(
+                in_keys=["observation", "reward"],
+                out_keys=["obs_sign", "reward_sign"],
+            )
+        )
+        data = TensorDict({"observation": 1, "reward": 2}, [])
+        data = t(data)
+        assert data["observation"] == 1
+        assert self.check_sign_applied(data["obs_sign"])
+        assert data["reward"] == 2
+        assert self.check_sign_applied(data["reward_sign"])
+
+    def test_transform_no_env(self):
+        t = SignTransform(
+            in_keys=["observation", "reward"],
+            out_keys=["obs_sign", "reward_sign"],
+        )
+        data = TensorDict({"observation": 1, "reward": 2}, [])
+        data = t(data)
+        assert data["observation"] == 1
+        assert self.check_sign_applied(data["obs_sign"])
+        assert data["reward"] == 2
+        assert self.check_sign_applied(data["reward_sign"])
+
+    def test_parallel_trans_env_check(self):
+        def make_env():
+            env = ContinuousActionVecMockEnv()
+            return TransformedEnv(
+                env,
+                SignTransform(
+                    in_keys=["observation", "reward"],
+                    in_keys_inv=["observation_orig"],
+                ),
+            )
+
+        env = ParallelEnv(2, make_env)
+        check_env_specs(env)
+
+    def test_serial_trans_env_check(self):
+        def make_env():
+            env = ContinuousActionVecMockEnv()
+            return TransformedEnv(
+                env,
+                SignTransform(
+                    in_keys=["observation", "reward"],
+                    in_keys_inv=["observation_orig"],
+                ),
+            )
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env)
+
+    def test_trans_parallel_env_check(self):
+        env = TransformedEnv(
+            ParallelEnv(2, ContinuousActionVecMockEnv),
+            SignTransform(
+                in_keys=["observation", "reward"],
+                in_keys_inv=["observation_orig"],
+            ),
+        )
+        check_env_specs(env)
+
+    def test_trans_serial_env_check(self):
+        env = TransformedEnv(
+            SerialEnv(2, ContinuousActionVecMockEnv),
+            SignTransform(
+                in_keys=["observation", "reward"],
+                in_keys_inv=["observation_orig"],
+            ),
+        )
+        check_env_specs(env)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index b38819263e6..31c94b38343 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -71,6 +71,7 @@
     RewardScaling,
     RewardSum,
     SelectTransform,
+    SignTransform,
     SqueezeTransform,
     StepCounter,
     TargetReturn,
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 849f41789c7..91968971c1f 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -39,6 +39,7 @@
     RewardScaling,
     RewardSum,
     SelectTransform,
+    SignTransform,
     SqueezeTransform,
     StepCounter,
     TargetReturn,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 2571ac6f074..21bb542cb1d 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -6708,3 +6708,80 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(burn_in={self.burn_in}, in_keys={self.in_keys}, out_keys={self.out_keys})"
+
+
+class SignTransform(Transform):
+    """A transform to compute the signs of TensorDict values.
+
+    This transform reads the tensors in ``in_keys`` and ``in_keys_inv``, computes
+    the signs of their elements, and writes the resulting sign tensors to
+    ``out_keys`` and ``out_keys_inv`` respectively.
+
+    Args:
+        in_keys (list of NestedKeys): input entries (read).
+        out_keys (list of NestedKeys): output entries (write).
+        in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls.
+        out_keys_inv (list of NestedKeys): output entries (write) during :meth:`~.inv` calls.
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.envs import GymEnv, TransformedEnv, SignTransform
+        >>> base_env = GymEnv("Pendulum-v1")
+        >>> env = TransformedEnv(base_env, SignTransform(in_keys=['observation']))
+        >>> r = env.rollout(100)
+        >>> obs = r["observation"]
+        >>> assert (torch.logical_or(torch.logical_or(obs == -1, obs == 1), obs == 0.0)).all()
+    """
+
+    def __init__(
+        self,
+        in_keys=None,
+        out_keys=None,
+        in_keys_inv=None,
+        out_keys_inv=None,
+    ):
+        if in_keys is None:
+            in_keys = []
+        if out_keys is None:
+            out_keys = copy(in_keys)
+        if in_keys_inv is None:
+            in_keys_inv = []
+        if out_keys_inv is None:
+            out_keys_inv = copy(in_keys_inv)
+        super().__init__(in_keys, out_keys, in_keys_inv, out_keys_inv)
+
+    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
+        return obs.sign()
+
+    def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
+        return state.sign()
+
+    @_apply_to_composite
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        return BoundedTensorSpec(
+            shape=observation_spec.shape,
+            device=observation_spec.device,
+            dtype=observation_spec.dtype,
+            high=1.0,
+            low=-1.0,
+        )
+
+    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
+        for key in self.in_keys:
+            if key in self.parent.reward_keys:
+                spec = self.parent.output_spec["full_reward_spec"][key]
+                self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec(
+                    shape=spec.shape,
+                    device=spec.device,
+                    dtype=spec.dtype,
+                    high=1.0,
+                    low=-1.0,
+                )
+        return self.parent.output_spec["full_reward_spec"]
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        with _set_missing_tolerance(self, True):
+            tensordict_reset = self._call(tensordict_reset)
+        return tensordict_reset
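Reviewer note: a minimal standalone sketch of the intended usage, mirroring `test_transform_no_env` above. It assumes this PR is merged so `SignTransform` is importable from `torchrl.envs`; the tensor values are illustrative only.

# Standalone usage sketch (assumes SignTransform from this PR is available).
# The transform maps every element of the selected entries to -1, 0 or 1
# via torch.Tensor.sign(), writing the result to the paired out_keys while
# leaving the in_keys entries untouched when out_keys differ.
import torch
from tensordict import TensorDict
from torchrl.envs import SignTransform

t = SignTransform(in_keys=["reward"], out_keys=["reward_sign"])
data = TensorDict({"reward": torch.tensor([-3.5, 0.0, 2.0])}, batch_size=[])
data = t(data)
print(data["reward"])       # unchanged: tensor([-3.5000, 0.0000, 2.0000])
print(data["reward_sign"])  # tensor([-1., 0., 1.])

Applied to `reward`, this reproduces the sign-based reward clipping commonly used in Atari DQN training setups.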