From 8cf7fcda858c7fe11bca0983127a1e900c660f61 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 19 Oct 2023 15:56:48 -0700 Subject: [PATCH] [jex] replace `extend.random.PRNGImpl` with `extend.random.define_prng_impl` Instead of exposing a constructor, only expose a function that returns an opaque object representing the defined implementation. This result can still be passed to `jax.random.key` and `wrap_key_data`. PiperOrigin-RevId: 575027938 --- haiku/_src/random_test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/haiku/_src/random_test.py b/haiku/_src/random_test.py index 0240cb165..d3f2025f9 100644 --- a/haiku/_src/random_test.py +++ b/haiku/_src/random_test.py @@ -97,7 +97,14 @@ def count_splits(_, num): num = tuple(num) if isinstance(num, Sequence) else (num,) return jnp.zeros((*num, 13), np.uint32) - differently_shaped_prng_impl = jex.random.PRNGImpl( + # TODO(frostig): remove after JAX 0.4.20, use + # jex.random.define_prng_impl directly + if hasattr(jex.random, "define_prng_impl"): + def_prng_impl = jex.random.define_prng_impl + else: + def_prng_impl = jex.random.PRNGImpl + + differently_shaped_prng_impl = def_prng_impl( # Testing a different key shape to make sure it's accepted by Haiku key_shape=(13,), seed=lambda _: jnp.zeros((13,), np.uint32), @@ -109,7 +116,7 @@ def count_splits(_, num): init, _ = transform.transform(base.next_rng_key) if do_jit: init = jax.jit(init) - key = jex.random.seed_with_impl(differently_shaped_prng_impl, 42) + key = jax.random.key(42, impl=differently_shaped_prng_impl) init(key) self.assertEqual(count, 1)