diff --git a/jax_md/partition.py b/jax_md/partition.py index 80034c15..07902008 100644 --- a/jax_md/partition.py +++ b/jax_md/partition.py @@ -132,14 +132,11 @@ def count_cell_filling(R: Array, particle_index = jnp.array(R / cell_size, dtype=jnp.int64) particle_hash = jnp.sum(particle_index * hash_multipliers, axis=1) - filling = jnp.zeros((cell_count,), dtype=jnp.int64) - def count(cell_hash, filling): - count = jnp.sum(particle_hash == cell_hash) - filling = ops.index_update(filling, ops.index[cell_hash], count) - return filling - - return lax.fori_loop(0, cell_count, count, filling) + filling = ops.segment_sum(jnp.ones_like(particle_hash), + particle_hash, + cell_count) + return filling def _is_variable_compatible_with_positions(R: Array) -> bool: