Skip to content

Commit

Permalink
modified the sequence of lbm step operators in JAX to match stream-th…
Browse files Browse the repository at this point in the history
…en-collide pattern in Warp.
  • Loading branch information
hsalehipour committed Aug 27, 2024
1 parent fbcf525 commit 91c706a
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,41 +47,41 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):
f_0 = self.precision_policy.cast_to_compute_jax(f_0)
f_1 = self.precision_policy.cast_to_compute_jax(f_1)

# Apply streaming
f_post_stream = self.stream(f_0)

# Apply boundary conditions
for bc in self.boundary_conditions:
if bc.implementation_step == ImplementationStep.STREAMING:
f_post_stream = bc(
f_0,
f_post_stream,
boundary_mask,
missing_mask,
)

# Compute the macroscopic variables
rho, u = self.macroscopic(f_0)
rho, u = self.macroscopic(f_post_stream)

# Compute equilibrium
feq = self.equilibrium(rho, u)

# Apply collision
f_post_collision = self.collision(f_0, feq, rho, u)
f_post_collision = self.collision(f_post_stream, feq, rho, u)

# Apply collision type boundary conditions
for bc in self.boundary_conditions:
f_post_collision = bc.prepare_bc_auxilary_data(f_0, f_post_collision, boundary_mask, missing_mask)
f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, boundary_mask, missing_mask)
if bc.implementation_step == ImplementationStep.COLLISION:
f_post_collision = bc(
f_0,
f_post_collision,
boundary_mask,
missing_mask,
)

# Apply streaming
f_1 = self.stream(f_post_collision)

# Apply boundary conditions
for bc in self.boundary_conditions:
if bc.implementation_step == ImplementationStep.STREAMING:
f_1 = bc(
f_post_stream,
f_post_collision,
f_1,
boundary_mask,
missing_mask,
)

# Copy back to store precision
f_1 = self.precision_policy.cast_to_store_jax(f_1)
f_1 = self.precision_policy.cast_to_store_jax(f_post_collision)

return f_1

Expand Down

0 comments on commit 91c706a

Please sign in to comment.