diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 6b3f70247..8906eefd4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -24,16 +24,7 @@ from dataclasses import dataclass from pathlib import Path from textwrap import indent -from typing import ( - Any, - Callable, - get_type_hints, - List, - overload, - Sequence, - Type, - TypeVar, -) +from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar import numpy as np import orjson as json @@ -371,20 +362,8 @@ def __call__(self, cls): return clz -@overload -def tensorclass(autocast: bool = False, frozen: bool = False) -> _tensorclass_dec: ... - - -@overload -def tensorclass(cls: T) -> T: ... - - -@overload -def tensorclass(cls: T) -> T: ... - - @dataclass_transform() -def tensorclass(*args, **kwargs): +def tensorclass(cls=None, /, *, autocast: bool = False, frozen: bool = False): """A decorator to create :obj:`tensorclass` classes. ``tensorclass`` classes are specialized :func:`dataclasses.dataclass` instances that @@ -465,7 +444,17 @@ def tensorclass(*args, **kwargs): """ - return _tensorclass_dec(*args, **kwargs) + + def wrap(cls): + return _tensorclass_dec(autocast, frozen)(cls) + + # See if we're being called as @tensorclass or @tensorclass(). + if cls is None: + # We're called with parens. + return wrap + + # We're called as @tensorclass without parens. + return wrap(cls) @dataclass_transform()