diff --git a/tensordict/base.py b/tensordict/base.py
index 3666b6772..3dea5a7d0 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -55,6 +55,7 @@
     _CloudpickleWrapper,
     _DTYPE2STRDTYPE,
     _GENERIC_NESTED_ERR,
+    _is_dataclass as is_dataclass,
     _is_non_tensor,
     _is_number,
     _is_tensorclass,
@@ -9452,6 +9453,7 @@ def _validate_value(
         if device is not None and value.device != device:
             if _device_recorder.marked and device.type != "cuda":
                 _device_recorder.record_transfer(device)
+                assert not non_blocking
             value = value.to(device, non_blocking=non_blocking)
         if check_shape:
             if is_tc is None:
@@ -9874,16 +9876,15 @@ def from_any(cls, obj, *, auto_batch_size: bool = False):
             return cls.from_dict(obj, auto_batch_size=auto_batch_size)
         if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"):
             return cls.from_struct_array(obj, auto_batch_size=auto_batch_size)
-        from dataclasses import is_dataclass
-
-        if is_dataclass(obj):
-            return cls.from_dataclass(obj, auto_batch_size=auto_batch_size)
-        if is_namedtuple(obj):
-            return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size)
         if isinstance(obj, tuple):
             return cls.from_tuple(obj, auto_batch_size=auto_batch_size)
         if isinstance(obj, list):
             return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size)
+        if is_dataclass(obj):
+            return cls.from_dataclass(obj, auto_batch_size=auto_batch_size)
+        if is_namedtuple(obj):
+            return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size)
         if _has_h5:
             import h5py
@@ -9942,7 +9943,7 @@ def from_dataclass(
             from tensordict.tensorclass import from_dataclass
 
             return from_dataclass(dataclass, auto_batch_size=auto_batch_size)
-        from dataclasses import fields, is_dataclass
+        from dataclasses import fields
 
         from tensordict import TensorDict
 
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 1fb7d3049..e1c8e77b4 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -46,6 +46,7 @@
     CompatibleType,
 )
 from tensordict.utils import (  # @manual=//pytorch/tensordict:_C
+    _is_dataclass as is_dataclass,
     _is_json_serializable,
     _is_tensorclass,
     _LOCK_ERROR,
@@ -450,7 +451,7 @@ def from_dataclass(
         by default, this method will return a tensorclass instance or type.
 
     """
-    from dataclasses import asdict, is_dataclass, make_dataclass
+    from dataclasses import asdict, make_dataclass
 
     if isinstance(obj, type):
         if is_tensorclass(obj):
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 54e00a174..81ab2fa0c 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -20,6 +20,7 @@
 from collections import defaultdict
 from collections.abc import KeysView
 from copy import copy
+from dataclasses import _FIELDS, GenericAlias
 from functools import wraps
 from importlib import import_module
 from numbers import Number
@@ -2813,3 +2814,13 @@ def _mismatch_keys(keys1, keys2):
     if sub2 is not None:
         main.append(sub2)
     raise KeyError(r" ".join(main))
+
+
+def _is_dataclass(obj):
+    """Like dataclasses.is_dataclass but compatible with compile."""
+    cls = (
+        obj
+        if isinstance(obj, type) and not isinstance(obj, GenericAlias)
+        else type(obj)
+    )
+    return hasattr(cls, _FIELDS)
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 372e1af6a..73d401c03 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -10661,7 +10661,8 @@ def test_non_tensor_call(self):
 
     def test_nontensor_dict(self, non_tensor_data):
         assert (
-            TensorDict.from_dict(non_tensor_data.to_dict()) == non_tensor_data
+            TensorDict.from_dict(non_tensor_data.to_dict(), auto_batch_size=True)
+            == non_tensor_data
         ).all()
 
     def test_nontensor_tensor(self):
@@ -11202,7 +11203,7 @@ def _to_float(td, td_name, tmpdir):
         td._source = td._source.float()
     elif td_name in ("td_h5",):
         td = PersistentTensorDict.from_dict(
-            td.float().to_dict(), filename=tmpdir + "/file.t"
+            td.float().to_dict(), filename=tmpdir + "/file.t", auto_batch_size=True
         )
     elif td_name in ("td_params",):
         td = TensorDictParams(td.data.float())