Skip to content

Commit

Permalink
[Deprecation] Act warned deprecations for v0.6
Browse files Browse the repository at this point in the history
ghstack-source-id: 9f9ce070d8726c74fcf5a22e0edd05b8c9fd7e19
Pull Request resolved: #1001
  • Loading branch information
vmoens committed Oct 8, 2024
1 parent 177a08a commit b24c37d
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 133 deletions.
15 changes: 2 additions & 13 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
Tuple,
Type,
)
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -1909,18 +1908,8 @@ def _apply_nest(
)
for i, (td, *oth) in enumerate(_zip_strict(self.tensordicts, *others))
]
if all(r is None for r in results):
if filter_empty is None:
warn(
"Your resulting tensordict has no leaves but you did not specify filter_empty=True. "
"This now returns None (filter_empty=True). "
"To silence this warning, set filter_empty to the desired value in your call to `apply`. "
"This warning will be removed in v0.6.",
category=DeprecationWarning,
)
return
elif filter_empty:
return
if all(r is None for r in results) and filter_empty in (None, True):
return
if not inplace:
out = type(self)(
*results,
Expand Down
7 changes: 0 additions & 7 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,13 +1401,6 @@ def make_result(names=names, batch_size=batch_size):
# we raise the deprecation warning only if the tensordict wasn't already empty.
# After we introduce the new behaviour, we will have to consider what happens
# to empty tensordicts by default: will they disappear or stay?
warn(
"Your resulting tensordict has no leaves but you did not specify filter_empty=True. "
"This now returns None (filter_empty=True). "
"To silence this warning, set filter_empty to the desired value in your call to `apply`. "
"This warning will be removed in v0.6.",
category=DeprecationWarning,
)
return
if result is None:
result = make_result()
Expand Down
59 changes: 38 additions & 21 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5178,6 +5178,14 @@ def update_(
if input_dict_or_td is self:
# no op
return self

if not _is_tensor_collection(type(input_dict_or_td)):
from tensordict import TensorDict

input_dict_or_td = TensorDict.from_dict(
input_dict_or_td, batch_dims=self.batch_dims
)

if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
Expand All @@ -5193,29 +5201,35 @@ def inplace_update(name, dest, source):
if key == name[: len(key)]:
dest.copy_(source, non_blocking=non_blocking)

self._apply_nest(
inplace_update,
input_dict_or_td,
nested_keys=True,
default=None,
filter_empty=True,
named=named,
is_leaf=_is_leaf_nontensor,
)
return self
else:
if not _is_tensor_collection(type(input_dict_or_td)):
from tensordict import TensorDict

input_dict_or_td = TensorDict.from_dict(
input_dict_or_td, batch_dims=self.batch_dims
)

# Fastest route using _foreach_copy_
keys, vals = self._items_list(True, True)
other_val = input_dict_or_td._values_list(True, True, sorting_keys=keys)
torch._foreach_copy_(vals, other_val)
return self
new_keys, other_val = input_dict_or_td._items_list(
True, True, sorting_keys=keys, default="intersection"
)
if len(new_keys):
if len(other_val) != len(vals):
vals = dict(*zip(keys, vals))
vals = [vals[k] for k in new_keys]
torch._foreach_copy_(vals, other_val)
return self
named = False

def inplace_update(dest, source):
if source is None:
return None
dest.copy_(source, non_blocking=non_blocking)

self._apply_nest(
inplace_update,
input_dict_or_td,
nested_keys=True,
default=None,
filter_empty=True,
named=named,
is_leaf=_is_leaf_nontensor,
)
return self

def update_at_(
self,
Expand Down Expand Up @@ -5638,7 +5652,10 @@ def _items_list(
leaves_only=leaves_only,
is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else None,
)
keys, vals = zip(*items)
keys_vals = tuple(zip(*items))
if not keys_vals:
return (), ()
keys, vals = keys_vals
if sorting_keys is None:
return list(keys), list(vals)
if default is None:
Expand Down
7 changes: 0 additions & 7 deletions tensordict/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T:

def pad_sequence(
list_of_tensordicts: Sequence[T],
batch_first: bool | None = None,
pad_dim: int = 0,
padding_value: float = 0.0,
out: T | None = None,
Expand Down Expand Up @@ -146,12 +145,6 @@ def pad_sequence(
"The device argument is ignored by this function and will be removed in v0.5. To cast your"
" result to a different device, call `tensordict.to(device)` instead."
)
if batch_first is not None:
warnings.warn(
"The batch_first argument is deprecated and will be removed in v0.6. "
"The output will always be batch_first.",
category=DeprecationWarning,
)

if not len(list_of_tensordicts):
raise RuntimeError("list_of_tensordicts cannot be empty")
Expand Down
1 change: 0 additions & 1 deletion tensordict/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
InteractionType,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
set_interaction_mode,
set_interaction_type,
)
from tensordict.nn.sequence import TensorDictSequential
Expand Down
15 changes: 0 additions & 15 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,21 +1085,6 @@ def forward(
) -> TensorDictBase:
"""When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict."""
try:
if len(args):
tensordict_out = args[0]
args = args[1:]
# we will get rid of tensordict_out as a regular arg, because it
# blocks us when using vmap
# with stateful but functional modules: the functional module checks if
# it still contains parameters. If so it considers that only a "params" kwarg
# is indicative of what the params are, when we could potentially make a
# special rule for TensorDictModule that states that the second arg is
# likely to be the module params.
warnings.warn(
"tensordict_out will be deprecated in v0.6. "
"Make sure you have removed any such arg by then.",
category=DeprecationWarning,
)
if len(args):
raise ValueError(
"Got a non-empty list of extra agruments, when none was expected."
Expand Down
82 changes: 17 additions & 65 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

import re
import warnings
from enum import auto, IntEnum

try:
from enum import StrEnum
except ImportError:
from .utils import StrEnum
from textwrap import indent
from typing import Any, Callable, Dict, List, Optional
from warnings import warn
from typing import Any, Dict, List, Optional

from tensordict._nestedkey import NestedKey

Expand All @@ -30,7 +33,7 @@
__all__ = ["ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential"]


class InteractionType(IntEnum):
class InteractionType(StrEnum):
"""A list of possible interaction types with a distribution.
MODE, MEDIAN and MEAN point to the property / attribute with the same name.
Expand All @@ -44,11 +47,11 @@ class InteractionType(IntEnum):
"""

MODE = auto()
MEDIAN = auto()
MEAN = auto()
RANDOM = auto()
DETERMINISTIC = auto()
MODE = "mode"
MEDIAN = "median"
MEAN = "mean"
RANDOM = "random"
DETERMINISTIC = "deterministic"

@classmethod
def from_str(cls, type_str: str) -> InteractionType:
Expand All @@ -62,57 +65,11 @@ def from_str(cls, type_str: str) -> InteractionType:
_INTERACTION_TYPE: InteractionType | None = None


def _insert_interaction_mode_deprecation_warning(
prefix: str = "",
) -> Callable[[str, Warning, int], None]:
return warn(
(
f"{prefix}interaction_mode is deprecated for naming clarity and will be removed in v0.6. "
f"Please use {prefix}interaction_type with InteractionType enum instead."
),
DeprecationWarning,
stacklevel=2,
)


def interaction_type() -> InteractionType | None:
"""Returns the current sampling type."""
return _INTERACTION_TYPE


def interaction_mode() -> str | None:
"""*Deprecated* Returns the current sampling mode."""
_insert_interaction_mode_deprecation_warning()
type = interaction_type()
return type.name.lower() if type else None


class set_interaction_mode(_DecoratorContextManager):
"""*Deprecated* Sets the sampling mode of all ProbabilisticTDModules to the desired mode.
Args:
mode (str): mode to use when the policy is being called.
"""

def __init__(self, mode: str | None = "mode") -> None:
_insert_interaction_mode_deprecation_warning("set_")
super().__init__()
self.mode = InteractionType.from_str(mode) if mode else None

def clone(self) -> set_interaction_mode:
# override this method if your children class takes __init__ parameters
return type(self)(self.mode)

def __enter__(self) -> None:
global _INTERACTION_TYPE
self.prev = _INTERACTION_TYPE
_INTERACTION_TYPE = self.mode

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _INTERACTION_TYPE
_INTERACTION_TYPE = self.prev


class set_interaction_type(_DecoratorContextManager):
"""Sets all ProbabilisticTDModules sampling to the desired type.
Expand Down Expand Up @@ -366,12 +323,10 @@ def __init__(
self.log_prob_key = log_prob_key

if default_interaction_mode is not None:
_insert_interaction_mode_deprecation_warning("default_")
self.default_interaction_type = InteractionType.from_str(
default_interaction_mode
raise ValueError(
"default_interaction_mode is deprecated, use default_interaction_type instead."
)
else:
self.default_interaction_type = default_interaction_type
self.default_interaction_type = default_interaction_type

if isinstance(distribution_class, str):
distribution_class = distributions_maps.get(distribution_class.lower())
Expand Down Expand Up @@ -418,12 +373,9 @@ def log_prob(self, tensordict):

@property
def SAMPLE_LOG_PROB_KEY(self):
warnings.warn(
"SAMPLE_LOG_PROB_KEY will be deprecated in v0.6."
"Use 'obj.log_prob_key' instead",
category=DeprecationWarning,
raise RuntimeError(
"SAMPLE_LOG_PROB_KEY is fully deprecated. Use `obj.log_prob_key` instead."
)
return self.log_prob_key

@dispatch(auto_batch_size=False)
@_set_skip_existing_None()
Expand Down
29 changes: 29 additions & 0 deletions tensordict/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
import inspect
import os
from enum import ReprEnum
from typing import Any, Callable

import torch
Expand Down Expand Up @@ -444,3 +445,31 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
global DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self._saved_mode


# Reproduce StrEnum for python<3.11


class StrEnum(str, ReprEnum): # noqa
def __new__(cls, *values):
if len(values) > 3:
raise TypeError("too many arguments for str(): %r" % (values,))
if len(values) == 1:
# it must be a string
if not isinstance(values[0], str):
raise TypeError("%r is not a string" % (values[0],))
if len(values) >= 2:
# check that encoding argument is a string
if not isinstance(values[1], str):
raise TypeError("encoding must be a string, not %r" % (values[1],))
if len(values) == 3:
# check that errors argument is a string
if not isinstance(values[2], str):
raise TypeError("errors must be a string, not %r" % (values[2]))
value = str(*values)
member = str.__new__(cls, value)
member._value_ = value
return member

def _generate_next_value_(name, start, count, last_values):
return name.lower()
6 changes: 2 additions & 4 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,11 +705,9 @@ def _get_item(tensor: Tensor, index: IndexType) -> Tensor:
if _is_lis_of_list_of_bools(index):
index = torch.tensor(index, device=tensor.device)
if index.dtype is torch.bool:
warnings.warn(
raise RuntimeError(
"Indexing a tensor with a nested list of boolean values is "
"going to be deprecated in v0.6 as this functionality is not supported "
f"by PyTorch. (follows error: {err})",
category=DeprecationWarning,
"not supported by PyTorch.",
)
return tensor[index]
raise err
Expand Down

0 comments on commit b24c37d

Please sign in to comment.