From 9b9781aeed71c68ce1b6cac0f2d9f079524bbdaf Mon Sep 17 00:00:00 2001 From: lc1021 Date: Mon, 12 Feb 2024 16:49:13 +0000 Subject: [PATCH] Change deprecated jax type --- qdax/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qdax/types.py b/qdax/types.py index 5000869b..90af9e66 100644 --- a/qdax/types.py +++ b/qdax/types.py @@ -45,5 +45,5 @@ def __init__(self) -> None: Mask: TypeAlias = jnp.ndarray # Others -RNGKey: TypeAlias = jax.random.KeyArray +RNGKey: TypeAlias = jax.random.Array Metrics: TypeAlias = Dict[str, jnp.ndarray]