diff --git a/examples/agents/composite_ppo.py b/examples/agents/composite_ppo.py
new file mode 100644
index 00000000000..d75ce3218b3
--- /dev/null
+++ b/examples/agents/composite_ppo.py
@@ -0,0 +1,203 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Multi-head agent and PPO loss
+=============================
+
+This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
+(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.
+
+The code first defines a module `make_params` that extracts the parameters of the distributions from an input
+tensordict. It then creates a `dist_constructor` function that takes these parameters as input and outputs a
+CompositeDistribution object containing the three distributions.
+
+The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to
+parameters, creates a distribution from these parameters, and samples from the distribution to output multiple
+actions.
+
+The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and
+KLPENPPOLoss.
+
+Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
+fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
+argument. The PPO loss modules are designed to handle both cases and will default to
+`aggregate_probabilities=False` if it is not specified.
+
+In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs are also included in
+the output tensordict. However, since we already have access to the individual log-probs, this feature is typically
+not needed.
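+
+For instance, with the `name_map` used in this example, a non-aggregated `log_prob` call returns one entry per
+head (an illustrative sketch, not verbatim output)::
+
+    dist = policy.get_dist(td)
+    lp = dist.log_prob(td, aggregate_probabilities=False, inplace=False, include_sum=False)
+    # lp is a TensorDict with entries such as ("agent0", "action_log_prob"),
+    # ("agent1", "action_log_prob") and ("agent2", "action_log_prob").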
+ +""" + +import functools + +import torch +from tensordict import TensorDict +from tensordict.nn import ( + CompositeDistribution, + InteractionType, + ProbabilisticTensorDictModule as Prob, + ProbabilisticTensorDictSequential as ProbSeq, + TensorDictModule as Mod, + TensorDictSequential as Seq, + WrapModule as Wrap, +) +from torch import distributions as d +from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss + +make_params = Mod( + lambda: ( + torch.ones(4), + torch.ones(4), + torch.ones(4, 2), + torch.ones(4, 2), + torch.ones(4, 10) / 10, + torch.zeros(4, 10), + torch.ones(4, 10), + ), + in_keys=[], + out_keys=[ + ("params", "gamma", "concentration"), + ("params", "gamma", "rate"), + ("params", "Kumaraswamy", "concentration0"), + ("params", "Kumaraswamy", "concentration1"), + ("params", "mixture", "logits"), + ("params", "mixture", "loc"), + ("params", "mixture", "scale"), + ], +) + + +def mixture_constructor(logits, loc, scale): + return d.MixtureSameFamily( + d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale) + ) + + +# ============================================================================= +# Example 0: aggregate_probabilities=None (default) =========================== + +dist_constructor = functools.partial( + CompositeDistribution, + distribution_map={ + "gamma": d.Gamma, + "Kumaraswamy": d.Kumaraswamy, + "mixture": mixture_constructor, + }, + name_map={ + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + }, + aggregate_probabilities=None, +) + + +policy = ProbSeq( + make_params, + Prob( + in_keys=["params"], + out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + distribution_class=dist_constructor, + return_log_prob=True, + default_interaction_type=InteractionType.RANDOM, + ), +) + +td = policy(TensorDict(batch_size=[4])) +print("0. result of policy call", td) + +dist = policy.get_dist(td) +log_prob = dist.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False +) +print("0. non-aggregated log-prob") + +# We can also get the log-prob from the policy directly +log_prob = policy.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False +) +print("0. non-aggregated log-prob (from policy)") + +# Build a dummy value operator +value_operator = Seq( + Wrap( + lambda td: td.set("state_value", torch.ones((*td.shape, 1))), + out_keys=["state_value"], + ) +) + +# Create fake data +data = policy(TensorDict(batch_size=[4])) +data.set( + "next", + TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)), +) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + + # Get the loss values + loss_vals = ppo(data) + print("0. ", loss_cls, loss_vals) + + +# =================================================================== +# Example 1: aggregate_probabilities=True =========================== + +dist_constructor.keywords["aggregate_probabilities"] = True + +td = policy(TensorDict(batch_size=[4])) +print("1. 
result of policy call", td) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since + # there is only one. + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")] + ) + + # Get the loss values + loss_vals = ppo(data) + print("1. ", loss_cls, loss_vals) + + +# =================================================================== +# Example 2: aggregate_probabilities=False =========================== + +dist_constructor.keywords["aggregate_probabilities"] = False + +td = policy(TensorDict(batch_size=[4])) +print("2. result of policy call", td) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + + # Get the loss values + loss_vals = ppo(data) + print("2. ", loss_cls, loss_vals) diff --git a/test/test_cost.py b/test/test_cost.py index 1f191e41db6..7c7c97eedfc 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -34,6 +34,7 @@ TensorDictModule as Mod, TensorDictSequential, TensorDictSequential as Seq, + WrapModule, ) from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key @@ -8864,9 +8865,7 @@ def test_ppo_tensordict_keys_run( @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) @pytest.mark.parametrize( "composite_action_dist", - [ - False, - ], + [False], ) def test_ppo_notensordict( self, @@ -9060,6 +9059,110 @@ def test_ppo_value_clipping( loss = loss_fn(td) assert "loss_critic" in loss.keys() + def test_ppo_composite_dists(self): + d = torch.distributions + + make_params = TensorDictModule( + lambda: ( + torch.ones(4), + torch.ones(4), + torch.ones(4, 2), + torch.ones(4, 2), + torch.ones(4, 10) / 10, + torch.zeros(4, 10), + torch.ones(4, 10), + ), + in_keys=[], + out_keys=[ + ("params", "gamma", "concentration"), + ("params", "gamma", "rate"), + ("params", "Kumaraswamy", "concentration0"), + ("params", "Kumaraswamy", "concentration1"), + ("params", "mixture", "logits"), + ("params", "mixture", "loc"), + ("params", "mixture", "scale"), + ], + ) + + def mixture_constructor(logits, loc, scale): + return d.MixtureSameFamily( + d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale) + ) + + dist_constructor = functools.partial( + CompositeDistribution, + distribution_map={ + "gamma": d.Gamma, + "Kumaraswamy": d.Kumaraswamy, + "mixture": mixture_constructor, + }, + name_map={ + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + }, + aggregate_probabilities=False, + include_sum=False, + inplace=True, + ) + policy = ProbSeq( + make_params, + ProbabilisticTensorDictModule( + in_keys=["params"], + out_keys=[ + ("agent0", "action"), + ("agent1", "action"), + ("agent2", "action"), + ], + distribution_class=dist_constructor, + return_log_prob=True, + default_interaction_type=InteractionType.RANDOM, + ), + ) + # We want to make sure there is no warning + td = policy(TensorDict(batch_size=[4])) + assert isinstance( + policy.get_dist(td).log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False + ), + 
TensorDict, + ) + assert isinstance( + policy.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False + ), + TensorDict, + ) + value_operator = Seq( + WrapModule( + lambda td: td.set("state_value", torch.ones((*td.shape, 1))), + out_keys=["state_value"], + ) + ) + for cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + data = policy(TensorDict(batch_size=[4])) + data.set( + "next", + TensorDict( + reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool) + ), + ) + ppo = cls(policy, value_operator) + ppo.set_keys( + action=[ + ("agent0", "action"), + ("agent1", "action"), + ("agent2", "action"), + ], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + loss = ppo(data) + loss.sum(reduce=True) + class TestA2C(LossModuleTestBase): seed = 0 diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 6e056589a8c..894c8db5212 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -323,7 +323,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 22e84673641..8bd37f38c39 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -306,7 +306,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index a0d193acbfc..16e7b5212a1 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -103,7 +103,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 2a4124c80de..d4df68c6cb6 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -195,7 +195,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device)) self.register_buffer( diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index eb9a916dfc1..5411687eb5e 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -8,7 +8,7 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import List, Tuple import torch from tensordict import ( @@ -27,12 +27,15 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl._utils import _replace_last from torchrl.objectives.common import LossModule from 
torchrl.objectives.utils import ( _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_add_or_extend_key, + _maybe_get_or_select, _reduce, _sum_td_features, default_value_kwargs, @@ -67,7 +70,10 @@ class PPOLoss(LossModule): Args: actor_network (ProbabilisticTensorDictSequential): policy operator. - critic_network (ValueOperator): value operator. + Typically, a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations + as input and outputting an action (or actions) as well as its log-probability value. + critic_network (ValueOperator): value operator. The critic will usually take the observations as input + and return a scalar value (``state_value`` by default) in the output keys. Keyword Args: entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the @@ -267,16 +273,16 @@ class _AcceptedKeys: Will be used for the underlying value estimator Defaults to ``"value_target"``. value (NestedKey): The input tensordict key where the state value is expected. Will be used for the underlying value estimator. Defaults to ``"state_value"``. - sample_log_prob (NestedKey): The input tensordict key where the + sample_log_prob (NestedKey or list of nested keys): The input tensordict key where the sample log probability is expected. Defaults to ``"sample_log_prob"``. - action (NestedKey): The input tensordict key where the action is expected. + action (NestedKey or list of nested keys): The input tensordict key where the action is expected. Defaults to ``"action"``. - reward (NestedKey): The input tensordict key where the reward is expected. + reward (NestedKey or list of nested keys): The input tensordict key where the reward is expected. Will be used for the underlying value estimator. Defaults to ``"reward"``. - done (NestedKey): The key in the input TensorDict that indicates + done (NestedKey or list of nested keys): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. - terminated (NestedKey): The key in the input TensorDict that indicates + terminated (NestedKey or list of nested keys): The key in the input TensorDict that indicates whether a trajectory is terminated. Will be used for the underlying value estimator. Defaults to ``"terminated"``. 
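+
+        When the actor produces a composite action (e.g., via a :class:`~tensordict.nn.CompositeDistribution`),
+        ``action`` and ``sample_log_prob`` can be passed as lists of nested keys, one entry per head. A minimal
+        sketch (the agent and key names below are purely illustrative):
+
+            >>> loss.set_keys(
+            ...     action=[("agent0", "action"), ("agent1", "action")],
+            ...     sample_log_prob=[("agent0", "action_log_prob"), ("agent1", "action_log_prob")],
+            ... )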
""" @@ -284,11 +290,11 @@ class _AcceptedKeys: advantage: NestedKey = "advantage" value_target: NestedKey = "value_target" value: NestedKey = "state_value" - sample_log_prob: NestedKey = "sample_log_prob" - action: NestedKey = "action" - reward: NestedKey = "reward" - done: NestedKey = "done" - terminated: NestedKey = "terminated" + sample_log_prob: NestedKey | List[NestedKey] = "sample_log_prob" + action: NestedKey | List[NestedKey] = "action" + reward: NestedKey | List[NestedKey] = "reward" + done: NestedKey | List[NestedKey] = "done" + terminated: NestedKey | List[NestedKey] = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE @@ -369,8 +375,8 @@ def __init__( try: device = next(self.parameters()).device - except AttributeError: - device = torch.device("cpu") + except (AttributeError, StopIteration): + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) if critic_coef is not None: @@ -408,16 +414,16 @@ def functional(self): return self._functional def _set_in_keys(self): - keys = [ - self.tensor_keys.action, - self.tensor_keys.sample_log_prob, - ("next", self.tensor_keys.reward), - ("next", self.tensor_keys.done), - ("next", self.tensor_keys.terminated), - *self.actor_network.in_keys, - *[("next", key) for key in self.actor_network.in_keys], - *self.critic_network.in_keys, - ] + keys = [] + _maybe_add_or_extend_key(keys, self.actor_network.in_keys) + _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next") + _maybe_add_or_extend_key(keys, self.critic_network.in_keys) + _maybe_add_or_extend_key(keys, self.tensor_keys.action) + _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob) + _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next") + self._in_keys = list(set(keys)) @property @@ -456,6 +462,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: reward=self.tensor_keys.reward, done=self.tensor_keys.done, terminated=self.tensor_keys.terminated, + sample_log_prob=self.tensor_keys.sample_log_prob, ) self._set_in_keys() @@ -463,34 +470,58 @@ def reset(self) -> None: pass def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: + if isinstance(dist, CompositeDistribution): + aggregate = dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = dist.include_sum + if include_sum is None: + include_sum = False + kwargs = {"aggregate_probabilities": aggregate, "include_sum": include_sum} + else: + kwargs = {} try: - if isinstance(dist, CompositeDistribution): - kwargs = {"aggregate_probabilities": False, "include_sum": False} - else: - kwargs = {} entropy = dist.entropy(**kwargs) - if is_tensor_collection(entropy): - entropy = _sum_td_features(entropy) except NotImplementedError: - x = dist.rsample((self.samples_mc_entropy,)) - log_prob = dist.log_prob(x) + if getattr(dist, "has_rsample", False): + x = dist.rsample((self.samples_mc_entropy,)) + else: + x = dist.sample((self.samples_mc_entropy,)) + log_prob = dist.log_prob(x, **kwargs) + if is_tensor_collection(log_prob): - log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + if isinstance(self.tensor_keys.sample_log_prob, NestedKey): + try: + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + except KeyError as err: + raise _make_lp_get_error(self.tensor_keys, log_prob, err) + 
else: + log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) + if is_tensor_collection(entropy): + entropy = _sum_td_features(entropy) return entropy.unsqueeze(-1) def _log_weight( self, tensordict: TensorDictBase ) -> Tuple[torch.Tensor, d.Distribution]: + # current log_prob of actions - action = tensordict.get(self.tensor_keys.action) + action = _maybe_get_or_select(tensordict, self.tensor_keys.action) with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) - prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) + try: + prev_log_prob = _maybe_get_or_select( + tensordict, self.tensor_keys.sample_log_prob + ) + except KeyError as err: + raise _make_lp_get_error(self.tensor_keys, tensordict, err) + if prev_log_prob.requires_grad: raise RuntimeError( f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad." @@ -505,21 +536,33 @@ def _log_weight( else: if isinstance(dist, CompositeDistribution): is_composite = True + aggregate = dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = dist.include_sum + if include_sum is None: + include_sum = False kwargs = { "inplace": False, - "aggregate_probabilities": False, - "include_sum": False, + "aggregate_probabilities": aggregate, + "include_sum": include_sum, } else: is_composite = False kwargs = {} - log_prob = dist.log_prob(tensordict, **kwargs) - if is_composite and not isinstance(prev_log_prob, TensorDict): + log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs) + if ( + is_composite + and not is_tensor_collection(prev_log_prob) + and is_tensor_collection(log_prob) + ): log_prob = _sum_td_features(log_prob) log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) + if is_tensor_collection(kl_approx): + kl_approx = _sum_td_features(kl_approx) return log_weight, dist, kl_approx @@ -893,6 +936,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: gain2 = ratio * advantage gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0] + if is_tensor_collection(gain): + gain = _sum_td_features(gain) td_out = TensorDict({"loss_objective": -gain}, batch_size=[]) td_out.set("clip_fraction", clip_fraction) @@ -1087,16 +1132,16 @@ def __init__( self.samples_mc_kl = samples_mc_kl def _set_in_keys(self): - keys = [ - self.tensor_keys.action, - self.tensor_keys.sample_log_prob, - ("next", self.tensor_keys.reward), - ("next", self.tensor_keys.done), - ("next", self.tensor_keys.terminated), - *self.actor_network.in_keys, - *[("next", key) for key in self.actor_network.in_keys], - *self.critic_network.in_keys, - ] + keys = [] + _maybe_add_or_extend_key(keys, self.actor_network.in_keys) + _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next") + _maybe_add_or_extend_key(keys, self.critic_network.in_keys) + _maybe_add_or_extend_key(keys, self.tensor_keys.action) + _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob) + _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next") + # Get the parameter keys from the actor dist actor_dist_module = None for module in self.actor_network.modules(): @@ -1156,6 +1201,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: advantage = (advantage - loc) / scale 
log_weight, dist, kl_approx = self._log_weight(tensordict_copy) neg_loss = log_weight.exp() * advantage + if is_tensor_collection(neg_loss): + neg_loss = _sum_td_features(neg_loss) with self.actor_network_params.to_module( self.actor_network @@ -1166,17 +1213,24 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) if isinstance(previous_dist, CompositeDistribution): + aggregate = previous_dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = previous_dist.include_sum + if include_sum is None: + include_sum = False kwargs = { - "aggregate_probabilities": False, + "aggregate_probabilities": aggregate, "inplace": False, - "include_sum": False, + "include_sum": include_sum, } else: kwargs = {} previous_log_prob = previous_dist.log_prob(x, **kwargs) current_log_prob = current_dist.log_prob(x, **kwargs) - if is_tensor_collection(current_log_prob): + if is_tensor_collection(previous_log_prob): previous_log_prob = _sum_td_features(previous_log_prob) + # Both dists have presumably the same params current_log_prob = _sum_td_features(current_log_prob) kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) @@ -1214,3 +1268,30 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: def reset(self) -> None: self.beta = self._beta_init + + +def _make_lp_get_error(tensor_keys, log_prob, err): + result = ( + f"The sample log probability key (tensor_keys.sample_log_prob={tensor_keys.sample_log_prob}) does " + f"not appear in the log-prob tensordict with keys {list(log_prob.keys(True, True))}. " + ) + # now check if we can substitute the actions with action_log_prob and retrieve the log-probs + action_keys = tensor_keys.action + if isinstance(action_keys, list): + has_all_log_probs = True + log_prob_keys = [] + for action_key in action_keys: + log_prob_key = _replace_last(action_key, "action_log_prob") + log_prob_keys.append(log_prob_key) + if log_prob_key not in log_prob: + has_all_log_probs = False + break + if has_all_log_probs: + result += ( + f"The action keys are {action_keys} and all log_prob keys {log_prob_keys} are present in the " + f"log-prob tensordict. Calling `loss.set_keys(sample_log_prob={log_prob_keys})` should resolve " + f"this error." + ) + return KeyError(result) + result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=)." 
+ return KeyError(result) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index e234df1a512..68eafb834e6 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -309,7 +309,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) self.register_buffer( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index eae6b7feb34..66431b9c9a5 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -383,7 +383,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 @@ -1102,7 +1102,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 9c46fc98262..3e0b97de710 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -8,10 +8,10 @@ import re import warnings from enum import Enum -from typing import Iterable, Optional, Union +from typing import Iterable, List, Optional, Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictModule from torch import nn, Tensor from torch.nn import functional as F @@ -620,3 +620,26 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize def _sum_td_features(data: TensorDictBase) -> torch.Tensor: # Sum all features and return a tensor return data.sum(dim="feature", reduce=True) + + +def _maybe_get_or_select(td, key_or_keys): + if isinstance(key_or_keys, (str, tuple)): + return td.get(key_or_keys) + return td.select(*key_or_keys) + + +def _maybe_add_or_extend_key( + tensor_keys: List[NestedKey], + key_or_list_of_keys: NestedKey | List[NestedKey], + prefix: NestedKey = None, +): + if prefix is not None: + if isinstance(key_or_list_of_keys, NestedKey): + tensor_keys.append(unravel_key((prefix, key_or_list_of_keys))) + else: + tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys]) + return + if isinstance(key_or_list_of_keys, NestedKey): + tensor_keys.append(key_or_list_of_keys) + else: + tensor_keys.extend(key_or_list_of_keys) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 3b08780e24c..fa05c8860a6 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -13,7 +13,7 @@ from typing import Callable, List, Union import torch -from tensordict import TensorDictBase +from tensordict import is_tensor_collection, TensorDictBase from tensordict.nn import ( CompositeDistribution, dispatch, @@ -23,13 +23,18 @@ TensorDictModuleBase, ) from tensordict.nn.probabilistic import interaction_type -from tensordict.utils import NestedKey +from tensordict.utils import NestedKey, unravel_key from torch 
import Tensor from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp -from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST +from torchrl.objectives.utils import ( + _maybe_get_or_select, + _vmap_func, + hold_out_net, + RANDOM_MODULE_LIST, +) from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -293,13 +298,18 @@ def out_keys(self): def set_keys(self, **kwargs) -> None: """Set tensordict key names.""" - for key, value in kwargs.items(): - if not isinstance(value, (str, tuple)): + for key, value in list(kwargs.items()): + if isinstance(value, list): + value = [unravel_key(k) for k in value] + elif not isinstance(value, (str, tuple)): + if value is None: + raise ValueError("tensordict keys cannot be None") raise ValueError( f"key name must be of type NestedKey (Union[str, Tuple[str]]) but got {type(value)}" ) - if value is None: - raise ValueError("tensordict keys cannot be None") + else: + value = unravel_key(value) + if key not in self._AcceptedKeys.__dict__: raise KeyError( f"{key} is not an accepted tensordict key for advantages" @@ -312,6 +322,7 @@ def set_keys(self, **kwargs) -> None: raise KeyError( f"value key '{value}' not found in value network out_keys {self.value_network.out_keys}" ) + kwargs[key] = value if self._tensor_keys is None: conf = asdict(self.default_keys) conf.update(self.dep_keys) @@ -1765,12 +1776,11 @@ def forward( value = tensordict.get(self.tensor_keys.value) next_value = tensordict.get(("next", self.tensor_keys.value)) - # Make sure we have the log prob computed at collection time - if self.tensor_keys.sample_log_prob not in tensordict.keys(): - raise ValueError( - f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict" - ) - log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value) + lp = _maybe_get_or_select(tensordict, self.tensor_keys.sample_log_prob) + if is_tensor_collection(lp): + # Sum all values to match the batch size + lp = lp.sum(dim="feature", reduce=True) + log_mu = lp.view_as(value) # Compute log prob with current policy with hold_out_net(self.actor_network):