From 36a1c3dc3ecec95dfdb8e65f90621dd98dfa4cd2 Mon Sep 17 00:00:00 2001
From: Simona Petravic
Date: Thu, 9 Nov 2023 14:10:37 +0100
Subject: [PATCH] Update AlphaLoss to support entropy schedules (#186)

This PR adds support for entropy schedules in the AlphaLoss callback.
It also removes the multiplication term from the default target entropy,
which previously evaluated to a very small value (CC @AliGhadirzadeh).
The target entropy can now be set directly through the new `t_entropy`
argument (a constant or a `Schedule`) instead of being derived from the
number of actions scaled by `entropy_eps`.
---
 emote/algorithms/sac.py | 23 +++++++++++++-------
 tests/test_sac.py       | 47 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 8 deletions(-)

diff --git a/emote/algorithms/sac.py b/emote/algorithms/sac.py
index 90171d4c..8c63c78c 100644
--- a/emote/algorithms/sac.py
+++ b/emote/algorithms/sac.py
@@ -4,13 +4,13 @@

 from typing import Any, Dict, Optional

-import numpy as np
 import torch

 from torch import nn, optim

 from emote.callback import Callback
 from emote.callbacks.loss import LossCallback
+from emote.extra.schedules import ConstantSchedule, Schedule
 from emote.mixins.logging import LoggingMixin
 from emote.proxies import AgentProxy, GenericAgentProxy
 from emote.utils.deprecated import deprecated
@@ -245,17 +245,17 @@ class AlphaLoss(LossCallback):
         probability given a state.
     :param ln_alpha (torch.tensor): The current weight for the entropy part of
         the soft Q.
-    :param lr_schedule (torch.optim.lr_scheduler._LRSchedule): Learning rate schedule
+    :param lr_schedule (torch.optim.lr_scheduler._LRScheduler | None): Learning rate schedule
         for the optimizer of alpha.
     :param opt (torch.optim.Optimizer): An optimizer for ln_alpha.
     :param n_actions (int): The dimension of the action space.
         Scales the target entropy.
     :param max_grad_norm (float): Clip the norm of the gradient during
         backprop using this value.
-    :param entropy_eps (float): Scaling value for the target entropy.
     :param name (str): The name of the module. Used e.g. while logging.
     :param data_group (str): The name of the data group from which this Loss
         takes its data.
+    :param t_entropy (float | Schedule | None): Value or schedule for the target entropy.
""" def __init__( @@ -264,13 +264,13 @@ def __init__( pi: nn.Module, ln_alpha: torch.tensor, opt: optim.Optimizer, - lr_schedule: Optional[optim.lr_scheduler._LRScheduler] = None, + lr_schedule: optim.lr_scheduler._LRScheduler | None = None, n_actions: int, max_grad_norm: float = 10.0, - entropy_eps: float = 0.089, max_alpha: float = 0.2, name: str = "alpha", data_group: str = "default", + t_entropy: float | Schedule | None = None, ): super().__init__( name=name, @@ -284,14 +284,20 @@ def __init__( self._max_ln_alpha = torch.log(torch.tensor(max_alpha, device=ln_alpha.device)) # TODO(singhblom) Check number of actions # self.t_entropy = -np.prod(self.env.action_space.shape).item() # Value from rlkit from Harnouja - self.t_entropy = n_actions * (1.0 + np.log(2.0 * np.pi * entropy_eps**2)) / 2.0 + t_entropy = -n_actions if t_entropy is None else t_entropy + if not isinstance(t_entropy, (int, float, Schedule)): + raise TypeError("t_entropy must be a number or an instance of Schedule") + + self.t_entropy = ( + t_entropy if isinstance(t_entropy, Schedule) else ConstantSchedule(t_entropy) + ) self.ln_alpha = ln_alpha # This is log(alpha) def loss(self, observation): with torch.no_grad(): _, logp_pi = self.policy(**observation) entropy = -logp_pi - error = entropy - self.t_entropy + error = entropy - self.t_entropy.value alpha_loss = torch.mean(self.ln_alpha * error.detach()) assert alpha_loss.dim() == 0 self.log_scalar("loss/alpha_loss", alpha_loss) @@ -304,7 +310,8 @@ def end_batch(self): self.ln_alpha = torch.clamp_max_(self.ln_alpha, self._max_ln_alpha) self.ln_alpha.requires_grad_(True) self.log_scalar("training/alpha_value", torch.exp(self.ln_alpha).item()) - self.log_scalar("training/target_entropy", self.t_entropy) + self.log_scalar("training/target_entropy", self.t_entropy.value) + self.t_entropy.step() def state_dict(self): state = super().state_dict() diff --git a/tests/test_sac.py b/tests/test_sac.py index 3ed4f443..40314c6c 100644 --- a/tests/test_sac.py +++ b/tests/test_sac.py @@ -6,6 +6,7 @@ import torch from emote.algorithms.sac import AlphaLoss, FeatureAgentProxy +from emote.extra.schedules import ConstantSchedule, CyclicSchedule from emote.nn.gaussian_policy import GaussianMlpPolicy from emote.typing import DictObservation, EpisodeState @@ -67,3 +68,49 @@ def test_alpha_value_ref_valid_after_load(): assert ( ln_alpha_before_load is ln_alpha_after_load ), "expected ln(alpha) to be the same python object after loading. The reference is used by other loss functions such as PolicyLoss!" 
+
+
+def test_target_entropy_schedules():
+    policy = GaussianMlpPolicy(IN_DIM, OUT_DIM, [16, 16])
+    init_ln_alpha = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
+    optim = torch.optim.Adam([init_ln_alpha])
+    loss = AlphaLoss(pi=policy, ln_alpha=init_ln_alpha, opt=optim, n_actions=OUT_DIM)
+
+    # Check if default is set correctly when no t_entropy is passed
+    init_entropy = loss.t_entropy.value
+    assert init_entropy == -OUT_DIM
+    print(init_entropy)
+
+    # Check that default schedule is constant and doesn't update the value
+    assert isinstance(loss.t_entropy, ConstantSchedule)
+    for _ in range(5):
+        loss.end_batch()
+    assert init_entropy == loss.t_entropy.value
+
+    # Check that value is updated when using a schedule
+    start = 5
+    end = 0
+    steps = 5
+    schedule = CyclicSchedule(start, end, steps, mode="triangular")
+    loss = AlphaLoss(
+        pi=policy, ln_alpha=init_ln_alpha, opt=optim, n_actions=OUT_DIM, t_entropy=schedule
+    )
+
+    for _ in range(steps + 1):
+        loss.end_batch()
+    assert loss.t_entropy.value == end
+
+    for _ in range(steps):
+        loss.end_batch()
+    assert loss.t_entropy.value == start
+
+    # Check that invalid types are not accepted
+    invalid_t_entropy = torch.optim.lr_scheduler.LinearLR(optim, 1, end / start, steps)
+    with pytest.raises(TypeError):
+        AlphaLoss(
+            pi=policy,
+            ln_alpha=init_ln_alpha,
+            opt=optim,
+            n_actions=OUT_DIM,
+            t_entropy=invalid_t_entropy,
+        )
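
Usage sketch (not part of the diff above): the snippet below shows how the new
`t_entropy` argument might be wired up with a `CyclicSchedule`, mirroring the
test added in this patch. The observation/action sizes, hidden layer sizes, and
the schedule length are placeholder values chosen only for illustration.

    import torch

    from emote.algorithms.sac import AlphaLoss
    from emote.extra.schedules import CyclicSchedule
    from emote.nn.gaussian_policy import GaussianMlpPolicy

    OBS_DIM, ACT_DIM = 3, 2  # placeholder dimensions for this sketch

    policy = GaussianMlpPolicy(OBS_DIM, ACT_DIM, [16, 16])
    ln_alpha = torch.tensor(0.0, requires_grad=True)
    opt = torch.optim.Adam([ln_alpha])

    # Cycle the target entropy between -ACT_DIM and 0 (step count is arbitrary here).
    t_entropy = CyclicSchedule(-ACT_DIM, 0, 1000, mode="triangular")

    # Passing a plain float instead keeps a fixed target via ConstantSchedule;
    # omitting t_entropy falls back to the new default of -n_actions.
    alpha_loss = AlphaLoss(
        pi=policy,
        ln_alpha=ln_alpha,
        opt=opt,
        n_actions=ACT_DIM,
        t_entropy=t_entropy,
    )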