diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index 0143ec856..be8aa42f1 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -121,7 +121,6 @@ def from_metadata(metadata=metadata, prefix=None): d[k] = from_metadata( v, prefix=prefix + (k,) if prefix is not None else (k,) ) - print('cls_metadata', cls_metadata) result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) if is_locked: result = result.lock_() diff --git a/tensordict/_td.py b/tensordict/_td.py index 66f4dc86b..07a98cdfb 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2101,13 +2101,11 @@ def from_dict_instance( # TODO: v0.7: remove the None cur_value = self.get(key, None) if cur_value is not None: - print(type(cur_value)) input_dict[key] = cur_value.from_dict_instance( value, device=device, auto_batch_size=False, ) - print(type(cur_value), type(input_dict[key])) continue else: # we don't know if another tensor of smaller size is coming @@ -2142,10 +2140,7 @@ def from_dict_instance( elif auto_batch_size is None: auto_batch_size = True if auto_batch_size: - print('self', self) - print('out', out) _set_max_batch_size(out, batch_dims) - print('out', out) else: out.batch_size = batch_size return out diff --git a/tensordict/base.py b/tensordict/base.py index b9a4077d5..3666b6772 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1217,7 +1217,6 @@ def _from_dict_validated(cls, *args, **kwargs): By default, falls back on :meth:`~.from_dict`. """ kwargs.setdefault("auto_batch_size", True) - print('kwargs', kwargs) return cls.from_dict(*args, **kwargs) @abc.abstractmethod @@ -1225,7 +1224,7 @@ def from_dict_instance( self, input_dict, *others, - auto_batch_size: bool | None=None, + auto_batch_size: bool | None = None, batch_size=None, device=None, batch_dims=None, diff --git a/tensordict/functional.py b/tensordict/functional.py index edd93a36f..2699f36bb 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -437,7 +437,7 @@ def make_tensordict( input_dict: dict[str, CompatibleType] | None = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, - auto_batch_size:bool|None=None, + auto_batch_size: bool | None = None, **kwargs: CompatibleType, # source ) -> TensorDict: """Returns a TensorDict created from the keyword arguments or an input dictionary. @@ -503,4 +503,6 @@ def make_tensordict( """ if input_dict is not None: kwargs.update(input_dict) - return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size) + return TensorDict.from_dict( + kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size + ) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 0b55d1cef..ffedba9ad 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -297,9 +297,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: f"The key {expected_key} wasn't found in the keyword arguments " f"but is expected to execute that function." ) + batch_size = torch.Size([]) if not self.auto_batch_size else None tensordict = make_tensordict( tensordict_values, - batch_size=torch.Size([]) if not self.auto_batch_size else None, + batch_size=batch_size, + auto_batch_size=False, ) if _self is not None: out = func(_self, tensordict, *args, **kwargs) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 2f2502353..980b5431f 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -634,8 +634,11 @@ def forward( tensordict_exec = tensordict.copy() else: tensordict_exec = tensordict - tensordict_exec = self.get_dist_params(tensordict_exec, tensordict_out, **kwargs) - tensordict_exec = self.module[-1](tensordict_exec, _requires_sample=self._requires_sample) + tensordict_exec = self.get_dist_params(tensordict_exec, tensordict_out, **kwargs + ) + tensordict_exec = self.module[-1]( + tensordict_exec, _requires_sample=self._requires_sample + ) if tensordict_out is not None: result = tensordict_out result.update(tensordict_exec, keys_to_update=self.out_keys) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 07b6d5faa..1fb7d3049 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1581,7 +1581,15 @@ def _to_dict(self, *, retain_none: bool = True) -> dict: return td_dict -def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None): +def _from_dict( + cls, + input_dict, + *, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, +): # we pass through a tensordict because keys could be passed as NestedKeys # We can't assume all keys are strings, otherwise calling cls(**kwargs) # would work ok @@ -1595,7 +1603,11 @@ def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=No non_tensordict=input_dict, ) td = TensorDict.from_dict( - input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims, auto_batch_size=auto_batch_size + input_dict, + batch_size=batch_size, + device=device, + batch_dims=batch_dims, + auto_batch_size=auto_batch_size, ) non_tensordict = {} @@ -1603,7 +1615,13 @@ def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=No def _from_dict_instance( - self, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None + self, + input_dict, + *, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, ): if batch_dims is not None and batch_size is not None: raise ValueError("Cannot pass both batch_size and batch_dims to `from_dict`.") @@ -1773,7 +1791,7 @@ def _is_castable(datatype): if isinstance(value, dict): if _is_tensor_collection(target_cls): - cast_val = target_cls.from_dict(value) + cast_val = target_cls.from_dict(value, auto_batch_size=False) self._tensordict.set( key, cast_val, inplace=inplace, non_blocking=non_blocking ) diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 127d4b77a..0f71bd743 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -584,7 +584,7 @@ def test_from_dict(self): class MyClass: a: TensorDictBase - tc = MyClass.from_dict(d) + tc = MyClass.from_dict(d, auto_batch_size=True) assert isinstance(tc, MyClass) assert isinstance(tc.a, TensorDict) assert tc.batch_size == torch.Size([10]) @@ -2148,7 +2148,9 @@ class TestClass: my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3] ) - assert (test_class == TestClass.from_dict(test_class.to_dict())).all() + assert ( + test_class == TestClass.from_dict(test_class.to_dict(), auto_batch_size=True) + ).all() # Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such # test_class2 = TestClass( @@ -2161,7 +2163,9 @@ class TestClass: my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3] ) - assert not (test_class == TestClass.from_dict(test_class3.to_dict())).all() + assert not ( + test_class == TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True) + ).all() @tensorclass(autocast=True) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index efe47b480..372e1af6a 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -992,7 +992,9 @@ class MyClass: ("4", "h5py", "my_nested_td", "inner"), } ) - assert set(td.keys(True, True)) == expected, set(td.keys(True, True)).symmetric_difference(expected) + assert set(td.keys(True, True)) == expected, set( + td.keys(True, True) + ).symmetric_difference(expected) def test_from_dataclass(self): @dataclass @@ -1024,7 +1026,11 @@ def test_from_dict(self, batch_size, batch_dims, device): ) return data = TensorDict.from_dict( - data, batch_size=batch_size, batch_dims=batch_dims, device=device, auto_batch_size=True + data, + batch_size=batch_size, + batch_dims=batch_dims, + device=device, + auto_batch_size=True, ) assert data.device == device assert "a" in data.keys() @@ -6500,7 +6506,7 @@ def recursive_checker(cur_dict): assert recursive_checker(td_dict) if td_name == "td_with_non_tensor": assert td_dict["data"]["non_tensor"] == "some text data" - assert (TensorDict.from_dict(td_dict,auto_batch_size=False) == td).all() + assert (TensorDict.from_dict(td_dict, auto_batch_size=False) == td).all() def test_to_namedtuple(self, td_name, device): def is_namedtuple(obj): @@ -7865,19 +7871,29 @@ def test_tensordict_batch_size(self): tensordict = make_tensordict(a=torch.randn(3, 4), auto_batch_size=True) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True) + tensordict = make_tensordict( + a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([3, 4]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5), auto_batch_size=True) # nested + nested_tensordict = make_tensordict( + c=tensordict, d=torch.randn(3, 5), auto_batch_size=True + ) # nested assert nested_tensordict.batch_size == torch.Size([3]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5), auto_batch_size=True) # nested + nested_tensordict = make_tensordict( + c=tensordict, d=torch.randn(4, 5), auto_batch_size=True + ) # nested assert nested_tensordict.batch_size == torch.Size([]) - tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True) + tensordict = make_tensordict( + a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True) + tensordict = make_tensordict( + a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([]) tensordict = make_tensordict( @@ -7893,7 +7909,10 @@ def test_tensordict_batch_size(self): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_device(self, device): tensordict = make_tensordict( - a=torch.randn(3, 4), b=torch.randn(3, 4), device=device, auto_batch_size=True + a=torch.randn(3, 4), + b=torch.randn(3, 4), + device=device, + auto_batch_size=True, ) assert tensordict.device == device assert tensordict["a"].device == device