diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 06be65580..62bc239b6 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -936,6 +936,8 @@ def wrapper(
             batch_size = torch.Size(())
         else:
             batch_size = kwargs.pop("batch_size", torch.Size(()))
+        if isinstance(batch_size, int):
+            batch_size = (batch_size,)
         if "names" in required_params:
             names = None
         else:
@@ -997,7 +999,7 @@ def wrapper(
 
         # convert the non tensor data in a regular data
         kwargs = {
-            key: value.data if is_non_tensor(value) else value
+            key: value.data if isinstance(value, NonTensorData) else value
             for key, value in kwargs.items()
         }
         __init__(self, **kwargs)
@@ -3259,6 +3261,9 @@ class NonTensorStack(LazyStackedTensorDict):
     _is_non_tensor: bool = True
 
     def __init__(self, *args, **kwargs):
+        args = [
+            arg if is_tensor_collection(arg) else NonTensorData(arg) for arg in args
+        ]
         super().__init__(*args, **kwargs)
         if not all(is_non_tensor(item) for item in self.tensordicts):
             raise RuntimeError("All tensordicts must be non-tensors.")
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index e29778bf9..cb2c9176e 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -43,6 +43,7 @@
     make_tensordict,
     PersistentTensorDict,
     set_get_defaults_to_none,
+    TensorClass,
     TensorDict,
 )
 from tensordict._lazy import _CustomOpTensorDict
@@ -10955,6 +10956,19 @@ def test_non_tensor_call(self):
         assert td["a"] == -1
         assert td["b"] == 1
 
+    def test_non_tensor_from_list(self):
+        class X(TensorClass):
+            non_tensor: str = None
+
+        x = X(batch_size=3)
+        x.non_tensor = NonTensorStack.from_list(["a", "b", "c"])
+        assert x[0].non_tensor == "a"
+        assert x[1].non_tensor == "b"
+
+        x = X(non_tensor=NonTensorStack("a", "b", "c"), batch_size=3)
+        assert x[0].non_tensor == "a"
+        assert x[1].non_tensor == "b"
+
     def test_nontensor_dict(self, non_tensor_data):
         assert (
             TensorDict.from_dict(non_tensor_data.to_dict(), auto_batch_size=True)