From 34e8d09c96d57da328f33caaf170936e2fe0c875 Mon Sep 17 00:00:00 2001 From: ChexDev Date: Thu, 12 Oct 2023 14:57:41 -0700 Subject: [PATCH] Fix warning of the form DeprecationWarning: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any] PiperOrigin-RevId: 573020724 --- chex/_src/pytypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index 1b2b4ea..0f70d15 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -14,7 +14,7 @@ # ============================================================================== """Type definitions to use for type annotations.""" -from typing import Any, Iterable, Mapping, Union +from typing import Any, Iterable, Mapping, Sequence, Union import jax import numpy as np @@ -52,7 +52,7 @@ # Other types. Scalar = Union[float, int] Numeric = Union[Array, Scalar] -Shape = jax.core.Shape +Shape = Sequence[Union[int, Any]] PRNGKey = jax.Array PyTreeDef = jax.tree_util.PyTreeDef Device = jax.Device