Skip to content

Commit

Permalink
parralization
Browse files Browse the repository at this point in the history
  • Loading branch information
loliverhennigh committed Feb 22, 2024
1 parent 05b87bf commit 9ebf244
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions xlb/grid/jax_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,47 @@ def _initialize_jax_backend(self):
self.grid_shape[0] // self.nDevices,
) + self.grid_shape[1:]


def parallelize_operator(self, operator: Operator):
# TODO: fix this

# Make parallel function
def _parallel_operator(f):
rightPerm = [
(i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices)
]
leftPerm = [
((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices)
]
f = self.func(f)
left_comm, right_comm = (
f[self.velocity_set.right_indices, :1, ...],
f[self.velocity_set.left_indices, -1:, ...],
)
left_comm, right_comm = (
lax.ppermute(left_comm, perm=rightPerm, axis_name="x"),
lax.ppermute(right_comm, perm=leftPerm, axis_name="x"),
)
f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm)
f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm)

return f

in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,)))
out_specs = in_specs

f = shard_map(
self._parallel_func,
mesh=self.grid.global_mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)(f)
return f




def create_field(self, name: str, cardinality: int, callback=None):
# Get shape of the field
shape = (cardinality,) + (self.shape)
Expand Down

0 comments on commit 9ebf244

Please sign in to comment.