[Refactor] better AddStateIndependentNormalScale
ghstack-source-id: b5911c0b4e023d3c8e20968732ff58da061f978b
Pull Request resolved: #1028
vmoens committed Oct 4, 2024
1 parent 04faf40 commit b5f1cd2
Showing 1 changed file with 40 additions and 10 deletions.

tensordict/nn/distributions/continuous.py
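In short, the refactor makes `scale_shape` optional (defaulting to `torch.Size(())`), adds `device` and `make_param` keyword arguments, documents `forward`, and builds the scale with `expand_as` instead of a `zeros_like` add. A minimal usage sketch of the refactored module (shapes and values are illustrative, assuming the class is re-exported as `tensordict.nn.AddStateIndependentNormalScale`):

```python
import torch
from tensordict.nn import AddStateIndependentNormalScale

# One trainable raw scale per trailing dimension; an int is promoted to (4,).
module = AddStateIndependentNormalScale(4)

loc = torch.randn(3, 4)      # a batch of 3 location vectors
loc, scale = module(loc)     # appends a state-independent scale
print(scale.shape)           # torch.Size([3, 4]); exp(0) = 1 everywhere
```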
@@ -7,7 +7,7 @@
 
 import warnings
 from numbers import Number
-from typing import Sequence, Union
+from typing import Sequence
 
 import numpy as np
 
@@ -137,10 +137,19 @@ class AddStateIndependentNormalScale(torch.nn.Module):
     The scale parameters are mapped onto positive values using the specified ``scale_mapping``.
 
     Args:
         scale_shape (torch.Size or equivalent, optional): the shape of the scale parameter.
             Defaults to ``torch.Size(())``.
 
     Keyword Args:
         scale_mapping (str, optional): positive mapping function to be used with the std.
-            default = "biased_softplus_1.0" (i.e. softplus map with bias such that fn(0.0) = 1.0)
-            choices: "softplus", "exp", "relu", "biased_softplus_1";
-        scale_lb (Number, optional): The minimum value that the variance can take. Default is 1e-4.
+            Defaults to ``"biased_softplus_1.0"`` (i.e. softplus map with bias such that fn(0.0) = 1.0)
+            choices: ``"softplus"``, ``"exp"``, ``"relu"``, ``"biased_softplus_1"``.
+        scale_lb (Number, optional): The minimum value that the variance can take.
+            Defaults to ``1e-4``.
+        device (torch.device, optional): the device of the module.
+        make_param (bool, optional): whether the scale should be a parameter (``True``)
+            or a buffer (``False``).
+            Defaults to ``True``.
 
     Examples:
         >>> from torch import nn
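The documented keyword arguments compose as below; a hedged sketch using the same import as above (argument values are illustrative, not taken from the commit):

```python
import torch
from tensordict.nn import AddStateIndependentNormalScale

# A buffer-backed scale pinned to CPU; "biased_softplus_1.0" maps the
# zero-initialized raw scale to fn(0.0) = 1.0, and scale_lb floors the result.
module = AddStateIndependentNormalScale(
    (4,),
    scale_mapping="biased_softplus_1.0",
    scale_lb=1e-4,
    device=torch.device("cpu"),
    make_param=False,  # registers a torch.nn.Buffer (newer PyTorch) instead of a Parameter
)
loc, scale = module(torch.zeros(2, 4))
print(scale)  # all ones: biased softplus of 0.0, well above the 1e-4 floor
```

With `make_param=False`, the scale still travels with `state_dict()` and `.to(device)`, but stays out of `parameters()` and hence out of the optimizer.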
@@ -165,29 +174,50 @@ class AddStateIndependentNormalScale(torch.nn.Module):
 
     def __init__(
         self,
-        scale_shape: Union[torch.Size, int, tuple],
+        scale_shape: torch.Size | int | tuple = None,
         *,
         scale_mapping: str = "exp",
         scale_lb: Number = 1e-4,
+        device: torch.device | None = None,
+        make_param: bool = True,
     ) -> None:
-
         super().__init__()
+        if scale_shape is None:
+            scale_shape = torch.Size(())
         self.scale_lb = scale_lb
         if isinstance(scale_shape, int):
             scale_shape = (scale_shape,)
-        self.scale_shape = scale_shape
+        self.scale_shape = torch.Size(scale_shape)
         self.scale_mapping = scale_mapping
-        self.state_independent_scale = torch.nn.Parameter(torch.zeros(scale_shape))
+        if make_param:
+            self.state_independent_scale = torch.nn.Parameter(
+                torch.zeros(scale_shape, device=device)
+            )
+        else:
+            self.state_independent_scale = torch.nn.Buffer(
+                torch.zeros(scale_shape, device=device)
+            )
 
-    def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        loc, *others = tensors
+    def forward(
+        self, loc: torch.Tensor, *others: torch.Tensor
+    ) -> tuple[torch.Tensor, ...]:
+        """Forward of AddStateIndependentNormalScale.
+
+        Args:
+            loc (torch.Tensor): a location parameter.
+            *others: other unused parameters.
+
+        Returns:
+            a tuple of two or more tensors containing the ``(loc, scale, *others)`` values.
+        """
         if self.scale_shape != loc.shape[-len(self.scale_shape) :]:
             raise RuntimeError(
                 f"Last dimensions of loc ({loc.shape[-len(self.scale_shape):]}) do not match the number of dimensions "
                 f"in scale ({self.state_independent_scale.shape})"
             )
 
-        scale = torch.zeros_like(loc) + self.state_independent_scale
+        scale = self.state_independent_scale.expand_as(loc)
         scale = mappings(self.scale_mapping)(scale).clamp_min(self.scale_lb)
 
         return (loc, scale, *others)
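For intuition on the `forward` change, a standalone sketch mirroring the diff (pure PyTorch, no tensordict import needed):

```python
import torch

raw = torch.zeros(4)          # zero-initialized state-independent raw scale
loc = torch.randn(2, 4)

# Old behaviour: allocate a loc-sized zero tensor, then broadcast-add.
old = torch.zeros_like(loc) + raw

# New behaviour: a broadcasted view of raw, with no intermediate allocation;
# the positive mapping applied afterwards produces the real output tensor.
new = raw.expand_as(loc)

assert torch.equal(old, new)  # identical values, one fewer allocation
```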