Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] SignTransform #1798

Merged
merged 4 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ to be able to create this other composition:
RewardSum
Reward2GoTransform
SelectTransform
SignTransform
SqueezeTransform
StepCounter
TargetReturn
Expand Down
163 changes: 163 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
RewardSum,
SelectTransform,
SerialEnv,
SignTransform,
SqueezeTransform,
StepCounter,
TargetReturn,
Expand Down Expand Up @@ -9594,6 +9595,168 @@ def test_transform_inverse(self):
raise pytest.skip("No inverse for BurnInTransform")


class TestSignTransform(TransformBase):
    @staticmethod
    def check_sign_applied(tensor):
        # After the transform, every element must be one of {-1, 0, 1}.
        valid = (tensor == -1) | (tensor == 1) | (tensor == 0.0)
        return valid.all()

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    def test_transform_rb(self, rbclass):
        torch.manual_seed(0)
        buffer = rbclass(storage=LazyTensorStorage(20))
        transform = Compose(
            SignTransform(
                in_keys=["observation", "reward"],
                out_keys=["obs_sign", "reward_sign"],
                in_keys_inv=["input"],
                out_keys_inv=["input_sign"],
            )
        )
        buffer.append_transform(transform)
        buffer.add(TensorDict({"observation": 1, "reward": 2, "input": 3}, []))
        sample = buffer.sample(20)
        # Source entries are untouched; sign results land in the out_keys.
        for key, value, signed_key in (
            ("observation", 1, "obs_sign"),
            ("reward", 2, "reward_sign"),
            ("input", 3, "input_sign"),
        ):
            assert (sample[key] == value).all()
            assert self.check_sign_applied(sample[signed_key])

    def test_single_trans_env_check(self):
        env = TransformedEnv(
            ContinuousActionVecMockEnv(),
            SignTransform(
                in_keys=["observation", "reward"],
                in_keys_inv=["observation_orig"],
            ),
        )
        check_env_specs(env)

    def test_transform_compose(self):
        transform = Compose(
            SignTransform(
                in_keys=["observation", "reward"],
                out_keys=["obs_sign", "reward_sign"],
            )
        )
        td = transform(TensorDict({"observation": 1, "reward": 2}, []))
        assert td["observation"] == 1
        assert self.check_sign_applied(td["obs_sign"])
        assert td["reward"] == 2
        assert self.check_sign_applied(td["reward_sign"])

    @pytest.mark.parametrize("device", get_default_devices())
    def test_transform_env(self, device):
        env = TransformedEnv(
            ContinuousActionVecMockEnv(device=device),
            SignTransform(
                in_keys=["observation", "reward"],
            ),
        )
        rollout = env.rollout(3)
        assert rollout.device == device
        assert self.check_sign_applied(rollout["observation"])
        assert self.check_sign_applied(rollout["next", "observation"])
        assert self.check_sign_applied(rollout["next", "reward"])
        check_env_specs(env)

    def test_transform_inverse(self):
        transform = SignTransform(
            in_keys_inv=["observation", "reward"],
            out_keys_inv=["obs_sign", "reward_sign"],
        )
        td = transform.inv(TensorDict({"observation": 1, "reward": 2}, []))
        assert td["observation"] == 1
        assert self.check_sign_applied(td["obs_sign"])
        assert td["reward"] == 2
        assert self.check_sign_applied(td["reward_sign"])

    def test_transform_model(self):
        model = nn.Sequential(
            SignTransform(
                in_keys=["observation", "reward"],
                out_keys=["obs_sign", "reward_sign"],
            )
        )
        td = model(TensorDict({"observation": 1, "reward": 2}, []))
        assert td["observation"] == 1
        assert self.check_sign_applied(td["obs_sign"])
        assert td["reward"] == 2
        assert self.check_sign_applied(td["reward_sign"])

    def test_transform_no_env(self):
        transform = SignTransform(
            in_keys=["observation", "reward"],
            out_keys=["obs_sign", "reward_sign"],
        )
        td = transform(TensorDict({"observation": 1, "reward": 2}, []))
        assert td["observation"] == 1
        assert self.check_sign_applied(td["obs_sign"])
        assert td["reward"] == 2
        assert self.check_sign_applied(td["reward_sign"])

    def test_parallel_trans_env_check(self):
        def make_env():
            return TransformedEnv(
                ContinuousActionVecMockEnv(),
                SignTransform(
                    in_keys=["observation", "reward"],
                    in_keys_inv=["observation_orig"],
                ),
            )

        check_env_specs(ParallelEnv(2, make_env))

    def test_serial_trans_env_check(self):
        def make_env():
            return TransformedEnv(
                ContinuousActionVecMockEnv(),
                SignTransform(
                    in_keys=["observation", "reward"],
                    in_keys_inv=["observation_orig"],
                ),
            )

        check_env_specs(SerialEnv(2, make_env))

    def test_trans_parallel_env_check(self):
        check_env_specs(
            TransformedEnv(
                ParallelEnv(2, ContinuousActionVecMockEnv),
                SignTransform(
                    in_keys=["observation", "reward"],
                    in_keys_inv=["observation_orig"],
                ),
            )
        )

    def test_trans_serial_env_check(self):
        check_env_specs(
            TransformedEnv(
                SerialEnv(2, ContinuousActionVecMockEnv),
                SignTransform(
                    in_keys=["observation", "reward"],
                    in_keys_inv=["observation_orig"],
                ),
            )
        )


