diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 0fa334a..7984dda 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1578,8 +1578,17 @@ def err_msg_fn(arr_1, arr_2) -> str: try: assert_fn(arr_1, arr_2) except AssertionError as e: - return (f"{str(e)} \nOriginal dtypes: " - f"{np.asarray(arr_1).dtype}, {np.asarray(arr_2).dtype}") + dtype_1 = ( + arr_1.dtype + if isinstance(arr_1, jax.Array) + else np.asarray(arr_1).dtype + ) + dtype_2 = ( + arr_2.dtype + if isinstance(arr_2, jax.Array) + else np.asarray(arr_1).dtype + ) + return f"{str(e)} \nOriginal dtypes: {dtype_1}, {dtype_2}" return "" assert_trees_all_equal_comparator(cmp_fn, err_msg_fn, *trees)