Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 22, 2024
2 parents 7c999e5 + 8c5c0e0 commit 47d0b3d
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 29 deletions.
1 change: 0 additions & 1 deletion tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def from_metadata(metadata=metadata, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
print('cls_metadata', cls_metadata)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if is_locked:
result = result.lock_()
Expand Down
5 changes: 0 additions & 5 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,13 +2101,11 @@ def from_dict_instance(
# TODO: v0.7: remove the None
cur_value = self.get(key, None)
if cur_value is not None:
print(type(cur_value))
input_dict[key] = cur_value.from_dict_instance(
value,
device=device,
auto_batch_size=False,
)
print(type(cur_value), type(input_dict[key]))
continue
else:
# we don't know if another tensor of smaller size is coming
Expand Down Expand Up @@ -2142,10 +2140,7 @@ def from_dict_instance(
elif auto_batch_size is None:
auto_batch_size = True
if auto_batch_size:
print('self', self)
print('out', out)
_set_max_batch_size(out, batch_dims)
print('out', out)
else:
out.batch_size = batch_size
return out
Expand Down
3 changes: 1 addition & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,15 +1217,14 @@ def _from_dict_validated(cls, *args, **kwargs):
By default, falls back on :meth:`~.from_dict`.
"""
kwargs.setdefault("auto_batch_size", True)
print('kwargs', kwargs)
return cls.from_dict(*args, **kwargs)

@abc.abstractmethod
def from_dict_instance(
self,
input_dict,
*others,
auto_batch_size: bool | None=None,
auto_batch_size: bool | None = None,
batch_size=None,
device=None,
batch_dims=None,
Expand Down
6 changes: 4 additions & 2 deletions tensordict/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def make_tensordict(
input_dict: dict[str, CompatibleType] | None = None,
batch_size: Sequence[int] | torch.Size | int | None = None,
device: DeviceType | None = None,
auto_batch_size:bool|None=None,
auto_batch_size: bool | None = None,
**kwargs: CompatibleType, # source
) -> TensorDict:
"""Returns a TensorDict created from the keyword arguments or an input dictionary.
Expand Down Expand Up @@ -503,4 +503,6 @@ def make_tensordict(
"""
if input_dict is not None:
kwargs.update(input_dict)
return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size)
return TensorDict.from_dict(
kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size
)
4 changes: 3 additions & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
f"The key {expected_key} wasn't found in the keyword arguments "
f"but is expected to execute that function."
)
batch_size = torch.Size([]) if not self.auto_batch_size else None
tensordict = make_tensordict(
tensordict_values,
batch_size=torch.Size([]) if not self.auto_batch_size else None,
batch_size=batch_size,
auto_batch_size=False,
)
if _self is not None:
out = func(_self, tensordict, *args, **kwargs)
Expand Down
7 changes: 5 additions & 2 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,11 @@ def forward(
tensordict_exec = tensordict.copy()
else:
tensordict_exec = tensordict
tensordict_exec = self.get_dist_params(tensordict_exec, tensordict_out, **kwargs)
tensordict_exec = self.module[-1](tensordict_exec, _requires_sample=self._requires_sample)
tensordict_exec = self.get_dist_params(tensordict_exec, tensordict_out, **kwargs
)
tensordict_exec = self.module[-1](
tensordict_exec, _requires_sample=self._requires_sample
)
if tensordict_out is not None:
result = tensordict_out
result.update(tensordict_exec, keys_to_update=self.out_keys)
Expand Down
26 changes: 22 additions & 4 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,7 +1581,15 @@ def _to_dict(self, *, retain_none: bool = True) -> dict:
return td_dict


def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None):
def _from_dict(
cls,
input_dict,
*,
auto_batch_size: bool | None = None,
batch_size=None,
device=None,
batch_dims=None,
):
# we pass through a tensordict because keys could be passed as NestedKeys
# We can't assume all keys are strings, otherwise calling cls(**kwargs)
# would work ok
Expand All @@ -1595,15 +1603,25 @@ def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=No
non_tensordict=input_dict,
)
td = TensorDict.from_dict(
input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims, auto_batch_size=auto_batch_size
input_dict,
batch_size=batch_size,
device=device,
batch_dims=batch_dims,
auto_batch_size=auto_batch_size,
)
non_tensordict = {}

return cls.from_tensordict(tensordict=td, non_tensordict=non_tensordict)


def _from_dict_instance(
self, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None
self,
input_dict,
*,
auto_batch_size: bool | None = None,
batch_size=None,
device=None,
batch_dims=None,
):
if batch_dims is not None and batch_size is not None:
raise ValueError("Cannot pass both batch_size and batch_dims to `from_dict`.")
Expand Down Expand Up @@ -1773,7 +1791,7 @@ def _is_castable(datatype):

if isinstance(value, dict):
if _is_tensor_collection(target_cls):
cast_val = target_cls.from_dict(value)
cast_val = target_cls.from_dict(value, auto_batch_size=False)
self._tensordict.set(
key, cast_val, inplace=inplace, non_blocking=non_blocking
)
Expand Down
10 changes: 7 additions & 3 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def test_from_dict(self):
class MyClass:
a: TensorDictBase

tc = MyClass.from_dict(d)
tc = MyClass.from_dict(d, auto_batch_size=True)
assert isinstance(tc, MyClass)
assert isinstance(tc.a, TensorDict)
assert tc.batch_size == torch.Size([10])
Expand Down Expand Up @@ -2148,7 +2148,9 @@ class TestClass:
my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3]
)

assert (test_class == TestClass.from_dict(test_class.to_dict())).all()
assert (
test_class == TestClass.from_dict(test_class.to_dict(), auto_batch_size=True)
).all()

# Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such
# test_class2 = TestClass(
Expand All @@ -2161,7 +2163,9 @@ class TestClass:
my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3]
)

assert not (test_class == TestClass.from_dict(test_class3.to_dict())).all()
assert not (
test_class == TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True)
).all()


@tensorclass(autocast=True)
Expand Down
37 changes: 28 additions & 9 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,9 @@ class MyClass:
("4", "h5py", "my_nested_td", "inner"),
}
)
assert set(td.keys(True, True)) == expected, set(td.keys(True, True)).symmetric_difference(expected)
assert set(td.keys(True, True)) == expected, set(
td.keys(True, True)
).symmetric_difference(expected)

def test_from_dataclass(self):
@dataclass
Expand Down Expand Up @@ -1024,7 +1026,11 @@ def test_from_dict(self, batch_size, batch_dims, device):
)
return
data = TensorDict.from_dict(
data, batch_size=batch_size, batch_dims=batch_dims, device=device, auto_batch_size=True
data,
batch_size=batch_size,
batch_dims=batch_dims,
device=device,
auto_batch_size=True,
)
assert data.device == device
assert "a" in data.keys()
Expand Down Expand Up @@ -6500,7 +6506,7 @@ def recursive_checker(cur_dict):
assert recursive_checker(td_dict)
if td_name == "td_with_non_tensor":
assert td_dict["data"]["non_tensor"] == "some text data"
assert (TensorDict.from_dict(td_dict,auto_batch_size=False) == td).all()
assert (TensorDict.from_dict(td_dict, auto_batch_size=False) == td).all()

def test_to_namedtuple(self, td_name, device):
def is_namedtuple(obj):
Expand Down Expand Up @@ -7865,19 +7871,29 @@ def test_tensordict_batch_size(self):
tensordict = make_tensordict(a=torch.randn(3, 4), auto_batch_size=True)
assert tensordict.batch_size == torch.Size([3, 4])

tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True)
tensordict = make_tensordict(
a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True
)
assert tensordict.batch_size == torch.Size([3, 4])

nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5), auto_batch_size=True) # nested
nested_tensordict = make_tensordict(
c=tensordict, d=torch.randn(3, 5), auto_batch_size=True
) # nested
assert nested_tensordict.batch_size == torch.Size([3])

nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5), auto_batch_size=True) # nested
nested_tensordict = make_tensordict(
c=tensordict, d=torch.randn(4, 5), auto_batch_size=True
) # nested
assert nested_tensordict.batch_size == torch.Size([])

tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True)
tensordict = make_tensordict(
a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True
)
assert tensordict.batch_size == torch.Size([3, 4])

tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True)
tensordict = make_tensordict(
a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True
)
assert tensordict.batch_size == torch.Size([])

tensordict = make_tensordict(
Expand All @@ -7893,7 +7909,10 @@ def test_tensordict_batch_size(self):
@pytest.mark.parametrize("device", get_available_devices())
def test_tensordict_device(self, device):
tensordict = make_tensordict(
a=torch.randn(3, 4), b=torch.randn(3, 4), device=device, auto_batch_size=True
a=torch.randn(3, 4),
b=torch.randn(3, 4),
device=device,
auto_batch_size=True,
)
assert tensordict.device == device
assert tensordict["a"].device == device
Expand Down

0 comments on commit 47d0b3d

Please sign in to comment.