Skip to content

Commit

Permalink
[Refactor] Make @Tensorclass work properly with pyright (#1042)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mxbonn authored Oct 16, 2024
1 parent 59a0ce5 commit fe6db77
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,7 @@
from dataclasses import dataclass
from pathlib import Path
from textwrap import indent
from typing import (
Any,
Callable,
get_type_hints,
List,
overload,
Sequence,
Type,
TypeVar,
)
from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar

import numpy as np
import orjson as json
Expand Down Expand Up @@ -371,20 +362,8 @@ def __call__(self, cls):
return clz


@overload
def tensorclass(autocast: bool = False, frozen: bool = False) -> _tensorclass_dec: ...


@overload
def tensorclass(cls: T) -> T: ...


@overload
def tensorclass(cls: T) -> T: ...


@dataclass_transform()
def tensorclass(*args, **kwargs):
def tensorclass(cls=None, /, *, autocast: bool = False, frozen: bool = False):
"""A decorator to create :obj:`tensorclass` classes.
``tensorclass`` classes are specialized :func:`dataclasses.dataclass` instances that
Expand Down Expand Up @@ -465,7 +444,17 @@ def tensorclass(*args, **kwargs):
"""
return _tensorclass_dec(*args, **kwargs)

def wrap(cls):
return _tensorclass_dec(autocast, frozen)(cls)

# See if we're being called as @tensorclass or @tensorclass().
if cls is None:
# We're called with parens.
return wrap

# We're called as @tensorclass without parens.
return wrap(cls)


@dataclass_transform()
Expand Down

0 comments on commit fe6db77

Please sign in to comment.