diff --git a/examples/interfaces/ldc.py b/examples/interfaces/ldc.py index d75ce8a..e5ca559 100644 --- a/examples/interfaces/ldc.py +++ b/examples/interfaces/ldc.py @@ -120,12 +120,18 @@ def run_ldc(backend, compute_mlup=True): precision_policy=precision_policy, compute_backend=compute_backend, ) + full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, stream=stream, - boundary_conditions=[equilibrium_bc, half_way_bc], + #boundary_conditions=[equilibrium_bc, half_way_bc, full_way_bc], + boundary_conditions=[half_way_bc, full_way_bc, equilibrium_bc], ) planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( velocity_set=velocity_set, @@ -148,21 +154,22 @@ def run_ldc(backend, compute_mlup=True): ) # Set outlet bc (top x face) - lower_bound = (nr-1, 1, 1) - upper_bound = (nr-1, nr-1, nr-1) + lower_bound = (nr-1, 0, 0) + upper_bound = (nr-1, nr, nr) direction = (-1, 0, 0) boundary_id, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) # Set half way bc (bottom y face) - lower_bound = (1, 0, 1) + lower_bound = (0, 0, 0) upper_bound = (nr, 0, nr) direction = (0, 1, 0) boundary_id, missing_mask = planar_boundary_masker( @@ -170,13 +177,14 @@ def run_ldc(backend, compute_mlup=True): upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) # Set half way bc (top y face) - lower_bound = (1, nr-1, 1) + lower_bound = (0, nr-1, 0) upper_bound = (nr, nr-1, nr) direction = (0, -1, 0) boundary_id, missing_mask = planar_boundary_masker( @@ -184,13 +192,14 @@ def run_ldc(backend, compute_mlup=True): upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) # Set half way bc (bottom z face) - lower_bound = (1, 1, 0) + lower_bound = (0, 0, 0) upper_bound = (nr, nr, 0) direction = (0, 0, 1) boundary_id, missing_mask = planar_boundary_masker( @@ -198,13 +207,14 @@ def run_ldc(backend, compute_mlup=True): upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) # Set half way bc (top z face) - lower_bound = (1, 1, nr-1) + lower_bound = (0, 0, nr-1) upper_bound = (nr, nr, nr-1) direction = (0, 0, -1) boundary_id, missing_mask = planar_boundary_masker( @@ -212,6 +222,7 @@ def run_ldc(backend, compute_mlup=True): upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) @@ -226,10 +237,10 @@ def run_ldc(backend, compute_mlup=True): f0 = equilibrium(rho, u) # Time stepping - plot_freq = 512 + plot_freq = 128 save_dir = "ldc" os.makedirs(save_dir, exist_ok=True) - num_steps = nr * 512 + num_steps = nr * 16 start = time.time() for _ in tqdm(range(num_steps)): diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 7f7aadc..92a2c1f 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -6,7 +6,7 @@ from jax import jit, device_count from functools import partial import numpy as np -from enum import Enum +from enum import Enum, auto from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -16,8 +16,8 @@ # Enum for implementation step class ImplementationStep(Enum): - COLLISION = 1 - STREAMING = 2 + COLLISION = auto() + STREAMING = auto() class BoundaryCondition(Operator): diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index f8a4dd7..e47cc26 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -52,7 +52,7 @@ def __init__( def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): boundary = boundary_id == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) - return lax.select(missing_mask & boundary, f_pre[self.velocity_set.opp_indices], f_post) + return lax.select(jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index 832a8b2..572f345 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -47,7 +47,7 @@ def jax_implementation( if direction[0] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1] + 1, lower_bound[2] : upper_bound[2] + 1].set(id_number) + boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -57,13 +57,13 @@ def jax_implementation( + direction[2] * self.velocity_set.c[2, l] ) if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1] + 1, lower_bound[2] : upper_bound[2] + 1].set(True) + mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(True) # y plane elif direction[1] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0] + 1, lower_bound[1], lower_bound[2] : upper_bound[2] + 1].set(id_number) + boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -73,13 +73,13 @@ def jax_implementation( + direction[2] * self.velocity_set.c[2, l] ) if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0] : upper_bound[0] + 1, lower_bound[1], lower_bound[2] : upper_bound[2] + 1].set(True) + mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(True) # z plane elif direction[2] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0] + 1, lower_bound[1] : upper_bound[1] + 1, lower_bound[2]].set(id_number) + boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -89,7 +89,7 @@ def jax_implementation( + direction[2] * self.velocity_set.c[2, l] ) if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0] : upper_bound[0] + 1, lower_bound[1] : upper_bound[1] + 1, lower_bound[2]].set(True) + mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(True) return boundary_id, mask @@ -116,15 +116,15 @@ def kernel( # Get local indices if direction[0] != 0: i = lower_bound[0] - start_index[0] - j = plane_i - start_index[1] - k = plane_j - start_index[2] + j = plane_i + lower_bound[1] - start_index[1] + k = plane_j + lower_bound[2] - start_index[2] elif direction[1] != 0: - i = plane_i - start_index[0] + i = plane_i + lower_bound[0] - start_index[0] j = lower_bound[1] - start_index[1] - k = plane_j - start_index[2] + k = plane_j + lower_bound[2] - start_index[2] elif direction[2] != 0: - i = plane_i - start_index[0] - j = plane_j - start_index[1] + i = plane_i + lower_bound[0] - start_index[0] + j = plane_j + lower_bound[1] - start_index[1] k = lower_bound[2] - start_index[2] # Check if in bounds @@ -165,18 +165,18 @@ def warp_implementation( # Get plane dimensions if direction[0] != 0: dim = ( - upper_bound[1] - lower_bound[1] + 1, - upper_bound[2] - lower_bound[2] + 1, + upper_bound[1] - lower_bound[1], + upper_bound[2] - lower_bound[2], ) elif direction[1] != 0: dim = ( - upper_bound[0] - lower_bound[0] + 1, - upper_bound[2] - lower_bound[2] + 1, + upper_bound[0] - lower_bound[0], + upper_bound[2] - lower_bound[2], ) elif direction[2] != 0: dim = ( - upper_bound[0] - lower_bound[0] + 1, - upper_bound[1] - lower_bound[1] + 1, + upper_bound[0] - lower_bound[0], + upper_bound[1] - lower_bound[1], ) # Launch the warp kernel diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index c5fd16e..8bb2568 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -46,10 +46,10 @@ def _streaming_jax_i(f, c): """ if self.velocity_set.d == 2: return jnp.roll( - f, (-c[0], -c[1]), axis=(0, 1) - ) # Negative sign is used to pull the distribution instead of pushing + f, (c[0], c[1]), axis=(0, 1) + ) elif self.velocity_set.d == 3: - return jnp.roll(f, (-c[0], -c[1], -c[2]), axis=(0, 1, 2)) + return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( f, jnp.array(self.velocity_set.c).T