diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 220489d..7d07e05 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -3,13 +3,22 @@ from xlb.precision_policy import PrecisionPolicy from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.stepper import IncompressibleNavierStokesStepper -from xlb.operator.boundary_condition import FullwayBounceBackBC, ZouHeBC, RegularizedBC, EquilibriumBC, DoNothingBC, ExtrapolationOutflowBC +from xlb.operator.boundary_condition import ( + FullwayBounceBackBC, + HalfwayBounceBackBC, + ZouHeBC, + RegularizedBC, + EquilibriumBC, + DoNothingBC, + ExtrapolationOutflowBC, +) from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np import jax.numpy as jnp +import time class FlowOverSphere: @@ -25,7 +34,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -34,7 +43,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): def _setup(self, omega): self.setup_boundary_conditions() - self.setup_boundary_masks() + self.setup_boundary_masker() self.initialize_fields() self.setup_stepper(omega) @@ -69,19 +78,20 @@ def setup_boundary_conditions(self): # bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) # bc_outlet = DoNothingBC(indices=outlet) bc_outlet = ExtrapolationOutflowBC(indices=outlet) - bc_sphere = FullwayBounceBackBC(indices=sphere) + bc_sphere = 
HalfwayBounceBackBC(indices=sphere) + self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls] # Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. # TODO: how to ensure about this behind in the src code? - def setup_boundary_masks(self): + def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0)) + self.boundary_map, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_map, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -90,12 +100,16 @@ def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") def run(self, num_steps, post_process_interval=100): + start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: self.post_process(i) + end_time = time.time() + print(f"Completing {i} iterations. Time elapsed for 1000 LBM steps in {end_time - start_time:.6f} seconds.") + start_time = time.time() def post_process(self, i): # Write the results. 
We'll use JAX backend for the post-processing @@ -114,7 +128,7 @@ def post_process(self, i): fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho[0]} - save_fields_vtk(fields, timestep=i) + # save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index b4540a9..16fb4f9 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -24,7 +24,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -33,7 +33,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): def _setup(self, omega): self.setup_boundary_conditions() - self.setup_boundary_masks() + self.setup_boundary_masker() self.initialize_fields() self.setup_stepper(omega) @@ -51,13 +51,13 @@ def setup_boundary_conditions(self): bc_walls = HalfwayBounceBackBC(indices=walls) self.boundary_conditions = [bc_top, bc_walls] - def setup_boundary_masks(self): + def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask) + self.boundary_map, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_map, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, 
self.backend) @@ -67,7 +67,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 460b217..8395579 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -9,14 +9,18 @@ FullwayBounceBackBC, EquilibriumBC, DoNothingBC, + RegularizedBC, + HalfwayBounceBackBC, ExtrapolationOutflowBC, ) +from xlb.operator.force.momentum_transfer import MomentumTransfer from xlb.operator.macroscopic import Macroscopic -from xlb.operator.boundary_masker import IndicesBoundaryMasker +from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np import jax.numpy as jnp +import matplotlib.pyplot as plt class WindTunnel3D: @@ -32,18 +36,25 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] # Setup the simulation BC, its initial conditions, and the stepper - self._setup(omega, wind_speed) - - def _setup(self, omega, wind_speed): - self.setup_boundary_conditions(wind_speed) - self.setup_boundary_masks() + self.wind_speed = wind_speed + self.omega = omega + self._setup() + + # Make list to store drag coefficients + self.time_steps = [] + self.drag_coefficients = [] + 
self.lift_coefficients = [] + + def _setup(self): + self.setup_boundary_conditions() + self.setup_boundary_masker() self.initialize_fields() - self.setup_stepper(omega) + self.setup_stepper() def voxelize_stl(self, stl_filename, length_lbm_unit): mesh = trimesh.load_mesh(stl_filename, process=False) @@ -64,46 +75,66 @@ def define_boundary_indices(self): for i in range(self.velocity_set.d) ] + # Load the mesh stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl" - grid_size_x = self.grid_shape[0] - car_length_lbm_unit = grid_size_x / 4 - car_voxelized, pitch = self.voxelize_stl(stl_filename, car_length_lbm_unit) - - # car_area = np.prod(car_voxelized.shape[1:]) - tx, ty, _ = np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape - shift = [tx // 4, ty // 2, 0] - car = np.argwhere(car_voxelized) + shift - car = np.array(car).T - car = [tuple(car[i]) for i in range(self.velocity_set.d)] + mesh = trimesh.load_mesh(stl_filename, process=False) + mesh_vertices = mesh.vertices + + # Transform the mesh points to be located in the right position in the wind tunnel + mesh_vertices -= mesh_vertices.min(axis=0) + mesh_extents = mesh_vertices.max(axis=0) + length_phys_unit = mesh_extents.max() + length_lbm_unit = self.grid_shape[0] / 4 + dx = length_phys_unit / length_lbm_unit + shift = np.array([self.grid_shape[0] * dx / 4, (self.grid_shape[1] * dx - mesh_extents[1]) / 2, 0.0]) + car = mesh_vertices + shift + self.grid_spacing = dx + self.car_cross_section = np.prod(mesh_extents[1:]) / dx**2 return inlet, outlet, walls, car - def setup_boundary_conditions(self, wind_speed): + def setup_boundary_conditions(self): inlet, outlet, walls, car = self.define_boundary_indices() - bc_left = EquilibriumBC(rho=1.0, u=(wind_speed, 0.0, 0.0), indices=inlet) + bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet) + # bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) 
bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) - bc_car = FullwayBounceBackBC(indices=car) + bc_car = HalfwayBounceBackBC(mesh_vertices=car) + # bc_car = FullwayBounceBackBC(mesh_vertices=car) self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] - def setup_boundary_masks(self): + def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0)) + mesh_boundary_masker = MeshBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.backend, + ) + bclist_other = self.boundary_conditions[:-1] + bc_mesh = self.boundary_conditions[-1] + dx = self.grid_spacing + origin, spacing = (0, 0, 0), (dx, dx, dx) + self.boundary_map, self.missing_mask = indices_boundary_masker(bclist_other, self.boundary_map, self.missing_mask) + self.boundary_map, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.boundary_map, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) - def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") + def setup_stepper(self): + self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") def run(self, num_steps, print_interval, post_process_interval=100): + # Setup the operator for computing surface forces at the interface of the specified BC + bc_car = self.boundary_conditions[-1] + self.momentum_transfer = MomentumTransfer(bc_car) + start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, 
self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: @@ -133,6 +164,49 @@ def post_process(self, i): save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) + # Compute lift and drag + boundary_force = self.momentum_transfer(self.f_0, self.boundary_map, self.missing_mask) + drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane + lift = boundary_force[2] + c_d = 2.0 * drag / (self.wind_speed**2 * self.car_cross_section) + c_l = 2.0 * lift / (self.wind_speed**2 * self.car_cross_section) + self.drag_coefficients.append(c_d) + self.lift_coefficients.append(c_l) + self.time_steps.append(i) + + # Save monitor plot + self.plot_drag_coefficient() + return + + def plot_drag_coefficient(self): + # Compute moving average of drag coefficient, 100, 1000, 10000 + drag_coefficients = np.array(self.drag_coefficients) + self.drag_coefficients_ma_10 = np.convolve(drag_coefficients, np.ones(10) / 10, mode="valid") + self.drag_coefficients_ma_100 = np.convolve(drag_coefficients, np.ones(100) / 100, mode="valid") + self.drag_coefficients_ma_1000 = np.convolve(drag_coefficients, np.ones(1000) / 1000, mode="valid") + self.drag_coefficients_ma_10000 = np.convolve(drag_coefficients, np.ones(10000) / 10000, mode="valid") + self.drag_coefficients_ma_100000 = np.convolve(drag_coefficients, np.ones(100000) / 100000, mode="valid") + + # Plot drag coefficient + plt.plot(self.time_steps, drag_coefficients, label="Raw") + if len(self.time_steps) > 10: + plt.plot(self.time_steps[9:], self.drag_coefficients_ma_10, label="MA 10") + if len(self.time_steps) > 100: + plt.plot(self.time_steps[99:], self.drag_coefficients_ma_100, label="MA 100") + if len(self.time_steps) > 1000: + plt.plot(self.time_steps[999:], self.drag_coefficients_ma_1000, label="MA 1,000") + if len(self.time_steps) > 10000: + 
plt.plot(self.time_steps[9999:], self.drag_coefficients_ma_10000, label="MA 10,000") + if len(self.time_steps) > 100000: + plt.plot(self.time_steps[99999:], self.drag_coefficients_ma_100000, label="MA 100,000") + + plt.ylim(-1.0, 1.0) + plt.legend() + plt.xlabel("Time step") + plt.ylabel("Drag coefficient") + plt.savefig("drag_coefficient_ma.png") + plt.close() + if __name__ == "__main__": # Grid parameters diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py index 68d1c2b..7214130 100644 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ b/examples/cfd_old_to_be_migrated/flow_past_sphere.py @@ -75,7 +75,7 @@ def warp_implementation(self, rho, u, vel): u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - boundary_mask = grid.create_field(cardinality=1, dtype=wp.uint8) + boundary_map = grid.create_field(cardinality=1, dtype=wp.uint8) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) # Make operators @@ -154,23 +154,19 @@ def warp_implementation(self, rho, u, vel): indices = wp.from_numpy(indices, dtype=wp.int32) # Set boundary conditions on the indices - boundary_mask, missing_mask = indices_boundary_masker(indices, half_way_bc.id, boundary_mask, missing_mask, (0, 0, 0)) + boundary_map, missing_mask = indices_boundary_masker(indices, half_way_bc.id, boundary_map, missing_mask, (0, 0, 0)) # Set inlet bc lower_bound = (0, 0, 0) upper_bound = (0, nr, nr) direction = (1, 0, 0) - boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, upper_bound, direction, equilibrium_bc.id, boundary_mask, missing_mask, (0, 0, 0) - ) + boundary_map, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, boundary_map, missing_mask, (0, 0, 0)) # Set outlet bc lower_bound = (nr - 1, 
0, 0) upper_bound = (nr - 1, nr, nr) direction = (-1, 0, 0) - boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, upper_bound, direction, do_nothing_bc.id, boundary_mask, missing_mask, (0, 0, 0) - ) + boundary_map, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, boundary_map, missing_mask, (0, 0, 0)) # Set initial conditions rho, u = initializer(rho, u, vel) @@ -185,7 +181,7 @@ def warp_implementation(self, rho, u, vel): num_steps = 1024 * 8 start = time.time() for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_mask, missing_mask, _) + f1 = stepper(f0, f1, boundary_map, missing_mask, _) f1, f0 = f0, f1 if (_ % plot_freq == 0) and (not compute_mlup): rho, u = macroscopic(f0, rho, u) @@ -195,7 +191,7 @@ def warp_implementation(self, rho, u, vel): plt.imshow(u[0, :, nr // 2, :].numpy()) plt.colorbar() plt.subplot(1, 2, 2) - plt.imshow(boundary_mask[0, :, nr // 2, :].numpy()) + plt.imshow(boundary_map[0, :, nr // 2, :].numpy()) plt.colorbar() plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py index 9ed7fa6..c5b40b7 100644 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ b/examples/cfd_old_to_be_migrated/taylor_green.py @@ -113,7 +113,7 @@ def run_taylor_green(backend, compute_mlup=True): u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_mask = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + boundary_map = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators @@ -149,10 +149,10 @@ def run_taylor_green(backend, compute_mlup=True): for _ 
in tqdm(range(num_steps)): # Time step if backend == "warp": - f1 = stepper(f0, f1, boundary_mask, missing_mask, _) + f1 = stepper(f0, f1, boundary_map, missing_mask, _) f1, f0 = f0, f1 elif backend == "jax": - f0 = stepper(f0, boundary_mask, missing_mask, _) + f0 = stepper(f0, boundary_map, missing_mask, _) # Plot if needed if (_ % plot_freq == 0) and (not compute_mlup): diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 74bfa04..602e741 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -42,9 +42,9 @@ def setup_simulation(args): def create_grid_and_fields(cube_edge): grid_shape = (cube_edge, cube_edge, cube_edge) - grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) + grid, f_0, f_1, missing_mask, boundary_map = create_nse_fields(grid_shape) - return grid, f_0, f_1, missing_mask, boundary_mask + return grid, f_0, f_1, missing_mask, boundary_map def define_boundary_indices(grid): @@ -67,7 +67,7 @@ def setup_boundary_conditions(grid): return [bc_top, bc_walls] -def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): +def run(f_0, f_1, backend, grid, boundary_map, missing_mask, num_steps): omega = 1.0 stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=setup_boundary_conditions(grid)) @@ -81,7 +81,7 @@ def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): start_time = time.time() for i in range(num_steps): - f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) + f_1 = stepper(f_0, f_1, boundary_map, missing_mask, i) f_0, f_1 = f_1, f_0 wp.synchronize() @@ -98,10 +98,10 @@ def calculate_mlups(cube_edge, num_steps, elapsed_time): def main(): args = parse_arguments() backend, precision_policy = setup_simulation(args) - grid, f_0, f_1, missing_mask, boundary_mask = create_grid_and_fields(args.cube_edge) + grid, f_0, f_1, missing_mask, boundary_map = create_grid_and_fields(args.cube_edge) f_0 = initialize_eq(f_0, 
grid, xlb.velocity_set.D3Q19(), backend) - elapsed_time = run(f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps) + elapsed_time = run(f_0, f_1, backend, grid, boundary_map, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 3e50fdb..9d2e4ff 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -32,7 +32,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -58,7 +58,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices=indices, ) - boundary_mask, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_map, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -66,7 +66,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask) + f = equilibrium_bc(f_pre, f_post, boundary_map, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py 
b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 0274eba..917e7e4 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -31,7 +31,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -58,7 +58,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): indices=indices, ) - boundary_mask, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_map, missing_mask, start_index=None) f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -66,7 +66,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask) + f = equilibrium_bc(f_pre, f_post, boundary_map, missing_mask) f = f.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 1b7edc2..2fe0b40 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -34,7 +34,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, 
grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -54,7 +54,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - boundary_mask, missing_mask = indices_boundary_masker([fullway_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([fullway_bc], boundary_map, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0) # Generate a random field with the same shape @@ -67,7 +67,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = fullway_bc(f_pre, f_post, boundary_mask, missing_mask) + f = fullway_bc(f_pre, f_post, boundary_map, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index da76f5e..b25d39e 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -34,7 +34,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, 
dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -54,7 +54,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - boundary_mask, missing_mask = indices_boundary_masker([fullway_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([fullway_bc], boundary_map, missing_mask, start_index=None) # Generate a random field with the same shape random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32) @@ -65,7 +65,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask) + f_pre = fullway_bc(f_pre, f_post, boundary_map, missing_mask) f = f_pre.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index ddbc761..af121d3 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -34,7 +34,7 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -56,26 +56,26 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): assert len(indices) == dim test_bc = 
xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - boundary_mask, missing_mask = indices_boundary_masker([test_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([test_bc], boundary_map, missing_mask, start_index=None) assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype - assert boundary_mask.dtype == xlb.Precision.UINT8.jax_dtype + assert boundary_map.dtype == xlb.Precision.UINT8.jax_dtype - assert boundary_mask.shape == (1,) + grid_shape + assert boundary_map.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert jnp.all(boundary_mask[0, indices[0], indices[1]] == test_bc.id) - # assert that the rest of the boundary_mask is zero - boundary_mask = boundary_mask.at[0, indices[0], indices[1]].set(0) - assert jnp.all(boundary_mask == 0) + assert jnp.all(boundary_map[0, indices[0], indices[1]] == test_bc.id) + # assert that the rest of the boundary_map is zero + boundary_map = boundary_map.at[0, indices[0], indices[1]].set(0) + assert jnp.all(boundary_map == 0) if dim == 3: - assert jnp.all(boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) - # assert that the rest of the boundary_mask is zero - boundary_mask = boundary_mask.at[0, indices[0], indices[1], indices[2]].set(0) - assert jnp.all(boundary_mask == 0) + assert jnp.all(boundary_map[0, indices[0], indices[1], indices[2]] == test_bc.id) + # assert that the rest of the boundary_map is zero + boundary_map = boundary_map.at[0, indices[0], indices[1], indices[2]].set(0) + assert jnp.all(boundary_map == 0) if __name__ == "__main__": diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 6919ba9..4d02540 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -32,7 +32,7 @@ def 
test_indices_masker_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -54,33 +54,33 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - boundary_mask, missing_mask = indices_boundary_masker( + boundary_map, missing_mask = indices_boundary_masker( [test_bc], - boundary_mask, + boundary_map, missing_mask, start_index=(0, 0, 0) if dim == 3 else (0, 0), ) assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype - assert boundary_mask.dtype == xlb.Precision.UINT8.wp_dtype + assert boundary_map.dtype == xlb.Precision.UINT8.wp_dtype - boundary_mask = boundary_mask.numpy() + boundary_map = boundary_map.numpy() missing_mask = missing_mask.numpy() - assert boundary_mask.shape == (1,) + grid_shape + assert boundary_map.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert np.all(boundary_mask[0, indices[0], indices[1]] == test_bc.id) - # assert that the rest of the boundary_mask is zero - boundary_mask[0, indices[0], indices[1]] = 0 - assert np.all(boundary_mask == 0) + assert np.all(boundary_map[0, indices[0], indices[1]] == test_bc.id) + # assert that the rest of the boundary_map is zero + boundary_map[0, indices[0], indices[1]] = 0 + assert np.all(boundary_map == 0) if dim == 3: - assert np.all(boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) - # assert that the rest of the boundary_mask is zero - boundary_mask[0, indices[0], indices[1], indices[2]] = 0 - assert np.all(boundary_mask == 0) + assert np.all(boundary_map[0, indices[0], indices[1], 
indices[2]] == test_bc.id) + # assert that the rest of the boundary_map is zero + boundary_map[0, indices[0], indices[1], indices[2]] = 0 + assert np.all(boundary_map == 0) if __name__ == "__main__": diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py index a42c6ac..96befa6 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_solver.py @@ -14,6 +14,6 @@ def create_nse_fields(grid_shape: Tuple[int, int, int], velocity_set=None, compu f_0 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) f_1 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL) - boundary_mask = grid.create_field(cardinality=1, dtype=Precision.UINT8) + boundary_map = grid.create_field(cardinality=1, dtype=Precision.UINT8) - return grid, f_0, f_1, missing_mask, boundary_mask + return grid, f_0, f_1, missing_mask, boundary_map diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index e8a91ee..6e8d317 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -35,6 +35,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_vertices=None, ): super().__init__( ImplementationStep.STREAMING, @@ -42,12 +43,13 @@ def __init__( precision_policy, compute_backend, indices, + mesh_vertices, ) @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): - boundary = boundary_mask == self.id + def jax_implementation(self, f_pre, f_post, boundary_map, missing_mask): + boundary = boundary_map == self.id return jnp.where(boundary, f_pre, f_post) def _construct_warp(self): @@ -65,7 +67,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: 
wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.uint8), ): # Get the global index @@ -73,10 +75,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(DoNothingBC.id): + if _boundary_map == wp.uint8(DoNothingBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -91,7 +93,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -99,10 +101,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(DoNothingBC.id): + if _boundary_map == wp.uint8(DoNothingBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -117,11 +119,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git 
a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 27d5eb2..6853c0e 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -40,6 +40,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_vertices=None, ): # Store the equilibrium information self.rho = rho @@ -56,15 +57,16 @@ def __init__( precision_policy, compute_backend, indices, + mesh_vertices, ) @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, boundary_map, missing_mask): feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) new_shape = feq.shape + (1,) * self.velocity_set.d feq = lax.broadcast_in_dim(feq, new_shape, [0]) - boundary = boundary_mask == self.id + boundary = boundary_map == self.id return jnp.where(boundary, feq, f_post) @@ -90,7 +92,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -98,10 +100,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(EquilibriumBC.id): + if _boundary_map == wp.uint8(EquilibriumBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -116,7 +118,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: 
wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -124,10 +126,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(EquilibriumBC.id): + if _boundary_map == wp.uint8(EquilibriumBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -142,11 +144,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index d16068b..55f094d 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -46,6 +46,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_vertices=None, ): # Call the parent constructor super().__init__( @@ -54,6 +55,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_vertices, ) # find and store the normal vector using indices @@ -92,13 +94,13 @@ def _roll(self, fld, vec): return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, 
boundary_mask, missing_mask): + def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_map, missing_mask): """ Prepare the auxilary distribution functions for the boundary condition. Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision """ sound_speed = 1.0 / jnp.sqrt(3.0) - boundary = boundary_mask == self.id + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -121,8 +123,8 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): - boundary = boundary_mask == self.id + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( @@ -193,7 +195,7 @@ def prepare_bc_auxilary_data( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -201,11 +203,11 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) _f_aux = _f_vec() # special preparation of auxiliary data - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): 
@@ -218,7 +220,7 @@ def kernel2d( _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]] # Apply the boundary condition - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both # collision and streaming? _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) @@ -234,7 +236,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -242,11 +244,11 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) _f_aux = _f_vec() # special preparation of auxiliary data - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -259,7 +261,7 @@ def kernel3d( _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] # Apply the boundary condition - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both # collision and streaming? 
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask) @@ -275,11 +277,11 @@ def kernel3d( return (functional, prepare_bc_auxilary_data), kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 6272aca..6af4226 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -35,6 +35,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_vertices=None, ): super().__init__( ImplementationStep.COLLISION, @@ -42,12 +43,13 @@ def __init__( precision_policy, compute_backend, indices, + mesh_vertices, ) @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): - boundary = boundary_mask == self.id + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) @@ -75,17 +77,17 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, 
_boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Check if the boundary is active - if _boundary_id == wp.uint8(FullwayBounceBackBC.id): + if _boundary_map == wp.uint8(FullwayBounceBackBC.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -100,7 +102,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -108,10 +110,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Check if the boundary is active - if _boundary_id == wp.uint8(FullwayBounceBackBC.id): + if _boundary_map == wp.uint8(FullwayBounceBackBC.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -126,11 +128,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index a363479..5c001d9 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ 
b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -37,6 +37,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_vertices=None, ): # Call the parent constructor super().__init__( @@ -45,12 +46,13 @@ def __init__( precision_policy, compute_backend, indices, + mesh_vertices, ) @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): - boundary = boundary_mask == self.id + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( @@ -86,7 +88,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -94,10 +96,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + if _boundary_map == wp.uint8(HalfwayBounceBackBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -112,7 +114,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -120,10 +122,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, 
_missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + if _boundary_map == wp.uint8(HalfwayBounceBackBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -138,11 +140,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 84dbbf9..b74c0b1 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -49,6 +49,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_vertices=None, ): # Call the parent constructor super().__init__( @@ -58,6 +59,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_vertices, ) # The operator to compute the momentum flux @@ -103,9 +105,9 @@ def regularize_fpop(self, fpop, feq): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): # creat a mask to slice boundary cells - boundary = boundary_mask == self.id + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, 
tuple(range(self.velocity_set.d + 1))) @@ -323,7 +325,7 @@ def functional2d_pressure( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -331,10 +333,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): + if _boundary_map == wp.uint8(self.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -349,7 +351,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -357,10 +359,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): + if _boundary_map == wp.uint8(self.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -383,11 +385,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + 
inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 61783f8..56c6868 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -43,6 +43,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_vertices=None, ): # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC # may have different types (velocity or pressure). @@ -59,6 +60,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_vertices, ) # Set the prescribed value for pressure or velocity @@ -154,9 +156,9 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): # creat a mask to slice boundary cells - boundary = boundary_mask == self.id + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -328,7 +330,7 @@ def functional2d_pressure( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -336,10 +338,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id 
== wp.uint8(self.id): + if _boundary_map == wp.uint8(self.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -354,7 +356,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -362,10 +364,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): + if _boundary_map == wp.uint8(self.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -388,11 +390,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 29ca2db..2cf6d67 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -33,6 +33,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_vertices=None, ): velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy @@ -42,6 +43,7 @@ def __init__( # Set the BC 
indices self.indices = indices + self.mesh_vertices = mesh_vertices # Set the implementation step self.implementation_step = implementation_step @@ -64,14 +66,14 @@ def prepare_bc_auxilary_data( def _get_thread_data_2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), index: wp.vec2i, ): # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1]] + _boundary_map = boundary_map[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -83,20 +85,20 @@ def _get_thread_data_2d( _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return _f_pre, _f_post, _boundary_id, _missing_mask + return _f_pre, _f_post, _boundary_map, _missing_mask @wp.func def _get_thread_data_3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), index: wp.vec3i, ): # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] + _boundary_map = boundary_map[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -108,7 +110,7 @@ def _get_thread_data_3d( _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return _f_pre, _f_post, _boundary_id, _missing_mask + return _f_pre, _f_post, _boundary_map, _missing_mask # Construct some helper warp functions for getting tid data if self.compute_backend == ComputeBackend.WARP: @@ -117,7 +119,7 @@ def _get_thread_data_3d( self.prepare_bc_auxilary_data = prepare_bc_auxilary_data @partial(jit, static_argnums=(0,), inline=True) - def 
prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): + def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_map, missing_mask): """ A placeholder function for prepare the auxilary distribution functions for the boundary condition. currently being called after collision only. diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index 262e638..20b16b5 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -1,6 +1,6 @@ from xlb.operator.boundary_masker.indices_boundary_masker import ( IndicesBoundaryMasker as IndicesBoundaryMasker, ) -from xlb.operator.boundary_masker.stl_boundary_masker import ( - STLBoundaryMasker as STLBoundaryMasker, +from xlb.operator.boundary_masker.mesh_boundary_masker import ( + MeshBoundaryMasker as MeshBoundaryMasker, ) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 7548cf0..208f50f 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -26,10 +26,27 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) + def are_indices_in_interior(self, indices, shape): + """ + Check if all 2D or 3D indices are inside the bounds of the domain with the given shape and not + at its boundary. + + :param indices: List of tuples, where each tuple contains indices for each dimension. + :param shape: Tuple representing the shape of the domain (nx, ny) for 2D or (nx, ny, nz) for 3D. + :return: Boolean flag is_inside indicating whether all indices are inside the bounds. 
+ """ + # Ensure that the number of dimensions in indices matches the domain shape + dim = len(shape) + if len(indices) != dim: + raise ValueError(f"Indices tuple must have {dim} dimensions to match the domain shape.") + + # Check if all indices are within the bounds + return all(0 < idx < shape[d] - 1 for d, idx_list in enumerate(indices) for idx in idx_list) + @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) - def jax_implementation(self, bclist, boundary_mask, missing_mask, start_index=None): + def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=None): # Pad the missing mask to create a grid mask to identify out of bound boundaries # Set padded regin to True (i.e. boundary) dim = missing_mask.ndim - 1 @@ -45,27 +62,38 @@ def jax_implementation(self, bclist, boundary_mask, missing_mask, start_index=No if start_index is None: start_index = (0,) * dim - bid = boundary_mask[0] + bmap = boundary_map[0] + domain_shape = bmap.shape for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" + assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" 
id_number = bc.id - local_indices = np.array(bc.indices) + np.array(start_index)[:, np.newaxis] + local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] - bid = bid.at[tuple(local_indices)].set(id_number) - # if dim == 2: - # grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) - # if dim == 3: - # grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) + bmap = bmap.at[tuple(local_indices)].set(id_number) + if self.are_indices_in_interior(bc.indices, domain_shape): + # checking if all indices associated with this BC are in the interior of the domain (not at the boundary). + # This flag is needed e.g. if the no-slip geometry is anywhere but at the boundaries of the computational domain. + if dim == 2: + grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) + if dim == 3: + grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) + + # Assign the boundary id to the push indices + push_indices = local_indices[:, :, None] + self.velocity_set.c[:, None, :] + push_indices = push_indices.reshape(3, -1) + bmap = bmap.at[tuple(push_indices)].set(id_number) + # We are done with bc.indices. 
Remove them from BC objects bc.__dict__.pop("indices", None) - boundary_mask = boundary_mask.at[0].set(bid) + boundary_map = boundary_map.at[0].set(bmap) grid_mask = self.stream(grid_mask) if dim == 2: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] if dim == 3: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] - return boundary_mask, missing_mask + return boundary_map, missing_mask def _construct_warp(self): # Make constants for warp @@ -77,7 +105,8 @@ def _construct_warp(self): def kernel2d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), - boundary_mask: wp.array3d(dtype=wp.uint8), + is_interior: wp.array1d(dtype=wp.bool), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): @@ -86,8 +115,8 @@ def kernel2d( # Get local indices index = wp.vec2i() - index[0] = indices[0, ii] + start_index[0] - index[1] = indices[1, ii] + start_index[1] + index[0] = indices[0, ii] - start_index[0] + index[1] = indices[1, ii] - start_index[1] # Check if index is in bounds if index[0] >= 0 and index[0] < missing_mask.shape[1] and index[1] >= 0 and index[1] < missing_mask.shape[2]: @@ -95,23 +124,37 @@ def kernel2d( for l in range(_q): # Get the index of the streaming direction pull_index = wp.vec2i() + push_index = wp.vec2i() for d in range(self.velocity_set.d): pull_index[d] = index[d] - _c[d, l] + push_index[d] = index[d] + _c[d, l] # check if pull index is out of bound # These directions will have missing information after streaming if pull_index[0] < 0 or pull_index[0] >= missing_mask.shape[1] or pull_index[1] < 0 or pull_index[1] >= missing_mask.shape[2]: # Set the missing mask missing_mask[l, index[0], index[1]] = True + boundary_map[0, index[0], index[1]] = id_number[ii] - boundary_mask[0, index[0], index[1]] = id_number[ii] + # handling geometries in the interior of the computational domain + elif ( + is_interior[ii] + and push_index[0] >= 0 + and push_index[0] < 
missing_mask.shape[1] + and push_index[1] >= 0 + and push_index[1] < missing_mask.shape[2] + ): + # Set the missing mask + missing_mask[l, push_index[0], push_index[1]] = True + boundary_map[0, push_index[0], push_index[1]] = id_number[ii] # Construct the warp 3D kernel @wp.kernel def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), - boundary_mask: wp.array4d(dtype=wp.uint8), + is_interior: wp.array1d(dtype=wp.bool), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -120,9 +163,9 @@ def kernel3d( # Get local indices index = wp.vec3i() - index[0] = indices[0, ii] + start_index[0] - index[1] = indices[1, ii] + start_index[1] - index[2] = indices[2, ii] + start_index[2] + index[0] = indices[0, ii] - start_index[0] + index[1] = indices[1, ii] - start_index[1] + index[2] = indices[2, ii] - start_index[2] # Check if index is in bounds if ( @@ -137,8 +180,10 @@ def kernel3d( for l in range(_q): # Get the index of the streaming direction pull_index = wp.vec3i() + push_index = wp.vec3i() for d in range(self.velocity_set.d): pull_index[d] = index[d] - _c[d, l] + push_index[d] = index[d] + _c[d, l] # check if pull index is out of bound # These directions will have missing information after streaming @@ -152,28 +197,46 @@ def kernel3d( ): # Set the missing mask missing_mask[l, index[0], index[1], index[2]] = True + boundary_map[0, index[0], index[1], index[2]] = id_number[ii] - boundary_mask[0, index[0], index[1], index[2]] = id_number[ii] + # handling geometries in the interior of the computational domain + elif ( + is_interior[ii] + and push_index[0] >= 0 + and push_index[0] < missing_mask.shape[1] + and push_index[1] >= 0 + and push_index[1] < missing_mask.shape[2] + and push_index[2] >= 0 + and push_index[2] < missing_mask.shape[3] + ): + # Set the missing mask + missing_mask[l, push_index[0], push_index[1], push_index[2]] = True + boundary_map[0, push_index[0], 
push_index[1], push_index[2]] = id_number[ii] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, bclist, boundary_mask, missing_mask, start_index=None): + def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=None): dim = self.velocity_set.d index_list = [[] for _ in range(dim)] id_list = [] + is_interior = [] for bc in bclist: assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC using keyword "indices"!' + assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) + is_interior += [self.are_indices_in_interior(bc.indices, boundary_map[0].shape)] * len(bc.indices[0]) + # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) indices = wp.array2d(index_list, dtype=wp.int32) id_number = wp.array1d(id_list, dtype=wp.uint8) + is_interior = wp.array1d(is_interior, dtype=wp.bool) if start_index is None: start_index = (0,) * dim @@ -184,11 +247,12 @@ def warp_implementation(self, bclist, boundary_mask, missing_mask, start_index=N inputs=[ indices, id_number, - boundary_mask, + is_interior, + boundary_map, missing_mask, start_index, ], dim=indices.shape[1], ) - return boundary_mask, missing_mask + return boundary_map, missing_mask diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py new file mode 100644 index 0000000..366c9d6 --- /dev/null +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -0,0 +1,148 @@ +# Base class for all equilibriums + +import numpy as np +import warp as wp +import jax +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from 
xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class MeshBoundaryMasker(Operator): + """ + Operator for creating a boundary missing_mask from an STL file + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.WARP, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + # Also using Warp kernels for JAX implementation + if self.compute_backend == ComputeBackend.JAX: + self.warp_functional, self.warp_kernel = self._construct_warp() + + @Operator.register_backend(ComputeBackend.JAX) + def jax_implementation( + self, + bc, + origin, + spacing, + boundary_map, + missing_mask, + start_index=(0, 0, 0), + ): + raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") + # Use Warp backend even for this particular operation. + wp.init() + boundary_map = wp.from_jax(boundary_map) + missing_mask = wp.from_jax(missing_mask) + boundary_map, missing_mask = self.warp_implementation(bc, origin, spacing, boundary_map, missing_mask, start_index) + return wp.to_jax(boundary_map), wp.to_jax(missing_mask) + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.wp_c + _q = wp.constant(self.velocity_set.q) + + # Construct the warp kernel + @wp.kernel + def kernel( + mesh_id: wp.uint64, + origin: wp.vec3, + spacing: wp.vec3, + id_number: wp.int32, + boundary_map: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + start_index: wp.vec3i, + ): + # get index + i, j, k = wp.tid() + + # Get local indices + index = wp.vec3i() + index[0] = i - start_index[0] + index[1] = j - start_index[1] + index[2] = k - start_index[2] + + # position of the point + ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2])) + ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center + pos = wp.cw_mul(ijk, spacing) + origin + + # Compute the maximum length + max_length 
= wp.sqrt( + (spacing[0] * wp.float32(boundary_map.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(boundary_map.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(boundary_map.shape[3])) ** 2.0 + ) + + # evaluate if point is inside mesh + face_index = int(0) + face_u = float(0.0) + face_v = float(0.0) + sign = float(0.0) + if wp.mesh_query_point_sign_winding_number(mesh_id, pos, max_length, sign, face_index, face_u, face_v): + # set point to be solid + if sign <= 0: # TODO: fix this + # Stream indices + for l in range(_q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + + # Set the boundary id and missing_mask + boundary_map[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + missing_mask[l, push_index[0], push_index[1], push_index[2]] = True + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + bc, + origin, + spacing, + boundary_map, + missing_mask, + start_index=(0, 0, 0), + ): + assert bc.mesh_vertices is not None, f'Please provide the mesh points for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' + assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" + assert ( + bc.mesh_vertices.shape[1] == self.velocity_set.d + ), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + mesh_vertices = bc.mesh_vertices + id_number = bc.id + + # We are done with bc.mesh_vertices. 
Remove them from BC objects + bc.__dict__.pop("mesh_vertices", None) + + mesh_indices = np.arange(mesh_vertices.shape[0]) + mesh = wp.Mesh( + points=wp.array(mesh_vertices, dtype=wp.vec3), + indices=wp.array(mesh_indices, dtype=int), + ) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + mesh.id, + origin, + spacing, + id_number, + boundary_map, + missing_mask, + start_index, + ], + dim=boundary_map.shape[1:], + ) + + return boundary_map, missing_mask diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py deleted file mode 100644 index b4ea8ca..0000000 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ /dev/null @@ -1,120 +0,0 @@ -# Base class for all equilibriums - -import numpy as np -from stl import mesh as np_mesh -import warp as wp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator - - -class STLBoundaryMasker(Operator): - """ - Operator for creating a boundary mask from an STL file - """ - - def __init__( - self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, - ): - # Call super - super().__init__(velocity_set, precision_policy, compute_backend) - - def _construct_warp(self): - # Make constants for warp - _c = self.velocity_set.wp_c - _q = wp.constant(self.velocity_set.q) - - # Construct the warp kernel - @wp.kernel - def kernel( - mesh: wp.uint64, - origin: wp.vec3, - spacing: wp.vec3, - id_number: wp.int32, - boundary_mask: wp.array4d(dtype=wp.uint8), - mask: wp.array4d(dtype=wp.bool), - start_index: wp.vec3i, - ): - # get index - i, j, k = wp.tid() - - # Get local indices - index = wp.vec3i() - index[0] = i - start_index[0] - index[1] = j - start_index[1] - index[2] = k - start_index[2] - - # position of the point - ijk = wp.vec3(wp.float32(index[0]), 
wp.float32(index[1]), wp.float32(index[2])) - ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center - pos = wp.cw_mul(ijk, spacing) + origin - - # Compute the maximum length - max_length = wp.sqrt( - (spacing[0] * wp.float32(boundary_mask.shape[1])) ** 2.0 - + (spacing[1] * wp.float32(boundary_mask.shape[2])) ** 2.0 - + (spacing[2] * wp.float32(boundary_mask.shape[3])) ** 2.0 - ) - - # evaluate if point is inside mesh - face_index = int(0) - face_u = float(0.0) - face_v = float(0.0) - sign = float(0.0) - if wp.mesh_query_point_sign_winding_number(mesh, pos, max_length, sign, face_index, face_u, face_v): - # set point to be solid - if sign <= 0: # TODO: fix this - # Stream indices - for l in range(_q): - # Get the index of the streaming direction - push_index = wp.vec3i() - for d in range(self.velocity_set.d): - push_index[d] = index[d] + _c[d, l] - - # Set the boundary id and mask - boundary_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) - mask[l, push_index[0], push_index[1], push_index[2]] = True - - return None, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation( - self, - stl_file, - origin, - spacing, - id_number, - boundary_mask, - mask, - start_index=(0, 0, 0), - ): - # Load the mesh - mesh = np_mesh.Mesh.from_file(stl_file) - mesh_points = mesh.points.reshape(-1, 3) - mesh_indices = np.arange(mesh_points.shape[0]) - mesh = wp.Mesh( - points=wp.array(mesh_points, dtype=wp.vec3), - indices=wp.array(mesh_indices, dtype=int), - ) - - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - mesh.id, - origin, - spacing, - id_number, - boundary_mask, - mask, - start_index, - ], - dim=boundary_mask.shape[1:], - ) - - return boundary_mask, mask diff --git a/xlb/operator/force/__init__.py b/xlb/operator/force/__init__.py new file mode 100644 index 0000000..6a991ce --- /dev/null +++ b/xlb/operator/force/__init__.py @@ -0,0 +1 @@ +from xlb.operator.force.momentum_transfer import MomentumTransfer as 
MomentumTransfer diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py new file mode 100644 index 0000000..66dba13 --- /dev/null +++ b/xlb/operator/force/momentum_transfer.py @@ -0,0 +1,216 @@ +from functools import partial +import jax.numpy as jnp +from jax import jit, lax +import warp as wp +from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream import Stream + + +class MomentumTransfer(Operator): + """ + An operator for the momentum exchange method to compute the boundary force vector exerted on the solid geometry + based on [1] as described in [3]. Ref [2] shows how [1] is applicable to curved geometries only by using a + bounce-back method (e.g. Bouzidi) that accounts for curved boundaries. + NOTE: this function should be called after BC's are imposed. + [1] A.J.C. Ladd, Numerical simulations of particulate suspensions via a discretized Boltzmann equation. + Part 2 (numerical results), J. Fluid Mech. 271 (1994) 311-339. + [2] R. Mei, D. Yu, W. Shyy, L.-S. Luo, Force evaluation in the lattice Boltzmann method involving + curved geometry, Phys. Rev. E 65 (2002) 041203. + [3] Caiazzo, A., & Junk, M. (2008). Boundary forces in lattice Boltzmann: Analysis of momentum exchange + algorithm. Computers & Mathematics with Applications, 55(7), 1415-1423. + + Notes + ----- + This method computes the force exerted on the solid geometry at each boundary node using the momentum exchange method. + The force is computed based on the post-streaming and post-collision distribution functions. This method + should be called after the boundary conditions are imposed.
+ """ + + def __init__( + self, + no_slip_bc_instance, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + self.no_slip_bc_instance = no_slip_bc_instance + self.stream = Stream(velocity_set, precision_policy, compute_backend) + + # Call the parent constructor + super().__init__( + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f, boundary_map, missing_mask): + """ + Parameters + ---------- + f : jax.numpy.ndarray + The post-collision distribution function at each node in the grid. + boundary_map : jax.numpy.ndarray + A grid field with 0 everywhere except for boundary nodes which are designated + by their respective boundary id's. + missing_mask : jax.numpy.ndarray + A grid field with lattice cardinality that specifies missing lattice directions + for each boundary node. + + Returns + ------- + jax.numpy.ndarray + The force exerted on the solid geometry at each boundary node. + """ + # Give the input post-collision populations, streaming once and apply the BC the find post-stream values. + f_post_collision = f + f_post_stream = self.stream(f_post_collision) + f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, boundary_map, missing_mask) + + # Compute momentum transfer + boundary = boundary_map == self.no_slip_bc_instance.id + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) + + # the following will return force as a grid-based field with zero everywhere except for boundary nodes. 
+ opp = self.velocity_set.opp_indices + phi = f_post_collision[opp] + f_post_stream + phi = jnp.where(jnp.logical_and(boundary, missing_mask), phi, 0.0) + force = jnp.tensordot(self.velocity_set.c[:, opp], phi, axes=(-1, 0)) + return force + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _opp_indices = self.velocity_set.wp_opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _no_slip_id = self.no_slip_bc_instance.id + + # Find velocity index for 0, 0, 0 + for l in range(self.velocity_set.q): + if _c[0, l] == 0 and _c[1, l] == 0 and _c[2, l] == 0: + zero_index = l + _zero_index = wp.int32(zero_index) + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + boundary_map: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + force: wp.array(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the boundary id + _boundary_map = boundary_map[0, index[0], index[1]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Determin if boundary is an edge by checking if center is missing + is_edge = wp.bool(False) + if _boundary_map == wp.uint8(_no_slip_id): + if _missing_mask[_zero_index] == wp.uint8(0): + is_edge = wp.bool(True) + + # If the boundary is an edge then add the momentum transfer + m = wp.vec2() + if is_edge: + # Get the distribution function + f_post_collision = _f_vec() + for l in range(self.velocity_set.q): + f_post_collision[l] = f[l, index[0], index[1]] + + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f, index) + f_post_stream = 
self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + + # Compute the momentum transfer + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] + for d in range(self.velocity_set.d): + m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + + wp.atomic_add(force, 0, m) + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + boundary_map: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + force: wp.array(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id + _boundary_map = boundary_map[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Determin if boundary is an edge by checking if center is missing + is_edge = wp.bool(False) + if _boundary_map == wp.uint8(_no_slip_id): + if _missing_mask[_zero_index] == wp.uint8(0): + is_edge = wp.bool(True) + + # If the boundary is an edge then add the momentum transfer + m = wp.vec3() + if is_edge: + # Get the distribution function + f_post_collision = _f_vec() + for l in range(self.velocity_set.q): + f_post_collision[l] = f[l, index[0], index[1], index[2]] + + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + + # Compute the momentum transfer + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] + for d in range(self.velocity_set.d): + m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + + wp.atomic_add(force, 0, m) + + # Return the correct 
kernel + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, boundary_map, missing_mask): + # Allocate the force vector (the total integral value will be computed) + force = wp.zeros((1), dtype=wp.vec3) if self.velocity_set.d == 3 else wp.zeros((1), dtype=wp.vec2) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f, boundary_map, missing_mask, force], + dim=f.shape[1:], + ) + return force.numpy()[0] diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index efbf847..ba3e294 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -14,6 +14,7 @@ from xlb.operator.macroscopic import Macroscopic from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition.boundary_condition import ImplementationStep +from xlb.operator.boundary_condition import DoNothingBC as DummyBC class IncompressibleNavierStokesStepper(Stepper): @@ -39,7 +40,7 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK"): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): + def jax_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ @@ -47,41 +48,41 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): f_0 = self.precision_policy.cast_to_compute_jax(f_0) f_1 = self.precision_policy.cast_to_compute_jax(f_1) + # Apply streaming + f_post_stream = self.stream(f_0) + + # Apply boundary conditions + for bc in self.boundary_conditions: + if bc.implementation_step == ImplementationStep.STREAMING: + f_post_stream = bc( + f_0, + f_post_stream, + boundary_map, + missing_mask, + ) + # Compute the macroscopic variables - rho, u = self.macroscopic(f_0) + rho, 
u = self.macroscopic(f_post_stream) # Compute equilibrium feq = self.equilibrium(rho, u) # Apply collision - f_post_collision = self.collision(f_0, feq, rho, u) + f_post_collision = self.collision(f_post_stream, feq, rho, u) # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.prepare_bc_auxilary_data(f_0, f_post_collision, boundary_mask, missing_mask) + f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, boundary_map, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( - f_0, + f_post_stream, f_post_collision, - boundary_mask, - missing_mask, - ) - - # Apply streaming - f_1 = self.stream(f_post_collision) - - # Apply boundary conditions - for bc in self.boundary_conditions: - if bc.implementation_step == ImplementationStep.STREAMING: - f_1 = bc( - f_post_collision, - f_1, - boundary_mask, + boundary_map, missing_mask, ) # Copy back to store precision - f_1 = self.precision_policy.cast_to_store_jax(f_1) + f_1 = self.precision_policy.cast_to_store_jax(f_post_collision) return f_1 @@ -115,32 +116,32 @@ def apply_post_streaming_bc( f_post: Any, f_aux: Any, missing_mask: Any, - _boundary_id: Any, + _boundary_map: Any, bc_struct: Any, ): # Apply post-streaming type boundary conditions - if _boundary_id == bc_struct.id_EquilibriumBC: + if _boundary_map == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_DoNothingBC: + elif _boundary_map == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post = self.DoNothingBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: + elif _boundary_map == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == 
bc_struct.id_ZouHeBC_velocity: + elif _boundary_map == bc_struct.id_ZouHeBC_velocity: # Zouhe boundary condition (bc type = velocity) f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_ZouHeBC_pressure: + elif _boundary_map == bc_struct.id_ZouHeBC_pressure: # Zouhe boundary condition (bc type = pressure) f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_RegularizedBC_velocity: + elif _boundary_map == bc_struct.id_RegularizedBC_velocity: # Regularized boundary condition (bc type = velocity) f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_RegularizedBC_pressure: + elif _boundary_map == bc_struct.id_RegularizedBC_pressure: # Regularized boundary condition (bc type = velocity) f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + elif _boundary_map == bc_struct.id_ExtrapolationOutflowBC: # Regularized boundary condition (bc type = velocity) f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) return f_post @@ -151,13 +152,13 @@ def apply_post_collision_bc( f_post: Any, f_aux: Any, missing_mask: Any, - _boundary_id: Any, + _boundary_map: Any, bc_struct: Any, ): - if _boundary_id == bc_struct.id_FullwayBounceBackBC: + if _boundary_map == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + elif _boundary_map == bc_struct.id_ExtrapolationOutflowBC: # f_aux is the neighbour's post-streaming values # Storing post-streaming data in directions that leave the domain f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(f_pre, f_post, f_aux, missing_mask) @@ -224,13 +225,13 
@@ def get_thread_data_3d( def get_bc_auxilary_data_2d( f_0: wp.array3d(dtype=Any), index: Any, - _boundary_id: Any, + _boundary_map: Any, _missing_mask: Any, bc_struct: Any, ): # special preparation of auxiliary data f_auxiliary = _f_vec() - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + if _boundary_map == bc_struct.id_ExtrapolationOutflowBC: nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -247,13 +248,13 @@ def get_bc_auxilary_data_2d( def get_bc_auxilary_data_3d( f_0: wp.array4d(dtype=Any), index: Any, - _boundary_id: Any, + _boundary_map: Any, _missing_mask: Any, bc_struct: Any, ): # special preparation of auxiliary data f_auxiliary = _f_vec() - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + if _boundary_map == bc_struct.id_ExtrapolationOutflowBC: nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -270,7 +271,7 @@ def get_bc_auxilary_data_3d( def kernel2d( f_0: wp.array3d(dtype=Any), f_1: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=Any), + boundary_map: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), bc_struct: Any, timestep: int, @@ -286,11 +287,11 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_id = boundary_mask[0, index[0], index[1]] - f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _boundary_map = boundary_map[0, index[0], index[1]] + f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_map, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_map, bc_struct) # Compute rho and u rho, u 
= self.macroscopic.warp_functional(f_post_stream) @@ -307,7 +308,7 @@ def kernel2d( ) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_map, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -318,7 +319,7 @@ def kernel2d( def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=Any), + boundary_map: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), bc_struct: Any, timestep: int, @@ -334,11 +335,11 @@ def kernel3d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _boundary_map = boundary_map[0, index[0], index[1], index[2]] + f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_map, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_map, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -350,7 +351,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_map, bc_struct) # Set the output for l in 
range(self.velocity_set.q): @@ -362,7 +363,7 @@ def kernel3d( return BoundaryConditionIDStruct, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): + def warp_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): # Get the boundary condition ids from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry @@ -378,9 +379,16 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): setattr(bc_struct, "id_" + bc_name, bc_to_id[bc_name]) active_bc_list.append("id_" + bc_name) - # Setting the Struct attributes and active BC classes based on the BC class names - bc_fallback = self.boundary_conditions[0] - # TODO: what if self.boundary_conditions is an empty list e.g. when we have periodic BC all around! + # Check if boundary_conditions is an empty list (e.g. all periodic and no BC) + # TODO: There is a huge issue here with perf. when boundary_conditions list + # is empty and is initialized with a dummy BC. If it is not empty, no perf + # loss ocurrs. The following code at least prevents syntax error for periodic examples. + if self.boundary_conditions: + bc_dummy = self.boundary_conditions[0] + else: + bc_dummy = DummyBC() + + # Setting the Struct attributes for inactive BC classes for var in vars(bc_struct): if var not in active_bc_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 @@ -388,7 +396,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Assing a fall-back BC for inactive BCs. This is just to ensure Warp codegen does not # produce error when a particular BC is not used in an example. 
- setattr(self, var.replace("id_", ""), bc_fallback) + setattr(self, var.replace("id_", ""), bc_dummy) # Launch the warp kernel wp.launch( @@ -396,7 +404,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): inputs=[ f_0, f_1, - boundary_mask, + boundary_map, missing_mask, bc_struct, timestep,