Skip to content

Commit

Permalink
Support overriding behavior hint in argument helptext (#161)
Browse files Browse the repository at this point in the history
* Support overriding behavior hint in argument helptext

* Handle empty string for `help_behavior_hint`

* Support lambdai nput for `help_behavior_hint=`
  • Loading branch information
brentyi authored Sep 18, 2024
1 parent bf3f61a commit 3b436a7
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 59 deletions.
35 changes: 20 additions & 15 deletions src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _rule_generate_helptext(
[arg.extern_prefix, arg.field.intern_name]
)

if primary_help is not None and primary_help != "":
if primary_help is not None:
help_parts.append(_rich_tag_if_enabled(primary_help, "helptext"))

if not lowered.required:
Expand Down Expand Up @@ -474,26 +474,31 @@ def _rule_generate_helptext(
else:
default_label = str(default)

# Include default value in helptext. We intentionally don't use the % template
# because the types of all arguments are set to strings, which will cause the
# default to be casted to a string and introduce extra quotation marks.
if lowered.instantiator is None:
# Suffix helptext with some behavior hint, such as the default value of the argument.
help_behavior_hint = arg.field.argconf.help_behavior_hint
if help_behavior_hint is not None:
behavior_hint = (
help_behavior_hint(default_label)
if callable(help_behavior_hint)
else help_behavior_hint
)
elif lowered.instantiator is None:
# Intentionally not quoted via shlex, since this can't actually be passed
# in via the commandline.
default_text = f"(fixed to: {default_label})"
behavior_hint = f"(fixed to: {default_label})"
elif lowered.action == "count":
# Repeatable argument.
default_text = "(repeatable)"
behavior_hint = "(repeatable)"
elif lowered.action == "append" and (
default in _fields.MISSING_SINGLETONS or len(cast(tuple, default)) == 0
):
default_text = "(repeatable)"
behavior_hint = "(repeatable)"
elif lowered.action == "append" and len(cast(tuple, default)) > 0:
assert default is not None # Just for type checker.
default_text = f"(repeatable, appends to: {default_label})"
behavior_hint = f"(repeatable, appends to: {default_label})"
elif arg.field.default is _fields.EXCLUDE_FROM_CALL:
# ^important to use arg.field.default and not the stringified default variable.
default_text = "(unset by default)"
behavior_hint = "(unset by default)"
elif (
_markers._OPTIONAL_GROUP in arg.field.markers
and default in _fields.MISSING_SINGLETONS
Expand All @@ -507,20 +512,20 @@ def _rule_generate_helptext(
# There are some usage details that aren't communicated right now in the
# helptext. For example: all arguments within an optional group without a
# default should be passed in or none at all.
default_text = "(optional)"
behavior_hint = "(optional)"
elif _markers._OPTIONAL_GROUP in arg.field.markers:
# Argument in an optional group, but which also has a default.
default_text = f"(default if used: {default_label})"
behavior_hint = f"(default if used: {default_label})"
else:
default_text = f"(default: {default_label})"
behavior_hint = f"(default: {default_label})"

help_parts.append(_rich_tag_if_enabled(default_text, "helptext_default"))
help_parts.append(_rich_tag_if_enabled(behavior_hint, "helptext_default"))
else:
help_parts.append(_rich_tag_if_enabled("(required)", "helptext_required"))

# Note that the percent symbol needs some extra handling in argparse.
# https://stackoverflow.com/questions/21168120/python-argparse-errors-with-in-help-string
lowered.help = " ".join(help_parts).replace("%", "%%")
lowered.help = " ".join([p for p in help_parts if len(p) > 0]).replace("%", "%%")
return


Expand Down
3 changes: 2 additions & 1 deletion src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def make(
None,
None,
help=None,
help_behavior_hint=None,
aliases=None,
prefix_name=True,
constructor_factory=None,
Expand Down Expand Up @@ -288,7 +289,7 @@ def field_list_from_callable(
custom_constructor=False,
markers={_markers.Positional, _markers._PositionalCall},
argconf=_confstruct._ArgConfiguration(
None, None, None, None, None, None
None, None, None, None, None, None, None
),
call_argname="",
)
Expand Down
93 changes: 51 additions & 42 deletions src/tyro/conf/_confstruct.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,55 @@
from __future__ import annotations

import dataclasses
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
from typing import Any, Callable, Sequence, overload

from .._fields import MISSING_NONPROP


@dataclasses.dataclass(frozen=True)
class _SubcommandConfiguration:
name: Optional[str]
name: str | None
default: Any
description: Optional[str]
description: str | None
prefix_name: bool
constructor_factory: Optional[Callable[[], Union[Type, Callable]]]
constructor_factory: Callable[[], type | Callable] | None

def __hash__(self) -> int:
return object.__hash__(self)


@overload
def subcommand(
name: Optional[str] = None,
name: str | None = None,
*,
default: Any = MISSING_NONPROP,
description: Optional[str] = None,
description: str | None = None,
prefix_name: bool = True,
constructor: None = None,
constructor_factory: Optional[Callable[[], Union[Type, Callable]]] = None,
constructor_factory: Callable[[], type | Callable] | None = None,
) -> Any: ...


@overload
def subcommand(
name: Optional[str] = None,
name: str | None = None,
*,
default: Any = MISSING_NONPROP,
description: Optional[str] = None,
description: str | None = None,
prefix_name: bool = True,
constructor: Optional[Union[Type, Callable]] = None,
constructor: type | Callable | None = None,
constructor_factory: None = None,
) -> Any: ...


def subcommand(
name: Optional[str] = None,
name: str | None = None,
*,
default: Any = MISSING_NONPROP,
description: Optional[str] = None,
description: str | None = None,
prefix_name: bool = True,
constructor: Optional[Union[Type, Callable]] = None,
constructor_factory: Optional[Callable[[], Union[Type, Callable]]] = None,
constructor: type | Callable | None = None,
constructor_factory: Callable[[], type | Callable] | None = None,
) -> Any:
"""Returns a metadata object for configuring subcommands with `typing.Annotated`.
Useful for aesthetics.
Expand Down Expand Up @@ -110,51 +112,53 @@ def subcommand(

@dataclasses.dataclass(frozen=True)
class _ArgConfiguration:
# These are all optional by default in order to support multiple tyro.conf.arg()
# annotations. A None value means "don't overwrite the current value".
name: Optional[str]
metavar: Optional[str]
help: Optional[str]
aliases: Optional[Tuple[str, ...]]
prefix_name: Optional[bool]
constructor_factory: Optional[Callable[[], Union[Type, Callable]]]
name: str | None
metavar: str | None
help: str | None
help_behavior_hint: str | Callable[[str], str] | None
aliases: tuple[str, ...] | None
prefix_name: bool | None
constructor_factory: Callable[[], type | Callable] | None


@overload
def arg(
*,
name: Optional[str] = None,
metavar: Optional[str] = None,
help: Optional[str] = None,
aliases: Optional[Sequence[str]] = None,
prefix_name: Optional[bool] = None,
name: str | None = None,
metavar: str | None = None,
help: str | None = None,
help_behavior_hint: str | Callable[[str], str] | None = None,
aliases: Sequence[str] | None = None,
prefix_name: bool | None = None,
constructor: None = None,
constructor_factory: Optional[Callable[[], Union[Type, Callable]]] = None,
constructor_factory: Callable[[], type | Callable] | None = None,
) -> Any: ...


@overload
def arg(
*,
name: Optional[str] = None,
metavar: Optional[str] = None,
help: Optional[str] = None,
aliases: Optional[Sequence[str]] = None,
prefix_name: Optional[bool] = None,
constructor: Optional[Union[Type, Callable]] = None,
name: str | None = None,
metavar: str | None = None,
help: str | None = None,
help_behavior_hint: str | Callable[[str], str] | None = None,
aliases: Sequence[str] | None = None,
prefix_name: bool | None = None,
constructor: type | Callable | None = None,
constructor_factory: None = None,
) -> Any: ...


def arg(
*,
name: Optional[str] = None,
metavar: Optional[str] = None,
help: Optional[str] = None,
aliases: Optional[Sequence[str]] = None,
prefix_name: Optional[bool] = None,
constructor: Optional[Union[Type, Callable]] = None,
constructor_factory: Optional[Callable[[], Union[Type, Callable]]] = None,
name: str | None = None,
metavar: str | None = None,
help: str | None = None,
help_behavior_hint: str | Callable[[str], str] | None = None,
aliases: Sequence[str] | None = None,
prefix_name: bool | None = None,
constructor: type | Callable | None = None,
constructor_factory: Callable[[], type | Callable] | None = None,
) -> Any:
"""Returns a metadata object for fine-grained argument configuration with
`typing.Annotated`. Should typically not be required.
Expand All @@ -173,7 +177,11 @@ def arg(
Arguments:
name: A new name for the argument in the CLI.
metavar: Argument name in usage messages. The type is used by default.
help: Helptext for this argument. The docstring is used by default.
help: Override helptext for this argument. The docstring is used by default.
help_behavior_hint: Override highlighted text that follows the helptext.
Typically used for behavior hints like the `(default: XXX)` or
`(optional)`. Can either be a string or a lambda function whose
input is a formatted default value.
aliases: Aliases for this argument. All strings in the sequence should start
with a hyphen (-). Aliases will _not_ currently be prefixed in a nested
structure, and are not supported for positional arguments.
Expand Down Expand Up @@ -201,6 +209,7 @@ def arg(
name=name,
metavar=metavar,
help=help,
help_behavior_hint=help_behavior_hint,
aliases=tuple(aliases) if aliases is not None else None,
prefix_name=prefix_name,
constructor_factory=constructor_factory
Expand Down
41 changes: 40 additions & 1 deletion tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,44 @@ def test_argconf_help() -> None:
@dataclasses.dataclass
class Struct:
a: Annotated[
int, tyro.conf.arg(name="nice", help="Hello world", metavar="NUMBER")
int,
tyro.conf.arg(
name="nice",
help="Hello world",
help_behavior_hint="(hint)",
metavar="NUMBER",
),
] = 5
b: tyro.conf.Suppress[str] = "7"

def main(x: Any = Struct()) -> int:
return x.a

helptext = get_helptext_with_checks(main)
assert "Hello world" in helptext
assert "INT" not in helptext
assert "NUMBER" in helptext
assert "(hint)" in helptext
assert "(default: 5)" not in helptext
assert "--x.a" not in helptext
assert "--x.nice" in helptext
assert "--x.b" not in helptext

assert tyro.cli(main, args=[]) == 5
assert tyro.cli(main, args=["--x.nice", "3"]) == 3


def test_argconf_help_behavior_hint_lambda() -> None:
@dataclasses.dataclass
class Struct:
a: Annotated[
int,
tyro.conf.arg(
name="nice",
help="Hello world",
help_behavior_hint=lambda default: f"(default value: {default})",
metavar="NUMBER",
),
] = 5
b: tyro.conf.Suppress[str] = "7"

Expand All @@ -638,6 +675,8 @@ def main(x: Any = Struct()) -> int:
assert "Hello world" in helptext
assert "INT" not in helptext
assert "NUMBER" in helptext
assert "(default value: 5)" in helptext
assert "(default: 5)" not in helptext
assert "--x.a" not in helptext
assert "--x.nice" in helptext
assert "--x.b" not in helptext
Expand Down

0 comments on commit 3b436a7

Please sign in to comment.