From 9ebf2441c6e77b9635f6691e70d41a6692429272 Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 22 Feb 2024 11:56:22 -0800 Subject: [PATCH] parralization --- xlb/grid/jax_grid.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 42687c8..092ea1d 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -34,6 +34,47 @@ def _initialize_jax_backend(self): self.grid_shape[0] // self.nDevices, ) + self.grid_shape[1:] + + def parallelize_operator(self, operator: Operator): + # TODO: fix this + + # Make parallel function + def _parallel_operator(f): + rightPerm = [ + (i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices) + ] + leftPerm = [ + ((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices) + ] + f = self.func(f) + left_comm, right_comm = ( + f[self.velocity_set.right_indices, :1, ...], + f[self.velocity_set.left_indices, -1:, ...], + ) + left_comm, right_comm = ( + lax.ppermute(left_comm, perm=rightPerm, axis_name="x"), + lax.ppermute(right_comm, perm=leftPerm, axis_name="x"), + ) + f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm) + f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm) + + return f + + in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,))) + out_specs = in_specs + + f = shard_map( + self._parallel_func, + mesh=self.grid.global_mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + )(f) + return f + + + + def create_field(self, name: str, cardinality: int, callback=None): # Get shape of the field shape = (cardinality,) + (self.shape)