diff --git a/src/tyro/_unsafe_cache.py b/src/tyro/_unsafe_cache.py index 8b3fc37c..1c90e120 100644 --- a/src/tyro/_unsafe_cache.py +++ b/src/tyro/_unsafe_cache.py @@ -1,4 +1,5 @@ import functools +import sys from typing import Any, Callable, Dict, List, TypeVar CallableType = TypeVar("CallableType", bound=Callable) @@ -23,11 +24,20 @@ def unsafe_cache(maxsize: int) -> Callable[[CallableType], CallableType]: def inner(f: CallableType) -> CallableType: @functools.wraps(f) def wrapped_f(*args, **kwargs): - key = tuple(unsafe_hash(arg) for arg in args) + tuple( - ("__kwarg__", k, unsafe_hash(v)) for k, v in kwargs.items() + key = tuple(_make_key(arg) for arg in args) + tuple( + ("__kwarg__", k, _make_key(v)) for k, v in kwargs.items() ) if key in local_cache: + # Fuzzy check for cache collisions if called from a pytest test. + if "pytest" in sys.modules: + import random + + if random.random() < 0.5: + a = f(*args, **kwargs) + b = local_cache[key] + assert a == b or str(a) == str(b) + return local_cache[key] out = f(*args, **kwargs) @@ -41,8 +51,12 @@ def wrapped_f(*args, **kwargs): return inner -def unsafe_hash(obj: Any) -> Any: +def _make_key(obj: Any) -> Any: + """Some context: https://github.com/brentyi/tyro/issues/214""" try: - return hash(obj) + # If the object is hashable, we can use it as a key directly. + hash(obj) + return obj except TypeError: - return id(obj) + # If the object is not hashable, we'll use assume the type/id are unique... + return type(obj), id(obj) diff --git a/tests/test_dcargs.py b/tests/test_dcargs.py index b82543d1..0481f44c 100644 --- a/tests/test_dcargs.py +++ b/tests/test_dcargs.py @@ -20,7 +20,14 @@ ) import pytest -from typing_extensions import Annotated, Final, Literal, TypeAlias +from typing_extensions import ( + Annotated, + Final, + Literal, + Protocol, + TypeAlias, + runtime_checkable, +) import tyro @@ -953,3 +960,41 @@ class NumericTower: assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2) with pytest.raises(SystemExit): tyro.cli(NumericTower, args="--d False".split(" ")) + + +def test_runtime_checkable_edge_case() -> None: + """From Kevin Black: https://github.com/brentyi/tyro/issues/214""" + + @runtime_checkable + class DummyProtocol(Protocol): + pass + + @dataclasses.dataclass(frozen=True) + class SubConfigA: + pass + + @dataclasses.dataclass(frozen=True) + class SubConfigB: + pass + + @dataclasses.dataclass + class Config: + subconfig: DummyProtocol + + CONFIGS = { + "a": Config(subconfig=SubConfigA()), + "b": Config(subconfig=SubConfigB()), + } + + assert ( + tyro.extras.overridable_config_cli( + {k: (k, v) for k, v in CONFIGS.items()}, args=["a"] + ).subconfig + == SubConfigA() + ) + assert ( + tyro.extras.overridable_config_cli( + {k: (k, v) for k, v in CONFIGS.items()}, args=["b"] + ).subconfig + == SubConfigB() + ) diff --git a/tests/test_py311_generated/test_dcargs_generated.py b/tests/test_py311_generated/test_dcargs_generated.py index 909b3a03..fdcbcfab 100644 --- a/tests/test_py311_generated/test_dcargs_generated.py +++ b/tests/test_py311_generated/test_dcargs_generated.py @@ -16,10 +16,12 @@ List, Literal, Optional, + Protocol, Text, Tuple, TypeAlias, TypeVar, + runtime_checkable, ) import pytest @@ -955,3 +957,41 @@ class NumericTower: assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2) with pytest.raises(SystemExit): tyro.cli(NumericTower, args="--d False".split(" ")) + + +def test_runtime_checkable_edge_case() -> None: + """From Kevin Black: https://github.com/brentyi/tyro/issues/214""" + + @runtime_checkable + class DummyProtocol(Protocol): + pass + + @dataclasses.dataclass(frozen=True) + class SubConfigA: + pass + + @dataclasses.dataclass(frozen=True) + class SubConfigB: + pass + + @dataclasses.dataclass + class Config: + subconfig: DummyProtocol + + CONFIGS = { + "a": Config(subconfig=SubConfigA()), + "b": Config(subconfig=SubConfigB()), + } + + assert ( + tyro.extras.overridable_config_cli( + {k: (k, v) for k, v in CONFIGS.items()}, args=["a"] + ).subconfig + == SubConfigA() + ) + assert ( + tyro.extras.overridable_config_cli( + {k: (k, v) for k, v in CONFIGS.items()}, args=["b"] + ).subconfig + == SubConfigB() + ) diff --git a/tests/test_py311_generated/test_unsafe_cache_generated.py b/tests/test_py311_generated/test_unsafe_cache_generated.py index 35cb4174..f93d086c 100644 --- a/tests/test_py311_generated/test_unsafe_cache_generated.py +++ b/tests/test_py311_generated/test_unsafe_cache_generated.py @@ -9,23 +9,24 @@ def f(dummy: int): nonlocal x x += 1 + # >= is because of fuzz testing inside of unsafe_cache f(0) f(0) f(0) - assert x == 1 + assert x >= 1 f(1) f(1) f(1) - assert x == 2 + assert x >= 2 f(0) f(0) f(0) - assert x == 2 + assert x >= 2 f(2) f(2) f(2) - assert x == 3 + assert x >= 3 f(0) f(0) f(0) - assert x == 4 + assert x >= 4 diff --git a/tests/test_unsafe_cache.py b/tests/test_unsafe_cache.py index 35cb4174..f93d086c 100644 --- a/tests/test_unsafe_cache.py +++ b/tests/test_unsafe_cache.py @@ -9,23 +9,24 @@ def f(dummy: int): nonlocal x x += 1 + # >= is because of fuzz testing inside of unsafe_cache f(0) f(0) f(0) - assert x == 1 + assert x >= 1 f(1) f(1) f(1) - assert x == 2 + assert x >= 2 f(0) f(0) f(0) - assert x == 2 + assert x >= 2 f(2) f(2) f(2) - assert x == 3 + assert x >= 3 f(0) f(0) f(0) - assert x == 4 + assert x >= 4