diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst index 7000ade01..cb6fb1739 100644 --- a/docs/source/reference/nn.rst +++ b/docs/source/reference/nn.rst @@ -197,6 +197,7 @@ to build distributions from network outputs and get summary statistics or sample TensorDictSequential TensorDictModuleWrapper CudaGraphModule + WrapModule Ensembles --------- diff --git a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index 55590889a..e930ac75d 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -9,6 +9,7 @@ TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, + WrapModule, ) from tensordict.nn.distributions import ( AddStateIndependentNormalScale, diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 395141c0a..7a1a7a22b 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1278,12 +1278,45 @@ def forward(self, *args: Any, **kwargs: Any) -> TensorDictBase: class WrapModule(TensorDictModuleBase): + """A wrapper around any callable that processes TensorDict instances. + + This wrapper is useful when building :class:`~tensordict.nn.TensorDictSequential` stacks and when a transform + requires the entire TensorDict instance to be visible. + + Args: + func (Callable[[TensorDictBase], TensorDictBase]): A callable function that takes in a TensorDictBase instance + and returns a transformed TensorDictBase instance. + + Keyword Args: + inplace (bool, optional): If ``True``, the input TensorDict will be modified in-place. Otherwise, a new TensorDict + will be returned (if the function does not modify it in-place and returns it). Defaults to ``False``. + + Examples: + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule + >>> seq = Seq( + ... Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]), + ... WrapModule(lambda td: td.reshape(-1)), + ... 
) + >>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4]) + >>> td = seq(td) + >>> assert td.shape == (12,) + >>> assert (td["y"] == 2).all() + >>> assert td["y"].shape == (12, 5) + + """ + in_keys = [] out_keys = [] - def __init__(self, func): - self.func = func + def __init__( + self, func: Callable[[TensorDictBase], TensorDictBase], *, inplace: bool = False + ) -> None: super().__init__() + self.func = func + self.inplace = inplace - def forward(self, data): - return self.func(data) + def forward(self, data: TensorDictBase) -> TensorDictBase: + result = self.func(data) + if self.inplace and result is not data: + return data.update(result) + return result diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py index 79a890545..ebded544b 100644 --- a/tensordict/nn/distributions/composite.py +++ b/tensordict/nn/distributions/composite.py @@ -221,22 +221,6 @@ def from_distributions( self.inplace = inplace return self - @property - def aggregate_probabilities(self): - aggregate_probabilities = self._aggregate_probabilities - if aggregate_probabilities is None: - warnings.warn( - "The default value of `aggregate_probabilities` will change from `False` to `True` in v0.7. " - "Please pass this value explicitly to avoid this warning.", - FutureWarning, - ) - aggregate_probabilities = self._aggregate_probabilities = False - return aggregate_probabilities - - @aggregate_probabilities.setter - def aggregate_probabilities(self, value): - self._aggregate_probabilities = value - def sample(self, shape=None) -> TensorDictBase: if shape is None: shape = torch.Size([]) @@ -337,7 +321,7 @@ def log_prob( aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities`` from the class. include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict. - Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default). 
+ Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default). Has no effect if ``aggregate_probabilities`` is set to ``True``. .. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor. @@ -356,6 +340,8 @@ def log_prob( """ if aggregate_probabilities is None: aggregate_probabilities = self.aggregate_probabilities + if aggregate_probabilities is None: + aggregate_probabilities = False if not aggregate_probabilities: return self.log_prob_composite( sample, include_sum=include_sum, inplace=inplace @@ -382,7 +368,7 @@ def log_prob_composite( Keyword Args: include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict. - Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default). + Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default). .. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor. @@ -451,7 +437,7 @@ def entropy( setting from the class. Determines whether to return a single summed entropy tensor or a TensorDict with individual entropies. Defaults to ``False`` if not set in the class. include_sum (bool, optional): Whether to include the summed entropy in the output TensorDict. - Defaults to `self.inplace`, which is set through the class constructor. Has no effect if + Defaults to `self.include_sum`, which is set through the class constructor. Has no effect if `aggregate_probabilities` is set to `True`. .. warning:: The default value of `include_sum` will switch to `False` in v0.9 in the constructor. 
@@ -466,6 +452,8 @@ def entropy( """ if aggregate_probabilities is None: aggregate_probabilities = self.aggregate_probabilities + if aggregate_probabilities is None: + aggregate_probabilities = False if not aggregate_probabilities: return self.entropy_composite(samples_mc, include_sum=include_sum) se = 0.0 diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 61df65d2b..24c9fe246 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -11,6 +11,8 @@ from textwrap import indent from typing import Any, Dict, List, Optional +import torch + from tensordict._nestedkey import NestedKey from tensordict.nn import CompositeDistribution @@ -27,6 +29,8 @@ from torch.utils._contextlib import _DecoratorContextManager +from .. import is_tensor_collection + try: from torch.compiler import is_compiling except ImportError: @@ -108,51 +112,49 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: class ProbabilisticTensorDictModule(TensorDictModuleBase): """A probabilistic TD Module. - `ProbabilisticTensorDictModule` is a non-parametric module representing a - probability distribution. It reads the distribution parameters from an input - TensorDict using the specified `in_keys`. The output is sampled given some rule, - specified by the input :obj:`default_interaction_type` argument and the - :obj:`interaction_type()` global function. + `ProbabilisticTensorDictModule` is a non-parametric module representing a probability distribution. + It reads the distribution parameters from an input TensorDict using the specified `in_keys`. + The output is sampled given some rule, specified by the input :obj:`default_interaction_type` argument and the + :func:`~tensordict.nn.interaction_type` global function. 
:obj:`ProbabilisticTensorDictModule` can be used to construct the distribution - (through the :obj:`get_dist()` method) and/or sampling from this distribution - (through a regular :obj:`__call__()` to the module). + (through the :meth:`~.get_dist` method) and/or sampling from this distribution + (through a regular :meth:`~.forward` to the module). + + A ``ProbabilisticTensorDictModule`` instance has two main features: - A :obj:`ProbabilisticTensorDictModule` instance has two main features: - It reads and writes TensorDict objects - It uses a real mapping R^n -> R^m to create a distribution in R^d from - which values can be sampled or computed. + which values can be sampled or computed. - When the :obj:`__call__` / :obj:`forward` method is called, a distribution is + When the :meth:`~.forward` method is called, a distribution is created, and a value computed (using the 'mean', 'mode', 'median' attribute or the 'rsample', 'sample' method). The sampling step is skipped if the supplied TensorDict has all of the desired key-value pairs already. - By default, ProbabilisticTensorDictModule distribution class is a Delta - distribution, making ProbabilisticTensorDictModule a simple wrapper around + By default, the ``ProbabilisticTensorDictModule`` distribution class is a ``Delta`` + distribution, making ``ProbabilisticTensorDictModule`` a simple wrapper around a deterministic mapping function. Args: - in_keys (NestedKey or list of NestedKey or dict): key(s) that will be read from the - input TensorDict and used to build the distribution. Importantly, if it's an - list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by - the distribution class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for - the Normal distribution and similar. 
If in_keys is a dictionary, the keys - are the keys of the distribution and the values are the keys in the + in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]): key(s) that will be read from the input TensorDict + and used to build the distribution. + Importantly, if it's a list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by + the distribution class of interest, e.g. ``"loc"`` and ``"scale"`` for + the :class:`~torch.distributions.Normal` distribution and similar. + If in_keys is a dictionary, the keys are the keys of the distribution and the values are the keys in the tensordict that will get match to the corresponding distribution keys. - out_keys (NestedKey or list of NestedKey): keys where the sampled values will be - written. Importantly, if these keys are found in the input TensorDict, the - sampling step will be skipped. - default_interaction_mode (str, optional): *Deprecated* keyword-only argument. - Please use default_interaction_type instead. + out_keys (NestedKey | List[NestedKey] | None): key(s) where the sampled values will be written. + Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped. + + Keyword Args: default_interaction_type (InteractionType, optional): keyword-only argument. Default method to be used to retrieve the output value. Should be one of InteractionType: MODE, MEDIAN, MEAN or RANDOM (in which case the value is sampled randomly from the distribution). Default is MODE. - .. note:: - When a sample is drawn, the + .. note:: When a sample is drawn, the :class:`ProbabilisticTensorDictModule` instance will first look for the interaction mode dictated by the :func:`~tensordict.nn.probabilistic.interaction_type` @@ -170,7 +172,7 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase): to get the value through a call to ``get_mode()``, ``get_median()`` or ``get_mean()`` if the method exists. 
- distribution_class (Type, optional): keyword-only argument. + distribution_class (Type or Callable[[Any], Distribution], optional): keyword-only argument. A :class:`torch.distributions.Distribution` class to be used for sampling. Default is :class:`~tensordict.nn.distributions.Delta`. @@ -184,6 +186,11 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase): distribution_kwargs (dict, optional): keyword-only argument. Keyword-argument pairs to be passed to the distribution. + + .. note:: if your kwargs contain tensors that you would like to transfer to device with the module, or + tensors that should see their dtype modified when calling `module.to(dtype)`, you can wrap the kwargs + in a :class:`~tensordict.nn.TensorDictParams` to do this automatically. + return_log_prob (bool, optional): keyword-only argument. If ``True``, the log-probability of the distribution sample will be written in the tensordict with the key @@ -288,6 +295,7 @@ def __init__( log_prob_key: Optional[NestedKey] = "sample_log_prob", cache_dist: bool = False, n_empirical_estimate: int = 1000, + num_samples: int | torch.Size | None = None, ) -> None: super().__init__() distribution_kwargs = ( @@ -339,11 +347,24 @@ def __init__( self._dist = None self.cache_dist = cache_dist if hasattr(distribution_class, "update") else False self.return_log_prob = return_log_prob + if isinstance(num_samples, (int, torch.SymInt)): + num_samples = torch.Size((num_samples,)) + self.num_samples = num_samples if self.return_log_prob and self.log_prob_key not in self.out_keys: self.out_keys.append(self.log_prob_key) def get_dist(self, tensordict: TensorDictBase) -> D.Distribution: - """Creates a :class:`torch.distribution.Distribution` instance with the parameters provided in the input tensordict.""" + """Creates a :class:`torch.distributions.Distribution` instance with the parameters provided in the input tensordict. 
+ + Args: + tensordict (TensorDictBase): The input tensordict containing the distribution parameters. + + Returns: + A :class:`torch.distributions.Distribution` instance created from the input tensordict. + + Raises: + TypeError: If the input tensordict does not match the distribution keywords. + """ try: dist_kwargs = {} for dist_key, td_key in _zip_strict(self.dist_keys, self.in_keys): @@ -367,12 +388,63 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution: raise err return dist - def log_prob(self, tensordict): - """Writes the log-probability of the distribution sample.""" - dist = self.get_dist(tensordict) + def log_prob( + self, + tensordict, + *, + dist: torch.distributions.Distribution | None = None, + aggregate_probabilities: bool | None = None, + inplace: bool | None = None, + include_sum: bool | None = None, + ): + """Computes the log-probability of the distribution sample. + + Args: + tensordict (TensorDictBase): The input tensordict containing the distribution parameters. + dist (torch.distributions.Distribution, optional): The distribution instance. Defaults to ``None``. + If ``None``, the distribution will be computed using the `get_dist` method. + aggregate_probabilities (bool, optional): Whether to aggregate probabilities. Defaults to ``None``. + If ``None``, the value from the distribution will be used, if indicated, or ``False`` otherwise. + inplace (bool, optional): Whether to perform operations in-place. Defaults to ``None``. + If ``None``, the value from the distribution will be used, if indicated, or ``True`` otherwise. + include_sum (bool, optional): Whether to include the sum of probabilities. Defaults to ``None``. + If ``None``, the value from the distribution will be used, if indicated, or ``True`` otherwise. + + Returns: + A tensor representing the log-probability of the distribution sample. 
+ """ + if dist is None: + dist = self.get_dist(tensordict) if isinstance(dist, CompositeDistribution): - td = dist.log_prob(tensordict, aggregate_probabilities=False) - return td.get(dist.log_prob_key) + # Check the values within the dist - if not set, choose defaults + if aggregate_probabilities is None: + if dist.aggregate_probabilities is not None: + aggregate_probabilities_inp = dist.aggregate_probabilities + else: + # TODO: warning + aggregate_probabilities_inp = False + else: + aggregate_probabilities_inp = aggregate_probabilities + if inplace is None: + if dist.inplace is not None: + inplace = dist.inplace + else: + # TODO: warning + inplace = True + if include_sum is None: + if dist.include_sum is not None: + include_sum = dist.include_sum + else: + # TODO: warning + include_sum = True + lp = dist.log_prob( + tensordict, + aggregate_probabilities=aggregate_probabilities_inp, + inplace=inplace, + include_sum=include_sum, + ) + if is_tensor_collection(lp) and aggregate_probabilities is None: + return lp.get(dist.log_prob_key) else: return dist.log_prob(tensordict.get(self.out_keys[0])) @@ -396,6 +468,11 @@ def forward( dist = self.get_dist(tensordict) if _requires_sample: out_tensors = self._dist_sample(dist, interaction_type=interaction_type()) + if self.num_samples is not None: + # TODO: capture contiguous error here + tensordict_out = tensordict_out.expand( + self.num_samples + tensordict_out.shape + ) if isinstance(out_tensors, TensorDictBase): if self.return_log_prob: kwargs = {} @@ -503,32 +580,214 @@ def _dist_sample( return dist.sample((self.n_empirical_estimate,)).mean(0) elif interaction_type is InteractionType.RANDOM: + num_samples = self.num_samples + if num_samples is None: + num_samples = torch.Size(()) if dist.has_rsample: - return dist.rsample() + return dist.rsample(num_samples) else: - return dist.sample() + return dist.sample(num_samples) else: raise NotImplementedError(f"unknown interaction_type {interaction_type}") class 
ProbabilisticTensorDictSequential(TensorDictSequential): - """A sequence of TensorDictModules ending in a ProbabilistictTensorDictModule. + """A sequence of :class:`~tensordict.nn.TensorDictModule` instances containing at least one :class:`~tensordict.nn.ProbabilisticTensorDictModule`. - Similarly to :obj:`TensorDictSequential`, but enforces that the final module in the - sequence is an :obj:`ProbabilisticTensorDictModule` and also exposes ``get_dist`` - method to recover the distribution object from the ``ProbabilisticTensorDictModule`` + This class extends :class:`~tensordict.nn.TensorDictSequential` and is typically configured with a sequence of + modules where the final module is an instance of :class:`~tensordict.nn.ProbabilisticTensorDictModule`. + However, it also supports configurations where one or more intermediate modules are instances of + :class:`~tensordict.nn.ProbabilisticTensorDictModule`, while the last module may or may not be probabilistic. + In all cases, it exposes the :meth:`~.get_dist` method to recover the distribution object from the + :class:`~tensordict.nn.ProbabilisticTensorDictModule` instances in the sequence. + + Multiple probabilistic modules can co-exist in a single ``ProbabilisticTensorDictSequential``. + If `return_composite` is ``False`` (default), only the last one will produce a distribution and the others + will be executed as regular :class:`~tensordict.nn.TensorDictModule` instances. + However, if the last module in the sequence is not a `ProbabilisticTensorDictModule` (or a `ProbabilisticTensorDictSequential`) and `return_composite=False`, + a `TypeError` will be raised when building the sequence. If `return_composite=True`, + all intermediate `ProbabilisticTensorDictModule` instances will contribute to a single + :class:`~tensordict.nn.CompositeDistribution` instance. + + Resulting log-probabilities will be conditional probabilities if samples are interdependent: + whenever + + .. math:: + Z = F(X, Y) + + then the log-probability of Z will be + + .. 
math:: + log(p(z | x, y)) Args: - modules (sequence of TensorDictModules): ordered sequence of TensorDictModule - instances, terminating in ProbabilisticTensorDictModule, to be run - sequentially. - partial_tolerant (bool, optional): if True, the input tensordict can miss some - of the input keys. If so, the only module that will be executed are those - who can be executed given the keys that are present. Also, if the input - tensordict is a lazy stack of tensordicts AND if partial_tolerant is - :obj:`True` AND if the stack does not have the required keys, then - TensorDictSequential will scan through the sub-tensordicts looking for those - that have the required keys, if any. + *modules (sequence of TensorDictModules): An ordered sequence of + :class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`, + to be run sequentially. + + Keyword Args: + partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some + of the input keys. If so, only the modules that can be executed given the + keys that are present will be executed. Also, if the input tensordict is a + lazy stack of tensordicts AND if partial_tolerant is ``True`` AND if the stack + does not have the required keys, then TensorDictSequential will scan through + the sub-tensordicts looking for those that have the required keys, if any. + Defaults to ``False``. + return_composite (bool, optional): If True and multiple + :class:`~tensordict.nn.ProbabilisticTensorDictModule` or + :class:`~tensordict.nn.ProbabilisticTensorDictSequential` instances are found, + a :class:`~tensordict.nn.CompositeDistribution` instance will be used. + Otherwise, only the last module will be used to build the distribution. + Defaults to ``False``. + + .. warning:: The behaviour of :attr:`return_composite` will change in v0.9 + and default to True from there on. 
+ + aggregate_probabilities (bool, optional): (:class:`~tensordict.nn.CompositeDistribution` + outputs only) If provided, overrides the default ``aggregate_probabilities`` + from the class. + include_sum (bool, optional): (:class:`~tensordict.nn.CompositeDistribution` + outputs only) Whether to include the summed log-probability in the output + TensorDict. Defaults to ``self.include_sum`` which is set through the class + constructor (True by default). Has no effect if ``aggregate_probabilities`` + is set to True. + + .. warning:: The default value of ``include_sum`` will switch to False in + v0.9 in the constructor. + + inplace (bool, optional): (:class:`~tensordict.nn.CompositeDistribution` + outputs only) Whether to update the input sample in-place or return a new + TensorDict. Defaults to ``self.inplace`` which is set through the class + constructor (True by default). Has no effect if ``aggregate_probabilities`` + is set to True. + + .. warning:: The default value of ``inplace`` will switch to False in v0.9 + in the constructor. + + Raises: + ValueError: If the input sequence of modules is empty. + TypeError: If the final module is not an instance of + :obj:`ProbabilisticTensorDictModule` or + :obj:`ProbabilisticTensorDictSequential`. + + Examples: + >>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq + >>> import torch + >>> # Typical usage: a single distribution is computed last in the sequence + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq, \ + ... TensorDictModule as Mod + >>> torch.manual_seed(0) + >>> + >>> module = Seq( + ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), + ... Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal, + ... distribution_kwargs={"scale": 1}), + ... 
) + >>> input = TensorDict(x=torch.ones(3)) + >>> td = module(input.copy()) + >>> print(td) + TensorDict( + fields={ + loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(module.get_dist(input)) + Normal(loc: torch.Size([3]), scale: torch.Size([3])) + >>> print(module.log_prob(td)) + tensor([-0.9189, -0.9189, -0.9189]) + >>> # Intermediate distributions are ignored when return_composite=False + >>> module = Seq( + ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), + ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, + ... distribution_kwargs={"scale": 1}), + ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]), + ... Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal, + ... distribution_kwargs={"scale": 1}), + ... return_composite=False, + ... 
) + >>> td = module(TensorDict(x=torch.ones(3))) + >>> print(td) + TensorDict( + fields={ + loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(module.get_dist(input)) + Normal(loc: torch.Size([3]), scale: torch.Size([3])) + >>> print(module.log_prob(td)) + tensor([-0.9189, -0.9189, -0.9189]) + >>> # Intermediate distributions produce a CompositeDistribution when return_composite=True + >>> module = Seq( + ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), + ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, + ... distribution_kwargs={"scale": 1}), + ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]), + ... Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal, + ... distribution_kwargs={"scale": 1}), + ... return_composite=True, + ... 
) + >>> input = TensorDict(x=torch.ones(3)) + >>> td = module(input.copy()) + >>> print(td) + TensorDict( + fields={ + loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(module.get_dist(input)) + CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))}) + >>> print(module.log_prob(td, aggregate_probabilities=False)) + TensorDict( + fields={ + sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when + >>> # return_composite=True + >>> module = Seq( + ... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), + ... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal, + ... distribution_kwargs={"scale": 1}), + ... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]), + ... return_composite=True, + ... 
) + >>> td = module(TensorDict(x=torch.ones(3))) + >>> print(td) + TensorDict( + fields={ + loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(module.get_dist(input)) + CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))}) + >>> print(module.log_prob(td, aggregate_probabilities=False, inplace=False, include_sum=False)) + TensorDict( + fields={ + sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) """ @@ -536,20 +795,24 @@ def __init__( self, *modules: TensorDictModuleBase | ProbabilisticTensorDictModule, partial_tolerant: bool = False, + return_composite: bool | None = None, + aggregate_probabilities: bool | None = None, + include_sum: bool | None = None, + inplace: bool | None = None, ) -> None: if len(modules) == 0: raise ValueError( "ProbabilisticTensorDictSequential must consist of zero or more " "TensorDictModules followed by a ProbabilisticTensorDictModule" ) - if not isinstance( + if not return_composite and not isinstance( modules[-1], (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential), ): raise TypeError( "The final module passed to ProbabilisticTensorDictSequential must be " "an instance of ProbabilisticTensorDictModule or another " - "ProbabilisticTensorDictSequential" + "ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)." 
) # if the modules not including the final probabilistic module return the sampled # key we wont be sampling it again, in that case @@ -559,6 +822,12 @@ def __init__( self._requires_sample = modules[-1].out_keys[0] not in set(out_keys) self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1]) super().__init__(*modules, partial_tolerant=partial_tolerant) + self.return_composite = return_composite + self.aggregate_probabilities = aggregate_probabilities + self.include_sum = include_sum + self.inplace = inplace + + _dist_sample = ProbabilisticTensorDictModule._dist_sample @property def det_part(self): @@ -570,39 +839,263 @@ def get_dist_params( tensordict_out: TensorDictBase | None = None, **kwargs, ) -> tuple[D.Distribution, TensorDictBase]: + """Returns the distribution parameters and output tensordict. + + This method runs the deterministic part of the :class:`~tensordict.nn.ProbabilisticTensorDictSequential` + module to obtain the distribution parameters. The interaction type is set to the current global + interaction type if available, otherwise it defaults to the interaction type of the last module. + + Args: + tensordict (TensorDictBase): The input tensordict. + tensordict_out (TensorDictBase, optional): The output tensordict. If ``None``, a new tensordict will be created. + Defaults to ``None``. + + Keyword Args: + **kwargs: Additional keyword arguments passed to the deterministic part of the module. + + Returns: + tuple[D.Distribution, TensorDictBase]: A tuple containing the distribution object and the output tensordict. + + .. note:: The interaction type is temporarily set to the specified value during the execution of this method. 
+ """ tds = self.det_part type = interaction_type() if type is None: - type = self.module[-1].default_interaction_type + for m in reversed(self.module): + if hasattr(m, "default_interaction_type"): + type = m.default_interaction_type + break + else: + raise ValueError("Could not find a default interaction in the modules.") with set_interaction_type(type): return tds(tensordict, tensordict_out, **kwargs) + @property + def num_samples(self): + num_samples = () + for tdm in self.module: + if isinstance( + tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential) + ): + num_samples = tdm.num_samples + num_samples + return num_samples + def get_dist( self, tensordict: TensorDictBase, tensordict_out: TensorDictBase | None = None, **kwargs, ) -> D.Distribution: - """Get the distribution that results from passing the input tensordict through the sequence, and then using the resulting parameters.""" - tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs) - return self.build_dist_from_params(tensordict_out) + """Returns the distribution resulting from passing the input tensordict through the sequence. + + If `return_composite` is ``False`` (default), this method will only consider the last probabilistic module in the sequence. + + Otherwise, it will return a :class:`~tensordict.nn.CompositeDistribution` instance containing the distributions of all probabilistic modules. + + Args: + tensordict (TensorDictBase): The input tensordict. + tensordict_out (TensorDictBase, optional): The output tensordict. If ``None``, a new tensordict will be created. + Defaults to ``None``. + + Keyword Args: + **kwargs: Additional keyword arguments passed to the underlying modules. + + Returns: + D.Distribution: The resulting distribution object. + + Raises: + RuntimeError: If no probabilistic module is found in the sequence. + + .. note:: + When `return_composite` is ``True``, the distributions are conditioned on the previous samples in the sequence. 
+ This means that if a module depends on the output of a previous probabilistic module, its distribution will be conditional. + + """ + if not self.return_composite: + tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs) + return self.build_dist_from_params(tensordict_out) + + td_copy = tensordict.copy() + dists = {} + for i, tdm in enumerate(self.module): + if isinstance( + tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential) + ): + dist = tdm.get_dist(td_copy) + if i < len(self.module) - 1: + sample = tdm._dist_sample(dist, interaction_type=interaction_type()) + if tdm.num_samples not in ((), None): + td_copy = td_copy.expand(tdm.num_samples + td_copy.shape) + if isinstance(tdm, ProbabilisticTensorDictModule): + if isinstance(sample, torch.Tensor): + sample = [sample] + for val, key in zip(sample, tdm.out_keys): + td_copy.set(key, val) + else: + td_copy.update(sample) + dists[tdm.out_keys[0]] = dist + else: + td_copy = tdm(td_copy) + if len(dists) == 0: + raise RuntimeError(f"No distribution module found in {self}.") + # elif len(dists) == 1: + # return dist + return CompositeDistribution.from_distributions( + td_copy, + dists, + aggregate_probabilities=self.aggregate_probabilities, + inplace=self.inplace, + include_sum=self.include_sum, + ) def log_prob( - self, tensordict, tensordict_out: TensorDictBase | None = None, **kwargs + self, + tensordict, + tensordict_out: TensorDictBase | None = None, + *, + dist: torch.distributions.Distribution | None = None, + aggregate_probabilities: bool | None = None, + inplace: bool | None = None, + include_sum: bool | None = None, + **kwargs, ): - tensordict_out = self.get_dist_params( - tensordict, - tensordict_out, - **kwargs, - ) - return self.module[-1].log_prob(tensordict_out) + """Returns the log-probability of the input tensordict. 
+ + If `return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`, + this method will return the log-probability of the entire composite distribution. + + Otherwise, it will only consider the last probabilistic module in the sequence. + + Args: + tensordict (TensorDictBase): The input tensordict. + tensordict_out (TensorDictBase, optional): The output tensordict. If ``None``, a new tensordict will be created. + Defaults to ``None``. + + Keyword Args: + dist (torch.distributions.Distribution, optional): The distribution object. If ``None``, it will be computed using `get_dist`. + Defaults to ``None``. + aggregate_probabilities (bool, optional): Whether to aggregate the probabilities of the composite distribution. + If ``None``, it will default to the value set in the constructor or the distribution object. + Defaults to ``None``. + inplace (bool, optional): Whether to update the input tensordict in-place or return a new tensordict. + If ``None``, it will default to the value set in the constructor or the distribution object. + Defaults to ``None``. + include_sum (bool, optional): Whether to include the summed log-probability in the output tensordict. + If ``None``, it will default to the value set in the constructor or the distribution object. + Defaults to ``None``. + + Returns: + TensorDictBase or torch.Tensor: The log-probability of the input tensordict. + + .. note:: + If `aggregate_probabilities` is ``True``, the log-probability will be aggregated across all components of the composite distribution. + If `inplace` is ``True``, the input tensordict will be updated in-place with the log-probability values. + If `include_sum` is ``True``, the summed log-probability will be included in the output tensordict. + + .. warning:: + In future releases (v0.9), the default values of `aggregate_probabilities`, `inplace`, and `include_sum` will change. 
+ To avoid warnings, it is recommended to explicitly pass these arguments to the `log_prob` method or set them in the constructor. + + """ + if tensordict_out is not None: + tensordict_inp = tensordict.copy() + else: + tensordict_inp = tensordict + if dist is None: + dist = self.get_dist(tensordict_inp) + if self.return_composite and isinstance(dist, CompositeDistribution): + # Check the values within the dist - if not set, choose defaults + if aggregate_probabilities is None: + if self.aggregate_probabilities is not None: + aggregate_probabilities = self.aggregate_probabilities + elif dist.aggregate_probabilities is not None: + aggregate_probabilities = dist.aggregate_probabilities + else: + warnings.warn( + f"aggregate_probabilities wasn't defined in the {type(self).__name__} instance. " + f"It couldn't be retrieved from the CompositeDistribution object either. " + f"Currently, the aggregate_probability will be `True` in this case but in a future release " + f"(v0.9) this will change and `aggregate_probabilities` will default to ``False`` such " + f"that log_prob will return a tensordict with the log-prob values. To silence this warning, " + f"pass `aggregate_probabilities` to the {type(self).__name__} constructor, to the distribution kwargs " + f"or to the log-prob method.", + category=DeprecationWarning, + ) + aggregate_probabilities = True + if inplace is None: + if self.inplace is not None: + inplace = self.inplace + elif dist.inplace is not None: + inplace = dist.inplace + else: + warnings.warn( + f"inplace wasn't defined in the {type(self).__name__} instance. " + f"It couldn't be retrieved from the CompositeDistribution object either. " + f"Currently, the `inplace` will be `True` in this case but in a future release " + f"(v0.9) this will change and `inplace` will default to ``False`` such " + f"that log_prob will return a new tensordict containing only the log-prob values. 
To silence this warning, " + f"pass `inplace` to the {type(self).__name__} constructor, to the distribution kwargs " + f"or to the log-prob method.", + category=DeprecationWarning, + ) + inplace = True + if include_sum is None: + if self.include_sum is not None: + include_sum = self.include_sum + elif dist.include_sum is not None: + include_sum = dist.include_sum + else: + warnings.warn( + f"include_sum wasn't defined in the {type(self).__name__} instance. " + f"It couldn't be retrieved from the CompositeDistribution object either. " + f"Currently, the `include_sum` will be `True` in this case but in a future release " + f"(v0.9) this will change and `include_sum` will default to ``False`` such " + f"that log_prob will return a new tensordict containing only the leaf log-prob values. " + f"To silence this warning, " + f"pass `include_sum` to the {type(self).__name__} constructor, to the distribution kwargs " + f"or to the log-prob method.", + category=DeprecationWarning, + ) + include_sum = True + return dist.log_prob( + tensordict, + aggregate_probabilities=aggregate_probabilities, + inplace=inplace, + include_sum=include_sum, + **kwargs, + ) + last_module: ProbabilisticTensorDictModule = self.module[-1] + out = last_module.log_prob(tensordict_inp, dist=dist, **kwargs) + if is_tensor_collection(out): + if tensordict_out is not None: + if out is tensordict_inp: + tensordict_out.update( + tensordict_inp.apply( + lambda x, y: x if x is not y else None, + tensordict, + filter_empty=True, + ) + ) + else: + tensordict_out.update(out) + else: + tensordict_out = out + return tensordict_out + return out def build_dist_from_params(self, tensordict: TensorDictBase) -> D.Distribution: - """Construct a distribution from the input parameters. Other modules in the sequence are not evaluated. + """Constructs a distribution from the input parameters without evaluating other modules in the sequence. 
+ + This method searches for the last :class:`~tensordict.nn.ProbabilisticTensorDictModule` in the sequence and uses it to build the distribution. - This method will look for the last ProbabilisticTensorDictModule contained in the - sequence and use it to build the distribution. + Args: + tensordict (TensorDictBase): The input tensordict containing the distribution parameters. + Returns: + D.Distribution: The constructed distribution object. + + Raises: + RuntimeError: If no :class:`~tensordict.nn.ProbabilisticTensorDictModule` is found in the sequence. """ dest_module = None for module in reversed(list(self.modules())): @@ -629,17 +1122,39 @@ def forward( tensordict_exec = tensordict.copy() else: tensordict_exec = tensordict - tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs) - tensordict_exec = self.module[-1]( - tensordict_exec, _requires_sample=self._requires_sample - ) + if self.return_composite: + for m in self.module: + if isinstance( + m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictModule) + ): + tensordict_exec = m( + tensordict_exec, _requires_sample=self._requires_sample + ) + else: + tensordict_exec = m(tensordict_exec, **kwargs) + else: + tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs) + tensordict_exec = self.module[-1]( + tensordict_exec, _requires_sample=self._requires_sample + ) if tensordict_out is not None: result = tensordict_out result.update(tensordict_exec, keys_to_update=self.out_keys) else: result = tensordict_exec if self._select_before_return: - return tensordict.update(result, keys_to_update=self.out_keys) + # We must also update any value that has been updated during the course of execution + # from the input data. 
+ if is_compiling(): + keys = [ # noqa: C416 + k + for k in {k for k in self.out_keys}.union( # noqa: C416 + {k for k in tensordict.keys(True, True)} # noqa: C416 + ) + ] + else: + keys = list(set(self.out_keys + list(tensordict.keys(True, True)))) + return tensordict.update(result, keys_to_update=keys) return result diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 5f8c84bb9..faa2f60a0 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -35,6 +35,10 @@ ) FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." +try: + from torch.compiler import is_compiling +except ImportError: + from torch._dynamo import is_compiling __all__ = ["TensorDictSequential"] @@ -489,7 +493,18 @@ def forward( else: result = tensordict_exec if self._select_before_return: - return tensordict.update(result.select(*self.out_keys)) + # We must also update any value that has been updated during the course of execution + # from the input data. 
+ if is_compiling(): + keys = [ # noqa: C416 + k + for k in {k for k in self.out_keys}.union( # noqa: C416 + {k for k in tensordict.keys(True, True)} # noqa: C416 + ) + ] + else: + keys = list(set(self.out_keys + list(tensordict.keys(True, True)))) + return tensordict.update(result, keys_to_update=keys) return result def __len__(self) -> int: diff --git a/test/test_nn.py b/test/test_nn.py index 510441207..0d6085ef9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1041,6 +1041,128 @@ def test_probtdseq(self, return_log_prob, td_out): == expected ) + @pytest.mark.parametrize("aggregate_probabilities", [None, False, True]) + @pytest.mark.parametrize("inplace", [None, False, True]) + @pytest.mark.parametrize("include_sum", [None, False, True]) + def test_probtdseq_multdist(self, include_sum, aggregate_probabilities, inplace): + + tdm0 = TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x"], out_keys=["loc"]) + tdm1 = ProbabilisticTensorDictModule( + in_keys=["loc"], + out_keys=["y"], + distribution_class=torch.distributions.Normal, + distribution_kwargs={"scale": 1}, + default_interaction_type="random", + ) + tdm2 = TensorDictModule(torch.nn.Linear(4, 5), in_keys=["y"], out_keys=["loc2"]) + tdm3 = ProbabilisticTensorDictModule( + in_keys={"loc": "loc2"}, + out_keys=["z"], + distribution_class=torch.distributions.Normal, + distribution_kwargs={"scale": 1}, + default_interaction_type="random", + ) + + tdm = ProbabilisticTensorDictSequential( + tdm0, + tdm1, + tdm2, + tdm3, + include_sum=include_sum, + aggregate_probabilities=aggregate_probabilities, + inplace=inplace, + return_composite=True, + ) + dist: CompositeDistribution = tdm.get_dist(TensorDict(x=torch.randn(10, 3))) + s = dist.sample() + assert dist.aggregate_probabilities is aggregate_probabilities + assert dist.inplace is inplace + assert dist.include_sum is include_sum + if aggregate_probabilities in (None, False): + assert isinstance(dist.log_prob(s), TensorDict) + else: + assert 
isinstance(dist.log_prob(s), torch.Tensor) + + v = tdm(TensorDict(x=torch.randn(10, 3))) + assert set(v.keys()) == {"x", "loc", "y", "loc2", "z"} + if aggregate_probabilities is None: + cm0 = pytest.warns( + expected_warning=DeprecationWarning, match="aggregate_probabilities" + ) + else: + cm0 = contextlib.nullcontext() + if include_sum is None: + cm1 = pytest.warns(expected_warning=DeprecationWarning, match="include_sum") + else: + cm1 = contextlib.nullcontext() + if inplace is None: + cm2 = pytest.warns(expected_warning=DeprecationWarning, match="inplace") + else: + cm2 = contextlib.nullcontext() + with cm0, cm1, cm2: + if aggregate_probabilities in (None, True): + assert isinstance(tdm.log_prob(v), torch.Tensor) + else: + assert isinstance(tdm.log_prob(v), TensorDict) + + @pytest.mark.parametrize("aggregate_probabilities", [None, False, True]) + @pytest.mark.parametrize("inplace", [None, False, True]) + @pytest.mark.parametrize("include_sum", [None, False, True]) + def test_probtdseq_intermediate_dist( + self, include_sum, aggregate_probabilities, inplace + ): + tdm0 = TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x"], out_keys=["loc"]) + tdm1 = ProbabilisticTensorDictModule( + in_keys=["loc"], + out_keys=["y"], + distribution_class=torch.distributions.Normal, + distribution_kwargs={"scale": 1}, + default_interaction_type="random", + ) + tdm2 = TensorDictModule(torch.nn.Linear(4, 5), in_keys=["y"], out_keys=["loc2"]) + tdm = ProbabilisticTensorDictSequential( + tdm0, + tdm1, + tdm2, + include_sum=include_sum, + aggregate_probabilities=aggregate_probabilities, + inplace=inplace, + return_composite=True, + ) + dist: CompositeDistribution = tdm.get_dist(TensorDict(x=torch.randn(10, 3))) + assert isinstance(dist, CompositeDistribution) + + s = dist.sample() + assert dist.aggregate_probabilities is aggregate_probabilities + assert dist.inplace is inplace + assert dist.include_sum is include_sum + if aggregate_probabilities in (None, False): + assert 
isinstance(dist.log_prob(s), TensorDict) + else: + assert isinstance(dist.log_prob(s), torch.Tensor) + + v = tdm(TensorDict(x=torch.randn(10, 3))) + assert set(v.keys()) == {"x", "loc", "y", "loc2"} + if aggregate_probabilities is None: + cm0 = pytest.warns( + expected_warning=DeprecationWarning, match="aggregate_probabilities" + ) + else: + cm0 = contextlib.nullcontext() + if include_sum is None: + cm1 = pytest.warns(expected_warning=DeprecationWarning, match="include_sum") + else: + cm1 = contextlib.nullcontext() + if inplace is None: + cm2 = pytest.warns(expected_warning=DeprecationWarning, match="inplace") + else: + cm2 = contextlib.nullcontext() + with cm0, cm1, cm2: + if aggregate_probabilities in (None, True): + assert isinstance(tdm.log_prob(v), torch.Tensor) + else: + assert isinstance(tdm.log_prob(v), TensorDict) + @pytest.mark.parametrize("lazy", [True, False]) def test_stateful_probabilistic(self, lazy): torch.manual_seed(0) @@ -1744,100 +1866,225 @@ def test_module_buffer(): assert module.td.device.type == "cuda" -@pytest.mark.parametrize( - "log_prob_key", - [ - None, - "sample_log_prob", - ("nested", "sample_log_prob"), - ("data", "sample_log_prob"), - ], -) -def test_nested_keys_probabilistic_delta(log_prob_key): - policy_module = TensorDictModule( - nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")] - ) - td = TensorDict({"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]) - - module = ProbabilisticTensorDictModule( - in_keys=[("data", "param")], - out_keys=[("data", "action")], - distribution_class=Delta, - return_log_prob=True, - log_prob_key=log_prob_key, - ) - td_out = module(policy_module(td)) - assert td_out["data", "action"].shape == (3, 4, 1) - if log_prob_key: - assert td_out[log_prob_key].shape == (3, 4) - else: - assert td_out["sample_log_prob"].shape == (3, 4) - - module = ProbabilisticTensorDictModule( - in_keys={"param": ("data", "param")}, - out_keys=[("data", "action")], - 
distribution_class=Delta, - return_log_prob=True, - log_prob_key=log_prob_key, - ) - td_out = module(policy_module(td)) - assert td_out["data", "action"].shape == (3, 4, 1) - if log_prob_key: - assert td_out[log_prob_key].shape == (3, 4) - else: - assert td_out["sample_log_prob"].shape == (3, 4) +class TestProbabilisticTensorDictModule: + @pytest.mark.parametrize("return_log_prob", [True, False]) + def test_probabilistic_n_samples(self, return_log_prob): + prob = ProbabilisticTensorDictModule( + in_keys=["loc"], + out_keys=["sample"], + distribution_class=Normal, + distribution_kwargs={"scale": 1}, + return_log_prob=return_log_prob, + num_samples=2, + default_interaction_type="random", + ) + # alone + td = TensorDict(loc=torch.randn(3, 4), batch_size=[3]) + td = prob(td) + assert "sample" in td + assert td.shape == (2, 3) + assert td["sample"].shape == (2, 3, 4) + if return_log_prob: + assert "sample_log_prob" in td + @pytest.mark.parametrize("return_log_prob", [True, False]) + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.parametrize("include_sum", [True, False]) + @pytest.mark.parametrize("aggregate_probabilities", [True, False]) + @pytest.mark.parametrize("return_composite", [True, False]) + def test_probabilistic_seq_n_samples( + self, + return_log_prob, + return_composite, + include_sum, + aggregate_probabilities, + inplace, + ): + prob = ProbabilisticTensorDictModule( + in_keys=["loc"], + out_keys=["sample"], + distribution_class=Normal, + distribution_kwargs={"scale": 1}, + return_log_prob=return_log_prob, + num_samples=2, + default_interaction_type="random", + ) + # in a sequence + seq = ProbabilisticTensorDictSequential( + TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), + prob, + inplace=inplace, + aggregate_probabilities=aggregate_probabilities, + include_sum=include_sum, + return_composite=return_composite, + ) + td = TensorDict(x=torch.randn(3, 4), batch_size=[3]) + if return_composite: + assert 
isinstance(seq.get_dist(td), CompositeDistribution) + else: + assert isinstance(seq.get_dist(td), Normal) + td = seq(td) + assert "sample" in td + assert td.shape == (2, 3) + assert td["sample"].shape == (2, 3, 4) + if return_log_prob: + assert "sample_log_prob" in td -@pytest.mark.parametrize( - "log_prob_key", - [ - None, - "sample_log_prob", - ("nested", "sample_log_prob"), - ("data", "sample_log_prob"), - ], -) -def test_nested_keys_probabilistic_normal(log_prob_key): - loc_module = TensorDictModule( - nn.Linear(1, 1), - in_keys=[("data", "states")], - out_keys=[("data", "loc")], - ) - scale_module = TensorDictModule( - nn.Linear(1, 1), - in_keys=[("data", "states")], - out_keys=[("data", "scale")], - ) - td = TensorDict({"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]) - - module = ProbabilisticTensorDictModule( - in_keys=[("data", "loc"), ("data", "scale")], - out_keys=[("data", "action")], - distribution_class=Normal, - return_log_prob=True, - log_prob_key=log_prob_key, + # log-prob from the sequence + log_prob = seq.log_prob(td) + if aggregate_probabilities or not return_composite: + assert isinstance(log_prob, torch.Tensor) + else: + assert isinstance(log_prob, TensorDict) + + @pytest.mark.parametrize("return_log_prob", [True, False]) + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.parametrize("include_sum", [True, False]) + @pytest.mark.parametrize("aggregate_probabilities", [True, False]) + @pytest.mark.parametrize("return_composite", [True]) + def test_intermediate_probabilistic_seq_n_samples( + self, + return_log_prob, + return_composite, + include_sum, + aggregate_probabilities, + inplace, + ): + prob = ProbabilisticTensorDictModule( + in_keys=["loc"], + out_keys=["sample"], + distribution_class=Normal, + distribution_kwargs={"scale": 1}, + return_log_prob=return_log_prob, + num_samples=2, + default_interaction_type="random", + ) + + # intermediate in a sequence + seq = ProbabilisticTensorDictSequential( + 
TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]), + prob, + TensorDictModule( + lambda x: x + 1, in_keys=["sample"], out_keys=["new_sample"] + ), + inplace=inplace, + aggregate_probabilities=aggregate_probabilities, + include_sum=include_sum, + return_composite=return_composite, + ) + td = TensorDict(x=torch.randn(3, 4), batch_size=[3]) + assert isinstance(seq.get_dist(td), CompositeDistribution) + td = seq(td) + assert "sample" in td + assert td.shape == (2, 3) + assert td["sample"].shape == (2, 3, 4) + if return_log_prob: + assert "sample_log_prob" in td + + # log-prob from the sequence + log_prob = seq.log_prob(td) + if aggregate_probabilities or not return_composite: + assert isinstance(log_prob, torch.Tensor) + else: + assert isinstance(log_prob, TensorDict) + + @pytest.mark.parametrize( + "log_prob_key", + [ + None, + "sample_log_prob", + ("nested", "sample_log_prob"), + ("data", "sample_log_prob"), + ], ) - with pytest.warns(UserWarning, match="deterministic_sample"): - td_out = module(loc_module(scale_module(td))) + def test_nested_keys_probabilistic_delta(self, log_prob_key): + policy_module = TensorDictModule( + nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")] + ) + td = TensorDict( + {"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3] + ) + + module = ProbabilisticTensorDictModule( + in_keys=[("data", "param")], + out_keys=[("data", "action")], + distribution_class=Delta, + return_log_prob=True, + log_prob_key=log_prob_key, + ) + td_out = module(policy_module(td)) assert td_out["data", "action"].shape == (3, 4, 1) if log_prob_key: - assert td_out[log_prob_key].shape == (3, 4, 1) + assert td_out[log_prob_key].shape == (3, 4) else: - assert td_out["sample_log_prob"].shape == (3, 4, 1) + assert td_out["sample_log_prob"].shape == (3, 4) module = ProbabilisticTensorDictModule( - in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")}, + in_keys={"param": ("data", "param")}, out_keys=[("data", 
"action")], - distribution_class=Normal, + distribution_class=Delta, return_log_prob=True, log_prob_key=log_prob_key, ) - td_out = module(loc_module(scale_module(td))) + td_out = module(policy_module(td)) assert td_out["data", "action"].shape == (3, 4, 1) if log_prob_key: - assert td_out[log_prob_key].shape == (3, 4, 1) + assert td_out[log_prob_key].shape == (3, 4) else: - assert td_out["sample_log_prob"].shape == (3, 4, 1) + assert td_out["sample_log_prob"].shape == (3, 4) + + @pytest.mark.parametrize( + "log_prob_key", + [ + None, + "sample_log_prob", + ("nested", "sample_log_prob"), + ("data", "sample_log_prob"), + ], + ) + def test_nested_keys_probabilistic_normal(self, log_prob_key): + loc_module = TensorDictModule( + nn.Linear(1, 1), + in_keys=[("data", "states")], + out_keys=[("data", "loc")], + ) + scale_module = TensorDictModule( + nn.Linear(1, 1), + in_keys=[("data", "states")], + out_keys=[("data", "scale")], + ) + td = TensorDict( + {"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3] + ) + + module = ProbabilisticTensorDictModule( + in_keys=[("data", "loc"), ("data", "scale")], + out_keys=[("data", "action")], + distribution_class=Normal, + return_log_prob=True, + log_prob_key=log_prob_key, + ) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = module(loc_module(scale_module(td))) + assert td_out["data", "action"].shape == (3, 4, 1) + if log_prob_key: + assert td_out[log_prob_key].shape == (3, 4, 1) + else: + assert td_out["sample_log_prob"].shape == (3, 4, 1) + + module = ProbabilisticTensorDictModule( + in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")}, + out_keys=[("data", "action")], + distribution_class=Normal, + return_log_prob=True, + log_prob_key=log_prob_key, + ) + td_out = module(loc_module(scale_module(td))) + assert td_out["data", "action"].shape == (3, 4, 1) + if log_prob_key: + assert td_out[log_prob_key].shape == (3, 4, 1) + else: + assert td_out["sample_log_prob"].shape == (3, 4, 1) 
class TestEnsembleModule: @@ -2574,6 +2821,7 @@ def test_prob_module(self, interaction, return_log_prob, map_names): assert all(key in sample for key in module.out_keys) sample_clone = sample.clone() lp = module.log_prob(sample_clone) + assert isinstance(lp, torch.Tensor) if return_log_prob: torch.testing.assert_close( lp, @@ -2705,7 +2953,13 @@ def test_prob_module_seq(self, interaction, return_log_prob): assert "cont_log_prob" in sample.keys() assert ("nested", "cont_log_prob") in sample.keys(True) sample_clone = sample.clone() + + dist = module.get_dist(sample_clone) + assert isinstance(dist, CompositeDistribution) + + sample_clone = sample.clone() lp = module.log_prob(sample_clone) + if return_log_prob: torch.testing.assert_close( lp,