Skip to content

Commit

Permalink
Avoid crashes when constructing signatures
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 564812691
  • Loading branch information
superbobry authored and copybara-github committed Sep 13, 2023
1 parent b1caab0 commit ad3a8f8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 25 deletions.
10 changes: 3 additions & 7 deletions haiku/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,12 @@ def sig_remove_state(s: inspect.Signature) -> inspect.Signature:
def sig_add_state(s: inspect.Signature) -> inspect.Signature:
"""Add hk.State to the return type of a signature."""
if s.return_annotation is inspect.Parameter.empty:
ret = Tuple[Any, hk.State]
ret = Any
else:
try:
ret = Tuple[s.return_annotation, hk.State]
except TypeError:
# annotations are not strictly _required_ to contain type information
ret = Tuple[Any, hk.State]
ret = s.return_annotation
return inspect.Signature(
parameters=list(s.parameters.values()),
return_annotation=ret,
return_annotation=Tuple[ret, hk.State],
__validate_parameters__=False)


Expand Down
19 changes: 1 addition & 18 deletions haiku/_src/transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Tests for haiku._src.transform."""

import inspect
from typing import Any, Mapping, Optional, Tuple, Union
from typing import Mapping, Optional, Tuple, Union

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -626,23 +626,6 @@ def expected_f_apply(
self.assertEqual(
inspect.signature(f.apply), inspect.signature(expected_f_apply))

def test_signature_unsupported(self):
# unsupported annotations should not error
@transform.transform
def f() -> ...:
raise NotImplementedError
def expected_f_init(rng: Optional[Union[PRNGKey, int]]) -> Params:
del rng
raise NotImplementedError
def expected_f_apply(
params: Optional[Params], rng: Optional[Union[PRNGKey, int]]) -> Any:
del params, rng
raise NotImplementedError
self.assertEqual(
inspect.signature(f.init), inspect.signature(expected_f_init))
self.assertEqual(
inspect.signature(f.apply), inspect.signature(expected_f_apply))

def test_init_return_type_is_mutable(self):
init, _ = transform.transform(lambda: None)
params = init(None)
Expand Down

0 comments on commit ad3a8f8

Please sign in to comment.