
Commit

[Refactor] Refactor context managers
ghstack-source-id: c16baa83f6e41c4afd6637f3b3739d4e5cf25f1e
Pull Request resolved: #1098
vmoens committed Nov 25, 2024
1 parent e2444ed commit 2567f62
Showing 8 changed files with 139 additions and 132 deletions.
2 changes: 1 addition & 1 deletion tensordict/_td.py
@@ -609,7 +609,7 @@ def _quick_set(swap_dict, swap_td):
             _quick_set(_swap, swap_dest)
             return swap_dest
         else:
-            return TensorDict._new_unsafe(_swap, batch_size=[])
+            return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))
 
     def __ne__(self, other: object) -> T | bool:
         if is_tensorclass(other):
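
For context, both spellings are accepted by the public constructor; passing a torch.Size directly simply skips the list-to-Size conversion, which is presumably the motivation for this one-liner (an assumption, not stated in the commit). A minimal check:

    import torch
    from tensordict import TensorDict

    td_list = TensorDict({"a": torch.zeros(3)}, batch_size=[])
    td_size = TensorDict({"a": torch.zeros(3)}, batch_size=torch.Size(()))
    # both produce an empty batch size
    assert td_list.batch_size == td_size.batch_size == torch.Size(())
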
8 changes: 6 additions & 2 deletions tensordict/base.py
@@ -9870,8 +9870,12 @@ def from_any(cls, obj, *, auto_batch_size: bool = False):
         """
         if is_tensor_collection(obj):
-            if is_non_tensor(obj):
-                return cls.from_any(obj.data, auto_batch_size=auto_batch_size)
+            # Conversions from non-tensor data must be done manually
+            # if is_non_tensor(obj):
+            #     from tensordict.tensorclass import LazyStackedTensorDict
+            #     if isinstance(obj, LazyStackedTensorDict):
+            #         return obj
+            #     return cls.from_any(obj.data, auto_batch_size=auto_batch_size)
             return obj
         if isinstance(obj, dict):
             return cls.from_dict(obj, auto_batch_size=auto_batch_size)
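
Note on the hunk above: from_any no longer unwraps non-tensor collections into their .data payload; any tensor collection, non-tensor ones included, is now returned unchanged. A minimal sketch of the expected effect (the NonTensorData usage is illustrative, not part of this diff):

    import torch
    from tensordict import NonTensorData, TensorDict

    # Plain dicts still go through from_dict as before.
    td = TensorDict.from_any({"a": torch.zeros(3)})

    # Non-tensor collections are returned as-is instead of being rebuilt
    # from their .data attribute.
    nt = NonTensorData(data="hello")
    assert TensorDict.from_any(nt) is nt
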
6 changes: 2 additions & 4 deletions tensordict/nn/params.py
@@ -111,10 +111,8 @@ def _maybe_make_param(tensor):
 
 
 def _maybe_make_param_or_buffer(tensor):
-    if (
-        isinstance(tensor, (Tensor, ftdim.Tensor))
-        and not isinstance(tensor, (nn.Parameter, Buffer))
-        and tensor.dtype in (torch.float, torch.double, torch.half)
+    if isinstance(tensor, (Tensor, ftdim.Tensor)) and not isinstance(
+        tensor, (nn.Parameter, Buffer)
     ):
         if not tensor.requires_grad and not is_batchedtensor(tensor):
             # convert all non-parameters to buffers
14 changes: 6 additions & 8 deletions tensordict/nn/probabilistic.py
@@ -22,7 +22,7 @@
 from tensordict.nn.utils import _set_skip_existing_None
 from tensordict.tensorclass import is_non_tensor
 from tensordict.tensordict import TensorDictBase
-from tensordict.utils import _zip_strict
+from tensordict.utils import _ContextManager, _zip_strict
 from torch import distributions as D, Tensor
 
 from torch.utils._contextlib import _DecoratorContextManager
@@ -66,12 +66,12 @@ def from_str(cls, type_str: str) -> InteractionType:
         return cls(type_str.lower())
 
 
-_INTERACTION_TYPE: InteractionType | None = None
+_interaction_type = _ContextManager()
 
 
 def interaction_type() -> InteractionType | None:
     """Returns the current sampling type."""
-    return _INTERACTION_TYPE
+    return _interaction_type.get_mode()
 
 
 class set_interaction_type(_DecoratorContextManager):
@@ -98,13 +98,11 @@ def clone(self) -> set_interaction_type:
         return type(self)(self.type)
 
     def __enter__(self) -> None:
-        global _INTERACTION_TYPE
-        self.prev = _INTERACTION_TYPE
-        _INTERACTION_TYPE = self.type
+        self.prev = _interaction_type.get_mode()
+        _interaction_type.set_mode(self.type)
 
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        global _INTERACTION_TYPE
-        _INTERACTION_TYPE = self.prev
+        _interaction_type.set_mode(self.prev)
 
 
 class ProbabilisticTensorDictModule(TensorDictModuleBase):
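
The public behaviour of set_interaction_type should be unchanged by this refactor: it still works as a context manager or as a decorator, only the storage behind it moves from a module-level global to a shared _ContextManager instance. A short usage sketch under that assumption:

    from tensordict.nn import InteractionType, set_interaction_type
    from tensordict.nn.probabilistic import interaction_type

    with set_interaction_type(InteractionType.MODE):
        assert interaction_type() is InteractionType.MODE

    @set_interaction_type(InteractionType.RANDOM)
    def sample_fn():
        return interaction_type()

    assert sample_fn() is InteractionType.RANDOM
    assert interaction_type() is None  # the previous (default) value is restored
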
39 changes: 18 additions & 21 deletions tensordict/nn/utils.py
@@ -12,22 +12,24 @@
 from typing import Any, Callable
 
 import torch
-from tensordict.utils import strtobool
+from tensordict.utils import _ContextManager, strtobool
 from torch import nn
 
+from torch.utils._contextlib import _DecoratorContextManager
+
 try:
     from torch.compiler import is_dynamo_compiling
 except ImportError:  # torch 2.0
     from torch._dynamo import is_compiling as is_dynamo_compiling
 
 
-DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True"))
+_dispatch_tdnn_modules = _ContextManager(
+    default=strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True"))
+)
 
 __all__ = ["mappings", "inv_softplus", "biased_softplus"]
 
-_SKIP_EXISTING = False
-
-from torch.utils._contextlib import _DecoratorContextManager
+_skip_existing = _ContextManager(default=False)
 
 
 def inv_softplus(bias: float | torch.Tensor) -> float | torch.Tensor:
@@ -300,19 +302,17 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
     def __enter__(self) -> None:
         if self.mode and is_dynamo_compiling():
             raise RuntimeError("skip_existing is not compatible with TorchDynamo.")
-        global _SKIP_EXISTING
-        self.prev = _SKIP_EXISTING
+        self.prev = _skip_existing.get_mode()
         if self.mode is not None:
-            _SKIP_EXISTING = self.mode
+            _skip_existing.set_mode(self.mode)
         elif not self._called:
             raise RuntimeError(
                 f"It seems you are using {type(self).__name__} as a context manager with ``None`` input. "
                 f"This behaviour is not allowed."
             )
 
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        global _SKIP_EXISTING
-        _SKIP_EXISTING = self.prev
+        _skip_existing.set_mode(self.prev)
 
 
 class _set_skip_existing_None(set_skip_existing):
@@ -353,12 +353,11 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
                 return tensordict
             if is_dynamo_compiling():
                 return func(_self, tensordict, *args, **kwargs)
-            global _SKIP_EXISTING
-            self.prev = _SKIP_EXISTING
+            self.prev = _skip_existing.get_mode()
             try:
                 result = func(_self, tensordict, *args, **kwargs)
             finally:
-                _SKIP_EXISTING = self.prev
+                _skip_existing.set_mode(self.prev)
             return result
 
         return wrapper
@@ -375,7 +374,7 @@ def clone(self) -> _set_skip_existing_None:
 
 def skip_existing():
     """Returns whether or not existing entries in a tensordict should be re-computed by a module."""
-    return _SKIP_EXISTING
+    return _skip_existing.get_mode()
 
 
 def _rebuild_buffer(data, requires_grad, backward_hooks):
@@ -397,7 +396,7 @@ def _rebuild_buffer(data, requires_grad, backward_hooks):
 
 def _dispatch_td_nn_modules():
     """Returns ``True`` if @dispatch should be used. Not using dispatch is faster and also better compatible with torch.compile."""
-    return DISPATCH_TDNN_MODULES
+    return _dispatch_tdnn_modules.get_mode()
 
 
 class _set_dispatch_td_nn_modules(_DecoratorContextManager):
@@ -411,17 +410,15 @@ def clone(self):
         return type(self)(self.mode)
 
     def __enter__(self):
-        global DISPATCH_TDNN_MODULES
         # We want to avoid changing global variables because compile puts guards on them
-        if DISPATCH_TDNN_MODULES != self.mode:
-            self._saved_mode = DISPATCH_TDNN_MODULES
-            DISPATCH_TDNN_MODULES = self.mode
+        if _dispatch_tdnn_modules.get_mode() != self.mode:
+            self._saved_mode = _dispatch_tdnn_modules
+            _dispatch_tdnn_modules.set_mode(self.mode)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if self._saved_mode is None:
             return
-        global DISPATCH_TDNN_MODULES
-        DISPATCH_TDNN_MODULES = self._saved_mode
+        _dispatch_tdnn_modules.set_mode(self._saved_mode)
 
 
 # Reproduce StrEnum for python<3.11
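
Same pattern for the two flags in this file: skip_existing() and _dispatch_td_nn_modules() now read from _ContextManager instances (with defaults False and the DISPATCH_TDNN_MODULES environment variable, respectively) instead of module-level globals, while the public context managers keep their semantics. A quick sketch of the expected behaviour:

    from tensordict.nn import set_skip_existing, skip_existing

    assert skip_existing() is False  # default, from _ContextManager(default=False)

    with set_skip_existing(True):
        # modules decorated with _set_skip_existing_None will not recompute
        # entries that are already present in the tensordict
        assert skip_existing() is True

    assert skip_existing() is False  # previous value restored on exit
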
50 changes: 33 additions & 17 deletions tensordict/utils.py
@@ -15,10 +15,12 @@
 import re
 
 import sys
+import threading
 import time
 import warnings
 from collections import defaultdict
 from collections.abc import KeysView
+from contextlib import nullcontext
 from copy import copy
 from dataclasses import _FIELDS, GenericAlias
 from functools import wraps
@@ -66,12 +68,9 @@
 except ImportError:
     _has_funcdim = False
 try:
-    from torch.compiler import assume_constant_result, is_dynamo_compiling
+    from torch.compiler import assume_constant_result, is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import (
-        assume_constant_result,
-        is_compiling as is_dynamo_compiling,
-    )
+    from torch._dynamo import assume_constant_result, is_compiling
 
 if TYPE_CHECKING:
     from tensordict.tensordict import TensorDictBase
@@ -862,7 +861,7 @@ def _is_tensorclass(cls: type) -> bool:
     out = _TENSORCLASS_MEMO.get(cls)
     if out is None:
         out = getattr(cls, "_is_tensorclass", False)
-        if not is_dynamo_compiling():
+        if not is_compiling():
             _TENSORCLASS_MEMO[cls] = out
     return out
 
@@ -1118,7 +1117,7 @@ def cache(fun):
 
     @wraps(fun)
    def newfun(_self: "TensorDictBase", *args, **kwargs):
-        if not _self.is_locked or is_dynamo_compiling():
+        if not _self.is_locked or is_compiling():
             return fun(_self, *args, **kwargs)
         cache = _self._cache
         if cache is None:
@@ -1358,7 +1357,7 @@ def _parse_to(*args, **kwargs):
     num_threads = kwargs.pop("num_threads", None)
     other = kwargs.pop("other", None)
     inplace = kwargs.pop("inplace", False)
-    if not is_dynamo_compiling():
+    if not is_compiling():
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
             *args, **kwargs
         )
@@ -1732,7 +1731,7 @@ def _check_keys(
         is_leaf=_is_leaf_nontensor,
     )
     # TODO: compile doesn't like set() over an arbitrary object
-    if is_dynamo_compiling():
+    if is_compiling():
         keys = {k for k in keys}  # noqa: C416
     else:
         keys: set[str] = set(keys)
@@ -1745,7 +1744,7 @@
     if not strict:
         keys = keys.intersection(k)
     else:
-        if is_dynamo_compiling():
+        if is_compiling():
             k = {v for v in k}  # noqa: C416
         else:
             k = set(k)
@@ -2014,7 +2013,7 @@ def _getitem_batch_size(batch_size, index):
             continue
         elif isinstance(idx, slice):
             batch = batch_size[count]
-            if is_dynamo_compiling():
+            if is_compiling():
                 out.append(len(range(*_slice_indices(idx, batch))))
             else:
                 out.append(len(range(*idx.indices(batch))))
@@ -2446,7 +2445,7 @@ def is_non_tensor(data):
 
 def _is_non_tensor(cls: type):
     out = None
-    is_dynamo = is_dynamo_compiling()
+    is_dynamo = is_compiling()
     if not is_dynamo:
         out = _NON_TENSOR_MEMO.get(cls)
     if out is None:
@@ -2502,7 +2501,7 @@ def new_func(self):
 
 
 def _unravel_key_to_tuple(key):
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return _unravel_key_to_tuple_cpp(key)
     if isinstance(key, str):
         return (key,)
@@ -2523,7 +2522,7 @@ def unravel_key(key):
         ("a", "b")
 
     """
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_key_cpp(key)
     if isinstance(key, str):
         return key
@@ -2536,14 +2535,14 @@
 
 def unravel_keys(*keys):
     """Unravels a sequence of keys."""
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_keys_cpp(*keys)
     return tuple(unravel_key(key) for key in keys)
 
 
 def unravel_key_list(keys):
     """Unravels a list of keys."""
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_key_list_cpp(keys)
     return [unravel_key(key) for key in keys]
 
@@ -2823,7 +2822,8 @@ def _is_dataclass(obj):
         if isinstance(obj, type) and not isinstance(obj, GenericAlias)
         else type(obj)
     )
-    return hasattr(cls, _FIELDS)
+    # return hasattr(cls, _FIELDS)
+    return getattr(cls, _FIELDS, None) is not None
 
 
 def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]:
@@ -2857,3 +2857,19 @@ def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]:
     if len(sizes):
         return True, (length_t, *list(sizes)[0]), dtype
     return True, (length_t,), dtype
+
+
+class _ContextManager:
+    def __init__(self, default=None):
+        self._mode: Any | None = default
+        self._lock = threading.Lock()
+
+    def get_mode(self) -> Any | None:
+        cm = self._lock if not is_compiling() else nullcontext()
+        with cm:
+            return self._mode
+
+    def set_mode(self, type: Any | None) -> None:
+        cm = self._lock if not is_compiling() else nullcontext()
+        with cm:
+            self._mode = type
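
The class above is the core of the refactor: each former module-level flag becomes one _ContextManager instance whose get_mode/set_mode guard the value with a lock, except under compile where the lock is skipped (presumably because Dynamo cannot trace lock acquisition, and the point of the change is to stop mutating globals that compile puts guards on). A self-contained sketch of how a flag migrates to this pattern (the names are illustrative, not from the diff):

    import threading
    from contextlib import nullcontext


    def is_compiling() -> bool:
        # stand-in for the torch.compiler.is_compiling import used above
        return False


    class _ContextManager:
        def __init__(self, default=None):
            self._mode = default
            self._lock = threading.Lock()

        def get_mode(self):
            cm = self._lock if not is_compiling() else nullcontext()
            with cm:
                return self._mode

        def set_mode(self, value) -> None:
            cm = self._lock if not is_compiling() else nullcontext()
            with cm:
                self._mode = value


    # Before: _MY_FLAG = False, mutated via "global _MY_FLAG" in __enter__/__exit__.
    # After: one shared holder that the context manager reads and restores.
    _my_flag = _ContextManager(default=False)


    class set_my_flag:
        def __init__(self, mode: bool):
            self.mode = mode

        def __enter__(self):
            self.prev = _my_flag.get_mode()
            _my_flag.set_mode(self.mode)

        def __exit__(self, exc_type, exc_value, traceback):
            _my_flag.set_mode(self.prev)


    with set_my_flag(True):
        assert _my_flag.get_mode() is True
    assert _my_flag.get_mode() is False
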