[Feature] ConditionalPolicySwitch transform
ghstack-source-id: f147e7c6b0f55da5746f79563af66ad057021d66
Pull Request resolved: #2711
vmoens committed Jan 26, 2025
1 parent bf707f5 commit 3c1241d
Showing 8 changed files with 510 additions and 24 deletions.
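
Judging by the tests added below, the new transform plays a secondary policy whenever a user-supplied condition holds on the tensordict, while the policy passed to rollout() handles the remaining steps; the intermediate turns never surface in the rollout output. A minimal usage sketch under those assumptions (GymEnv("Pendulum-v1") and the even/odd condition are stand-ins for the internal CountingEnv setup used by the tests; gymnasium must be installed):

from torchrl.envs import Compose, ConditionalPolicySwitch, GymEnv, StepCounter

# Stand-in environment; the tests in this commit use an internal CountingEnv.
base_env = GymEnv("Pendulum-v1")

# Switch to `policy_even` whenever the step count is even
# (StepCounter writes the "step_count" entry the condition reads).
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
policy_even = lambda td: td.set("action", base_env.action_spec.rand())
policy_odd = lambda td: td.set("action", base_env.action_spec.zero())

env = base_env.append_transform(
    Compose(
        StepCounter(),
        ConditionalPolicySwitch(condition=condition, policy=policy_even),
    )
)
# Odd steps come from `policy_odd`; even steps are played by `policy_even`
# inside the transform and do not appear in the returned rollout.
rollout = env.rollout(10, policy_odd, break_when_any_done=False)
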
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -816,6 +816,7 @@ to be able to create this other composition:
    CenterCrop
    ClipTransform
    Compose
    ConditionalPolicySwitch
    Crop
    DTypeCastTransform
    DeviceCastTransform
72 changes: 54 additions & 18 deletions examples/agents/ppo-chess.py
@@ -5,20 +5,24 @@
import tensordict.nn
import torch
import tqdm
from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
    ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
from tensordict.nn import (
    ProbabilisticTensorDictModule as TDProb,
    ProbabilisticTensorDictSequential as TDProbSeq,
    TensorDictModule as TDMod,
    TensorDictSequential as TDSeq,
)
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

from torchrl.envs import ChessEnv, Tokenizer
from torchrl.modules import MLP
from torchrl.modules.distributions import MaskedCategorical
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement

tensordict.nn.set_composite_lp_aggregate(False)

@@ -39,7 +43,9 @@
embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)

# Embedding for the fen
embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)
embedding_fen = nn.Embedding(
    num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
)

backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)

@@ -49,20 +55,30 @@
critic_head = nn.Linear(512, 1)
critic_head.bias.data.fill_(0)

prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)
prob = TDProb(
    in_keys=["logits", "mask"],
    out_keys=["action"],
    distribution_class=MaskedCategorical,
    return_log_prob=True,
)


def make_mask(idx):
    mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
    return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]


actor = TDProbSeq(
    TDMod(
        make_mask,
        in_keys=["legal_moves"], out_keys=["mask"]),
    TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
    TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
    TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
    TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
          out_keys=["features"]),
    TDMod(
        lambda *args: torch.cat(
            [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
        ),
        in_keys=["embedded_legal_moves", "embedded_fen"],
        out_keys=["features"],
    ),
    TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
    TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
    prob,
@@ -78,7 +94,9 @@ def make_mask(idx):

optim = Adam(loss.parameters())

gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)
gae = GAE(
    value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
)

# Create a data collector
collector = SyncDataCollector(
@@ -88,12 +106,20 @@ def make_mask(idx):
    total_frames=1_000_000,
)

replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
replay_buffer0 = ReplayBuffer(
    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
    batch_size=batch_size,
    sampler=SamplerWithoutReplacement(),
)
replay_buffer1 = ReplayBuffer(
    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
    batch_size=batch_size,
    sampler=SamplerWithoutReplacement(),
)

for data in tqdm.tqdm(collector):
    data = data.filter_non_tensor_data()
    print('data', data[0::2])
    print("data", data[0::2])
    for i in range(num_epochs):
        replay_buffer0.empty()
        replay_buffer1.empty()
@@ -103,14 +129,24 @@ def make_mask(idx):
        # player 1
        data1 = gae(data[1::2])
        if i == 0:
            print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
            print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
            print(
                "win rate for 0",
                data0["next", "reward"].sum()
                / data["next", "done"].sum().clamp_min(1e-6),
            )
            print(
                "win rate for 1",
                data1["next", "reward"].sum()
                / data["next", "done"].sum().clamp_min(1e-6),
            )

        replay_buffer0.extend(data0)
        replay_buffer1.extend(data1)

        n_iter = collector.frames_per_batch//(2 * batch_size)
        for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
        n_iter = collector.frames_per_batch // (2 * batch_size)
        for (d0, d1) in tqdm.tqdm(
            zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
        ):
            loss_vals = (loss(d0) + loss(d1)) / 2
            loss_vals.sum(reduce=True).backward()
            gn = clip_grad_norm_(loss.parameters(), 100.0)
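
In the updated example above, both players are trained from the same self-play rollout: even-indexed transitions (data[0::2]) feed player 0's replay buffer and odd-indexed ones (data[1::2]) feed player 1's. A small, self-contained illustration of that slicing pattern on a TensorDict with made-up fields:

import torch
from tensordict import TensorDict

# Eight alternating turns: even indices belong to player 0, odd ones to player 1.
rollout = TensorDict(
    {"turn": torch.arange(8) % 2, "reward": torch.zeros(8)},
    batch_size=[8],
)

player0 = rollout[0::2]  # turns 0, 2, 4, 6
player1 = rollout[1::2]  # turns 1, 3, 5, 7
assert (player0["turn"] == 0).all()
assert (player1["turn"] == 1).all()
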
203 changes: 203 additions & 0 deletions test/test_transforms.py
@@ -20,6 +20,8 @@

import tensordict.tensordict
import torch
from tensordict.nn import WrapModule

from tensordict import (
    NonTensorData,
    NonTensorStack,
@@ -56,6 +58,7 @@
    CenterCrop,
    ClipTransform,
    Compose,
    ConditionalPolicySwitch,
    Crop,
    DeviceCastTransform,
    DiscreteActionProjection,
@@ -13341,6 +13344,206 @@ def test_composite_reward_spec(self) -> None:
        assert transform.transform_reward_spec(reward_spec) == expected_reward_spec


class TestConditionalPolicySwitch(TransformBase):
    def test_single_trans_env_check(self):
        base_env = CountingEnv(max_steps=15)
        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
        # Player 0
        policy_odd = lambda td: td.set("action", env.action_spec.zero())
        policy_even = lambda td: td.set("action", env.action_spec.one())
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        env = base_env.append_transform(transforms)
        env.check_env_specs()

    def _create_policy_odd(self, base_env):
        return WrapModule(
            lambda td, base_env=base_env: td.set(
                "action", base_env.action_spec_unbatched.zero(td.shape)
            ),
            out_keys=["action"],
        )

    def _create_policy_even(self, base_env):
        return WrapModule(
            lambda td, base_env=base_env: td.set(
                "action", base_env.action_spec_unbatched.one(td.shape)
            ),
            out_keys=["action"],
        )

    def _create_transforms(self, condition, policy_even):
        return Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )

    def _make_env(self, max_count, env_cls):
        torch.manual_seed(0)
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        base_env = env_cls(max_steps=max_count)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        return base_env.append_transform(transforms)

    def _test_env(self, env, policy_odd):
        env.check_env_specs()
        env.set_seed(0)
        r = env.rollout(100, policy_odd, break_when_any_done=False)
        # Check results are independent: one reset / step in one env should not impact results in another
        r0, r1, r2 = r.unbind(0)
        r0_split = r0.split(6)
        assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:]))
        r1_split = r1.split(7)
        assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:]))
        r2_split = r2.split(8)
        assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:]))

    def test_trans_serial_env_check(self):
        torch.manual_seed(0)
        base_env = SerialEnv(
            3,
            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
            batch_locked=False,
        )
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        policy_odd = self._create_policy_odd(base_env)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        env = base_env.append_transform(transforms)
        self._test_env(env, policy_odd)

    def test_trans_parallel_env_check(self):
        torch.manual_seed(0)
        base_env = ParallelEnv(
            3,
            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
            batch_locked=False,
            mp_start_method=mp_ctx,
        )
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        policy_odd = self._create_policy_odd(base_env)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        env = base_env.append_transform(transforms)
        self._test_env(env, policy_odd)

    def test_serial_trans_env_check(self):
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        policy_odd = self._create_policy_odd(CountingEnv())

        def make_env(max_count):
            return partial(self._make_env, max_count, CountingEnv)

        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
        self._test_env(env, policy_odd)

    def test_parallel_trans_env_check(self):
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        policy_odd = self._create_policy_odd(CountingEnv())

        def make_env(max_count):
            return partial(self._make_env, max_count, CountingEnv)

        env = ParallelEnv(
            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
        )
        self._test_env(env, policy_odd)

    def test_transform_no_env(self):
        policy_odd = lambda td: td
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    def test_transform_compose(self):
        policy_odd = lambda td: td
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = Compose(
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    def test_transform_env(self):
        base_env = CountingEnv(max_steps=15)
        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
        # Player 0
        policy_odd = lambda td: td.set("action", env.action_spec.zero())
        policy_even = lambda td: td.set("action", env.action_spec.one())
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        env = base_env.append_transform(transforms)
        env.check_env_specs()
        r = env.rollout(1000, policy_odd, break_when_all_done=True)
        assert r.shape[0] == 15
        assert (r["action"] == 0).all()
        assert (
            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
        ).all()
        assert r["next", "done"].any()

        # Player 1
        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
        )
        env = base_env.append_transform(transforms)
        r = env.rollout(1000, policy_even, break_when_all_done=True)
        assert r.shape[0] == 16
        assert (r["action"] == 1).all()
        assert (
            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
        ).all()
        assert r["next", "done"].any()

    def test_transform_model(self):
        policy_odd = lambda td: td
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = nn.Sequential(
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    def test_transform_rb(self, rbclass):
        policy_odd = lambda td: td
        policy_even = lambda td: td
        condition = lambda td: True
        rb = rbclass(storage=LazyTensorStorage(10))
        rb.append_transform(
            ConditionalPolicySwitch(condition=condition, policy=policy_even)
        )
        rb.extend(TensorDict(batch_size=[2]))
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            rb.sample(2)

    def test_transform_inverse(self):
        return


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -55,6 +55,7 @@
    CenterCrop,
    ClipTransform,
    Compose,
    ConditionalPolicySwitch,
    Crop,
    DeviceCastTransform,
    DiscreteActionProjection,
(Diffs for the remaining 4 changed files are not shown.)
