diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 3c8eb09d..736b13e9 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -95,7 +95,7 @@ def from_callable_or_type( # superclass. if f in parent_classes and f is not dict: raise _instantiators.UnsupportedTypeAnnotationError( - f"Found a cyclic dataclass dependency with type {f}." + f"Found a cyclic dependency with type {f}." ) # TODO: we are abusing the (minor) distinctions between types, classes, and diff --git a/src/tyro/_resolver.py b/src/tyro/_resolver.py index b28a9454..5e5e5726 100644 --- a/src/tyro/_resolver.py +++ b/src/tyro/_resolver.py @@ -2,6 +2,7 @@ import collections.abc import copy import dataclasses +import inspect import sys import types import warnings @@ -20,7 +21,7 @@ cast, ) -from typing_extensions import Annotated, get_args, get_origin, get_type_hints +from typing_extensions import Annotated, Self, get_args, get_origin, get_type_hints from . import _fields, _unsafe_cache from ._typing import TypeForm @@ -61,8 +62,17 @@ def resolve_generic_types( # We'll ignore NewType when getting the origin + args for generics. origin_cls = get_origin(unwrap_newtype(cls)[0]) + type_from_typevar: Dict[TypeVar, TypeForm[Any]] = {} + + # Support typing.Self. + # We'll do this by pretending that `Self` is a TypeVar... + if hasattr(cls, "__self__"): + self_type = getattr(cls, "__self__") + if inspect.isclass(self_type): + type_from_typevar[cast(TypeVar, Self)] = self_type + else: + type_from_typevar[cast(TypeVar, Self)] = self_type.__class__ - type_from_typevar = {} if ( # Apply some heuristics for generic types. Should revisit this. origin_cls is not None diff --git a/tests/test_py311_generated/test_base_configs_nested_generated.py b/tests/test_py311_generated/test_base_configs_nested_generated.py index ba83318e..8134dbe0 100644 --- a/tests/test_py311_generated/test_base_configs_nested_generated.py +++ b/tests/test_py311_generated/test_base_configs_nested_generated.py @@ -177,3 +177,35 @@ def main(cfg: BaseConfig) -> BaseConfig: ), DataConfig(2), ) + + +def test_pernicious_override(): + """From: https://github.com/nerfstudio-project/nerfstudio/issues/2789 + + Situation where we: + - have a default value in the config class + - override that default value with a subcommand annotation + - override it again with a default instance + """ + assert ( + tyro.cli( + BaseConfig, + default=BaseConfig( + "test", + "test", + ExperimentConfig( + dataset="mnist", + optimizer=AdamOptimizer(), + batch_size=2048, + num_layers=4, + units=64, + train_steps=30_000, + seed=0, + activation=nn.ReLU, + ), + DataConfig(0), + ), + args="small small-data".split(" "), + ).data_config.test + == 0 + ) diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index 75dcc1e0..5945752e 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -224,6 +224,57 @@ class Parent: ) == Parent(Nested1(Nested2(B(7)))) +def test_subparser_in_nested_with_metadata_suppressed() -> None: + @dataclasses.dataclass(frozen=True) + class A: + a: tyro.conf.Suppress[int] + + @dataclasses.dataclass + class B: + b: int + a: A = A(5) + + @dataclasses.dataclass + class Nested2: + subcommand: Annotated[ + A, tyro.conf.subcommand("command-a", default=A(7)) + ] | Annotated[B, tyro.conf.subcommand("command-b", default=B(9))] + + @dataclasses.dataclass + class Nested1: + nested2: Nested2 + + @dataclasses.dataclass + class Parent: + nested1: Nested1 + + assert tyro.cli( + Parent, + args="nested1.nested2.subcommand:command-a".split(" "), + ) == Parent(Nested1(Nested2(A(7)))) + assert tyro.cli( + Parent, + args=( + "nested1.nested2.subcommand:command-a --nested1.nested2.subcommand.a 3".split( + " " + ) + ), + ) == Parent(Nested1(Nested2(A(3)))) + + assert tyro.cli( + Parent, + args="nested1.nested2.subcommand:command-b".split(" "), + ) == Parent(Nested1(Nested2(B(9)))) + assert tyro.cli( + Parent, + args=( + "nested1.nested2.subcommand:command-b --nested1.nested2.subcommand.b 7".split( + " " + ) + ), + ) == Parent(Nested1(Nested2(B(7)))) + + def test_subparser_in_nested_with_metadata_generic() -> None: @dataclasses.dataclass(frozen=True) class A: @@ -1264,7 +1315,7 @@ def instantiate_dataclasses( classes: Tuple[Type[T], ...], args: List[str] ) -> Tuple[T, ...]: return tyro.cli( - tyro.conf.OmitArgPrefixes[ + tyro.conf.OmitArgPrefixes[ # type: ignore # Convert (type1, type2) into Tuple[type1, type2] Tuple.__getitem__( # type: ignore tuple(Annotated[c, tyro.conf.arg(name=c.__name__)] for c in classes) diff --git a/tests/test_py311_generated/test_helptext_generated.py b/tests/test_py311_generated/test_helptext_generated.py index ef47b280..48952a95 100644 --- a/tests/test_py311_generated/test_helptext_generated.py +++ b/tests/test_py311_generated/test_helptext_generated.py @@ -48,6 +48,31 @@ class Helptext: assert "Documentation 3 (default: 3)" in helptext +def test_helptext_sphinx_autodoc_style() -> None: + @dataclasses.dataclass + class Helptext: + """This docstring should be printed as a description.""" + + x: int #: Documentation 1 + + #:Documentation 2 + y: Annotated[int, "ignored"] + z: int = 3 + + helptext = get_helptext(Helptext) + assert cast(str, helptext) in helptext + assert "x INT" in helptext + assert "y INT" in helptext + assert "z INT" in helptext + assert "Documentation 1 (required)" in helptext + assert ": Documentation 1" not in helptext + assert "Documentation 2 (required)" in helptext + assert ":Documentation 2" not in helptext + + # :Documentation 2 should not be applied to `z`. + assert helptext.count("Documentation 2") == 1 + + def test_helptext_from_class_docstring() -> None: @dataclasses.dataclass class Helptext2: diff --git a/tests/test_py311_generated/test_self_type_generated.py b/tests/test_py311_generated/test_self_type_generated.py new file mode 100644 index 00000000..45058b30 --- /dev/null +++ b/tests/test_py311_generated/test_self_type_generated.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Self + +import pytest + +import tyro + + +class TestClass: + def __init__(self, a: int, b: int) -> None: + self.a = a + self.b = b + + def method1(self, x: Self) -> None: + self.effect = x + + @classmethod + def method2(cls, x: Self) -> TestClass: + return x + + # Self is not valid in static methods. + # https://peps.python.org/pep-0673/#valid-locations-for-self + # + # @staticmethod + # def method3(x: Self) -> TestClass: + # return x + + +class TestSubclass(TestClass): + ... + + +def test_method() -> None: + x = TestClass(0, 0) + with pytest.raises(SystemExit): + tyro.cli(x.method1, args=[]) + + assert tyro.cli(x.method1, args="--x.a 3 --x.b 3".split(" ")) is None + assert x.effect.a == 3 and x.effect.b == 3 + assert isinstance(x, TestClass) + + +def test_classmethod() -> None: + x = TestClass(0, 0) + with pytest.raises(SystemExit): + tyro.cli(x.method2, args=[]) + with pytest.raises(SystemExit): + tyro.cli(TestClass.method2, args=[]) + + y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) + + y = tyro.cli(TestClass.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) + + +def test_subclass_method() -> None: + x = TestSubclass(0, 0) + with pytest.raises(SystemExit): + tyro.cli(x.method1, args=[]) + + assert tyro.cli(x.method1, args="--x.a 3 --x.b 3".split(" ")) is None + assert x.effect.a == 3 and x.effect.b == 3 + assert isinstance(x, TestSubclass) + + y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) + + +def test_subclass_classmethod() -> None: + x = TestSubclass(0, 0) + with pytest.raises(SystemExit): + tyro.cli(x.method2, args=[]) + with pytest.raises(SystemExit): + tyro.cli(TestSubclass.method2, args=[]) + + y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) + + y = tyro.cli(TestSubclass.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) diff --git a/tests/test_self_type.py b/tests/test_self_type.py new file mode 100644 index 00000000..0b9ffbfc --- /dev/null +++ b/tests/test_self_type.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import pytest +from typing_extensions import Self + +import tyro + + +class TestClass: + def __init__(self, a: int, b: int) -> None: + self.a = a + self.b = b + + def method1(self, x: Self) -> None: + self.effect = x + + @classmethod + def method2(cls, x: Self) -> TestClass: + return x + + # Self is not valid in static methods. + # https://peps.python.org/pep-0673/#valid-locations-for-self + # + # @staticmethod + # def method3(x: Self) -> TestClass: + # return x + + +class TestSubclass(TestClass): + ... + + +def test_method() -> None: + x = TestClass(0, 0) + with pytest.raises(SystemExit): + tyro.cli(x.method1, args=[]) + + assert tyro.cli(x.method1, args="--x.a 3 --x.b 3".split(" ")) is None + assert x.effect.a == 3 and x.effect.b == 3 + assert isinstance(x, TestClass) + + +def test_classmethod() -> None: + x = TestClass(0, 0) + with pytest.raises(SystemExit): + tyro.cli(x.method2, args=[]) + with pytest.raises(SystemExit): + tyro.cli(TestClass.method2, args=[]) + + y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) + + y = tyro.cli(TestClass.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) + + +def test_subclass_method() -> None: + x = TestSubclass(0, 0) + with pytest.raises(SystemExit): + tyro.cli(x.method1, args=[]) + + assert tyro.cli(x.method1, args="--x.a 3 --x.b 3".split(" ")) is None + assert x.effect.a == 3 and x.effect.b == 3 + assert isinstance(x, TestSubclass) + + y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) + + +def test_subclass_classmethod() -> None: + x = TestSubclass(0, 0) + with pytest.raises(SystemExit): + tyro.cli(x.method2, args=[]) + with pytest.raises(SystemExit): + tyro.cli(TestSubclass.method2, args=[]) + + y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass) + + y = tyro.cli(TestSubclass.method2, args="--x.a 3 --x.b 3".split(" ")) + assert y.a == 3 + assert y.b == 3 + assert isinstance(y, TestClass)