diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 06be65580..fdd5d4b83 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -936,6 +936,9 @@ def wrapper( batch_size = torch.Size(()) else: batch_size = kwargs.pop("batch_size", torch.Size(())) + if batch_size is None: + batch_size = torch.Size(()) + if "names" in required_params: names = None else: