Commit

Update
[ghstack-poisoned]
vmoens committed Jan 10, 2025
1 parent 010c84f commit c10adf1
Showing 2 changed files with 52 additions and 36 deletions.
16 changes: 7 additions & 9 deletions test/test_cost.py
@@ -8158,18 +8158,19 @@ def _create_seq_mock_data_ppo(
         obs = total_obs[:, :T]
         next_obs = total_obs[:, 1:]
         if atoms:
-            action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
-                -1, 1
-            )
+            action_shape = (batch, T, atoms, action_dim)
         else:
-            action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
+            action_shape = (batch, T, action_dim)
+        params_mean = torch.randn(action_shape, device=device) / 10
+        params_scale = torch.rand(action_shape, device=device) / 10
+        action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp(
+            -1, 1
+        )
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         mask = torch.ones(batch, T, dtype=torch.bool, device=device)
         action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
-        params_mean = torch.randn_like(action) / 10
-        params_scale = torch.rand_like(action) / 10
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
         td = TensorDict(
@@ -8184,9 +8185,6 @@ def _create_seq_mock_data_ppo(
                 },
                 "collector": {"mask": mask},
                 action_key: {"action1": action} if composite_action_dist else action,
-                sample_log_prob_key: (
-                    torch.randn_like(action[..., 1]) / 10
-                ).masked_fill_(~mask, 0.0),
             },
             device=device,
             names=[None, "time"],
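
For context, the updated fixture now derives the action from the same loc/scale parameters it stores, instead of sampling the action and the distribution parameters independently. A standalone sketch of that construction, using hypothetical sizes (batch=2, T=5, action_dim=3) in place of the test's own arguments:

import torch

batch, T, action_dim = 2, 5, 3  # hypothetical sizes, not taken from the test
action_shape = (batch, T, action_dim)

params_mean = torch.randn(action_shape) / 10   # behaviour-policy loc
params_scale = torch.rand(action_shape) / 10   # behaviour-policy scale
action = (params_mean + params_scale * torch.randn(action_shape)).clamp(-1, 1)

mask = torch.ones(batch, T, dtype=torch.bool)
action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
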
72 changes: 45 additions & 27 deletions torchrl/objectives/ppo.py
@@ -510,17 +510,28 @@ def _log_weight(
         # current log_prob of actions
         action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
 
+        is_composite = None
+        if all(key in tensordict for key in self.actor_network.dist_params_keys):
+            prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
+            kwargs, is_composite = _get_composite_kwargs(prev_dist)
+            if is_composite:
+                prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
+            else:
+                prev_log_prob = prev_dist.log_prob(action, **kwargs)
+            print('prev_log_prob', prev_log_prob)
+        else:
+            try:
+                prev_log_prob = _maybe_get_or_select(
+                    tensordict, self.tensor_keys.sample_log_prob
+                )
+            except KeyError as err:
+                raise _make_lp_get_error(self.tensor_keys, tensordict, err)
+
         with self.actor_network_params.to_module(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
-            dist = self.actor_network.get_dist(tensordict)
+            current_dist = self.actor_network.get_dist(tensordict)
 
-        try:
-            prev_log_prob = _maybe_get_or_select(
-                tensordict, self.tensor_keys.sample_log_prob
-            )
-        except KeyError as err:
-            raise _make_lp_get_error(self.tensor_keys, tensordict, err)
 
         if prev_log_prob.requires_grad:
             raise RuntimeError(
@@ -532,25 +543,11 @@ def _log_weight(
                 f"tensordict stored {self.tensor_keys.action} requires grad."
             )
         if isinstance(action, torch.Tensor):
-            log_prob = dist.log_prob(action)
+            log_prob = current_dist.log_prob(action)
         else:
-            if isinstance(dist, CompositeDistribution):
-                is_composite = True
-                aggregate = dist.aggregate_probabilities
-                if aggregate is None:
-                    aggregate = False
-                include_sum = dist.include_sum
-                if include_sum is None:
-                    include_sum = False
-                kwargs = {
-                    "inplace": False,
-                    "aggregate_probabilities": aggregate,
-                    "include_sum": include_sum,
-                }
-            else:
-                is_composite = False
-                kwargs = {}
-            log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
+            if is_composite is None:
+                kwargs, is_composite = _get_composite_kwargs(current_dist)
+            log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
         if (
             is_composite
             and not is_tensor_collection(prev_log_prob)
@@ -564,7 +561,7 @@ def _log_weight(
         if is_tensor_collection(kl_approx):
             kl_approx = _sum_td_features(kl_approx)
 
-        return log_weight, dist, kl_approx
+        return log_weight, current_dist, kl_approx
 
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -640,6 +637,9 @@ def _cached_critic_network_params_detached(self):
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = tensordict.clone(False)
+
+        log_weight, dist, kl_approx = self._log_weight(tensordict)
+
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
@@ -653,7 +653,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             scale = advantage.std().clamp_min(1e-6)
             advantage = (advantage - loc) / scale
 
-        log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
         log_weight = log_weight.view(advantage.shape)
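
As a reminder of what this log-weight feeds into, below is a minimal, self-contained sketch of the generic clipped-surrogate objective that consumes such a log-weight. It is textbook PPO math with placeholder tensors and a placeholder clip_epsilon, not the code of this module:

import torch

# Placeholder per-sample quantities (illustrative values only).
log_weight = torch.tensor([0.05, -0.20, 0.10])  # log pi_new(a|s) - log pi_old(a|s)
advantage = torch.tensor([1.0, -0.5, 2.0])
clip_epsilon = 0.2

ratio = log_weight.exp()
clipped_ratio = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon)
surrogate = torch.min(ratio * advantage, clipped_ratio * advantage)
loss_objective = -surrogate.mean()  # minimized, i.e. ascent on the clipped surrogate
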
@@ -1295,3 +1294,22 @@ def _make_lp_get_error(tensor_keys, log_prob, err):
         return KeyError(result)
     result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."
     return KeyError(result)
+
+def _get_composite_kwargs(current_dist):
+    if isinstance(current_dist, CompositeDistribution):
+        is_composite = True
+        aggregate = current_dist.aggregate_probabilities
+        if aggregate is None:
+            aggregate = False
+        include_sum = current_dist.include_sum
+        if include_sum is None:
+            include_sum = False
+        kwargs = {
+            "inplace": False,
+            "aggregate_probabilities": aggregate,
+            "include_sum": include_sum,
+        }
+    else:
+        is_composite = False
+        kwargs = {}
+    return kwargs, is_composite
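
A hedged usage sketch of the new helper: it normalizes the optional CompositeDistribution flags into explicit log_prob keyword arguments so that both call sites in _log_weight can share one code path. The dist, action and tensordict objects below are illustrative placeholders, and only the non-composite branch is exercised by this snippet:

import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(3), torch.ones(3))  # stand-in for the actor's distribution
action = dist.sample()

kwargs, is_composite = _get_composite_kwargs(dist)  # -> {}, False for a plain torch distribution
if is_composite:
    # A CompositeDistribution's log_prob consumes the whole tensordict,
    # honouring the inplace/aggregate_probabilities/include_sum kwargs built above.
    log_prob = dist.log_prob(tensordict, **kwargs)
else:
    # Regular torch distribution: kwargs is empty and log_prob takes the action.
    log_prob = dist.log_prob(action, **kwargs)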
