diff --git a/src/spox/_fields.py b/src/spox/_fields.py index 8d5a67a..aed590b 100644 --- a/src/spox/_fields.py +++ b/src/spox/_fields.py @@ -5,7 +5,7 @@ import enum from collections.abc import Iterable, Iterator, Sequence from dataclasses import Field, dataclass -from typing import Any, Optional, Union +from typing import Any, Optional, Union, get_type_hints from ._attributes import Attr from ._var import Var @@ -61,13 +61,14 @@ def __post_init__(self) -> None: @classmethod def _get_field_type(cls, field: Field) -> VarFieldKind: """Access the kind of the field (single, optional, variadic) based on its type annotation.""" - # The field may be unannotated as per + # The field.type may be unannotated as per + field_type = get_type_hints(cls)[field.name] # from __future__ import annotations - if field.type in [Var, "Var"]: + if field_type == Var: return VarFieldKind.SINGLE - elif field.type in [Optional[Var], "Optional[Var]"]: + elif field_type == Optional[Var]: return VarFieldKind.OPTIONAL - elif field.type in [Sequence[Var], "Sequence[Var]"]: + elif field_type == Sequence[Var]: return VarFieldKind.VARIADIC raise ValueError(f"Bad field type: '{field.type}'.")