From b24c37d654d6b09e57b1c8516863b1a334e9e680 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 8 Oct 2024 11:07:55 +0100 Subject: [PATCH] [Deprecation] Act warned deprecations for v0.6 ghstack-source-id: 9f9ce070d8726c74fcf5a22e0edd05b8c9fd7e19 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1001 --- tensordict/_lazy.py | 15 +------ tensordict/_td.py | 7 --- tensordict/base.py | 59 +++++++++++++++--------- tensordict/functional.py | 7 --- tensordict/nn/__init__.py | 1 - tensordict/nn/common.py | 15 ------- tensordict/nn/probabilistic.py | 82 +++++++--------------------------- tensordict/nn/utils.py | 29 ++++++++++++ tensordict/utils.py | 6 +-- 9 files changed, 88 insertions(+), 133 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 18cb89e10..88145d71e 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -27,7 +27,6 @@ Tuple, Type, ) -from warnings import warn import numpy as np @@ -1909,18 +1908,8 @@ def _apply_nest( ) for i, (td, *oth) in enumerate(_zip_strict(self.tensordicts, *others)) ] - if all(r is None for r in results): - if filter_empty is None: - warn( - "Your resulting tensordict has no leaves but you did not specify filter_empty=True. " - "This now returns None (filter_empty=True). " - "To silence this warning, set filter_empty to the desired value in your call to `apply`. " - "This warning will be removed in v0.6.", - category=DeprecationWarning, - ) - return - elif filter_empty: - return + if all(r is None for r in results) and filter_empty in (None, True): + return if not inplace: out = type(self)( *results, diff --git a/tensordict/_td.py b/tensordict/_td.py index b6a581ebb..f56e25052 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1401,13 +1401,6 @@ def make_result(names=names, batch_size=batch_size): # we raise the deprecation warning only if the tensordict wasn't already empty. # After we introduce the new behaviour, we will have to consider what happens # to empty tensordicts by default: will they disappear or stay? - warn( - "Your resulting tensordict has no leaves but you did not specify filter_empty=True. " - "This now returns None (filter_empty=True). " - "To silence this warning, set filter_empty to the desired value in your call to `apply`. " - "This warning will be removed in v0.6.", - category=DeprecationWarning, - ) return if result is None: result = make_result() diff --git a/tensordict/base.py b/tensordict/base.py index f8dce9784..8adf8dcef 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5178,6 +5178,14 @@ def update_( if input_dict_or_td is self: # no op return self + + if not _is_tensor_collection(type(input_dict_or_td)): + from tensordict import TensorDict + + input_dict_or_td = TensorDict.from_dict( + input_dict_or_td, batch_dims=self.batch_dims + ) + if keys_to_update is not None: if len(keys_to_update) == 0: return self @@ -5193,29 +5201,35 @@ def inplace_update(name, dest, source): if key == name[: len(key)]: dest.copy_(source, non_blocking=non_blocking) - self._apply_nest( - inplace_update, - input_dict_or_td, - nested_keys=True, - default=None, - filter_empty=True, - named=named, - is_leaf=_is_leaf_nontensor, - ) - return self else: - if not _is_tensor_collection(type(input_dict_or_td)): - from tensordict import TensorDict - - input_dict_or_td = TensorDict.from_dict( - input_dict_or_td, batch_dims=self.batch_dims - ) - # Fastest route using _foreach_copy_ keys, vals = self._items_list(True, True) - other_val = input_dict_or_td._values_list(True, True, sorting_keys=keys) - torch._foreach_copy_(vals, other_val) - return self + new_keys, other_val = input_dict_or_td._items_list( + True, True, sorting_keys=keys, default="intersection" + ) + if len(new_keys): + if len(other_val) != len(vals): + vals = dict(*zip(keys, vals)) + vals = [vals[k] for k in new_keys] + torch._foreach_copy_(vals, other_val) + return self + named = False + + def inplace_update(dest, source): + if source is None: + return None + dest.copy_(source, non_blocking=non_blocking) + + self._apply_nest( + inplace_update, + input_dict_or_td, + nested_keys=True, + default=None, + filter_empty=True, + named=named, + is_leaf=_is_leaf_nontensor, + ) + return self def update_at_( self, @@ -5638,7 +5652,10 @@ def _items_list( leaves_only=leaves_only, is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else None, ) - keys, vals = zip(*items) + keys_vals = tuple(zip(*items)) + if not keys_vals: + return (), () + keys, vals = keys_vals if sorting_keys is None: return list(keys), list(vals) if default is None: diff --git a/tensordict/functional.py b/tensordict/functional.py index 4daa3f303..d5b770e06 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -99,7 +99,6 @@ def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T: def pad_sequence( list_of_tensordicts: Sequence[T], - batch_first: bool | None = None, pad_dim: int = 0, padding_value: float = 0.0, out: T | None = None, @@ -146,12 +145,6 @@ def pad_sequence( "The device argument is ignored by this function and will be removed in v0.5. To cast your" " result to a different device, call `tensordict.to(device)` instead." ) - if batch_first is not None: - warnings.warn( - "The batch_first argument is deprecated and will be removed in v0.6. " - "The output will always be batch_first.", - category=DeprecationWarning, - ) if not len(list_of_tensordicts): raise RuntimeError("list_of_tensordicts cannot be empty") diff --git a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index c7aa4f8f7..55590889a 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -30,7 +30,6 @@ InteractionType, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, - set_interaction_mode, set_interaction_type, ) from tensordict.nn.sequence import TensorDictSequential diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 51ff52b33..d1faceec3 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1085,21 +1085,6 @@ def forward( ) -> TensorDictBase: """When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.""" try: - if len(args): - tensordict_out = args[0] - args = args[1:] - # we will get rid of tensordict_out as a regular arg, because it - # blocks us when using vmap - # with stateful but functional modules: the functional module checks if - # it still contains parameters. If so it considers that only a "params" kwarg - # is indicative of what the params are, when we could potentially make a - # special rule for TensorDictModule that states that the second arg is - # likely to be the module params. - warnings.warn( - "tensordict_out will be deprecated in v0.6. " - "Make sure you have removed any such arg by then.", - category=DeprecationWarning, - ) if len(args): raise ValueError( "Got a non-empty list of extra agruments, when none was expected." diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 3dc887dc4..e3fae2fb5 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -7,10 +7,13 @@ import re import warnings -from enum import auto, IntEnum + +try: + from enum import StrEnum +except ImportError: + from .utils import StrEnum from textwrap import indent -from typing import Any, Callable, Dict, List, Optional -from warnings import warn +from typing import Any, Dict, List, Optional from tensordict._nestedkey import NestedKey @@ -30,7 +33,7 @@ __all__ = ["ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential"] -class InteractionType(IntEnum): +class InteractionType(StrEnum): """A list of possible interaction types with a distribution. MODE, MEDIAN and MEAN point to the property / attribute with the same name. @@ -44,11 +47,11 @@ class InteractionType(IntEnum): """ - MODE = auto() - MEDIAN = auto() - MEAN = auto() - RANDOM = auto() - DETERMINISTIC = auto() + MODE = "mode" + MEDIAN = "median" + MEAN = "mean" + RANDOM = "random" + DETERMINISTIC = "deterministic" @classmethod def from_str(cls, type_str: str) -> InteractionType: @@ -62,57 +65,11 @@ def from_str(cls, type_str: str) -> InteractionType: _INTERACTION_TYPE: InteractionType | None = None -def _insert_interaction_mode_deprecation_warning( - prefix: str = "", -) -> Callable[[str, Warning, int], None]: - return warn( - ( - f"{prefix}interaction_mode is deprecated for naming clarity and will be removed in v0.6. " - f"Please use {prefix}interaction_type with InteractionType enum instead." - ), - DeprecationWarning, - stacklevel=2, - ) - - def interaction_type() -> InteractionType | None: """Returns the current sampling type.""" return _INTERACTION_TYPE -def interaction_mode() -> str | None: - """*Deprecated* Returns the current sampling mode.""" - _insert_interaction_mode_deprecation_warning() - type = interaction_type() - return type.name.lower() if type else None - - -class set_interaction_mode(_DecoratorContextManager): - """*Deprecated* Sets the sampling mode of all ProbabilisticTDModules to the desired mode. - - Args: - mode (str): mode to use when the policy is being called. - """ - - def __init__(self, mode: str | None = "mode") -> None: - _insert_interaction_mode_deprecation_warning("set_") - super().__init__() - self.mode = InteractionType.from_str(mode) if mode else None - - def clone(self) -> set_interaction_mode: - # override this method if your children class takes __init__ parameters - return type(self)(self.mode) - - def __enter__(self) -> None: - global _INTERACTION_TYPE - self.prev = _INTERACTION_TYPE - _INTERACTION_TYPE = self.mode - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - global _INTERACTION_TYPE - _INTERACTION_TYPE = self.prev - - class set_interaction_type(_DecoratorContextManager): """Sets all ProbabilisticTDModules sampling to the desired type. @@ -366,12 +323,10 @@ def __init__( self.log_prob_key = log_prob_key if default_interaction_mode is not None: - _insert_interaction_mode_deprecation_warning("default_") - self.default_interaction_type = InteractionType.from_str( - default_interaction_mode + raise ValueError( + "default_interaction_mode is deprecated, use default_interaction_type instead." ) - else: - self.default_interaction_type = default_interaction_type + self.default_interaction_type = default_interaction_type if isinstance(distribution_class, str): distribution_class = distributions_maps.get(distribution_class.lower()) @@ -418,12 +373,9 @@ def log_prob(self, tensordict): @property def SAMPLE_LOG_PROB_KEY(self): - warnings.warn( - "SAMPLE_LOG_PROB_KEY will be deprecated in v0.6." - "Use 'obj.log_prob_key' instead", - category=DeprecationWarning, + raise RuntimeError( + "SAMPLE_LOG_PROB_KEY is fully deprecated. Use `obj.log_prob_key` instead." ) - return self.log_prob_key @dispatch(auto_batch_size=False) @_set_skip_existing_None() diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 8b4875056..067e2c549 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -8,6 +8,7 @@ import functools import inspect import os +from enum import ReprEnum from typing import Any, Callable import torch @@ -444,3 +445,31 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): global DISPATCH_TDNN_MODULES DISPATCH_TDNN_MODULES = self._saved_mode + + +# Reproduce StrEnum for python<3.11 + + +class StrEnum(str, ReprEnum): # noqa + def __new__(cls, *values): + if len(values) > 3: + raise TypeError("too many arguments for str(): %r" % (values,)) + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): + raise TypeError("%r is not a string" % (values[0],)) + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): + raise TypeError("encoding must be a string, not %r" % (values[1],)) + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): + raise TypeError("errors must be a string, not %r" % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member + + def _generate_next_value_(name, start, count, last_values): + return name.lower() diff --git a/tensordict/utils.py b/tensordict/utils.py index fd3140401..5edc54e4b 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -705,11 +705,9 @@ def _get_item(tensor: Tensor, index: IndexType) -> Tensor: if _is_lis_of_list_of_bools(index): index = torch.tensor(index, device=tensor.device) if index.dtype is torch.bool: - warnings.warn( + raise RuntimeError( "Indexing a tensor with a nested list of boolean values is " - "going to be deprecated in v0.6 as this functionality is not supported " - f"by PyTorch. (follows error: {err})", - category=DeprecationWarning, + "not supported by PyTorch.", ) return tensor[index] raise err