Skip to content

Commit

Permalink
Support recursive argument suppression (#217)
Browse files Browse the repository at this point in the history
* Support recursive argument suppression

* ruff
  • Loading branch information
brentyi authored Dec 19, 2024
1 parent 167f1ae commit 735530a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 19 deletions.
9 changes: 9 additions & 0 deletions src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ class ArgumentDefinition:
subcommand_prefix: str # Prefix for nesting.
field: _fields.FieldDefinition

def __post_init__(self) -> None:
if (
_markers.Fixed in self.field.markers
or _markers.Suppress in self.field.markers
) and self.field.default in _singleton.MISSING_AND_MISSING_NONPROP:
raise UnsupportedTypeAnnotationError(
f"Field {self.field.intern_name} is missing a default value!"
)

def add_argument(
self, parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup]
) -> None:
Expand Down
12 changes: 1 addition & 11 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from ._singleton import MISSING_AND_MISSING_NONPROP, MISSING_NONPROP
from ._typing import TypeForm
from .conf import _confstruct, _markers
from .constructors._primitive_spec import (
UnsupportedTypeAnnotationError,
)
from .constructors._primitive_spec import UnsupportedTypeAnnotationError
from .constructors._registry import ConstructorRegistry
from .constructors._struct_spec import (
StructFieldSpec,
Expand Down Expand Up @@ -51,14 +49,6 @@ class FieldDefinition:
# doesn't match the keyword expected by our callable.
call_argname: Any

def __post_init__(self):
if (
_markers.Fixed in self.markers or _markers.Suppress in self.markers
) and self.default in MISSING_AND_MISSING_NONPROP:
raise UnsupportedTypeAnnotationError(
f"Field {self.intern_name} is missing a default value!"
)

@staticmethod
@contextlib.contextmanager
def marker_context(markers: Tuple[_markers.Marker, ...]):
Expand Down
12 changes: 4 additions & 8 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,7 @@ def handle_field(
field.type, field.markers, nondefault_only=True
)

if (
not force_primitive
and _markers.Fixed not in field.markers
and _markers.Suppress not in field.markers
):
if not force_primitive:
# (1) Handle Unions over callables; these result in subparsers.
subparsers_attempt = SubparsersSpecification.from_field(
field,
Expand All @@ -367,9 +363,9 @@ def handle_field(
extern_prefix=_strings.make_field_name([extern_prefix, field.extern_name]),
)
if subparsers_attempt is not None:
if (
not subparsers_attempt.required
and _markers.AvoidSubcommands in field.markers
if not subparsers_attempt.required and (
_markers.AvoidSubcommands in field.markers
or _markers.Suppress in field.markers
):
# Don't make a subparser.
field = field.with_new_type_stripped(type(field.default))
Expand Down
13 changes: 13 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,3 +1545,16 @@ def main(
# Doesn't work in Python 3.7 because of argparse limitations.
assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2)
assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3)


def test_nested_suppress() -> None:
@dataclasses.dataclass
class Bconfig:
b: int = 1

@dataclasses.dataclass
class Aconfig:
a: str = "hello"
b_conf: Bconfig = dataclasses.field(default_factory=Bconfig)

assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig()
13 changes: 13 additions & 0 deletions tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,3 +1540,16 @@ def main(
# Doesn't work in Python 3.7 because of argparse limitations.
assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2)
assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3)


def test_nested_suppress() -> None:
@dataclasses.dataclass
class Bconfig:
b: int = 1

@dataclasses.dataclass
class Aconfig:
a: str = "hello"
b_conf: Bconfig = dataclasses.field(default_factory=Bconfig)

assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig()

0 comments on commit 735530a

Please sign in to comment.