From 36545af5062821dada2cdb91594209442d3dd0e6 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 12 Sep 2024 10:47:12 +0100
Subject: [PATCH] [BugFix] compatibility to new Composite dist log_prob/entropy APIs

ghstack-source-id: a09b6c34000f57a66736bb9811ca3656c861ec0c
Pull Request resolved: https://github.com/pytorch/rl/pull/2435
---
 test/test_cost.py         |  5 +++++
 torchrl/objectives/a2c.py |  9 +++++++--
 torchrl/objectives/ppo.py | 11 ++++++++---
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index ab95c55ef83..b11cec924e3 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -7565,6 +7565,7 @@ def _create_mock_actor(
                 "action1": (action_key, "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -7634,6 +7635,7 @@ def _create_mock_actor_value(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -7690,6 +7692,7 @@ def _create_mock_actor_value_shared(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -8627,6 +8630,7 @@ def _create_mock_actor(
                 "action1": (action_key, "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
        )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -8727,6 +8731,7 @@ def _create_mock_common_layer_setup(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index ff9b5f3883e..34c62bc3260 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -420,8 +420,13 @@ def _log_probs(
         if isinstance(action, torch.Tensor):
             log_prob = dist.log_prob(action)
         else:
-            tensordict = dist.log_prob(tensordict)
-            log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+            maybe_log_prob = dist.log_prob(tensordict)
+            if not isinstance(maybe_log_prob, torch.Tensor):
+                # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
+                # be a tensor
+                log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = maybe_log_prob
         log_prob = log_prob.unsqueeze(-1)
         return log_prob, dist

diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index b4779a90663..9d9790ab294 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -490,8 +490,13 @@ def _log_weight(
         if isinstance(action, torch.Tensor):
             log_prob = dist.log_prob(action)
         else:
-            tensordict = dist.log_prob(tensordict)
-            log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+            maybe_log_prob = dist.log_prob(tensordict)
+            if not isinstance(maybe_log_prob, torch.Tensor):
+                # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
+                # be a tensor
+                log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = maybe_log_prob

         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
@@ -1130,7 +1135,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             x = previous_dist.sample((self.samples_mc_kl,))
             previous_log_prob = previous_dist.log_prob(x)
             current_log_prob = current_dist.log_prob(x)
-            if is_tensor_collection(x):
+            if is_tensor_collection(current_log_prob):
                 previous_log_prob = previous_log_prob.get(
                     self.tensor_keys.sample_log_prob
                 )
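
Note on the change: the a2c.py and ppo.py hunks above apply the same defensive pattern, because a Composite distribution's log_prob() may return either a plain tensor or a tensordict of log-probabilities depending on whether probability aggregation is enabled. Below is a minimal standalone sketch of that pattern, not part of the patch; the names extract_log_prob, dist, data and sample_log_prob_key are illustrative placeholders.

    import torch

    def extract_log_prob(dist, data, sample_log_prob_key="sample_log_prob"):
        # log_prob() may hand back either a tensor (already aggregated) or a
        # tensordict-like container holding the log-prob under a known key,
        # as handled in the a2c.py/ppo.py hunks above.
        maybe_log_prob = dist.log_prob(data)
        if not isinstance(maybe_log_prob, torch.Tensor):
            return maybe_log_prob.get(sample_log_prob_key)
        return maybe_log_prob

The test changes take the other route: they pass aggregate_probabilities=True when building the composite distribution so that log_prob() keeps returning a single aggregated tensor.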