From 1d1560bb6c12e2fe9c2e1a1aeabc5691aad396ec Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Sun, 6 Oct 2024 16:58:35 -0400 Subject: [PATCH] Significantly simplified boundary application --- .../flow_past_sphere.py | 203 -------------- .../cfd_old_to_be_migrated/taylor_green.py | 181 ------------- .../boundary_condition/bc_regularized.py | 2 + xlb/operator/stepper/nse_stepper.py | 248 +++++++----------- 4 files changed, 93 insertions(+), 541 deletions(-) delete mode 100644 examples/cfd_old_to_be_migrated/flow_past_sphere.py delete mode 100644 examples/cfd_old_to_be_migrated/taylor_green.py diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py deleted file mode 100644 index 1684266..0000000 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ /dev/null @@ -1,203 +0,0 @@ -# Simple flow past sphere example using the functional interface to xlb - -import time -from tqdm import tqdm -import os -import matplotlib.pyplot as plt -from typing import Any -import numpy as np - -from xlb.compute_backend import ComputeBackend - -import warp as wp - -import xlb - -xlb.init( - default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=xlb.velocity_set.D2Q9, -) - - -from xlb.operator import Operator - - -class UniformInitializer(Operator): - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - vel: float, - ): - # Get the global index - i, j, k = wp.tid() - - # Set the velocity - u[0, i, j, k] = vel - u[1, i, j, k] = 0.0 - u[2, i, j, k] = 0.0 - - # Set the density - rho[0, i, j, k] = 1.0 - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u, vel): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - rho, - u, - vel, - ], - dim=rho.shape[1:], - ) - return rho, u - - -if __name__ == "__main__": - # Set parameters - compute_backend = xlb.ComputeBackend.WARP - precision_policy = xlb.PrecisionPolicy.FP32FP32 - velocity_set = xlb.velocity_set.D3Q19() - - # Make feilds - nr = 256 - vel = 0.05 - shape = (nr, nr, nr) - grid = xlb.grid.grid_factory(shape=shape) - rho = grid.create_field(cardinality=1) - u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) - f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - bc_mask = grid.create_field(cardinality=1, dtype=wp.uint8) - missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) - - # Make operators - initializer = UniformInitializer( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - collision = xlb.operator.collision.BGK( - omega=1.95, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - stream = xlb.operator.stream.Stream( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - rho=1.0, - u=(vel, 0.0, 0.0), - equilibrium_operator=equilibrium, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, - equilibrium=equilibrium, - macroscopic=macroscopic, - stream=stream, - equilibrium_bc=equilibrium_bc, - do_nothing_bc=do_nothing_bc, - half_way_bc=half_way_bc, - ) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - - # Make indices for boundary conditions (sphere) - sphere_radius = 32 - x = np.arange(nr) - y = np.arange(nr) - z = np.arange(nr) - X, Y, Z = np.meshgrid(x, y, z) - indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) - indices = np.array(indices).T - indices = wp.from_numpy(indices, dtype=wp.int32) - - # Set boundary conditions on the indices - bc_mask, missing_mask = indices_boundary_masker(indices, half_way_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set inlet bc - lower_bound = (0, 0, 0) - upper_bound = (0, nr, nr) - direction = (1, 0, 0) - bc_mask, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set outlet bc - lower_bound = (nr - 1, 0, 0) - upper_bound = (nr - 1, nr, nr) - direction = (-1, 0, 0) - bc_mask, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set initial conditions - rho, u = initializer(rho, u, vel) - f0 = equilibrium(rho, u, f0) - - # Time stepping - plot_freq = 512 - save_dir = "flow_past_sphere" - os.makedirs(save_dir, exist_ok=True) - # compute_mlup = False # Plotting results - compute_mlup = True - num_steps = 1024 * 8 - start = time.time() - for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, bc_mask, missing_mask, _) - f1, f0 = f0, f1 - if (_ % plot_freq == 0) and (not compute_mlup): - rho, u = macroscopic(f0, rho, u) - - # Plot the velocity field and boundary id side by side - plt.subplot(1, 2, 1) - plt.imshow(u[0, :, nr // 2, :].numpy()) - plt.colorbar() - plt.subplot(1, 2, 2) - plt.imshow(bc_mask[0, :, nr // 2, :].numpy()) - plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") - plt.close() - - wp.synchronize() - end = time.time() - - # Print MLUPS - print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py deleted file mode 100644 index 846ba30..0000000 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ /dev/null @@ -1,181 +0,0 @@ -# Simple Taylor green example using the functional interface to xlb - -import time -from tqdm import tqdm -import os -import matplotlib.pyplot as plt -from typing import Any -import jax.numpy as jnp -import warp as wp - -wp.init() - -import xlb -from xlb.operator import Operator - - -class TaylorGreenInitializer(Operator): - """ - Initialize the Taylor-Green vortex. - """ - - @Operator.register_backend(xlb.ComputeBackend.JAX) - # @partial(jit, static_argnums=(0)) - def jax_implementation(self, vel, nr): - # Make meshgrid - x = jnp.linspace(0, 2 * jnp.pi, nr) - y = jnp.linspace(0, 2 * jnp.pi, nr) - z = jnp.linspace(0, 2 * jnp.pi, nr) - X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij") - - # Compute u - u = jnp.stack( - [ - vel * jnp.sin(X) * jnp.cos(Y) * jnp.cos(Z), - -vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), - jnp.zeros_like(X), - ], - axis=0, - ) - - # Compute rho - rho = 3.0 * vel * vel * (1.0 / 16.0) * (jnp.cos(2.0 * X) + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0))) + 1.0 - rho = jnp.expand_dims(rho, axis=0) - - return rho, u - - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - vel: float, - nr: int, - ): - # Get the global index - i, j, k = wp.tid() - - # Get real pos - x = 2.0 * wp.pi * wp.float(i) / wp.float(nr) - y = 2.0 * wp.pi * wp.float(j) / wp.float(nr) - z = 2.0 * wp.pi * wp.float(k) / wp.float(nr) - - # Compute u - u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) - u[1, i, j, k] = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) - u[2, i, j, k] = 0.0 - - # Compute rho - rho[0, i, j, k] = 3.0 * vel * vel * (1.0 / 16.0) * (wp.cos(2.0 * x) + (wp.cos(2.0 * y) * (wp.cos(2.0 * z) + 2.0))) + 1.0 - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u, vel, nr): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - rho, - u, - vel, - nr, - ], - dim=rho.shape[1:], - ) - return rho, u - - -def run_taylor_green(backend, compute_mlup=True): - # Set the compute backend - if backend == "warp": - compute_backend = xlb.ComputeBackend.WARP - elif backend == "jax": - compute_backend = xlb.ComputeBackend.JAX - - # Set the precision policy - precision_policy = xlb.PrecisionPolicy.FP32FP32 - - # Set the velocity set - velocity_set = xlb.velocity_set.D3Q19() - - # Make grid - nr = 128 - shape = (nr, nr, nr) - if backend == "jax": - grid = xlb.grid.JaxGrid(shape=shape) - elif backend == "warp": - grid = xlb.grid.WarpGrid(shape=shape) - - # Make feilds - rho = grid.create_field(cardinality=1, precision=xlb.Precision.FP32) - u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) - f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - bc_mask = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) - missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) - - # Make operators - initializer = TaylorGreenInitializer(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - collision = xlb.operator.collision.BGK(omega=1.9, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend - ) - macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - stream = xlb.operator.stream.Stream(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, stream=stream - ) - - # Parrallelize the stepper TODO: Add this functionality - # stepper = grid.parallelize_operator(stepper) - - # Set initial conditions - if backend == "warp": - rho, u = initializer(rho, u, 0.1, nr) - f0 = equilibrium(rho, u, f0) - elif backend == "jax": - rho, u = initializer(0.1, nr) - f0 = equilibrium(rho, u) - - # Time stepping - plot_freq = 32 - save_dir = "taylor_green" - os.makedirs(save_dir, exist_ok=True) - num_steps = 8192 - start = time.time() - - for _ in tqdm(range(num_steps)): - # Time step - if backend == "warp": - f1 = stepper(f0, f1, bc_mask, missing_mask, _) - f1, f0 = f0, f1 - elif backend == "jax": - f0 = stepper(f0, bc_mask, missing_mask, _) - - # Plot if needed - if (_ % plot_freq == 0) and (not compute_mlup): - if backend == "warp": - rho, u = macroscopic(f0, rho, u) - local_u = u.numpy() - elif backend == "jax": - rho, local_u = macroscopic(f0) - - plt.imshow(local_u[0, :, nr // 2, :]) - plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") - plt.close() - wp.synchronize() - end = time.time() - - # Print MLUPS - print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") - - -if __name__ == "__main__": - # Run Taylor-Green vortex on different backends - backends = ["warp", "jax"] - # backends = ["jax"] - for backend in backends: - run_taylor_green(backend, compute_mlup=True) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index bb4b5f0..e1505b7 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -62,6 +62,8 @@ def __init__( mesh_vertices, ) + self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__) + # The operator to compute the momentum flux self.momentum_flux = MomentumFlux() diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 62790a6..f977519 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -14,7 +14,7 @@ from xlb.operator.macroscopic import Macroscopic from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition.boundary_condition import ImplementationStep -from xlb.operator.boundary_condition import DoNothingBC as DummyBC +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.collision import ForcedCollision @@ -40,6 +40,9 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK", forcing_ operators = [self.macroscopic, self.equilibrium, self.collision, self.stream] + self.boundary_conditions = boundary_conditions + self.active_bcs = set(type(bc).__name__ for bc in boundary_conditions) + super().__init__(operators, boundary_conditions) @Operator.register_backend(ComputeBackend.JAX) @@ -91,92 +94,84 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): return f_0, f_1 def _construct_warp(self): - # Set local constants TODO: This is a hack and should be fixed with warp update + # Set local constants _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) _opp_indices = self.velocity_set.opp_indices - @wp.struct - class BoundaryConditionIDStruct: - # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning - # One needs to manually add the names of additional BC's as they are added. - # TODO: Any way to improve this? - id_EquilibriumBC: wp.uint8 - id_DoNothingBC: wp.uint8 - id_HalfwayBounceBackBC: wp.uint8 - id_FullwayBounceBackBC: wp.uint8 - id_ZouHeBC_velocity: wp.uint8 - id_ZouHeBC_pressure: wp.uint8 - id_RegularizedBC_velocity: wp.uint8 - id_RegularizedBC_pressure: wp.uint8 - id_ExtrapolationOutflowBC: wp.uint8 - id_GradsApproximationBC: wp.uint8 + # Read the list of bc_to_id created upon instantiation + bc_to_id = boundary_condition_registry.bc_to_id + id_to_bc = boundary_condition_registry.id_to_bc + + for bc in self.boundary_conditions: + bc_name = id_to_bc[bc.id] + setattr(self, bc_name, bc) @wp.func def apply_post_streaming_bc( index: Any, timestep: Any, _boundary_id: Any, - bc_struct: Any, missing_mask: Any, f_0: Any, f_1: Any, f_pre: Any, f_post: Any, ): - # Apply post-streaming type boundary conditions - # NOTE: 'f_pre' is included here as an input to the BC functionals for consistency with the BC API, - # particularly when compared to post-collision boundary conditions (see below). - - if _boundary_id == bc_struct.id_EquilibriumBC: - # Equilibrium boundary condition - f_post = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_DoNothingBC: - # Do nothing boundary condition - f_post = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: - # Half way boundary condition - f_post = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ZouHeBC_velocity: - # Zouhe boundary condition (bc type = velocity) - f_post = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ZouHeBC_pressure: - # Zouhe boundary condition (bc type = pressure) - f_post = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_RegularizedBC_velocity: - # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_RegularizedBC_pressure: - # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - # Regularized boundary condition (bc type = velocity) - f_post = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_GradsApproximationBC: - # Reformulated Grads boundary condition - f_post = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - return f_post + f_result = f_post + + if wp.static("EquilibriumBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["EquilibriumBC"]): + f_result = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("DoNothingBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["DoNothingBC"]): + f_result = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("HalfwayBounceBackBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["HalfwayBounceBackBC"]): + f_result = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("ZouHeBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ZouHeBC"]): + f_result = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("RegularizedBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["RegularizedBC"]): + f_result = self.RegularizedBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("ExtrapolationOutflowBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ExtrapolationOutflowBC"]): + f_result = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("GradsApproximationBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): + f_result = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + return f_result @wp.func def apply_post_collision_bc( index: Any, timestep: Any, _boundary_id: Any, - bc_struct: Any, missing_mask: Any, f_0: Any, f_1: Any, f_pre: Any, f_post: Any, ): - # Apply post-collision type boundary conditions or special boundary preparations - if _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Full way boundary condition - f_post = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - # Storing post-streaming data in directions that leave the domain - f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - return f_post + f_result = f_post + + if wp.static("FullwayBounceBackBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["FullwayBounceBackBC"]): + f_result = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("ExtrapolationOutflowBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ExtrapolationOutflowBC"]): + f_result = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + return f_result @wp.func def get_thread_data_2d( @@ -186,17 +181,17 @@ def get_thread_data_2d( index: Any, ): # Read thread data for populations and missing mask - f0_thread = _f_vec() - f1_thread = _f_vec() + _f0_thread = _f_vec() + _f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): - f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) - f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) + _f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) + _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f0_thread, f1_thread, _missing_mask + return _f0_thread, _f1_thread, _missing_mask @wp.func def get_thread_data_3d( @@ -206,19 +201,19 @@ def get_thread_data_3d( index: Any, ): # Read thread data for populations - f0_thread = _f_vec() - f1_thread = _f_vec() + _f0_thread = _f_vec() + _f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) - f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) + _f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) + _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f0_thread, f1_thread, _missing_mask + return _f0_thread, _f1_thread, _missing_mask @wp.kernel def kernel2d( @@ -226,27 +221,23 @@ def kernel2d( f_1: wp.array3d(dtype=Any), bc_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), - bc_struct: Any, timestep: int, ): - # Get the global index i, j = wp.tid() - index = wp.vec2i(i, j) # TODO warp should fix this + index = wp.vec2i(i, j) - # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1]] if _boundary_id == wp.uint8(255): return - # Apply streaming (pull method) + # Apply streaming _f_post_stream = self.stream.warp_functional(f_0, index) - # Apply post-streaming type boundary conditions - f0_thread, f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) - _f_post_collision = f0_thread - _f_post_stream = apply_post_streaming_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream - ) + _f0_thread, _f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) + _f_post_collision = _f0_thread + + # Apply post-streaming boundary conditions + _f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -257,119 +248,62 @@ def kernel2d( # Apply collision _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) - # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision - ) + # Apply post-collision boundary conditions + _f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision) - # Set the output + # Store the result in f_1 for l in range(self.velocity_set.q): f_1[l, index[0], index[1]] = self.store_dtype(_f_post_collision[l]) - # Construct the kernel @wp.kernel def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), - bc_struct: Any, timestep: int, ): - # Get the global index i, j, k = wp.tid() - index = wp.vec3i(i, j, k) # TODO warp should fix this + index = wp.vec3i(i, j, k) - # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1], index[2]] if _boundary_id == wp.uint8(255): return - # Apply streaming (pull method) + # Apply streaming _f_post_stream = self.stream.warp_functional(f_0, index) - # Apply post-streaming type boundary conditions - f0_thread, f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) - _f_post_collision = f0_thread - _f_post_stream = apply_post_streaming_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream - ) + _f0_thread, _f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) + _f_post_collision = _f0_thread - # Compute rho and u - _rho, _u = self.macroscopic.warp_functional(_f_post_stream) + # Apply post-streaming boundary conditions + _f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream) - # Compute equilibrium + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) _feq = self.equilibrium.warp_functional(_rho, _u) - - # Apply collision _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) - # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision - ) + # Apply post-collision boundary conditions + _f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision) - # Set the output + # Store the result in f_1 for l in range(self.velocity_set.q): - # TODO 1: fix the perf drop due to l324-l236 even in cases where this BC is not used. - # TODO 2: is there better way to move these lines to a function inside BC class like "restore_bc_data" - # if _boundary_id == bc_struct.id_GradsApproximationBC: - # if _missing_mask[l] == wp.uint8(1): - # f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(f1_thread[_opp_indices[l]]) + if wp.static("GradsApproximationBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): + if _missing_mask[l] == wp.uint8(1): + f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return BoundaryConditionIDStruct, kernel + return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): - # Get the boundary condition ids - from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - # Read the list of bc_to_id created upon instantiation - bc_to_id = boundary_condition_registry.bc_to_id - id_to_bc = boundary_condition_registry.id_to_bc - bc_struct = self.warp_functional() - active_bc_list = [] - for bc in self.boundary_conditions: - # Setting the Struct attributes and active BC classes based on the BC class names - bc_name = id_to_bc[bc.id] - setattr(self, bc_name, bc) - setattr(bc_struct, "id_" + bc_name, bc_to_id[bc_name]) - active_bc_list.append("id_" + bc_name) - - # Check if boundary_conditions is an empty list (e.g. all periodic and no BC) - # TODO: There is a huge issue here with perf. when boundary_conditions list - # is empty and is initialized with a dummy BC. If it is not empty, no perf - # loss ocurrs. The following code at least prevents syntax error for periodic examples. - if self.boundary_conditions: - bc_dummy = self.boundary_conditions[0] - else: - bc_dummy = DummyBC() - - # Setting the Struct attributes for inactive BC classes - for var in vars(bc_struct): - if var not in active_bc_list and not var.startswith("_"): - # set unassigned boundaries to the maximum integer in uint8 - setattr(bc_struct, var, 255) - - # Assing a fall-back BC for inactive BCs. This is just to ensure Warp codegen does not - # produce error when a particular BC is not used in an example. - setattr(self, var.replace("id_", ""), bc_dummy) - - # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[ - f_0, - f_1, - bc_mask, - missing_mask, - bc_struct, - timestep, - ], + inputs=[f_0, f_1, bc_mask, missing_mask, timestep], dim=f_0.shape[1:], ) return f_0, f_1