diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 4f1e86d..4aafe56 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1226,7 +1226,7 @@ def _assert_fn(path, leaf): # Check that the leaf is a ShardedArray. if isinstance(leaf, jax.Array): if _check_sharding(leaf): - shards = tuple(buf.device() for buf in leaf.device_buffers) + shards = tuple(shard.device for shard in leaf.addressable_shards) if shards != devices: errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' is sharded "