diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 98224ef..76276b1 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -47,41 +47,41 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): f_0 = self.precision_policy.cast_to_compute_jax(f_0) f_1 = self.precision_policy.cast_to_compute_jax(f_1) + # Apply streaming + f_post_stream = self.stream(f_0) + + # Apply boundary conditions + for bc in self.boundary_conditions: + if bc.implementation_step == ImplementationStep.STREAMING: + f_post_stream = bc( + f_0, + f_post_stream, + boundary_mask, + missing_mask, + ) + # Compute the macroscopic variables - rho, u = self.macroscopic(f_0) + rho, u = self.macroscopic(f_post_stream) # Compute equilibrium feq = self.equilibrium(rho, u) # Apply collision - f_post_collision = self.collision(f_0, feq, rho, u) + f_post_collision = self.collision(f_post_stream, feq, rho, u) # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.prepare_bc_auxilary_data(f_0, f_post_collision, boundary_mask, missing_mask) + f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, boundary_mask, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( - f_0, - f_post_collision, - boundary_mask, - missing_mask, - ) - - # Apply streaming - f_1 = self.stream(f_post_collision) - - # Apply boundary conditions - for bc in self.boundary_conditions: - if bc.implementation_step == ImplementationStep.STREAMING: - f_1 = bc( + f_post_stream, f_post_collision, - f_1, boundary_mask, missing_mask, ) # Copy back to store precision - f_1 = self.precision_policy.cast_to_store_jax(f_1) + f_1 = self.precision_policy.cast_to_store_jax(f_post_collision) return f_1