
[Feature] Add scheduler for alpha/beta parameters of PrioritizedSampler (#2452)

Co-authored-by: Vincent Moens <[email protected]>
LTluttmann and vmoens authored Sep 30, 2024
1 parent 6d1a1b3 commit 5851652
Showing 3 changed files with 360 additions and 0 deletions.
77 changes: 77 additions & 0 deletions test/test_rb.py
@@ -59,6 +59,11 @@
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.scheduler import (
LinearScheduler,
SchedulerList,
StepScheduler,
)

from torchrl.data.replay_buffers.storages import (
LazyMemmapStorage,
@@ -100,6 +105,7 @@
VecNorm,
)


OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
_has_gym = importlib.util.find_spec("gym") is not None
@@ -3041,6 +3047,77 @@ def test_prioritized_slice_sampler_episodes(device):
), "after priority update, only episode 1 and 3 are expected to be sampled"


@pytest.mark.parametrize("alpha", [0.6, torch.tensor(1.0)])
@pytest.mark.parametrize("beta", [0.7, torch.tensor(0.1)])
@pytest.mark.parametrize("gamma", [0.1])
@pytest.mark.parametrize("total_steps", [200])
@pytest.mark.parametrize("n_annealing_steps", [100])
@pytest.mark.parametrize("anneal_every_n", [10, 159])
@pytest.mark.parametrize("alpha_min", [0, 0.2])
@pytest.mark.parametrize("beta_max", [1, 1.4])
def test_prioritized_parameter_scheduler(
alpha,
beta,
gamma,
total_steps,
n_annealing_steps,
anneal_every_n,
alpha_min,
beta_max,
):
rb = TensorDictPrioritizedReplayBuffer(
alpha=alpha, beta=beta, storage=ListStorage(max_size=1000)
)
data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
rb.extend(data)
alpha_scheduler = LinearScheduler(
rb, param_name="alpha", final_value=alpha_min, num_steps=n_annealing_steps
)
beta_scheduler = StepScheduler(
rb,
param_name="beta",
gamma=gamma,
n_steps=anneal_every_n,
max_value=beta_max,
mode="additive",
)

scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))

alpha = alpha if torch.is_tensor(alpha) else torch.tensor(alpha)
alpha_min = torch.tensor(alpha_min)
expected_alpha_vals = torch.linspace(alpha, alpha_min, n_annealing_steps + 1)
expected_alpha_vals = torch.nn.functional.pad(
expected_alpha_vals, (0, total_steps - n_annealing_steps), value=alpha_min
)

expected_beta_vals = [beta]
annealing_steps = total_steps // anneal_every_n
gammas = torch.arange(0, annealing_steps + 1, dtype=torch.float32) * gamma
expected_beta_vals = (
(beta + gammas).repeat_interleave(anneal_every_n).clip(None, beta_max)
)
for i in range(total_steps):
curr_alpha = rb.sampler.alpha
torch.testing.assert_close(
curr_alpha
if torch.is_tensor(curr_alpha)
else torch.tensor(curr_alpha).float(),
expected_alpha_vals[i],
msg=f"expected {expected_alpha_vals[i]}, got {curr_alpha}",
)
curr_beta = rb.sampler.beta
torch.testing.assert_close(
curr_beta
if torch.is_tensor(curr_beta)
else torch.tensor(curr_beta).float(),
expected_beta_vals[i],
msg=f"expected {expected_beta_vals[i]}, got {curr_beta}",
)
rb.sample(20)
scheduler.step()


class TestEnsemble:
def _make_data(self, data_type):
if data_type is torch.Tensor:
16 changes: 16 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
@@ -395,6 +395,22 @@ def __repr__(self):
def max_size(self):
return self._max_capacity

@property
def alpha(self):
return self._alpha

@alpha.setter
def alpha(self, value):
self._alpha = value

@property
def beta(self):
return self._beta

@beta.setter
def beta(self, value):
self._beta = value

def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
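
These property setters are what the new schedulers (in `scheduler.py` below) rely on: a scheduler reads and writes the parameter generically via `getattr`/`setattr` on the sampler, which dispatches to the properties added here. A minimal sketch, assuming `rb` is a prioritized replay buffer:

sampler = rb.sampler            # the PrioritizedSampler behind the buffer
setattr(sampler, "beta", 0.8)   # routed through the beta.setter above
assert sampler.beta == 0.8      # read back through the beta property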
267 changes: 267 additions & 0 deletions torchrl/data/replay_buffers/scheduler.py
@@ -0,0 +1,267 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from abc import ABC, abstractmethod

from typing import Any, Callable, Dict

import numpy as np

import torch

from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler


class ParameterScheduler(ABC):
"""Scheduler to adjust the value of a given parameter of a replay buffer's sampler.
The scheduler can, for example, be used to alter the alpha and beta values of the PrioritizedSampler.
Args:
obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter.
min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
Defaults to `None`.
max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
Defaults to `None`.
"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
min_value: int | float | None = None,
max_value: int | float | None = None,
):
if not isinstance(obj, (ReplayBuffer, Sampler)):
raise TypeError(
f"ParameterScheduler only supports Sampler class. Pass either `ReplayBuffer` or `Sampler` object. Got {type(obj)} instead."
)
self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
self.param_name = param_name
# explicit None checks so that a bound of 0 is not silently dropped
self._min_val = min_value if min_value is not None else float("-inf")
self._max_val = max_value if max_value is not None else float("inf")
if not hasattr(self.sampler, self.param_name):
raise ValueError(
f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
)
initial_val = getattr(self.sampler, self.param_name)
if isinstance(initial_val, torch.Tensor):
initial_val = initial_val.clone()
self.backend = torch
else:
self.backend = np
self.initial_val = initial_val
self._step_cnt = 0

def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in ``self.__dict__`` which
is not the sampler.
"""
sd = dict(self.__dict__)
del sd["sampler"]
return sd

def load_state_dict(self, state_dict: Dict[str, Any]):
"""Load the scheduler's state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)

def step(self):
self._step_cnt += 1
# Apply the step function
new_value = self._step()
# clip value to specified range
new_value_clipped = self.backend.clip(new_value, self._min_val, self._max_val)
# Set the new value of the parameter dynamically
setattr(self.sampler, self.param_name, new_value_clipped)

@abstractmethod
def _step(self):
...
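

# Illustrative sketch (not part of this commit): a subclass only needs to implement
# `_step`, which returns the new raw value; `step()` then clips it to
# [min_value, max_value] and assigns it to the sampler attribute named `param_name`.
class ExponentialScheduler(ParameterScheduler):
    """Decays the parameter by a constant factor `gamma` on every call to `step()`."""

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        gamma: float = 0.99,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.gamma = gamma

    def _step(self):
        # initial_val * gamma ** step_cnt; works for tensor and scalar parameters alike
        return self.initial_val * self.gamma**self._step_cnt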


class LambdaScheduler(ParameterScheduler):
Sets a parameter to its initial value multiplied by the output of a given function of the step count.
Similar to :class:`~torch.optim.lr_scheduler.LambdaLR`.
Args:
obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
beta parameter.
lambda_fn (Callable[[int], float]): A function which computes a multiplicative factor given an integer
parameter ``step_count``.
min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
Defaults to `None`.
max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
Defaults to `None`.
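Example:
>>> # xdoctest: +SKIP
>>> # Illustrative sketch (values assumed): if the sampler's initial beta is 1.0
>>> # and lambda_fn(step) = 0.5 ** step, then
>>> # beta = 0.5   if step == 1
>>> # beta = 0.25  if step == 2
>>> # beta = 0.125 if step == 3
>>> scheduler = LambdaScheduler(sampler, param_name='beta', lambda_fn=lambda step: 0.5**step)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()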
"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
lambda_fn: Callable[[int], float],
min_value: int | float | None = None,
max_value: int | float | None = None,
):
super().__init__(obj, param_name, min_value, max_value)
self.lambda_fn = lambda_fn

def _step(self):
return self.initial_val * self.lambda_fn(self._step_cnt)


class LinearScheduler(ParameterScheduler):
"""A linear scheduler for gradually altering a parameter in an object over a given number of steps.
This scheduler linearly interpolates between the initial value of the parameter and a final target value.
Args:
obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
beta parameter.
final_value (number): The final value that the parameter will reach after the
specified number of steps.
num_steps (int): The total number of steps over which the parameter
will be linearly altered.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming sampler uses initial beta = 0.6
>>> # beta = 0.7 if step == 1
>>> # beta = 0.8 if step == 2
>>> # beta = 0.9 if step == 3
>>> # beta = 1.0 if step >= 4
>>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
final_value: int | float,
num_steps: int,
):
super().__init__(obj, param_name)
if isinstance(self.initial_val, torch.Tensor):
# cast to same type as initial value
final_value = torch.tensor(final_value).to(self.initial_val)
self.final_val = final_value
self.num_steps = num_steps
self._delta = (self.final_val - self.initial_val) / self.num_steps

def _step(self):
# Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile
# without graph breaks
if self._step_cnt < self.num_steps:
return self.initial_val + (self._delta * self._step_cnt)
else:
return self.final_val
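
# A compile-friendlier variant of the branch above could look roughly like this
# (illustrative sketch; it assumes the scheduled parameter is a tensor):
#     progress = torch.clamp(torch.as_tensor(self._step_cnt / self.num_steps), max=1.0)
#     return self.initial_val + (self.final_val - self.initial_val) * progress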


class StepScheduler(ParameterScheduler):
"""A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.
The scheduler can apply:
1. Multiplicative changes: `new_val = curr_val * gamma`
2. Additive changes: `new_val = curr_val + gamma`
Args:
obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
beta parameter.
gamma (int or float, optional): The value by which to adjust the parameter,
either in a multiplicative or additive way.
n_steps (int, optional): The number of steps after which the parameter should be altered.
Defaults to 1.
mode (str, optional): The mode of scheduling. Can be either `'multiplicative'` or `'additive'`.
Defaults to `'multiplicative'`.
min_value (int or float, optional): a lower bound for the parameter to be adjusted.
Defaults to `None`.
max_value (int or float, optional): an upper bound for the parameter to be adjusted.
Defaults to `None`.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming sampler uses initial beta = 0.6
>>> # beta = 0.6 if 0 <= step < 10
>>> # beta = 0.7 if 10 <= step < 20
>>> # beta = 0.8 if 20 <= step < 30
>>> # beta = 0.9 if 30 <= step < 40
>>> # beta = 1.0 if 40 <= step
>>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, mode='additive', max_value=1.0)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
gamma: int | float = 0.9,
n_steps: int = 1,
mode: str = "multiplicative",
min_value: int | float | None = None,
max_value: int | float | None = None,
):

super().__init__(obj, param_name, min_value, max_value)
self.gamma = gamma
self.n_steps = n_steps
self.mode = mode
if mode == "additive":
operator = self.backend.add
elif mode == "multiplicative":
operator = self.backend.multiply
else:
raise ValueError(
f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
)
self.operator = operator

def _step(self):
"""Applies the scheduling logic to alter the parameter value every `n_steps`."""
# Check if the current step count is a multiple of n_steps
current_val = getattr(self.sampler, self.param_name)
# Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile
# without graph breaks
if self._step_cnt % self.n_steps == 0:
return self.operator(current_val, self.gamma)
else:
return current_val


class SchedulerList:
"""Simple container abstracting a list of schedulers."""

def __init__(self, schedulers: list[ParameterScheduler]) -> None:
if isinstance(schedulers, ParameterScheduler):
schedulers = [schedulers]
self.schedulers = schedulers

def append(self, scheduler: ParameterScheduler):
self.schedulers.append(scheduler)

def step(self):
for scheduler in self.schedulers:
scheduler.step()
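
Putting the pieces together, a typical use of the new schedulers mirrors the test added above. A minimal sketch (hyperparameter values are arbitrary; `TensorDictPrioritizedReplayBuffer` and `ListStorage` are pre-existing TorchRL classes):

import torch
from tensordict import TensorDict
from torchrl.data import ListStorage, TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    SchedulerList,
    StepScheduler,
)

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7, beta=0.4, storage=ListStorage(max_size=1000)
)
rb.extend(TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000))

# Anneal alpha linearly to 0 over 100 calls to step(), and raise beta by 0.1
# every 10 calls, capped at 1.0.
schedulers = SchedulerList(
    schedulers=(
        LinearScheduler(rb, param_name="alpha", final_value=0.0, num_steps=100),
        StepScheduler(
            rb, param_name="beta", gamma=0.1, n_steps=10, mode="additive", max_value=1.0
        ),
    )
)

for _ in range(200):
    batch = rb.sample(20)  # learn from the batch here
    schedulers.step()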

1 comment on commit 5851652

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous benchmark result, exceeding the threshold of 2.

Benchmark suite: benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000]
Current (5851652): 713.0029736949961 iter/sec (stddev: 0.028560150917693494)
Previous (6d1a1b3): 1615.1295495842237 iter/sec (stddev: 0.000029816661506490736)
Ratio: 2.27

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
