Skip to content

Commit

Permalink
Fix test_assert_trees_all_equal_prng_keys
Browse files Browse the repository at this point in the history
Converting from JAX PRNG key to numpy array is now an error. This happens when
construcing an error string in the tests.

PiperOrigin-RevId: 696182424
  • Loading branch information
stompchicken authored and ChexDev committed Nov 13, 2024
1 parent 7b2f989 commit 1dc7862
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,8 +1578,17 @@ def err_msg_fn(arr_1, arr_2) -> str:
try:
assert_fn(arr_1, arr_2)
except AssertionError as e:
return (f"{str(e)} \nOriginal dtypes: "
f"{np.asarray(arr_1).dtype}, {np.asarray(arr_2).dtype}")
dtype_1 = (
arr_1.dtype
if isinstance(arr_1, jax.Array)
else np.asarray(arr_1).dtype
)
dtype_2 = (
arr_2.dtype
if isinstance(arr_2, jax.Array)
else np.asarray(arr_1).dtype
)
return f"{str(e)} \nOriginal dtypes: {dtype_1}, {dtype_2}"
return ""

assert_trees_all_equal_comparator(cmp_fn, err_msg_fn, *trees)
Expand Down

0 comments on commit 1dc7862

Please sign in to comment.