From 98a2e192d08b192913d356c8878c911fcf7bdd68 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 28 Nov 2024 11:40:12 -0500 Subject: [PATCH] Added abstraction layer for boundary condition application and stepper initialization, and the capability to add profiles to boundary conditions --- CHANGELOG.md | 2 + examples/cfd/flow_past_sphere_3d.py | 95 ++++++--- examples/cfd/lid_driven_cavity_2d.py | 41 ++-- .../cfd/lid_driven_cavity_2d_distributed.py | 16 +- examples/cfd/turbulent_channel_3d.py | 44 ++-- examples/cfd/windtunnel_3d.py | 80 +++----- examples/performance/mlups_3d.py | 42 ++-- xlb/helper/nse_solver.py | 25 ++- .../bc_extrapolation_outflow.py | 15 +- .../boundary_condition/bc_regularized.py | 27 ++- xlb/operator/boundary_condition/bc_zouhe.py | 191 ++++++++++++++---- .../boundary_condition/boundary_condition.py | 71 ++++++- .../indices_boundary_masker.py | 23 +-- .../boundary_masker/mesh_boundary_masker.py | 71 +++---- xlb/operator/stepper/nse_stepper.py | 139 +++++++++++-- xlb/operator/stepper/stepper.py | 36 ++-- 16 files changed, 603 insertions(+), 315 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d036de..69ad1f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,3 +18,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - XLB is now installable via pip - Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern - Added NVIDIA's Warp backend for state-of-the-art performance +- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data +- Added the capability to add profiles to boundary conditions \ No newline at end of file diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 1b5905e..a960786 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -1,20 +1,16 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps +from xlb.grid import grid_factory from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import ( FullwayBounceBackBC, HalfwayBounceBackBC, - ZouHeBC, RegularizedBC, - EquilibriumBC, - DoNothingBC, ExtrapolationOutflowBC, ) from xlb.operator.macroscopic import Macroscopic -from xlb.operator.boundary_masker import IndicesBoundaryMasker -from xlb.utils import save_fields_vtk, save_image +from xlb.utils import save_image import warp as wp import numpy as np import jax.numpy as jnp @@ -34,18 +30,19 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) - self.stepper = None + self.omega = omega self.boundary_conditions = [] + self.u_max = 0.04 - # Setup the simulation BC, its initial conditions, and the stepper - self._setup(omega) + # Create grid using factory + self.grid = grid_factory(grid_shape, compute_backend=backend) - def _setup(self, omega): + # Setup the simulation BC and stepper + self._setup() + + def _setup(self): self.setup_boundary_conditions() - self.setup_boundary_masker() - self.initialize_fields() - self.setup_stepper(omega) + self.setup_stepper() def define_boundary_indices(self): box = self.grid.bounding_box_indices() @@ -69,31 +66,63 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): inlet, outlet, walls, sphere = self.define_boundary_indices() - bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet) - # bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet) + bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet) + # bc_left = RegularizedBC("velocity", prescribed_value=(self.u_max, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) - # bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) - # bc_outlet = DoNothingBC(indices=outlet) bc_outlet = ExtrapolationOutflowBC(indices=outlet) bc_sphere = HalfwayBounceBackBC(indices=sphere) self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere] - def setup_boundary_masker(self): - # check boundary condition list for duplicate indices before creating bc mask - check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend) - - indices_boundary_masker = IndicesBoundaryMasker( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.backend, + def setup_stepper(self): + self.stepper = IncompressibleNavierStokesStepper( + omega=self.omega, + grid=self.grid, + boundary_conditions=self.boundary_conditions, + collision_type="BGK", ) - self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0)) - - def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) - - def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") + self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields() + + def bc_profile(self): + u_max = self.u_max # u_max = 0.04 + # Get the grid dimensions for the y and z directions + H_y = float(self.grid_shape[1] - 1) # Height in y direction + H_z = float(self.grid_shape[2] - 1) # Height in z direction + + @wp.func + def bc_profile_warp(index: wp.vec3i): + # Poiseuille flow profile: parabolic velocity distribution + y = self.precision_policy.store_precision.wp_dtype(index[1]) + z = self.precision_policy.store_precision.wp_dtype(index[2]) + + # Calculate normalized distance from center + y_center = y - (H_y / 2.0) + z_center = z - (H_z / 2.0) + r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0 + + # Parabolic profile: u = u_max * (1 - r²) + return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1) + + def bc_profile_jax(): + y = jnp.arange(self.grid_shape[1]) + z = jnp.arange(self.grid_shape[2]) + Y, Z = jnp.meshgrid(y, z, indexing="ij") + + # Calculate normalized distance from center + y_center = Y - (H_y / 2.0) + z_center = Z - (H_z / 2.0) + r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0 + + # Parabolic profile for x velocity, zero for y and z + u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared) + u_y = jnp.zeros_like(u_x) + u_z = jnp.zeros_like(u_x) + + return jnp.stack([u_x, u_y, u_z]) + + if self.backend == ComputeBackend.JAX: + return bc_profile_jax + elif self.backend == ComputeBackend.WARP: + return bc_profile_warp def run(self, num_steps, post_process_interval=100): start_time = time.time() diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index f8bd65a..17b5915 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -1,8 +1,7 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps -from xlb.operator.boundary_masker import IndicesBoundaryMasker +from xlb.grid import grid_factory from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC from xlb.operator.macroscopic import Macroscopic @@ -26,19 +25,21 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) - self.stepper = None + self.omega = omega self.boundary_conditions = [] self.prescribed_vel = prescribed_vel - # Setup the simulation BC, its initial conditions, and the stepper - self._setup(omega) + # Create grid using factory + self.grid = grid_factory(grid_shape, compute_backend=backend) - def _setup(self, omega): + # Setup the simulation BC and stepper + self._setup() + + def _setup(self): self.setup_boundary_conditions() - self.setup_boundary_masker() - self.initialize_fields() - self.setup_stepper(omega) + self.setup_stepper() + # Initialize fields using the stepper + self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields() def define_boundary_indices(self): box = self.grid.bounding_box_indices() @@ -54,21 +55,13 @@ def setup_boundary_conditions(self): bc_walls = HalfwayBounceBackBC(indices=walls) self.boundary_conditions = [bc_walls, bc_top] - def setup_boundary_masker(self): - # check boundary condition list for duplicate indices before creating bc mask - check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend) - indices_boundary_masker = IndicesBoundaryMasker( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.backend, + def setup_stepper(self): + self.stepper = IncompressibleNavierStokesStepper( + omega=self.omega, + grid=self.grid, + boundary_conditions=self.boundary_conditions, + collision_type="BGK", ) - self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask) - - def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) - - def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) def run(self, num_steps, post_process_interval=100): for i in range(num_steps): diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index cdc9027..1018ec5 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -10,15 +10,21 @@ class LidDrivenCavity2D_distributed(LidDrivenCavity2D): def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy): super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy) - def setup_stepper(self, omega): - stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) - distributed_stepper = distribute( + def setup_stepper(self): + # Create the base stepper + stepper = IncompressibleNavierStokesStepper( + omega=self.omega, + grid=self.grid, + boundary_conditions=self.boundary_conditions, + collision_type="BGK", + ) + + # Distribute the stepper + self.stepper = distribute( stepper, self.grid, self.velocity_set, ) - self.stepper = distributed_stepper - return if __name__ == "__main__": diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index eb73fdc..ccee2ec 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -2,12 +2,12 @@ import time from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq +from xlb.grid import grid_factory from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import RegularizedBC from xlb.operator.macroscopic import Macroscopic -from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image +from xlb.helper import initialize_eq import warp as wp import numpy as np import jax.numpy as jnp @@ -48,18 +48,16 @@ def __init__(self, channel_half_width, Re_tau, u_tau, grid_shape, velocity_set, self.u_tau = u_tau self.visc = u_tau * channel_half_width / Re_tau self.omega = 1.0 / (3.0 * self.visc + 0.5) - # DeltaPlus = Re_tau / channel_half_width - # DeltaPlus = u_tau / nu * Delta where u_tau / nu = Re_tau / channel_half_width - self.grid_shape = grid_shape self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) - self.stepper = None self.boundary_conditions = [] - # Setup the simulation BC, its initial conditions, and the stepper + # Create grid using factory + self.grid = grid_factory(grid_shape, compute_backend=backend) + + # Setup the simulation BC and stepper self._setup() def get_force(self): @@ -71,9 +69,10 @@ def get_force(self): def _setup(self): self.setup_boundary_conditions() - self.setup_boundary_masker() - self.initialize_fields() self.setup_stepper() + # Initialize fields using the stepper + self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields() + self.initialize_fields() def define_boundary_indices(self): # top and bottom sides of the channel are no-slip and the other directions are periodic @@ -83,19 +82,12 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): walls = self.define_boundary_indices() - bc_walls = RegularizedBC("velocity", (0.0, 0.0, 0.0), indices=walls) + bc_walls = RegularizedBC("velocity", prescribed_value=(0.0, 0.0, 0.0), indices=walls) self.boundary_conditions = [bc_walls] - def setup_boundary_masker(self): - indices_boundary_masker = IndicesBoundaryMasker( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.backend, - ) - self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask) - def initialize_fields(self): - shape = (self.velocity_set.d,) + (self.grid_shape) + # Initialize with random velocity field + shape = (self.velocity_set.d,) + self.grid_shape np.random.seed(0) u_init = np.random.random(shape) if self.backend == ComputeBackend.JAX: @@ -105,9 +97,12 @@ def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init) def setup_stepper(self): - force = self.get_force() self.stepper = IncompressibleNavierStokesStepper( - self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC", forcing_scheme="exact_difference", force_vector=force + omega=self.omega, + grid=self.grid, + boundary_conditions=self.boundary_conditions, + collision_type="BGK", + force_vector=self.get_force(), ) def run(self, num_steps, print_interval, post_process_interval=100): @@ -142,14 +137,12 @@ def post_process(self, i): u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_z": u[2], "u_magnitude": u_magnitude} save_fields_vtk(fields, timestep=i) - save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) + save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) # Save monitor plot self.plot_uplus(u, i) - return def plot_uplus(self, u, timestep): - # Compute moving average of drag coefficient, 100, 1000, 10000 # mean streamwise velocity in wall units u^+(z) # Wall distance in wall units to be used inside output_data zz = np.arange(self.grid_shape[-1]) @@ -165,6 +158,7 @@ def plot_uplus(self, u, timestep): ax.set_ylim([0, 20]) fname = "uplus_" + str(timestep // 10000).zfill(5) + ".png" plt.savefig(fname, format="png") + plt.close() if __name__ == "__main__": diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 96c79f7..a77bbac 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -3,20 +3,16 @@ import time from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps +from xlb.grid import grid_factory from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import ( + HalfwayBounceBackBC, FullwayBounceBackBC, - EquilibriumBC, - DoNothingBC, RegularizedBC, - HalfwayBounceBackBC, ExtrapolationOutflowBC, - GradsApproximationBC, ) from xlb.operator.force.momentum_transfer import MomentumTransfer from xlb.operator.macroscopic import Macroscopic -from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker, MeshDistanceBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np @@ -37,13 +33,14 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) - self.stepper = None + self.omega = omega self.boundary_conditions = [] - - # Setup the simulation BC, its initial conditions, and the stepper self.wind_speed = wind_speed - self.omega = omega + + # Create grid using factory + self.grid = grid_factory(grid_shape, compute_backend=backend) + + # Setup the simulation BC and stepper self._setup() # Make list to store drag coefficients @@ -52,11 +49,10 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.lift_coefficients = [] def _setup(self): - # NOTE: it is important to initialize fields before setup_boundary_masker is called because f_0 or f_1 might be used to store BC information - self.initialize_fields() self.setup_boundary_conditions() - self.setup_boundary_masker() self.setup_stepper() + # Initialize fields using the stepper + self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields() def voxelize_stl(self, stl_filename, length_lbm_unit): mesh = trimesh.load_mesh(stl_filename, process=False) @@ -85,57 +81,28 @@ def define_boundary_indices(self): length_phys_unit = mesh_extents.max() length_lbm_unit = self.grid_shape[0] / 4 dx = length_phys_unit / length_lbm_unit - shift = np.array([self.grid_shape[0] * dx / 4, (self.grid_shape[1] * dx - mesh_extents[1]) / 2, 0.0]) + mesh_vertices = mesh_vertices / dx + shift = np.array([self.grid_shape[0] / 4, (self.grid_shape[1] - mesh_extents[1] / dx) / 2, 0.0]) car = mesh_vertices + shift - self.grid_spacing = dx self.car_cross_section = np.prod(mesh_extents[1:]) / dx**2 return inlet, outlet, walls, car def setup_boundary_conditions(self): inlet, outlet, walls, car = self.define_boundary_indices() - bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet) - # bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet) + bc_left = RegularizedBC("velocity", prescribed_value=(self.wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) - # bc_car = HalfwayBounceBackBC(mesh_vertices=car) - bc_car = GradsApproximationBC(mesh_vertices=car) - # bc_car = FullwayBounceBackBC(mesh_vertices=car) + bc_car = HalfwayBounceBackBC(mesh_vertices=car) self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car] - def setup_boundary_masker(self): - # check boundary condition list for duplicate indices before creating bc mask - check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend) - - indices_boundary_masker = IndicesBoundaryMasker( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.backend, - ) - # mesh_boundary_masker = MeshBoundaryMasker( - # velocity_set=self.velocity_set, - # precision_policy=self.precision_policy, - # compute_backend=self.backend, - # ) - mesh_distance_boundary_masker = MeshDistanceBoundaryMasker( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.backend, - ) - bclist_other = self.boundary_conditions[:-1] - bc_mesh = self.boundary_conditions[-1] - dx = self.grid_spacing - origin, spacing = (0, 0, 0), (dx, dx, dx) - self.bc_mask, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_mask, self.missing_mask) - # self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) - self.bc_mask, self.missing_mask, self.f_1 = mesh_distance_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask, self.f_1) - - def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) - self.f_1 = initialize_eq(self.f_1, self.grid, self.velocity_set, self.precision_policy, self.backend) - def setup_stepper(self): - self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") + self.stepper = IncompressibleNavierStokesStepper( + omega=self.omega, + grid=self.grid, + boundary_conditions=self.boundary_conditions, + collision_type="BGK", + ) def run(self, num_steps, print_interval, post_process_interval=100): # Setup the operator for computing surface forces at the interface of the specified BC @@ -176,7 +143,7 @@ def post_process(self, i): fields = {"u_magnitude": u_magnitude} save_fields_vtk(fields, timestep=i) - save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) + save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) # Compute lift and drag boundary_force = self.momentum_transfer(self.f_0, self.f_1, self.bc_mask, self.missing_mask) @@ -190,7 +157,6 @@ def post_process(self, i): # Save monitor plot self.plot_drag_coefficient() - return def plot_drag_coefficient(self): # Compute moving average of drag coefficient, 100, 1000, 10000 @@ -230,14 +196,14 @@ def plot_drag_coefficient(self): # Configuration backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend) wind_speed = 0.02 num_steps = 100000 print_interval = 1000 # Set up Reynolds number and deduce relaxation time (omega) - # Re = 50000.0 - Re = 500000000000.0 + Re = 5000.0 clength = grid_size_x - 1 visc = wind_speed * clength / Re omega = 1.0 / (3.0 * visc + 0.5) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 2001fb2..f22dd94 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -2,9 +2,10 @@ import argparse import time import warp as wp +import numpy as np from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq +from xlb.grid import grid_factory from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC from xlb.distribute import distribute @@ -40,32 +41,21 @@ def setup_simulation(args): return backend, precision_policy -def create_grid_and_fields(cube_edge): - grid_shape = (cube_edge, cube_edge, cube_edge) - grid, f_0, f_1, missing_mask, bc_mask = create_nse_fields(grid_shape) - - return grid, f_0, f_1, missing_mask, bc_mask - - -def define_boundary_indices(grid): +def run(backend, precision_policy, grid_shape, num_steps): + # Create grid and setup boundary conditions + grid = grid_factory(grid_shape) box = grid.bounding_box_indices() box_no_edge = grid.bounding_box_indices(remove_edges=True) lid = box_no_edge["top"] walls = [box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] for i in range(len(grid.shape))] - return lid, walls + walls = np.unique(np.array(walls), axis=-1).tolist() + boundary_conditions = [EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid), FullwayBounceBackBC(indices=walls)] -def setup_boundary_conditions(grid): - lid, walls = define_boundary_indices(grid) - bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid) - bc_walls = FullwayBounceBackBC(indices=walls) - return [bc_top, bc_walls] - - -def run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, num_steps): - omega = 1.0 - stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=setup_boundary_conditions(grid)) + # Create stepper + stepper = IncompressibleNavierStokesStepper(omega=1.0, grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK") + # Distribute if using JAX backend if backend == ComputeBackend.JAX: stepper = distribute( stepper, @@ -73,6 +63,8 @@ def run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, num_st xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend), ) + # Initialize fields and run simulation + f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields() start_time = time.time() for i in range(num_steps): @@ -80,8 +72,7 @@ def run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, num_st f_0, f_1 = f_1, f_0 wp.synchronize() - end_time = time.time() - return end_time - start_time + return time.time() - start_time def calculate_mlups(cube_edge, num_steps, elapsed_time): @@ -93,11 +84,8 @@ def calculate_mlups(cube_edge, num_steps, elapsed_time): def main(): args = parse_arguments() backend, precision_policy = setup_simulation(args) - velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) - grid, f_0, f_1, missing_mask, bc_mask = create_grid_and_fields(args.cube_edge) - f_0 = initialize_eq(f_0, grid, velocity_set, precision_policy, backend) - - elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, args.num_steps) + grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge) + elapsed_time = run(backend, precision_policy, grid_shape, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py index 3d0ef3e..b96d05b 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_solver.py @@ -3,12 +3,33 @@ from xlb.precision_policy import Precision from typing import Tuple +def create_nse_fields( + grid_shape: Tuple[int, int, int] = None, + grid=None, + velocity_set=None, + compute_backend=None, + precision_policy=None, +): + """Create fields for Navier-Stokes equation solver. -def create_nse_fields(grid_shape: Tuple[int, int, int], velocity_set=None, compute_backend=None, precision_policy=None): + Args: + grid_shape: Tuple of grid dimensions. Required if grid is not provided. + grid: Optional Grid object. If provided, will be used instead of creating new grid. + velocity_set: Optional velocity set. Defaults to DefaultConfig.velocity_set. + compute_backend: Optional compute backend. Defaults to DefaultConfig.default_backend. + precision_policy: Optional precision policy. Defaults to DefaultConfig.default_precision_policy. + + Returns: + Tuple of (grid, f_0, f_1, missing_mask, bc_mask) + """ velocity_set = velocity_set or DefaultConfig.velocity_set compute_backend = compute_backend or DefaultConfig.default_backend precision_policy = precision_policy or DefaultConfig.default_precision_policy - grid = grid_factory(grid_shape, compute_backend=compute_backend) + + if grid is None: + if grid_shape is None: + raise ValueError("grid_shape must be provided when grid is None") + grid = grid_factory(grid_shape, compute_backend=compute_backend) # Create fields f_0 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 8a2a482..fa75490 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -56,12 +56,15 @@ def __init__( mesh_vertices, ) + # Set the flag for auxilary data recovery + self.needs_aux_recovery = True + # find and store the normal vector using indices self._get_normal_vec(indices) # Unpack the two warp functionals needed for this BC! if self.compute_backend == ComputeBackend.WARP: - self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional + self.warp_functional, self.update_bc_auxilary_data = self.warp_functional def _get_normal_vec(self, indices): # Get the frequency count and most common element directly @@ -92,9 +95,9 @@ def _roll(self, fld, vec): return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): + def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): """ - Prepare the auxilary distribution functions for the boundary condition. + Update the auxilary distribution functions for the boundary condition. Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision """ sound_speed = 1.0 / jnp.sqrt(3.0) @@ -171,7 +174,7 @@ def functional( return _f @wp.func - def prepare_bc_auxilary_data( + def update_bc_auxilary_data( index: Any, timestep: Any, missing_mask: Any, @@ -180,7 +183,7 @@ def prepare_bc_auxilary_data( f_pre: Any, f_post: Any, ): - # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and + # Update the auxilary data for this BC using the neighbour's populations stored in f_aux and # f_pre (post-streaming values of the current voxel). We use directions that leave the domain # for storing this prepared data. _f = f_post @@ -199,7 +202,7 @@ def prepare_bc_auxilary_data( kernel = self._construct_kernel(functional) - return (functional, prepare_bc_auxilary_data), kernel + return (functional, update_bc_auxilary_data), kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index d1abd4e..dee8679 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -7,7 +7,8 @@ import jax.lax as lax from functools import partial import warp as wp -from typing import Any +from typing import Any, Union, Tuple +import numpy as np from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -44,7 +45,8 @@ class RegularizedBC(ZouHeBC): def __init__( self, bc_type, - prescribed_value, + profile=None, + prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, @@ -54,6 +56,7 @@ def __init__( # Call the parent constructor super().__init__( bc_type, + profile, prescribed_value, velocity_set, precision_policy, @@ -127,15 +130,11 @@ def _construct_warp(self): # assign placeholders for both u and rho based on prescribed_value _d = self.velocity_set.d _q = self.velocity_set.q - u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d - rho = self.prescribed_value if self.bc_type == "pressure" else 0.0 # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) # compute Qi tensor and store it in self _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = self.compute_dtype(rho) - _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.opp_indices _w = self.velocity_set.w _c = self.velocity_set.c @@ -222,6 +221,15 @@ def functional_velocity( # Find normal vector normals = get_normal_vectors(missing_mask) + # Find the value of u from the missing directions + for l in range(_q): + # Since we are only considering normal velocity, we only need to find one value + if missing_mask[l] == wp.uint8(1): + # Create velocity vector by multiplying the prescribed value with the normal vector + prescribed_value = f_1[_opp_indices[l], index[0], index[1], index[2]] + _u = -prescribed_value * normals + break + # calculate rho fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) @@ -253,6 +261,13 @@ def functional_pressure( # Find normal vector normals = get_normal_vectors(missing_mask) + # Find the value of rho from the missing directions + for q in range(_q): + # Since we need only one scalar value, we only need to find one value + if missing_mask[q] == wp.uint8(1): + _rho = f_1[_opp_indices[q], index[0], index[1], index[2]] + break + # calculate velocity fsum = _get_fsum(_f, missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 3e89992..8e889b0 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -7,7 +7,8 @@ import jax.lax as lax from functools import partial import warp as wp -from typing import Any +from typing import Any, Union, Tuple +import numpy as np from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -21,6 +22,7 @@ boundary_condition_registry, ) from xlb.operator.equilibrium import QuadraticEquilibrium +import jax class ZouHeBC(BoundaryCondition): @@ -38,7 +40,8 @@ class ZouHeBC(BoundaryCondition): def __init__( self, bc_type, - prescribed_value, + profile=None, + prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, @@ -50,7 +53,7 @@ def __init__( assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() - self.prescribed_value = prescribed_value + self.profile = profile # Call the parent constructor super().__init__( @@ -62,15 +65,76 @@ def __init__( mesh_vertices, ) - # Set the prescribed value for pressure or velocity - dim = self.velocity_set.d - if self.compute_backend == ComputeBackend.JAX: - self.prescribed_value = jnp.atleast_1d(prescribed_value)[(slice(None),) + (None,) * dim] - # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! + # Handle prescribed value if provided + if prescribed_value is not None: + if profile is not None: + raise ValueError("Cannot specify both profile and prescribed_value") + + # Convert input to numpy array for validation + if isinstance(prescribed_value, (tuple, list)): + prescribed_value = np.array(prescribed_value, dtype=np.float64) + elif isinstance(prescribed_value, (int, float)): + if bc_type == "pressure": + prescribed_value = float(prescribed_value) + else: + raise ValueError("Velocity prescribed_value must be a tuple or array") + elif isinstance(prescribed_value, np.ndarray): + prescribed_value = prescribed_value.astype(np.float64) + + # Validate prescribed value + if bc_type == "velocity": + if not isinstance(prescribed_value, np.ndarray): + raise ValueError("Velocity prescribed_value must be an array-like") + + # Check for non-zero elements - only one element should be non-zero + non_zero_count = np.count_nonzero(prescribed_value) + if non_zero_count > 1: + raise ValueError("This BC only supports normal prescribed values (only one non-zero element allowed)") + + self.prescribed_value = prescribed_value + self.profile = self._create_constant_prescribed_profile() + + # This BC needs auxilary data initialization before streaming + self.needs_aux_init = True + + # This BC needs auxilary data recovery after streaming + self.needs_aux_recovery = True + + # This BC needs one auxilary data for the density or normal velocity + self.num_of_aux_data = 1 # This BC needs padding for finding missing directions when imposed on a geometry that is in the domain interior self.needs_padding = True + def _create_constant_prescribed_profile(self): + if self.bc_type == "velocity": + + @wp.func + def prescribed_profile_warp(index: wp.vec3i): + # Get the non-zero value from prescribed_value + value = wp.static( + self.precision_policy.store_precision.wp_dtype(float(self.prescribed_value[np.nonzero(self.prescribed_value)[0][0]])) + ) + return wp.vec(value, length=1) + + def prescribed_profile_jax(): + return jnp.array(self.prescribed_value, dtype=self.precision_policy.store_precision.jax_dtype).reshape(-1, 1) + + else: # pressure + + @wp.func + def prescribed_profile_warp(index: wp.vec3i): + value = wp.static(self.precision_policy.store_precision.wp_dtype(self.prescribed_value)) + return wp.vec(value, length=1) + + def prescribed_profile_jax(): + return jnp.array(self.prescribed_value) + + if self.compute_backend == ComputeBackend.JAX: + return prescribed_profile_jax + elif self.compute_backend == ComputeBackend.WARP: + return prescribed_profile_warp + @partial(jit, static_argnums=(0,), inline=True) def _get_known_middle_mask(self, missing_mask): known_mask = missing_mask[self.velocity_set.opp_indices] @@ -84,13 +148,53 @@ def _get_normal_vec(self, missing_mask): normals = -jnp.tensordot(main_c, m, axes=(-1, 0)) return normals + @partial(jit, static_argnums=(0, 2, 3), inline=True) + def _broadcast_prescribed_values(self, prescribed_values, prescribed_values_shape, target_shape): + """ + Broadcasts `prescribed_values` to `target_shape` following specific rules: + + - If `prescribed_values_shape` is (2, 1) or (3, 1) (for constant profiles), + broadcast along the last 2 or 3 dimensions of `target_shape` respectively. + - For other shapes, identify mismatched dimensions and broadcast only in that direction. + """ + # Determine the number of dimensions to match + num_dims_prescribed = len(prescribed_values_shape) + num_dims_target = len(target_shape) + + if num_dims_prescribed > num_dims_target: + raise ValueError("prescribed_values has more dimensions than target_shape") + + # Insert singleton dimensions after the first dimension to match target_shape + if num_dims_prescribed < num_dims_target: + # Number of singleton dimensions to add + num_singleton = num_dims_target - num_dims_prescribed + + if num_dims_prescribed == 0: + # If prescribed_values is scalar, reshape to all singleton dimensions + prescribed_values_shape = (1,) * num_dims_target + else: + # Insert singleton dimensions after the first dimension + prescribed_values_shape = (prescribed_values_shape[0], *(1,) * num_singleton, *prescribed_values_shape[1:]) + prescribed_values = prescribed_values.reshape(prescribed_values_shape) + + # Create broadcast shape based on the rules + broadcast_shape = [] + for pv_dim, tgt_dim in zip(prescribed_values_shape, target_shape): + if pv_dim == 1 or pv_dim == tgt_dim: + broadcast_shape.append(tgt_dim) + else: + raise ValueError(f"Cannot broadcast dimension {pv_dim} to {tgt_dim}") + + return jnp.broadcast_to(prescribed_values, target_shape) + @partial(jit, static_argnums=(0,), inline=True) def get_rho(self, fpop, missing_mask): if self.bc_type == "velocity": - vel = self.prescribed_value + target_shape = (self.velocity_set.d,) + fpop.shape[1:] + vel = self._broadcast_prescribed_values(self.prescribed_values, self.prescribed_values.shape, target_shape) rho = self.calculate_rho(fpop, vel, missing_mask) elif self.bc_type == "pressure": - rho = self.prescribed_value + rho = self.prescribed_values else: raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") return rho @@ -98,9 +202,10 @@ def get_rho(self, fpop, missing_mask): @partial(jit, static_argnums=(0,), inline=True) def get_vel(self, fpop, missing_mask): if self.bc_type == "velocity": - vel = self.prescribed_value + target_shape = (self.velocity_set.d,) + fpop.shape[1:] + vel = self._broadcast_prescribed_values(self.prescribed_values, self.prescribed_values.shape, target_shape) elif self.bc_type == "pressure": - rho = self.prescribed_value + rho = self.prescribed_values vel = self.calculate_vel(fpop, rho, missing_mask) else: raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") @@ -134,14 +239,13 @@ def calculate_rho(self, fpop, vel, missing_mask): return rho @partial(jit, static_argnums=(0,), inline=True) - def calculate_equilibrium(self, fpop, missing_mask): + def calculate_equilibrium(self, f_post, missing_mask): """ This is the ZouHe method of calculating the missing macroscopic variables at the boundary. """ - rho = self.get_rho(fpop, missing_mask) - vel = self.get_vel(fpop, missing_mask) + rho = self.get_rho(f_post, missing_mask) + vel = self.get_vel(f_post, missing_mask) - # compute feq at the boundary feq = self.equilibrium_operator(rho, vel) return feq @@ -176,14 +280,10 @@ def _construct_warp(self): # assign placeholders for both u and rho based on prescribed_value _d = self.velocity_set.d _q = self.velocity_set.q - u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d - rho = self.prescribed_value if self.bc_type == "pressure" else 0.0 # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = self.compute_dtype(rho) - _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.opp_indices _c = self.velocity_set.c _c_float = self.velocity_set.c_float @@ -231,62 +331,79 @@ def bounceback_nonequilibrium( def functional_velocity( index: Any, timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, + _missing_mask: Any, f_pre: Any, f_post: Any, + _f_pre: Any, + _f_post: Any, ): # Post-streaming values are only modified at missing direction - _f = f_post + _f = _f_post # Find normal vector - normals = get_normal_vectors(missing_mask) + normals = get_normal_vectors(_missing_mask) # calculate rho - fsum = _get_fsum(_f, missing_mask) + fsum = _get_fsum(_f, _missing_mask) unormal = self.compute_dtype(0.0) + + # Find the value of u from the missing directions + for l in range(_q): + # Since we are only considering normal velocity, we only need to find one value (all values are the same in the missing directions) + if _missing_mask[l] == wp.uint8(1): + # Create velocity vector by multiplying the prescribed value with the normal vector + # TODO: This can be optimized by saving _missing_mask[l] in the bc class later since it is the same for all boundary cells + prescribed_value = f_post[_opp_indices[l], index[0], index[1], index[2]] + _u = -prescribed_value * normals + break + for d in range(_d): unormal += _u[d] * normals[d] + _rho = fsum / (self.compute_dtype(1.0) + unormal) # impose non-equilibrium bounceback - feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) + _feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, _feq, _missing_mask) return _f @wp.func def functional_pressure( index: Any, timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, + _missing_mask: Any, f_pre: Any, f_post: Any, + _f_pre: Any, + _f_post: Any, ): # Post-streaming values are only modified at missing direction - _f = f_post + _f = _f_post # Find normal vector - normals = get_normal_vectors(missing_mask) + normals = get_normal_vectors(_missing_mask) + + # Find the value of rho from the missing directions + for q in range(_q): + # Since we need only one scalar value, we only need to find one value (all values are the same in the missing directions) + if _missing_mask[q] == wp.uint8(1): + _rho = f_post[_opp_indices[q], index[0], index[1], index[2]] + break # calculate velocity - fsum = _get_fsum(_f, missing_mask) + fsum = _get_fsum(_f, _missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) + _f = bounceback_nonequilibrium(_f, feq, _missing_mask) return _f if self.bc_type == "velocity": functional = functional_velocity elif self.bc_type == "pressure": functional = functional_pressure - elif self.bc_type == "velocity": - functional = functional_pressure kernel = self._construct_kernel(functional) diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index bf1eef2..86ab042 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -7,6 +7,8 @@ from typing import Any from jax import jit from functools import partial +import jax +import jax.numpy as jnp from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -57,13 +59,25 @@ def __init__( # A flag for BCs that need implicit boundary distance between the grid and a mesh (to be set to True if applicable inside each BC) self.needs_mesh_distance = False + # A flag for BCs that need auxilary data initialization before stepper + self.needs_aux_init = False + + # A flag to track if the BC is initialized with auxilary data + self.is_initialized_with_aux_data = False + + # Number of auxilary data needed for the BC (for prescribed values) + self.num_of_aux_data = 0 + + # A flag for BCs that need auxilary data recovery after streaming + self.needs_aux_recovery = False + if self.compute_backend == ComputeBackend.WARP: # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool @wp.func - def prepare_bc_auxilary_data( + def update_bc_auxilary_data( index: Any, timestep: Any, missing_mask: Any, @@ -102,10 +116,10 @@ def _get_thread_data( # Construct some helper warp functions for getting tid data if self.compute_backend == ComputeBackend.WARP: self._get_thread_data = _get_thread_data - self.prepare_bc_auxilary_data = prepare_bc_auxilary_data + self.update_bc_auxilary_data = update_bc_auxilary_data @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): + def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): """ A placeholder function for prepare the auxilary distribution functions for the boundary condition. currently being called after collision only. @@ -146,3 +160,54 @@ def kernel( f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) return kernel + + def _construct_aux_data_init_kernel(self, functional): + """ + Constructs the warp kernel for the auxilary data recovery. + """ + _id = wp.uint8(self.id) + _opp_indices = self.velocity_set.opp_indices + _num_of_aux_data = self.num_of_aux_data + + # Construct the warp kernel + @wp.kernel + def aux_data_init_kernel( + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # read tid data + _f_0, _f_1, _boundary_id, _missing_mask = self._get_thread_data(f_0, f_1, bc_mask, missing_mask, index) + + # Apply the functional + if _boundary_id == _id: + # prescribed_values is a q-sized vector of type wp.vec + prescribed_values = functional(index) + # Write the result for all q directions, but only store up to num_of_aux_data + # TODO: Somehow raise an error if the number of prescribed values does not match the number of missing directions + counter = wp.int32(0) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1) and counter < _num_of_aux_data: + f_1[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(prescribed_values[counter]) + counter += 1 + + return aux_data_init_kernel + + def aux_data_init(self, f_0, f_1, bc_mask, missing_mask): + if self.compute_backend == ComputeBackend.WARP: + # Launch the warp kernel + wp.launch( + self._construct_aux_data_init_kernel(self.profile), + inputs=[f_0, f_1, bc_mask, missing_mask], + dim=f_0.shape[1:], + ) + elif self.compute_backend == ComputeBackend.JAX: + # We don't use boundary aux encoding/decoding in JAX + self.prescribed_values = self.profile() + self.is_initialized_with_aux_data = True + return f_0, f_1 diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 0c1f7e1..ff014f3 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -47,12 +47,18 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None): dim = missing_mask.ndim - 1 nDevices = jax.device_count() pad_x, pad_y, pad_z = nDevices, 1, 1 + # TODO MEHDI: There is sometimes a halting problem here when padding is used in a multi-GPU setting since we're not jitting this function. + # For now, we compute the bmap on GPU zero. if dim == 2: + bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1]), dtype=jnp.uint8) + bmap = bmap.at[pad_x : -pad_x, pad_y : -pad_y].set(bc_mask[0]) grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True) - bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) + # bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) if dim == 3: + bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1], pad_z * 2 + bc_mask[0].shape[2]), dtype=jnp.uint8) + bmap = bmap.at[pad_x : -pad_x, pad_y : -pad_y, pad_z : -pad_z].set(bc_mask[0]) grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True) - bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) + # bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) # shift indices shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z) @@ -111,16 +117,15 @@ def kernel( is_interior: wp.array1d(dtype=wp.bool), bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - start_index: wp.vec3i, ): # Get the index of indices ii = wp.tid() # Get local indices index = wp.vec3i() - index[0] = indices[0, ii] - start_index[0] - index[1] = indices[1, ii] - start_index[1] - index[2] = indices[2, ii] - start_index[2] + index[0] = indices[0, ii] + index[1] = indices[1, ii] + index[2] = indices[2, ii] # Check if index is in bounds shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3]) @@ -198,11 +203,6 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): wp_id_numbers = wp.array(id_numbers, dtype=wp.uint8) wp_is_interior = wp.array(is_interior, dtype=wp.bool) - if start_index is None: - start_index = wp.vec3i(0, 0, 0) - else: - start_index = wp.vec3i(*start_index) - # Launch the warp kernel wp.launch( self.warp_kernel, @@ -213,7 +213,6 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): wp_is_interior, bc_mask, missing_mask, - start_index, ], ) diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index edfdb42..b424422 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -35,73 +35,56 @@ def __init__( def jax_implementation( self, bc, - origin, - spacing, - id_number, bc_mask, missing_mask, - start_index=(0, 0, 0), ): raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") # Use Warp backend even for this particular operation. wp.init() bc_mask = wp.from_jax(bc_mask) missing_mask = wp.from_jax(missing_mask) - bc_mask, missing_mask = self.warp_implementation(bc, origin, spacing, bc_mask, missing_mask, start_index) + bc_mask, missing_mask = self.warp_implementation(bc, bc_mask, missing_mask) return wp.to_jax(bc_mask), wp.to_jax(missing_mask) def _construct_warp(self): # Make constants for warp _c = self.velocity_set.c - _q = wp.constant(self.velocity_set.q) + _q = self.velocity_set.q # Construct the warp kernel @wp.kernel def kernel( mesh_id: wp.uint64, - origin: wp.vec3, - spacing: wp.vec3, id_number: wp.int32, bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - start_index: wp.vec3i, ): # get index i, j, k = wp.tid() # Get local indices - index = wp.vec3i() - index[0] = i - start_index[0] - index[1] = j - start_index[1] - index[2] = k - start_index[2] + index = wp.vec3i(i, j, k) # position of the point ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2])) - ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center - pos = wp.cw_mul(ijk, spacing) + origin - + pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center # Compute the maximum length - max_length = wp.sqrt( - (spacing[0] * wp.float32(bc_mask.shape[1])) ** 2.0 - + (spacing[1] * wp.float32(bc_mask.shape[2])) ** 2.0 - + (spacing[2] * wp.float32(bc_mask.shape[3])) ** 2.0 - ) + max_length = wp.sqrt(2.0) / 2.0 # half of unit cell diagonal # evaluate if point is inside mesh - query = wp.mesh_query_point_sign_winding_number(mesh_id, pos, max_length) + query = wp.mesh_query_point_no_sign(mesh_id, pos, max_length) if query.result: # set point to be solid - if query.sign <= 0: # TODO: fix this - # Stream indices - for l in range(1, _q): - # Get the index of the streaming direction - push_index = wp.vec3i() - for d in range(self.velocity_set.d): - push_index[d] = index[d] + _c[d, l] - - # Set the boundary id and missing_mask - bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) - missing_mask[l, push_index[0], push_index[1], push_index[2]] = True + # Stream indices + for l in range(1, _q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + + # Set the boundary id and missing_mask + bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + missing_mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel @@ -109,20 +92,27 @@ def kernel( def warp_implementation( self, bc, - origin, - spacing, bc_mask, missing_mask, - start_index=(0, 0, 0), ): assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" - assert ( - bc.mesh_vertices.shape[1] == self.velocity_set.d - ), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + assert bc.mesh_vertices.shape[1] == self.velocity_set.d, ( + "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + ) mesh_vertices = bc.mesh_vertices id_number = bc.id + # Check mesh extents against domain dimensions + domain_shape = bc_mask.shape[1:] # (nx, ny, nz) + mesh_min = np.min(mesh_vertices, axis=0) + mesh_max = np.max(mesh_vertices, axis=0) + + if any(mesh_min < 0) or any(mesh_max >= domain_shape): + raise ValueError( + f"Mesh extents ({mesh_min}, {mesh_max}) exceed domain dimensions {domain_shape}. The mesh must be fully contained within the domain." + ) + # We are done with bc.mesh_vertices. Remove them from BC objects bc.__dict__.pop("mesh_vertices", None) @@ -140,12 +130,9 @@ def warp_implementation( self.warp_kernel, inputs=[ mesh.id, - origin, - spacing, id_number, bc_mask, missing_mask, - start_index, ], dim=bc_mask.shape[1:], ) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index e08e95c..62079ec 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -7,6 +7,7 @@ from xlb import DefaultConfig from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import Precision from xlb.operator import Operator from xlb.operator.stream import Stream from xlb.operator.collision import BGK, KBC @@ -16,31 +17,114 @@ from xlb.operator.boundary_condition.boundary_condition import ImplementationStep from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.collision import ForcedCollision +from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker +from xlb.helper import check_bc_overlaps +from xlb.helper.nse_solver import create_nse_fields class IncompressibleNavierStokesStepper(Stepper): - def __init__(self, omega, boundary_conditions=[], collision_type="BGK", forcing_scheme="exact_difference", force_vector=None): - velocity_set = DefaultConfig.velocity_set - precision_policy = DefaultConfig.default_precision_policy - compute_backend = DefaultConfig.default_backend + def __init__( + self, + omega, + grid, + boundary_conditions=[], + collision_type="BGK", + forcing_scheme="exact_difference", + force_vector=None, + ): + super().__init__(grid, boundary_conditions) # Construct the collision operator if collision_type == "BGK": - self.collision = BGK(omega, velocity_set, precision_policy, compute_backend) + self.collision = BGK(omega, self.velocity_set, self.precision_policy, self.compute_backend) elif collision_type == "KBC": - self.collision = KBC(omega, velocity_set, precision_policy, compute_backend) + self.collision = KBC(omega, self.velocity_set, self.precision_policy, self.compute_backend) if force_vector is not None: self.collision = ForcedCollision(collision_operator=self.collision, forcing_scheme=forcing_scheme, force_vector=force_vector) # Construct the operators - self.stream = Stream(velocity_set, precision_policy, compute_backend) - self.equilibrium = QuadraticEquilibrium(velocity_set, precision_policy, compute_backend) - self.macroscopic = Macroscopic(velocity_set, precision_policy, compute_backend) - - operators = [self.macroscopic, self.equilibrium, self.collision, self.stream] + self.stream = Stream(self.velocity_set, self.precision_policy, self.compute_backend) + self.equilibrium = QuadraticEquilibrium(self.velocity_set, self.precision_policy, self.compute_backend) + self.macroscopic = Macroscopic(self.velocity_set, self.precision_policy, self.compute_backend) + + def prepare_fields(self, initializer=None): + """Prepare the fields required for the stepper. + + Args: + initializer: Optional operator to initialize the distribution functions. + If provided, it should be a callable that takes (grid, velocity_set, + precision_policy, compute_backend) as arguments and returns initialized f_0. + If None, default equilibrium initialization is used with rho=1 and u=0. + + Returns: + Tuple of (f_0, f_1, bc_mask, missing_mask): + - f_0: Initial distribution functions + - f_1: Copy of f_0 for double-buffering + - bc_mask: Boundary condition mask indicating which BC applies to each node + - missing_mask: Mask indicating which populations are missing at boundary nodes + """ + # Create fields using the helper function + _, f_0, f_1, missing_mask, bc_mask = create_nse_fields( + grid=self.grid, velocity_set=self.velocity_set, compute_backend=self.compute_backend, precision_policy=self.precision_policy + ) - super().__init__(operators, boundary_conditions) + # Initialize distribution functions if initializer is provided + if initializer is not None: + f_0 = initializer(self.grid, self.velocity_set, self.precision_policy, self.compute_backend) + else: + from xlb.helper.initializers import initialize_eq + f_0 = initialize_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend) + + # Copy f_0 using backend-specific copy to f_1 + if self.compute_backend == ComputeBackend.JAX: + f_1 = f_0.copy() + else: + wp.copy(f_1, f_0) + + # Process boundary conditions and update masks + bc_mask, missing_mask = self._process_boundary_conditions(self.boundary_conditions, bc_mask, missing_mask) + # Initialize auxiliary data if needed + f_0, f_1 = self._initialize_auxiliary_data(self.boundary_conditions, f_0, f_1, bc_mask, missing_mask) + + return f_0, f_1, bc_mask, missing_mask + + @classmethod + def _process_boundary_conditions(cls, boundary_conditions, bc_mask, missing_mask): + """Process boundary conditions and update boundary masks.""" + # Check for boundary condition overlaps + check_bc_overlaps(boundary_conditions, DefaultConfig.velocity_set.d, DefaultConfig.default_backend) + # Create boundary maskers + indices_masker = IndicesBoundaryMasker( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + ) + # Split boundary conditions by type + bc_with_vertices = [bc for bc in boundary_conditions if bc.mesh_vertices is not None] + bc_with_indices = [bc for bc in boundary_conditions if bc.indices is not None] + # Process indices-based boundary conditions + if bc_with_indices: + bc_mask, missing_mask = indices_masker(bc_with_indices, bc_mask, missing_mask) + # Process mesh-based boundary conditions for 3D + if DefaultConfig.velocity_set.d == 3 and bc_with_vertices: + mesh_masker = MeshBoundaryMasker( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + ) + for bc in bc_with_vertices: + bc_mask, missing_mask = mesh_masker(bc, bc_mask, missing_mask) + + return bc_mask, missing_mask + + @staticmethod + def _initialize_auxiliary_data(boundary_conditions, f_0, f_1, bc_mask, missing_mask): + """Initialize auxiliary data for boundary conditions that require it.""" + for bc in boundary_conditions: + if bc.needs_aux_init and not bc.is_initialized_with_aux_data: + f_0, f_1 = bc.aux_data_init(f_0, f_1, bc_mask, missing_mask) + return f_0, f_1 @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) @@ -76,7 +160,7 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) + f_post_collision = bc.update_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_post_stream, @@ -108,6 +192,8 @@ def _construct_warp(self): # Group active boundary conditions active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions) + _opp_indices = self.velocity_set.opp_indices + @wp.func def apply_bc( index: Any, @@ -134,7 +220,7 @@ def apply_bc( f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids): if _boundary_id == wp.static(self.boundary_conditions[i].id): - f_result = wp.static(self.boundary_conditions[i].prepare_bc_auxilary_data)( + f_result = wp.static(self.boundary_conditions[i].update_bc_auxilary_data)( index, timestep, missing_mask, f_0, f_1, f_pre, f_post ) return f_result @@ -161,6 +247,23 @@ def get_thread_data( return _f0_thread, _f1_thread, _missing_mask + @wp.func + def apply_aux_recovery_bc( + index: Any, + _boundary_id: Any, + _missing_mask: Any, + f_0: Any, + _f1_thread: Any, + ): + # Unroll the loop over boundary conditions + for i in range(wp.static(len(self.boundary_conditions))): + if wp.static(self.boundary_conditions[i].needs_aux_recovery): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + # Perform the swapping of data + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) + @wp.kernel def kernel( f_0: wp.array4d(dtype=Any), @@ -192,13 +295,11 @@ def kernel( # Apply post-collision boundary conditions _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) + # Apply auxiliary recovery for boundary conditions (swapping) + apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0, _f1_thread) + # Store the result in f_1 for l in range(self.velocity_set.q): - # TODO: Improve this later - if wp.static("GradsApproximationBC" in active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): - if _missing_mask[l] == wp.uint8(1): - f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) return None, kernel diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 44aab5f..b2ed741 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -8,25 +8,27 @@ class Stepper(Operator): Class that handles the construction of lattice boltzmann stepping operator """ - def __init__(self, operators, boundary_conditions): - # Get the boundary condition ids - from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - self.operators = operators + def __init__(self, grid, boundary_conditions): + self.grid = grid self.boundary_conditions = boundary_conditions - # Get velocity set, precision policy, and compute backend - velocity_sets = set([op.velocity_set for op in self.operators if op is not None]) - assert len(velocity_sets) < 2, "All velocity sets must be the same. Got {}".format(velocity_sets) - velocity_set = DefaultConfig.velocity_set if not velocity_sets else velocity_sets.pop() - - precision_policies = set([op.precision_policy for op in self.operators if op is not None]) - assert len(precision_policies) < 2, "All precision policies must be the same. Got {}".format(precision_policies) - precision_policy = DefaultConfig.default_precision_policy if not precision_policies else precision_policies.pop() - - compute_backends = set([op.compute_backend for op in self.operators if op is not None]) - assert len(compute_backends) < 2, "All compute backends must be the same. Got {}".format(compute_backends) - compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop() + velocity_set = DefaultConfig.velocity_set + precision_policy = DefaultConfig.default_precision_policy + compute_backend = DefaultConfig.default_backend # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) + + def prepare_fields(self, initializer=None): + """Initialize the fields required for the stepper. + + Args: + initializer: Optional operator to initialize the distribution functions. + If provided, it should be a callable that takes (grid, velocity_set, + precision_policy, compute_backend) as arguments and returns initialized f_0. + If None, default equilibrium initialization is used with rho=1 and u=0. + + Returns: + Tuple of (f_0, f_1, bc_mask, missing_mask) + """ + raise NotImplementedError("Subclasses must implement prepare_fields()")