diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 62bc239b6..4d53334e1 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -938,6 +938,9 @@ def wrapper( batch_size = kwargs.pop("batch_size", torch.Size(())) if isinstance(batch_size, int): batch_size = (batch_size,) + elif batch_size is None: + batch_size = torch.Size(()) + if "names" in required_params: names = None else: