From ad3a8f83c6497427a5e0f1fd4461910005b6f7e7 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 12 Sep 2023 13:16:48 -0700 Subject: [PATCH] Avoid crashes when constructing signatures PiperOrigin-RevId: 564812691 --- haiku/_src/transform.py | 10 +++------- haiku/_src/transform_test.py | 19 +------------------ 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/haiku/_src/transform.py b/haiku/_src/transform.py index 3a2f29254..2e1d02073 100644 --- a/haiku/_src/transform.py +++ b/haiku/_src/transform.py @@ -81,16 +81,12 @@ def sig_remove_state(s: inspect.Signature) -> inspect.Signature: def sig_add_state(s: inspect.Signature) -> inspect.Signature: """Add hk.State to the return type of a signature.""" if s.return_annotation is inspect.Parameter.empty: - ret = Tuple[Any, hk.State] + ret = Any else: - try: - ret = Tuple[s.return_annotation, hk.State] - except TypeError: - # annotations are not strictly _required_ to contain type information - ret = Tuple[Any, hk.State] + ret = s.return_annotation return inspect.Signature( parameters=list(s.parameters.values()), - return_annotation=ret, + return_annotation=Tuple[ret, hk.State], __validate_parameters__=False) diff --git a/haiku/_src/transform_test.py b/haiku/_src/transform_test.py index 67df1d7e9..be48c7699 100644 --- a/haiku/_src/transform_test.py +++ b/haiku/_src/transform_test.py @@ -15,7 +15,7 @@ """Tests for haiku._src.transform.""" import inspect -from typing import Any, Mapping, Optional, Tuple, Union +from typing import Mapping, Optional, Tuple, Union from absl.testing import absltest from absl.testing import parameterized @@ -626,23 +626,6 @@ def expected_f_apply( self.assertEqual( inspect.signature(f.apply), inspect.signature(expected_f_apply)) - def test_signature_unsupported(self): - # unsupported annotations should not error - @transform.transform - def f() -> ...: - raise NotImplementedError - def expected_f_init(rng: Optional[Union[PRNGKey, int]]) -> Params: - del rng - raise NotImplementedError - def expected_f_apply( - params: Optional[Params], rng: Optional[Union[PRNGKey, int]]) -> Any: - del params, rng - raise NotImplementedError - self.assertEqual( - inspect.signature(f.init), inspect.signature(expected_f_init)) - self.assertEqual( - inspect.signature(f.apply), inspect.signature(expected_f_apply)) - def test_init_return_type_is_mutable(self): init, _ = transform.transform(lambda: None) params = init(None)