Update (base update)
[ghstack-poisoned]
vmoens committed Jan 10, 2025
1 parent ed656a1 commit 010c84f
Showing 11 changed files with 497 additions and 77 deletions.
203 changes: 203 additions & 0 deletions examples/agents/composite_ppo.py
@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Multi-head agent and PPO loss
=============================
This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.
The code first defines a module `make_params` that writes the parameters of the three distributions into the tensordict.
It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution
object containing the three distributions.
The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters,
creates a distribution from these parameters, and samples from the distribution to output multiple actions.
The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss.
Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False`
if not specified.
In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in
the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used.
"""

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

make_params = Mod(
    lambda: (
        torch.ones(4),
        torch.ones(4),
        torch.ones(4, 2),
        torch.ones(4, 2),
        torch.ones(4, 10) / 10,
        torch.zeros(4, 10),
        torch.ones(4, 10),
    ),
    in_keys=[],
    out_keys=[
        ("params", "gamma", "concentration"),
        ("params", "gamma", "rate"),
        ("params", "Kumaraswamy", "concentration0"),
        ("params", "Kumaraswamy", "concentration1"),
        ("params", "mixture", "logits"),
        ("params", "mixture", "loc"),
        ("params", "mixture", "scale"),
    ],
)


def mixture_constructor(logits, loc, scale):
    return d.MixtureSameFamily(
        d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
    )


# =============================================================================
# Example 0: aggregate_probabilities=None (default) ===========================

dist_constructor = functools.partial(
    CompositeDistribution,
    distribution_map={
        "gamma": d.Gamma,
        "Kumaraswamy": d.Kumaraswamy,
        "mixture": mixture_constructor,
    },
    name_map={
        "gamma": ("agent0", "action"),
        "Kumaraswamy": ("agent1", "action"),
        "mixture": ("agent2", "action"),
    },
    aggregate_probabilities=None,
)
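# name_map routes each head's sample to a per-agent nested key: the Gamma sample lands
# at ("agent0", "action"), the Kumaraswamy sample at ("agent1", "action") and the
# mixture sample at ("agent2", "action"). This is what makes the policy "multi-head".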


policy = ProbSeq(
    make_params,
    Prob(
        in_keys=["params"],
        out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        distribution_class=dist_constructor,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = policy(TensorDict(batch_size=[4]))
print("0. result of policy call", td)

dist = policy.get_dist(td)
log_prob = dist.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob")

# We can also get the log-prob from the policy directly
log_prob = policy.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob (from policy)")

# Build a dummy value operator
value_operator = Seq(
    Wrap(
        lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
        out_keys=["state_value"],
    )
)
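# The value operator is only a stand-in: it writes a constant "state_value" of ones
# (with a trailing singleton dimension) so that the PPO losses below can compute a
# critic term without a real value network.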

# Create fake data
data = policy(TensorDict(batch_size=[4]))
data.set(
    "next",
    TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("0. ", loss_cls, loss_vals)


# ===================================================================
# Example 1: aggregate_probabilities=True ===========================

dist_constructor.keywords["aggregate_probabilities"] = True

td = policy(TensorDict(batch_size=[4]))
print("1. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since
    # there is only one.
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")]
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("1. ", loss_cls, loss_vals)


# ===================================================================
# Example 2: aggregate_probabilities=False ===========================

dist_constructor.keywords["aggregate_probabilities"] = False

td = policy(TensorDict(batch_size=[4]))
print("2. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("2. ", loss_cls, loss_vals)
109 changes: 106 additions & 3 deletions test/test_cost.py
@@ -34,6 +34,7 @@
    TensorDictModule as Mod,
    TensorDictSequential,
    TensorDictSequential as Seq,
    WrapModule,
)
from tensordict.nn.utils import Buffer
from tensordict.utils import unravel_key
@@ -8864,9 +8865,7 @@ def test_ppo_tensordict_keys_run(
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
@pytest.mark.parametrize(
"composite_action_dist",
[
False,
],
[False],
)
def test_ppo_notensordict(
self,
@@ -9060,6 +9059,110 @@ def test_ppo_value_clipping(
        loss = loss_fn(td)
        assert "loss_critic" in loss.keys()

    def test_ppo_composite_dists(self):
        d = torch.distributions

        make_params = TensorDictModule(
            lambda: (
                torch.ones(4),
                torch.ones(4),
                torch.ones(4, 2),
                torch.ones(4, 2),
                torch.ones(4, 10) / 10,
                torch.zeros(4, 10),
                torch.ones(4, 10),
            ),
            in_keys=[],
            out_keys=[
                ("params", "gamma", "concentration"),
                ("params", "gamma", "rate"),
                ("params", "Kumaraswamy", "concentration0"),
                ("params", "Kumaraswamy", "concentration1"),
                ("params", "mixture", "logits"),
                ("params", "mixture", "loc"),
                ("params", "mixture", "scale"),
            ],
        )

        def mixture_constructor(logits, loc, scale):
            return d.MixtureSameFamily(
                d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
            )

        dist_constructor = functools.partial(
            CompositeDistribution,
            distribution_map={
                "gamma": d.Gamma,
                "Kumaraswamy": d.Kumaraswamy,
                "mixture": mixture_constructor,
            },
            name_map={
                "gamma": ("agent0", "action"),
                "Kumaraswamy": ("agent1", "action"),
                "mixture": ("agent2", "action"),
            },
            aggregate_probabilities=False,
            include_sum=False,
            inplace=True,
        )
        policy = ProbSeq(
            make_params,
            ProbabilisticTensorDictModule(
                in_keys=["params"],
                out_keys=[
                    ("agent0", "action"),
                    ("agent1", "action"),
                    ("agent2", "action"),
                ],
                distribution_class=dist_constructor,
                return_log_prob=True,
                default_interaction_type=InteractionType.RANDOM,
            ),
        )
        # We want to make sure there is no warning
        td = policy(TensorDict(batch_size=[4]))
        assert isinstance(
            policy.get_dist(td).log_prob(
                td, aggregate_probabilities=False, inplace=False, include_sum=False
            ),
            TensorDict,
        )
        assert isinstance(
            policy.log_prob(
                td, aggregate_probabilities=False, inplace=False, include_sum=False
            ),
            TensorDict,
        )
        value_operator = Seq(
            WrapModule(
                lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
                out_keys=["state_value"],
            )
        )
        for cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
            data = policy(TensorDict(batch_size=[4]))
            data.set(
                "next",
                TensorDict(
                    reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)
                ),
            )
            ppo = cls(policy, value_operator)
            ppo.set_keys(
                action=[
                    ("agent0", "action"),
                    ("agent1", "action"),
                    ("agent2", "action"),
                ],
                sample_log_prob=[
                    ("agent0", "action_log_prob"),
                    ("agent1", "action_log_prob"),
                    ("agent2", "action_log_prob"),
                ],
            )
            loss = ppo(data)
            loss.sum(reduce=True)
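            # Summing with reduce=True collapses every entry of the loss tensordict into
            # a single scalar tensor, a quick check that all loss terms are present and
            # can be aggregated for backpropagation.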


class TestA2C(LossModuleTestBase):
    seed = 0
2 changes: 1 addition & 1 deletion torchrl/objectives/cql.py
@@ -323,7 +323,7 @@ def __init__(
        try:
            device = next(self.parameters()).device
        except AttributeError:
            device = torch.device("cpu")
            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
        self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
        if bool(min_alpha) ^ bool(max_alpha):
            min_alpha = min_alpha if min_alpha else 0.0
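The same one-line change recurs in cql.py, crossq.py, decision_transformer.py and deprecated.py: when no parameter is available to infer a device from, the fallback is now the global default device instead of a hard-coded CPU, guarded with getattr so that older torch releases without torch.get_default_device keep the previous behaviour. A minimal sketch of the pattern, using a hypothetical helper name that is not part of the commit and assuming torch is imported:

def _fallback_device(module):
    # Prefer the device of existing parameters, as the loss modules do.
    try:
        return next(module.parameters()).device
    except (AttributeError, StopIteration):
        # torch.get_default_device exists only in recent torch releases; older
        # versions fall back to CPU, mirroring the previous behaviour.
        return getattr(torch, "get_default_device", lambda: torch.device("cpu"))()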
2 changes: 1 addition & 1 deletion torchrl/objectives/crossq.py
@@ -306,7 +306,7 @@ def __init__(
        try:
            device = next(self.parameters()).device
        except AttributeError:
            device = torch.device("cpu")
            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
        self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
        if bool(min_alpha) ^ bool(max_alpha):
            min_alpha = min_alpha if min_alpha else 0.0
2 changes: 1 addition & 1 deletion torchrl/objectives/decision_transformer.py
@@ -103,7 +103,7 @@ def __init__(
        try:
            device = next(self.parameters()).device
        except AttributeError:
            device = torch.device("cpu")
            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

        self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
        if bool(min_alpha) ^ bool(max_alpha):
2 changes: 1 addition & 1 deletion torchrl/objectives/deprecated.py
@@ -195,7 +195,7 @@ def __init__(
        try:
            device = next(self.parameters()).device
        except AttributeError:
            device = torch.device("cpu")
            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

        self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
        self.register_buffer(