Update AlphaLoss to support entropy schedules (#186)
This PR adds support for entropy schedules in the AlphaLoss callback. It also
removes the multiplication term from the default target entropy, since it
defaulted to a very small value (CC @AliGhadirzadeh); the target entropy can
now be set directly via the new `t_entropy` argument instead of by scaling
the number of actions through `entropy_eps`.
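
A minimal usage sketch based on the new test added in this diff (the dimensions, hidden-layer sizes, and schedule parameters below are placeholder values):

import torch

from emote.algorithms.sac import AlphaLoss
from emote.extra.schedules import CyclicSchedule
from emote.nn.gaussian_policy import GaussianMlpPolicy

# Placeholder dimensions for illustration.
IN_DIM, OUT_DIM = 3, 2

policy = GaussianMlpPolicy(IN_DIM, OUT_DIM, [16, 16])
ln_alpha = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
opt = torch.optim.Adam([ln_alpha])

# Default: a constant target entropy of -n_actions.
loss = AlphaLoss(pi=policy, ln_alpha=ln_alpha, opt=opt, n_actions=OUT_DIM)

# Or drive the target entropy with a schedule; it is stepped once per end_batch().
schedule = CyclicSchedule(5, 0, 5, mode="triangular")  # start, end, steps, as in the test below
loss = AlphaLoss(pi=policy, ln_alpha=ln_alpha, opt=opt, n_actions=OUT_DIM, t_entropy=schedule)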
spetravic authored Nov 9, 2023
1 parent 3d2939f commit 36a1c3d
Showing 2 changed files with 62 additions and 8 deletions.
23 changes: 15 additions & 8 deletions emote/algorithms/sac.py
@@ -4,13 +4,13 @@

from typing import Any, Dict, Optional

import numpy as np
import torch

from torch import nn, optim

from emote.callback import Callback
from emote.callbacks.loss import LossCallback
from emote.extra.schedules import ConstantSchedule, Schedule
from emote.mixins.logging import LoggingMixin
from emote.proxies import AgentProxy, GenericAgentProxy
from emote.utils.deprecated import deprecated
@@ -245,17 +245,17 @@ class AlphaLoss(LossCallback):
probability given a state.
:param ln_alpha (torch.tensor): The current weight for the entropy part of the
soft Q.
:param lr_schedule (torch.optim.lr_scheduler._LRSchedule): Learning rate schedule
:param lr_schedule (torch.optim.lr_scheduler._LRSchedule | None): Learning rate schedule
for the optimizer of alpha.
:param opt (torch.optim.Optimizer): An optimizer for ln_alpha.
:param n_actions (int): The dimension of the action space. Scales the target
entropy.
:param max_grad_norm (float): Clip the norm of the gradient during backprop using
this value.
:param entropy_eps (float): Scaling value for the target entropy.
:param name (str): The name of the module. Used e.g. while logging.
:param data_group (str): The name of the data group from which this Loss takes its
data.
:param t_entropy (float | Schedule | None): Value or schedule for the target entropy.
"""

def __init__(
@@ -264,13 +264,13 @@ def __init__(
pi: nn.Module,
ln_alpha: torch.tensor,
opt: optim.Optimizer,
lr_schedule: Optional[optim.lr_scheduler._LRScheduler] = None,
lr_schedule: optim.lr_scheduler._LRScheduler | None = None,
n_actions: int,
max_grad_norm: float = 10.0,
entropy_eps: float = 0.089,
max_alpha: float = 0.2,
name: str = "alpha",
data_group: str = "default",
t_entropy: float | Schedule | None = None,
):
super().__init__(
name=name,
@@ -284,14 +284,20 @@ def __init__(
self._max_ln_alpha = torch.log(torch.tensor(max_alpha, device=ln_alpha.device))
# TODO(singhblom) Check number of actions
# self.t_entropy = -np.prod(self.env.action_space.shape).item() # Value from rlkit from Harnouja
self.t_entropy = n_actions * (1.0 + np.log(2.0 * np.pi * entropy_eps**2)) / 2.0
t_entropy = -n_actions if t_entropy is None else t_entropy
if not isinstance(t_entropy, (int, float, Schedule)):
raise TypeError("t_entropy must be a number or an instance of Schedule")

self.t_entropy = (
t_entropy if isinstance(t_entropy, Schedule) else ConstantSchedule(t_entropy)
)
self.ln_alpha = ln_alpha # This is log(alpha)

def loss(self, observation):
with torch.no_grad():
_, logp_pi = self.policy(**observation)
entropy = -logp_pi
error = entropy - self.t_entropy
error = entropy - self.t_entropy.value
alpha_loss = torch.mean(self.ln_alpha * error.detach())
assert alpha_loss.dim() == 0
self.log_scalar("loss/alpha_loss", alpha_loss)
@@ -304,7 +310,8 @@ def end_batch(self):
self.ln_alpha = torch.clamp_max_(self.ln_alpha, self._max_ln_alpha)
self.ln_alpha.requires_grad_(True)
self.log_scalar("training/alpha_value", torch.exp(self.ln_alpha).item())
self.log_scalar("training/target_entropy", self.t_entropy)
self.log_scalar("training/target_entropy", self.t_entropy.value)
self.t_entropy.step()

def state_dict(self):
state = super().state_dict()
47 changes: 47 additions & 0 deletions tests/test_sac.py
@@ -6,6 +6,7 @@
import torch

from emote.algorithms.sac import AlphaLoss, FeatureAgentProxy
from emote.extra.schedules import ConstantSchedule, CyclicSchedule
from emote.nn.gaussian_policy import GaussianMlpPolicy
from emote.typing import DictObservation, EpisodeState

@@ -67,3 +68,49 @@ def test_alpha_value_ref_valid_after_load():
assert (
ln_alpha_before_load is ln_alpha_after_load
), "expected ln(alpha) to be the same python object after loading. The reference is used by other loss functions such as PolicyLoss!"


def test_target_entropy_schedules():
policy = GaussianMlpPolicy(IN_DIM, OUT_DIM, [16, 16])
init_ln_alpha = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
optim = torch.optim.Adam([init_ln_alpha])
loss = AlphaLoss(pi=policy, ln_alpha=init_ln_alpha, opt=optim, n_actions=OUT_DIM)

# Check if default is set correctly when no t_entropy is passed
init_entropy = loss.t_entropy.value
assert init_entropy == -OUT_DIM
print(init_entropy)

# Check that default schedule is constant and doesn't update the value
assert isinstance(loss.t_entropy, ConstantSchedule)
for _ in range(5):
loss.end_batch()
assert init_entropy == loss.t_entropy.value

# Check that value is updated when using a schedule
start = 5
end = 0
steps = 5
schedule = CyclicSchedule(start, end, steps, mode="triangular")
loss = AlphaLoss(
pi=policy, ln_alpha=init_ln_alpha, opt=optim, n_actions=OUT_DIM, t_entropy=schedule
)

for _ in range(steps + 1):
loss.end_batch()
assert loss.t_entropy.value == end

for _ in range(steps):
loss.end_batch()
assert loss.t_entropy.value == start

# Check that invalid types are not accepted
invalid_t_entropy = torch.optim.lr_scheduler.LinearLR(optim, 1, end / start, steps)
with pytest.raises(TypeError):
AlphaLoss(
pi=policy,
ln_alpha=init_ln_alpha,
opt=optim,
n_actions=OUT_DIM,
t_entropy=invalid_t_entropy,
)
