Skip to content

Commit

Permalink
[jex] replace extend.random.PRNGImpl with `extend.random.define_prn…
Browse files Browse the repository at this point in the history
…g_impl`

Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 575027938
  • Loading branch information
froystig authored and copybara-github committed Nov 1, 2023
1 parent afd7c2b commit 8cf7fcd
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions haiku/_src/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,14 @@ def count_splits(_, num):
num = tuple(num) if isinstance(num, Sequence) else (num,)
return jnp.zeros((*num, 13), np.uint32)

differently_shaped_prng_impl = jex.random.PRNGImpl(
# TODO(frostig): remove after JAX 0.4.20, use
# jex.random.define_prng_impl directly
if hasattr(jex.random, "define_prng_impl"):
def_prng_impl = jex.random.define_prng_impl
else:
def_prng_impl = jex.random.PRNGImpl

differently_shaped_prng_impl = def_prng_impl(
# Testing a different key shape to make sure it's accepted by Haiku
key_shape=(13,),
seed=lambda _: jnp.zeros((13,), np.uint32),
Expand All @@ -109,7 +116,7 @@ def count_splits(_, num):
init, _ = transform.transform(base.next_rng_key)
if do_jit:
init = jax.jit(init)
key = jex.random.seed_with_impl(differently_shaped_prng_impl, 42)
key = jax.random.key(42, impl=differently_shaped_prng_impl)
init(key)
self.assertEqual(count, 1)

Expand Down

0 comments on commit 8cf7fcd

Please sign in to comment.