if __name__ == "__main__":
    # Forward any unrecognized CLI arguments straight to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
RewardScaling,
RewardSum,
SelectTransform,
SignTransform,
SqueezeTransform,
StepCounter,
TargetReturn,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
RewardScaling,
RewardSum,
SelectTransform,
SignTransform,
SqueezeTransform,
StepCounter,
TargetReturn,
Expand Down
76 changes: 76 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6708,3 +6708,79 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

def __repr__(self) -> str:
return f"{self.__class__.__name__}(burn_in={self.burn_in}, in_keys={self.in_keys}, out_keys={self.out_keys})"


class SignTransform(Transform):
    """Replaces TensorDict entries by their element-wise sign.

    Tensors found at ``in_keys`` (forward calls) and ``in_keys_inv``
    (:meth:`~.inv` calls) are mapped through ``torch.sign`` and the results
    are written to the matching entries of ``out_keys`` and ``out_keys_inv``
    respectively.

    Args:
        in_keys (list of NestedKeys): entries read during forward calls.
        out_keys (list of NestedKeys): entries written during forward calls.
        in_keys_inv (list of NestedKeys): entries read during :meth:`~.inv` calls.
        out_keys_inv (list of NestedKeys): entries written during :meth:`~.inv` calls.

    Examples:
        >>> from torchrl.envs import GymEnv, TransformedEnv, SignTransform
        >>> base_env = GymEnv("Pendulum-v1")
        >>> env = TransformedEnv(base_env, SignTransform(in_keys=['observation']))
        >>> r = env.rollout(100)
        >>> obs = r["observation"]
        >>> assert (torch.logical_or(torch.logical_or(obs == -1, obs == 1), obs == 0.0)).all()
    """

    def __init__(
        self,
        in_keys=None,
        out_keys=None,
        in_keys_inv=None,
        out_keys_inv=None,
    ):
        # Missing out-keys default to in-place replacement of the in-keys.
        in_keys = [] if in_keys is None else in_keys
        out_keys = copy(in_keys) if out_keys is None else out_keys
        in_keys_inv = [] if in_keys_inv is None else in_keys_inv
        out_keys_inv = copy(in_keys_inv) if out_keys_inv is None else out_keys_inv
        super().__init__(in_keys, out_keys, in_keys_inv, out_keys_inv)

    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        # torch.sign maps every element to -1, 0 or 1.
        return torch.sign(obs)

    def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
        # The inverse pass applies the same element-wise sign.
        return torch.sign(state)

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        # Sign outputs are bounded in [-1, 1] regardless of the original bounds.
        return BoundedTensorSpec(
            low=-1.0,
            high=1.0,
            shape=observation_spec.shape,
            device=observation_spec.device,
            dtype=observation_spec.dtype,
        )

    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
        # Only rewrite the specs of rewards this transform actually touches.
        for key in self.in_keys:
            if key not in self.parent.reward_keys:
                continue
            spec = self.parent.output_spec["full_reward_spec"][key]
            self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec(
                low=-1.0,
                high=1.0,
                shape=spec.shape,
                device=spec.device,
                dtype=spec.dtype,
            )
        return self.parent.output_spec["full_reward_spec"]

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # Reset data may lack some in_keys; tolerate the missing entries.
        with _set_missing_tolerance(self, True):
            return self._call(tensordict_reset)
Loading