diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index 1b2b4ea..8265f44 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -14,7 +14,7 @@ # ============================================================================== """Type definitions to use for type annotations.""" -from typing import Any, Iterable, Mapping, Union +from typing import Any, Iterable, Mapping, Union, Sequence import jax import numpy as np @@ -52,7 +52,7 @@ # Other types. Scalar = Union[float, int] Numeric = Union[Array, Scalar] -Shape = jax.core.Shape +Shape = Sequence[int | Any] PRNGKey = jax.Array PyTreeDef = jax.tree_util.PyTreeDef Device = jax.Device