[Feature] Extend MaxValueWriter with reduce parameter for the rank_key (
albertbou92 authored Jan 14, 2024
1 parent 748526e commit b632be9
Showing 4 changed files with 84 additions and 33 deletions.
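
A minimal usage sketch of the new ``reduction`` argument (a sketch, not taken from the diff: the key names, shapes, and import paths below are illustrative and assume torchrl's public API at this commit):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer

# Keep only the single highest-ranked element, ranking each element by the
# mean of its "key" entry rather than the previous behavior (sum over the last dim).
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1),
    batch_size=4,
    writer=TensorDictMaxValueWriter(rank_key="key", reduction="mean"),
)
td = TensorDict(
    {"key": torch.rand(4, 5), "obs": torch.rand(4, 5)},
    batch_size=[4],
)
rb.extend(td)
sample = rb.sample()  # every sampled element is the one with the highest mean "key"
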
48 changes: 44 additions & 4 deletions test/test_rb.py
@@ -1374,11 +1374,11 @@ def test_replay_buffer_iter(size, drop_last):
assert i == (size - 1) // 3


@pytest.mark.parametrize("size", [20, 25, 30])
@pytest.mark.parametrize("batch_size", [1, 10, 15])
@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)])
@pytest.mark.parametrize("device", get_default_devices())
class TestMaxValueWriter:
@pytest.mark.parametrize("size", [20, 25, 30])
@pytest.mark.parametrize("batch_size", [1, 10, 15])
@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)])
@pytest.mark.parametrize("device", get_default_devices())
def test_max_value_writer(self, size, batch_size, reward_ranges, device):
torch.manual_seed(0)
rb = TensorDictReplayBuffer(
@@ -1448,6 +1448,10 @@ def test_max_value_writer(self, size, batch_size, reward_ranges, device):
sample = rb.sample()
assert (sample.get("key") != 0).all()

@pytest.mark.parametrize("size", [20, 25, 30])
@pytest.mark.parametrize("batch_size", [1, 10, 15])
@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)])
@pytest.mark.parametrize("device", get_default_devices())
def test_max_value_writer_serialize(
self, size, batch_size, reward_ranges, device, tmpdir
):
@@ -1480,6 +1484,42 @@ def test_max_value_writer_serialize(
torch.tensor(other._current_top_values),
)

@pytest.mark.parametrize("size", [[], [1], [2, 3]])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("reduction", ["max", "min", "mean", "median", "sum"])
def test_max_value_writer_reduce(self, size, device, reduction):
torch.manual_seed(0)
batch_size = 4
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(1, device=device),
sampler=SamplerWithoutReplacement(),
batch_size=batch_size,
writer=TensorDictMaxValueWriter(rank_key="key", reduction=reduction),
)

key = torch.rand(batch_size, *size)
obs = torch.rand(batch_size, *size)
td = TensorDict(
{"key": key, "obs": obs},
batch_size=batch_size,
device=device,
)
rb.extend(td)
sample = rb.sample()
if reduction == "max":
rank_key = torch.stack([k.max() for k in key.unbind(0)])
elif reduction == "min":
rank_key = torch.stack([k.min() for k in key.unbind(0)])
elif reduction == "mean":
rank_key = torch.stack([k.mean() for k in key.unbind(0)])
elif reduction == "median":
rank_key = torch.stack([k.median() for k in key.unbind(0)])
elif reduction == "sum":
rank_key = torch.stack([k.sum() for k in key.unbind(0)])

top_rank = torch.argmax(rank_key)
assert (sample.get("obs") == obs[top_rank]).all()


class TestMultiProc:
@staticmethod
20 changes: 1 addition & 19 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -41,6 +41,7 @@
StorageEnsemble,
)
from torchrl.data.replay_buffers.utils import (
_reduce,
_to_numpy,
_to_torch,
INT_CLASSES,
@@ -1093,25 +1094,6 @@ def __call__(self, list_of_tds):
return self.out


def _reduce(
tensor: torch.Tensor, reduction: str, dim: int | None = None
) -> Union[float, torch.Tensor]:
"""Reduces a tensor given the reduction method."""
if reduction == "max":
result = tensor.max(dim=dim)
elif reduction == "min":
result = tensor.min(dim=dim)
elif reduction == "mean":
result = tensor.mean(dim=dim)
elif reduction == "median":
result = tensor.median(dim=dim)
else:
raise NotImplementedError(f"Unknown reduction method {reduction}")
if isinstance(result, tuple):
result = result[0]
return result.item() if dim is None else result


def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
"""Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together.
23 changes: 23 additions & 0 deletions torchrl/data/replay_buffers/utils.py
@@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# import tree
from __future__ import annotations

import typing
from typing import Any, Callable, Union

@@ -64,3 +66,24 @@ def _pin_memory(output: Any) -> Any:
return output.pin_memory()
else:
return output


def _reduce(
tensor: torch.Tensor, reduction: str, dim: int | None = None
) -> Union[float, torch.Tensor]:
"""Reduces a tensor given the reduction method."""
if reduction == "max":
result = tensor.max(dim=dim)
elif reduction == "min":
result = tensor.min(dim=dim)
elif reduction == "mean":
result = tensor.mean(dim=dim)
elif reduction == "median":
result = tensor.median(dim=dim)
elif reduction == "sum":
result = tensor.sum(dim=dim)
else:
raise NotImplementedError(f"Unknown reduction method {reduction}")
if isinstance(result, tuple):
result = result[0]
return result.item() if dim is None else result
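
A small sketch of how the relocated helper behaves (assuming it is importable from ``torchrl.data.replay_buffers.utils`` as added above; the values are illustrative):

import torch
from torchrl.data.replay_buffers.utils import _reduce

t = torch.tensor([1.0, 2.0, 6.0])
_reduce(t, "sum", dim=0)     # tensor(9.)
_reduce(t, "mean", dim=0)    # tensor(3.)
_reduce(t, "median", dim=0)  # tensor(2.) -- (values, indices) tuples are unpacked to values
# Any other reduction string raises NotImplementedError; with dim=None the
# result is converted to a Python float via .item().
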
26 changes: 16 additions & 10 deletions torchrl/data/replay_buffers/writers.py
@@ -18,7 +18,8 @@
from tensordict.utils import _STRDTYPE2DTYPE
from torch import multiprocessing as mp

from .storages import Storage
from torchrl.data.replay_buffers.storages import Storage
from torchrl.data.replay_buffers.utils import _reduce


class Writer(ABC):
@@ -201,7 +202,10 @@ def extend(self, data: Sequence) -> torch.Tensor:
class TensorDictMaxValueWriter(Writer):
"""A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
If rank_key is not provided, the key will be ``("next", "reward")``.
Args:
rank_key (str or tuple of str): the key to rank the elements by. Defaults to ``("next", "reward")``.
reduction (str): the reduction method to use if the rank key has more than one element.
Can be ``"max"``, ``"min"``, ``"mean"``, ``"median"`` or ``"sum"``.
Examples:
>>> import torch
@@ -237,11 +241,12 @@ class TensorDictMaxValueWriter(Writer):
19
"""

def __init__(self, rank_key=None, **kwargs) -> None:
def __init__(self, rank_key=None, reduction: str = "sum", **kwargs) -> None:
super().__init__(**kwargs)
self._cursor = 0
self._current_top_values = []
self._rank_key = rank_key
self._reduction = reduction
if self._rank_key is None:
self._rank_key = ("next", "reward")

@@ -261,7 +266,10 @@ def get_insert_index(self, data: Any) -> int:
rank_data = data.get("_data", default=data).get(self._rank_key)

# If the rank key has more than one element (e.g. a time dimension), reduce it to a scalar.
rank_data = rank_data.sum(-1).item()
if rank_data.numel() > 1:
rank_data = _reduce(rank_data.reshape(-1), self._reduction, dim=0)
else:
rank_data = rank_data.item()

if rank_data is None:
raise KeyError(f"Rank key {self._rank_key} not found in data.")
@@ -289,9 +297,8 @@ def get_insert_index(self, data: Any) -> int:
def add(self, data: Any) -> int:
"""Inserts a single element of data at an appropriate index, and returns that index.
The data passed to this module should be structured as :obj:`[]` or :obj:`[T]` where
:obj:`T` is the time dimension. If the data is a trajectory, the rank key will be summed
over the time dimension.
The ``rank_key`` in the data passed to this module should be structured as [].
If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
"""
index = self.get_insert_index(data)
if index is not None:
@@ -302,9 +309,8 @@ def extend(self, data: Sequence) -> None:
def extend(self, data: Sequence) -> None:
"""Inserts a series of data points at appropriate indices.
The data passed to this module should be structured as :obj:`[B]` or :obj:`[B, T]` where :obj:`B` is
the batch size and :obj:`T` is the time dimension. If the data is a trajectory, the rank key will be summed over the
time dimension.
The ``rank_key`` in the data passed to this module should be structured as [B].
If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
"""
data_to_replace = {}
for i, sample in enumerate(data):
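
To make the updated ``add``/``extend`` docstrings concrete, a hedged sketch of ranking ``[B, T]`` trajectory data (key names and shapes are illustrative, not from the diff):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer

# Two trajectories of length 3; keep only the one whose *maximum* reward is largest.
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1),
    batch_size=1,
    writer=TensorDictMaxValueWriter(rank_key="reward", reduction="max"),
)
data = TensorDict(
    {
        "reward": torch.tensor([[0.1, 0.2, 0.3], [5.0, 0.0, 0.0]]),
        "obs": torch.arange(2),
    },
    batch_size=[2],
)
rb.extend(data)
print(rb.sample()["obs"])  # expected: tensor([1]), the trajectory whose max reward is 5.0
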
