Skip to content

Commit

Permalink
Avoid calling np.testing.assert_array_equal on random keys
Browse files Browse the repository at this point in the history
Internally this attempts to convert keys to NumPy arrays, which will soon be disallowed by jax-ml/jax#24481.

PiperOrigin-RevId: 689550536
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Oct 24, 2024
1 parent 3423f06 commit 0435990
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
14 changes: 10 additions & 4 deletions haiku/_src/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def replace_rng_sequence_state_example():

class BaseTest(parameterized.TestCase):

def assert_keys_equal(self, a, b):
self.assertEqual(jax.random.key_impl(a), jax.random.key_impl(b))
np.testing.assert_array_equal(
jax.random.key_data(a), jax.random.key_data(b)
)

@test_utils.transform_and_run
def test_parameter_reuse(self):
w1 = base.get_parameter("w", [], init=jnp.zeros)
Expand Down Expand Up @@ -646,7 +652,7 @@ def test_prng_reserve(self):
s.reserve(10)
hk_keys = tuple(next(s) for _ in range(10))
jax_keys = tuple(jax.random.split(test_utils.clone(k), num=11)[1:])
jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys)
jax.tree.map(self.assert_keys_equal, hk_keys, jax_keys)

def test_prng_reserve_twice(self):
k = jax.random.PRNGKey(42)
Expand All @@ -657,14 +663,14 @@ def test_prng_reserve_twice(self):
k, subkey1, subkey2 = tuple(jax.random.split(test_utils.clone(k), num=3))
_, subkey3, subkey4 = tuple(jax.random.split(k, num=3))
jax_keys = (subkey1, subkey2, subkey3, subkey4)
jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys)
jax.tree.map(self.assert_keys_equal, hk_keys, jax_keys)

def test_prng_sequence_split(self):
k = jax.random.PRNGKey(42)
s = base.PRNGSequence(k)
hk_keys = s.take(10)
jax_keys = tuple(jax.random.split(test_utils.clone(k), num=11)[1:])
jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys)
jax.tree.map(self.assert_keys_equal, hk_keys, jax_keys)

@parameterized.parameters(42, 28)
def test_with_rng(self, seed):
Expand Down Expand Up @@ -782,7 +788,7 @@ def test_rng_reserve_size(self):
for _ in range(2):
split_key, *expected_keys = jax.random.split(split_key, size+1)
hk_keys = hk.next_rng_keys(size)
np.testing.assert_array_equal(hk_keys, expected_keys)
jax.tree.map(self.assert_keys_equal, hk_keys, expected_keys)

@parameterized.parameters(
base.get_params, base.get_current_state, base.get_initial_state
Expand Down
33 changes: 22 additions & 11 deletions haiku/_src/stateful_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,37 @@ def wrapper(*a, **kw):

class StatefulTest(parameterized.TestCase):

def assert_keys_equal(self, a, b):
self.assertEqual(jax.random.key_impl(a), jax.random.key_impl(b))
np.testing.assert_array_equal(
jax.random.key_data(a), jax.random.key_data(b)
)

def assert_keys_not_equal(self, a, b):
self.assertFalse(
(jax.random.key_impl(a) == jax.random.key_impl(b)) and
(jnp.all(jax.random.key_data(a) == jax.random.key_data(b))))

@test_utils.transform_and_run
def test_grad(self):
x = jnp.array(3.)
x = jnp.array(3.0)
g = stateful.grad(SquareModule())(x)
np.testing.assert_allclose(g, 2 * x, rtol=1e-4)

def test_grad_no_transform(self):
x = jnp.array(3.)
x = jnp.array(3.0)
with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
stateful.grad(jnp.square)(x)

@test_utils.transform_and_run
def test_value_and_grad(self):
x = jnp.array(2.)
x = jnp.array(2.0)
y, g = stateful.value_and_grad(SquareModule())(x)
self.assertEqual(y, x ** 2)
self.assertEqual(y, x**2)
np.testing.assert_allclose(g, 2 * x, rtol=1e-4)

def test_value_and_grad_no_transform(self):
x = jnp.array(3.)
x = jnp.array(3.0)
with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
stateful.value_and_grad(jnp.square)(x)

Expand Down Expand Up @@ -645,12 +656,12 @@ def test_vmap_no_split_rng(self):
x = jnp.arange(4)
k1, k2, k3, k4 = f(x)
key_after = base.next_rng_key()
np.testing.assert_array_equal(k1, k2)
np.testing.assert_array_equal(k2, k3)
np.testing.assert_array_equal(k3, k4)
self.assertFalse(np.array_equal(key_before, k1))
self.assertFalse(np.array_equal(key_after, k1))
self.assertFalse(np.array_equal(key_before, key_after))
self.assert_keys_equal(k1, k2)
self.assert_keys_equal(k2, k3)
self.assert_keys_equal(k3, k4)
self.assert_keys_not_equal(key_before, k1)
self.assert_keys_not_equal(key_after, k1)
self.assert_keys_not_equal(key_before, key_after)

@test_utils.transform_and_run
def test_vmap_split_rng(self):
Expand Down

0 comments on commit 0435990

Please sign in to comment.