diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index 0f70d15..c6da077 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -60,6 +60,6 @@ # TODO(iukemaev, jakevdp): upgrade minimum jax version & remove this condition. if hasattr(jax.typing, 'DTypeLike'): # jax version 0.4.19 or newer - ArrayDType = jax.typing.DTypeLike + ArrayDType = jax.typing.DTypeLike # pylint:disable=invalid-name else: - ArrayDType = Any + ArrayDType = Any # pylint:disable=invalid-name