[Feature] ConditionalPolicySwitch transform
ghstack-source-id: defb61a46ba3657f499c510d326dc917dd690b56
Pull Request resolved: #2711
vmoens committed Jan 22, 2025
1 parent dd2bf20 commit 29f7971
Showing 7 changed files with 472 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -816,6 +816,7 @@ to be able to create this other composition:
CenterCrop
ClipTransform
Compose
ConditionalPolicySwitch
Crop
DTypeCastTransform
DeviceCastTransform
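
For reference, a minimal usage sketch of the new transform, adapted from the tests added in this commit (not part of the diff). CountingEnv is the counting environment from the test suite's mocking helpers and stands in for any environment with an action spec; the condition and policies below are illustrative.

from torchrl.envs import Compose, ConditionalPolicySwitch, StepCounter

from mocking_classes import CountingEnv  # test-suite helper; assumed importable

base_env = CountingEnv(max_steps=15)

# On every even step count, the transform hands control to `policy_even`;
# the policy passed to rollout() plays the remaining (odd) steps.
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
policy_odd = lambda td: td.set("action", base_env.action_spec.zero())
policy_even = lambda td: td.set("action", base_env.action_spec.one())

env = base_env.append_transform(
    Compose(
        StepCounter(),  # provides the "step_count" entry read by the condition
        ConditionalPolicySwitch(condition=condition, policy=policy_even),
    )
)
rollout = env.rollout(1000, policy_odd, break_when_all_done=True)
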
72 changes: 54 additions & 18 deletions examples/agents/ppo-chess.py
@@ -5,20 +5,24 @@
import tensordict.nn
import torch
import tqdm
from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
from tensordict.nn import (
ProbabilisticTensorDictModule as TDProb,
ProbabilisticTensorDictSequential as TDProbSeq,
TensorDictModule as TDMod,
TensorDictSequential as TDSeq,
)
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

from torchrl.envs import ChessEnv, Tokenizer
from torchrl.modules import MLP
from torchrl.modules.distributions import MaskedCategorical
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement

tensordict.nn.set_composite_lp_aggregate(False)

@@ -39,7 +43,9 @@
embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)

# Embedding for the fen
embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)
embedding_fen = nn.Embedding(
num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
)

backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)

@@ -49,20 +55,30 @@
critic_head = nn.Linear(512, 1)
critic_head.bias.data.fill_(0)

prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)
prob = TDProb(
in_keys=["logits", "mask"],
out_keys=["action"],
distribution_class=MaskedCategorical,
return_log_prob=True,
)


def make_mask(idx):
mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]


actor = TDProbSeq(
TDMod(
make_mask,
in_keys=["legal_moves"], out_keys=["mask"]),
TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
out_keys=["features"]),
TDMod(
lambda *args: torch.cat(
[arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
),
in_keys=["embedded_legal_moves", "embedded_fen"],
out_keys=["features"],
),
TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
prob,
@@ -78,7 +94,9 @@ def make_mask(idx):

optim = Adam(loss.parameters())

gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)
gae = GAE(
value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
)

# Create a data collector
collector = SyncDataCollector(
@@ -88,12 +106,20 @@ def make_mask(idx):
total_frames=1_000_000,
)

replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
replay_buffer0 = ReplayBuffer(
storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
batch_size=batch_size,
sampler=SamplerWithoutReplacement(),
)
replay_buffer1 = ReplayBuffer(
storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
batch_size=batch_size,
sampler=SamplerWithoutReplacement(),
)

for data in tqdm.tqdm(collector):
data = data.filter_non_tensor_data()
print('data', data[0::2])
print("data", data[0::2])
for i in range(num_epochs):
replay_buffer0.empty()
replay_buffer1.empty()
@@ -103,14 +129,24 @@ def make_mask(idx):
# player 1
data1 = gae(data[1::2])
if i == 0:
print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
print(
"win rate for 0",
data0["next", "reward"].sum()
/ data["next", "done"].sum().clamp_min(1e-6),
)
print(
"win rate for 1",
data1["next", "reward"].sum()
/ data["next", "done"].sum().clamp_min(1e-6),
)

replay_buffer0.extend(data0)
replay_buffer1.extend(data1)

n_iter = collector.frames_per_batch//(2 * batch_size)
for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
n_iter = collector.frames_per_batch // (2 * batch_size)
for (d0, d1) in tqdm.tqdm(
zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
):
loss_vals = (loss(d0) + loss(d1)) / 2
loss_vals.sum(reduce=True).backward()
gn = clip_grad_norm_(loss.parameters(), 100.0)
171 changes: 171 additions & 0 deletions test/test_transforms.py
@@ -20,6 +20,7 @@

import tensordict.tensordict
import torch
from tensordict.nn import WrapModule

from torchrl.collectors import MultiSyncDataCollector

@@ -106,6 +107,7 @@
CenterCrop,
ClipTransform,
Compose,
ConditionalPolicySwitch,
Crop,
DeviceCastTransform,
DiscreteActionProjection,
@@ -13192,6 +13194,175 @@ def test_composite_reward_spec(self) -> None:
assert transform.transform_reward_spec(reward_spec) == expected_reward_spec


class TestConditionalPolicySwitch(TransformBase):
def test_single_trans_env_check(self):
base_env = CountingEnv(max_steps=15)
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
# Player 0
policy_odd = lambda td: td.set("action", env.action_spec.zero())
policy_even = lambda td: td.set("action", env.action_spec.one())
transforms = Compose(
StepCounter(),
ConditionalPolicySwitch(condition=condition, policy=policy_even),
)
env = base_env.append_transform(transforms)
r = env.rollout(1000, policy_odd, break_when_all_done=True)
assert r.shape[0] == 15
assert (r["action"] == 0).all()
assert (
r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
).all()
assert r["next", "done"].any()

# Player 1
condition = lambda td: ((td.get("step_count") % 2) == 1).all()
transforms = Compose(
StepCounter(),
ConditionalPolicySwitch(condition=condition, policy=policy_odd),
)
env = base_env.append_transform(transforms)
r = env.rollout(1000, policy_even, break_when_all_done=True)
assert r.shape[0] == 16
assert (r["action"] == 1).all()
assert (
r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
).all()
assert r["next", "done"].any()

def _create_policy_odd(self, base_env):
return WrapModule(
lambda td, base_env=base_env: td.set(
"action", base_env.action_spec_unbatched.zero(td.shape)
),
out_keys=["action"],
)

def _create_policy_even(self, base_env):
return WrapModule(
lambda td, base_env=base_env: td.set(
"action", base_env.action_spec_unbatched.one(td.shape)
),
out_keys=["action"],
)

def _create_transforms(self, condition, policy_even):
return Compose(
StepCounter(),
ConditionalPolicySwitch(condition=condition, policy=policy_even),
)

def _make_env(self, max_count, env_cls):
torch.manual_seed(0)
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
base_env = env_cls(max_steps=max_count)
policy_even = self._create_policy_even(base_env)
transforms = self._create_transforms(condition, policy_even)
return base_env.append_transform(transforms)

def _test_env(self, env, policy_odd):
env.check_env_specs()
env.set_seed(0)
r = env.rollout(100, policy_odd, break_when_any_done=False)
# Check results are independent: one reset / step in one env should not impact results in another
r0, r1, r2 = r.unbind(0)
r0_split = r0.split(6)
assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:]))
r1_split = r1.split(7)
assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:]))
r2_split = r2.split(8)
assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:]))

def test_trans_serial_env_check(self):
torch.manual_seed(0)
base_env = SerialEnv(
3,
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
batch_locked=False,
)
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
policy_odd = self._create_policy_odd(base_env)
policy_even = self._create_policy_even(base_env)
transforms = self._create_transforms(condition, policy_even)
env = base_env.append_transform(transforms)
self._test_env(env, policy_odd)

def test_trans_parallel_env_check(self):
torch.manual_seed(0)
base_env = ParallelEnv(
3,
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
batch_locked=False,
mp_start_method=mp_ctx,
)
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
policy_odd = self._create_policy_odd(base_env)
policy_even = self._create_policy_even(base_env)
transforms = self._create_transforms(condition, policy_even)
env = base_env.append_transform(transforms)
self._test_env(env, policy_odd)

def test_serial_trans_env_check(self):
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
policy_odd = self._create_policy_odd(CountingEnv())

def make_env(max_count):
return partial(self._make_env, max_count, CountingEnv)

env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
self._test_env(env, policy_odd)

def test_parallel_trans_env_check(self):
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
policy_odd = self._create_policy_odd(CountingEnv())

def make_env(max_count):
return partial(self._make_env, max_count, CountingEnv)

env = ParallelEnv(
3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
)
self._test_env(env, policy_odd)

def test_transform_no_env(self):
"""tests the transform on dummy data, without an env."""
raise NotImplementedError

def test_transform_compose(self):
"""tests the transform on dummy data, without an env but inside a Compose."""
raise NotImplementedError

def test_transform_env(self):
"""tests the transform on a real env.
If possible, do not use a mock env, as bugs may go unnoticed if the dynamic is too
simplistic. A call to reset() and step() should be tested independently, ie
a check that reset produces the desired output and that step() does too.
"""
raise NotImplementedError

def test_transform_model(self):
"""tests the transform before an nn.Module that reads the output."""
raise NotImplementedError

def test_transform_rb(self):
"""tests the transform when used with a replay buffer.
If your transform is not supposed to work with a replay buffer, test that
an error will be raised when called or appended to an RB.
"""
raise NotImplementedError

def test_transform_inverse(self):
"""tests the inverse transform. If not applicable, simply skip this test.
If your transform is not supposed to work offline, test that
an error will be raised when called in an nn.Module.
"""
raise NotImplementedError


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -55,6 +55,7 @@
CenterCrop,
ClipTransform,
Compose,
ConditionalPolicySwitch,
Crop,
DeviceCastTransform,
DiscreteActionProjection,
11 changes: 8 additions & 3 deletions torchrl/envs/batched_envs.py
@@ -191,6 +191,8 @@ class BatchedEnvBase(EnvBase):
one of the environment has dynamic specs.
.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
batch_locked (bool, optional): if provided, will override the ``batch_locked`` attribute of the
nested environments. ``batch_locked=False`` may allow for partial steps.
.. note::
One can pass keyword arguments to each sub-environment using the following
@@ -305,6 +307,7 @@ def __init__(
non_blocking: bool = False,
mp_start_method: str = None,
use_buffers: bool = None,
batch_locked: bool | None = None,
):
super().__init__(device=device)
self.serial_for_single = serial_for_single
@@ -344,6 +347,7 @@ def __init__(

# if share_individual_td is None, we will assess later if the output can be stacked
self.share_individual_td = share_individual_td
self._batch_locked = batch_locked
self._share_memory = shared_memory
self._memmap = memmap
self.allow_step_when_done = allow_step_when_done
@@ -610,8 +614,8 @@ def map_device(key, value, device_map=device_map):
self._env_tensordict.named_apply(
map_device, nested_keys=True, filter_empty=True
)

self._batch_locked = meta_data.batch_locked
if self._batch_locked is None:
self._batch_locked = meta_data.batch_locked
else:
self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
devices = set()
@@ -652,7 +656,8 @@ def map_device(key, value, device_map=device_map):
self._env_tensordict = torch.stack(
[meta_data.tensordict for meta_data in meta_data], 0
)
self._batch_locked = meta_data[0].batch_locked
if self._batch_locked is None:
self._batch_locked = meta_data[0].batch_locked
self.has_lazy_inputs = contains_lazy_spec(self.input_spec)

def state_dict(self) -> OrderedDict:
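
For reference, a sketch of how the new batch_locked override is exercised by the tests above (not part of the diff; CountingEnv is again the test-suite helper, and the argument values are taken from the new tests):

from functools import partial

from torchrl.envs import SerialEnv

from mocking_classes import CountingEnv  # test-suite helper; assumed importable

# Three heterogeneous sub-envs with different step limits. Passing
# batch_locked=False overrides the sub-envs' own batch_locked attribute,
# which may allow partial steps across the batch.
env = SerialEnv(
    3,
    [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
    batch_locked=False,
)
env.check_env_specs()
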
