diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 2487919..2e0df95 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -102,7 +102,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index f94e209..20f3b7c 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -68,7 +68,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index 65b56bf..2ec5560 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -113,7 +113,7 @@ def setup_stepper(self): def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 077ae98..140e756 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -12,10 +12,11 @@ RegularizedBC, HalfwayBounceBackBC, ExtrapolationOutflowBC, + GradsApproximationBC, ) from xlb.operator.force.momentum_transfer import MomentumTransfer from xlb.operator.macroscopic import Macroscopic -from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker +from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker, MeshDistanceBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np @@ -51,9 +52,10 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.lift_coefficients = [] def _setup(self): + # NOTE: it is important to initialize fields before setup_boundary_masker is called because f_0 or f_1 might be used to store BC information + self.initialize_fields() self.setup_boundary_conditions() self.setup_boundary_masker() - self.initialize_fields() self.setup_stepper() def voxelize_stl(self, stl_filename, length_lbm_unit): @@ -99,7 +101,8 @@ def setup_boundary_conditions(self): # bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) - bc_car = HalfwayBounceBackBC(mesh_vertices=car) + # bc_car = HalfwayBounceBackBC(mesh_vertices=car) + bc_car = GradsApproximationBC(mesh_vertices=car) # bc_car = FullwayBounceBackBC(mesh_vertices=car) self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] @@ -109,7 +112,12 @@ def setup_boundary_masker(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - mesh_boundary_masker = MeshBoundaryMasker( + # mesh_boundary_masker = MeshBoundaryMasker( + # velocity_set=self.velocity_set, + # precision_policy=self.precision_policy, + # compute_backend=self.backend, + # ) + mesh_distance_boundary_masker = MeshDistanceBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.backend, @@ -119,10 +127,12 @@ def setup_boundary_masker(self): dx = self.grid_spacing origin, spacing = (0, 0, 0), (dx, dx, dx) self.bc_mask, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_mask, self.missing_mask) - self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) + # self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) + self.bc_mask, self.missing_mask, self.f_1 = mesh_distance_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask, self.f_1) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) + self.f_1 = initialize_eq(self.f_1, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") @@ -134,7 +144,7 @@ def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: @@ -169,7 +179,7 @@ def post_process(self, i): save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) # Compute lift and drag - boundary_force = self.momentum_transfer(self.f_0, self.bc_mask, self.missing_mask) + boundary_force = self.momentum_transfer(self.f_0, self.f_1, self.bc_mask, self.missing_mask) drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane lift = boundary_force[2] c_d = 2.0 * drag / (self.wind_speed**2 * self.car_cross_section) @@ -226,7 +236,8 @@ def plot_drag_coefficient(self): print_interval = 1000 # Set up Reynolds number and deduce relaxation time (omega) - Re = 50000.0 + # Re = 50000.0 + Re = 500000000000.0 clength = grid_size_x - 1 visc = wind_speed * clength / Re omega = 1.0 / (3.0 * visc + 0.5) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 907c1f2..1812d95 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -81,7 +81,7 @@ def run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, num_st start_time = time.time() for i in range(num_steps): - f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i) + f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i) f_0, f_1 = f_1, f_0 wp.synchronize() diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index b7ede03..925dfdc 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -9,3 +9,4 @@ from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC as ExtrapolationOutflowBC +from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC as GradsApproximationBC diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 0ddbcfc..55ce9ed 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -56,10 +56,13 @@ def _construct_warp(self): # Construct the functional for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): return f_pre @@ -79,8 +82,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -105,8 +108,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 6d4e3ed..8c33d29 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -79,10 +79,13 @@ def _construct_warp(self): # Construct the functional for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f @@ -104,8 +107,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -130,8 +133,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 55a5851..8b5f139 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -123,7 +123,7 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -160,10 +160,13 @@ def get_normal_vectors_3d( # Construct the functionals for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -171,23 +174,60 @@ def functional( # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): _f[l] = f_pre[_opp_indices[l]] - return _f @wp.func - def prepare_bc_auxilary_data( + def prepare_bc_auxilary_data_2d( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, + ): + # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and + # f_pre (post-streaming values of the current voxel). We use directions that leave the domain + # for storing this prepared data. + _f = f_post + nv = get_normal_vectors_2d(missing_mask) + for l in range(self.velocity_set.q): + if missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) + _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux + return _f + + @wp.func + def prepare_bc_auxilary_data_3d( + index: Any, + timestep: Any, missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, ): # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and - # f_pre (posti-streaming values of the current voxel). We use directions that leave the domain + # f_pre (post-streaming values of the current voxel). We use directions that leave the domain # for storing this prepared data. _f = f_post + nv = get_normal_vectors_3d(missing_mask) for l in range(self.velocity_set.q): if missing_mask[l] == wp.uint8(1): - _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux[l] + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) + _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux return _f # Construct the warp kernel @@ -201,29 +241,20 @@ def kernel2d( # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) + timestep = 0 # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - _f_aux = _f_vec() # special preparation of auxiliary data if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - nv = get_normal_vectors_2d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]] + _f_pre = prepare_bc_auxilary_data_2d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post) # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both # collision and streaming? - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -242,6 +273,7 @@ def kernel3d( # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) + timestep = 0 # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) @@ -249,22 +281,13 @@ def kernel3d( # special preparation of auxiliary data if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - nv = get_normal_vectors_3d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] + _f_pre = prepare_bc_auxilary_data_3d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post) # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both # collision and streaming? - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -273,6 +296,7 @@ def kernel3d( f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + prepare_bc_auxilary_data = prepare_bc_auxilary_data_3d if self.velocity_set.d == 3 else prepare_bc_auxilary_data_2d return (functional, prepare_bc_auxilary_data), kernel diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 57d29fd..29f83c1 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -48,7 +48,7 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -63,10 +63,13 @@ def _construct_warp(self): # Construct the functional for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): fliped_f = _f_vec() for l in range(_q): @@ -88,8 +91,8 @@ def kernel2d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f_aux = _f_vec() - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -114,8 +117,8 @@ def kernel3d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f_aux = _f_vec() - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py new file mode 100644 index 0000000..870635e --- /dev/null +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -0,0 +1,353 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit +import jax.lax as lax +from functools import partial +import warp as wp +from typing import Any +from collections import Counter +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.macroscopic.zero_moment import ZeroMoment +from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + + +class GradsApproximationBC(BoundaryCondition): + """ + Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and + Dirichlet BCs [2] + [1] S. Chikatamarla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann + simulations", Europhys. Lett. 74, 215 (2006). + [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and + stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. + + """ + + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + indices=None, + mesh_vertices=None, + ): + # TODO: the input velocity must be suitably stored elesewhere when mesh is moving. + self.u = (0, 0, 0) + + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + indices, + mesh_vertices, + ) + + # Instantiate the operator for computing macroscopic values + self.macroscopic = Macroscopic() + self.zero_moment = ZeroMoment() + self.equilibrium = QuadraticEquilibrium() + self.momentum_flux = MomentumFlux() + + # This BC needs implicit distance to the mesh + self.needs_mesh_distance = True + + # If this BC is defined using indices, it would need padding in order to find missing directions + # when imposed on a geometry that is in the domain interior + if self.mesh_vertices is None: + assert self.indices is not None + self.needs_padding = True + + # Raise error if used for 2d examples: + if self.velocity_set.d == 2: + raise NotImplementedError("This BC is not implemented in 2D!") + + # if indices is not None: + # # this BC would be limited to stationary boundaries + # # assert mesh_vertices is None + # if mesh_vertices is not None: + # # this BC would be applicable for stationary and moving boundaries + # assert indices is None + # if mesh_velocity_function is not None: + # # mesh is moving and/or deforming + + assert self.compute_backend == ComputeBackend.WARP, "This BC is currently only implemented with the Warp backend!" + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # TODO + raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") + return + + def _construct_warp(self): + # Set local variables and constants + _c = self.velocity_set.c + _q = self.velocity_set.q + _d = self.velocity_set.d + _w = self.velocity_set.w + _qi = self.velocity_set.qi + _opp_indices = self.velocity_set.opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _u_wall = _u_vec(self.u[0], self.u[1], self.u[2]) if _d == 3 else _u_vec(self.u[0], self.u[1]) + # diagonal = wp.vec3i(0, 3, 5) if _d == 3 else wp.vec2i(0, 2) + + @wp.func + def regularize_fpop( + missing_mask: Any, + rho: Any, + u: Any, + fpop: Any, + ): + """ + Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. + """ + # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} + feq = self.equilibrium.warp_functional(rho, u) + f_neq = fpop - feq + PiNeq = self.momentum_flux.warp_functional(f_neq) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + for l in range(_q): + QiPi1 = self.compute_dtype(0.0) + for t in range(nt): + QiPi1 += _qi[l, t] * PiNeq[t] + + # assign all populations based on eq 45 of Latt et al (2008) + # fneq ~ f^1 + fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1 + fpop[l] = feq[l] + fpop1 + return fpop + + @wp.func + def grads_approximate_fpop( + missing_mask: Any, + rho: Any, + u: Any, + f_post: Any, + ): + # Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and + # Dirichlet BCs [2] + # [1] S. Chikatax`marla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann + # simulations", Europhys. Lett. 74, 215 (2006). + # [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and + # stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. + + # Note: See also self.regularize_fpop function which is somewhat similar. + + # Compute pressure tensor Pi using all f_post-streaming values + Pi = self.momentum_flux.warp_functional(f_post) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + for l in range(_q): + # if missing_mask[l] == wp.uint8(1): + QiPi = self.compute_dtype(0.0) + for t in range(nt): + if t == 0 or t == 3 or t == 5: + QiPi += _qi[l, t] * (Pi[t] - rho / self.compute_dtype(3.0)) + else: + QiPi += _qi[l, t] * Pi[t] + + # Compute c.u + cu = self.compute_dtype(0.0) + for d in range(self.velocity_set.d): + if _c[d, l] == 1: + cu += u[d] + elif _c[d, l] == -1: + cu -= u[d] + cu *= self.compute_dtype(3.0) + + # change f_post using the Grad's approximation + f_post[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu) + _w[l] * self.compute_dtype(4.5) * QiPi + + return f_post + + # Construct the functionals for this BC + @wp.func + def functional_method1( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + ): + # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper. + # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming). + one = self.compute_dtype(1.0) + for l in range(_q): + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 + # weight = f_1[_opp_indices[l], index[0], index[1], index[2]] + weight = self.compute_dtype(0.5) + + # Use differentiable interpolated BB to find f_missing: + f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) + + # # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC + # cu = self.compute_dtype(0.0) + # for d in range(_d): + # if _c[d, l] == 1: + # cu += _u_wall[d] + # elif _c[d, l] == -1: + # cu -= _u_wall[d] + # cu *= self.compute_dtype(-6.0) * _w[l] + # f_post[l] += cu + + # Compute density, velocity using all f_post-streaming values + rho, u = self.macroscopic.warp_functional(f_post) + + # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al. + f_post = regularize_fpop(missing_mask, rho, u, f_post) + # f_post = grads_approximate_fpop(missing_mask, rho, u, f_post) + return f_post + + # Construct the functionals for this BC + @wp.func + def functional_method2( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + ): + # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper. + # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming). + # NOTE: f_aux should contain populations at "x_f" (see their fig 1) in the missign direction of the BC which amounts + # to post-collision values being pulled from appropriate cells like ExtrapolationBC + # + # here I need to compute all terms in Eq (10) + # Strategy: + # 1) "weights" should have been stored somewhere to be used here. + # 2) Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) + # NOTE: in the original paper "u_target" is associated with the previous time step not current time. + # 3) Given "weights" use differentiable interpolated BB to find f_missing as I had before: + # fmissing = ((1. - weights) * f_poststreaming_iknown + weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + weights) + # 4) Add contribution due to u_w to f_missing as is usual in regular Bouzidi BC (ie. -6.0 * self.lattice.w * jnp.dot(self.vel, c) + # 5) Compute rho_target = \sum(f_ibb) based on these values + # 6) Compute feq using feq = self.equilibrium(rho_target, u_target) + # 7) Compute Pi_neq and Pi_eq using all f_post-streaming values as per: + # Pi_neq = self.momentum_flux(fneq) and Pi_eq = self.momentum_flux(feq) + # 8) Compute Grad's appriximation using full equation as in Eq (10) + # NOTE: this is very similar to the regularization procedure. + + _f_nbr = _f_vec() + u_target = _u_vec(0.0, 0.0, 0.0) if _d == 3 else _u_vec(0.0, 0.0) + num_missing = 0 + one = self.compute_dtype(1.0) + for l in range(_q): + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + # Find the neighbour and its velocity value + for ll in range(_q): + # f_0 is the post-collision values of the current time-step + # Get index associated with the fluid neighbours + fluid_nbr_index = type(index)() + for d in range(_d): + fluid_nbr_index[d] = index[d] + _c[d, l] + # The following is the post-collision values of the fluid neighbor cell + _f_nbr[ll] = self.compute_dtype(f_0[ll, fluid_nbr_index[0], fluid_nbr_index[1], fluid_nbr_index[2]]) + + # Compute the velocity vector at the fluid neighbouring cells + _, u_f = self.macroscopic.warp_functional(_f_nbr) + + # Record the number of missing directions + num_missing += 1 + + # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 + weight = f_1[_opp_indices[l], index[0], index[1], index[2]] + + # Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) + for d in range(_d): + u_target[d] += (weight * u_f[d] + _u_wall[d]) / (one + weight) + + # Use differentiable interpolated BB to find f_missing: + f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) + + # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC + cu = self.compute_dtype(0.0) + for d in range(_d): + if _c[d, l] == 1: + cu += _u_wall[d] + elif _c[d, l] == -1: + cu -= _u_wall[d] + cu *= self.compute_dtype(-6.0) * _w[l] + f_post[l] += cu + + # Compute rho_target = \sum(f_ibb) based on these values + rho_target = self.zero_moment.warp_functional(f_post) + for d in range(_d): + u_target[d] /= num_missing + + # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al. + f_post = grads_approximate_fpop(missing_mask, rho_target, u_target, f_post) + return f_post + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + timestep = 0 + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) + _f_aux = _f_vec() + + # Apply the boundary condition + if _boundary_id == wp.uint8(GradsApproximationBC.id): + # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both + # collision and streaming? + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) + + functional = functional_method1 + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index e723570..6e787c2 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -54,7 +54,7 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -71,10 +71,13 @@ def _construct_warp(self): # Construct the functional for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -103,8 +106,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -129,8 +132,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index a42b695..bb4b5f0 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -105,7 +105,7 @@ def regularize_fpop(self, fpop, feq): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): # creat a mask to slice boundary cells boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] @@ -200,24 +200,26 @@ def regularize_fpop( # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) nt = _d * (_d + 1) // 2 - QiPi1 = _f_vec() for l in range(_q): - QiPi1[l] = self.compute_dtype(0.0) + QiPi1 = self.compute_dtype(0.0) for t in range(nt): - QiPi1[l] += _qi[l, t] * PiNeq[t] + QiPi1 += _qi[l, t] * PiNeq[t] # assign all populations based on eq 45 of Latt et al (2008) # fneq ~ f^1 - fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1[l] + fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1 fpop[l] = feq[l] + fpop1 return fpop @wp.func def functional3d_velocity( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -242,10 +244,13 @@ def functional3d_velocity( @wp.func def functional3d_pressure( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -268,10 +273,13 @@ def functional3d_pressure( @wp.func def functional2d_velocity( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -296,10 +304,13 @@ def functional2d_velocity( @wp.func def functional2d_pressure( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -337,8 +348,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f_aux = _f_vec() - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -363,8 +374,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f_aux = _f_vec() - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 4e9fe29..66b6377 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -159,7 +159,7 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): # creat a mask to slice boundary cells boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] @@ -234,10 +234,13 @@ def bounceback_nonequilibrium( @wp.func def functional3d_velocity( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -259,10 +262,13 @@ def functional3d_velocity( @wp.func def functional3d_pressure( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -282,10 +288,13 @@ def functional3d_pressure( @wp.func def functional2d_velocity( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -307,10 +316,13 @@ def functional2d_velocity( @wp.func def functional2d_pressure( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -345,8 +357,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -371,8 +383,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 9f6ef5d..be920bf 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -52,6 +52,9 @@ def __init__( # when inside/outside of the geoemtry is not known self.needs_padding = False + # A flag for BCs that need implicit boundary distance between the grid and a mesh (to be set to True if applicable inside each BC) + self.needs_mesh_distance = False + if self.compute_backend == ComputeBackend.WARP: # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) @@ -59,10 +62,13 @@ def __init__( @wp.func def prepare_bc_auxilary_data( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): return f_post diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index 20b16b5..fbe851d 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -4,3 +4,6 @@ from xlb.operator.boundary_masker.mesh_boundary_masker import ( MeshBoundaryMasker as MeshBoundaryMasker, ) +from xlb.operator.boundary_masker.mesh_distance_boundary_masker import ( + MeshDistanceBoundaryMasker as MeshDistanceBoundaryMasker, +) diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index b88d251..edfdb42 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -23,6 +23,10 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) + # Raise error if used for 2d examples: + if self.velocity_set.d == 2: + raise NotImplementedError("This Operator is not implemented in 2D!") + # Also using Warp kernels for JAX implementation if self.compute_backend == ComputeBackend.JAX: self.warp_functional, self.warp_kernel = self._construct_warp() @@ -33,6 +37,7 @@ def jax_implementation( bc, origin, spacing, + id_number, bc_mask, missing_mask, start_index=(0, 0, 0), @@ -83,15 +88,12 @@ def kernel( ) # evaluate if point is inside mesh - face_index = int(0) - face_u = float(0.0) - face_v = float(0.0) - sign = float(0.0) - if wp.mesh_query_point_sign_winding_number(mesh_id, pos, max_length, sign, face_index, face_u, face_v): + query = wp.mesh_query_point_sign_winding_number(mesh_id, pos, max_length) + if query.result: # set point to be solid - if sign <= 0: # TODO: fix this + if query.sign <= 0: # TODO: fix this # Stream indices - for l in range(_q): + for l in range(1, _q): # Get the index of the streaming direction push_index = wp.vec3i() for d in range(self.velocity_set.d): @@ -113,7 +115,7 @@ def warp_implementation( missing_mask, start_index=(0, 0, 0), ): - assert bc.mesh_vertices is not None, f'Please provide the mesh points for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' + assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" assert ( bc.mesh_vertices.shape[1] == self.velocity_set.d @@ -124,6 +126,9 @@ def warp_implementation( # We are done with bc.mesh_vertices. Remove them from BC objects bc.__dict__.pop("mesh_vertices", None) + # Ensure this masker is called only for BCs that need implicit distance to the mesh + assert not bc.needs_mesh_distance, 'Please use "MeshDistanceBoundaryMasker" if this BC needs mesh distance!' + mesh_indices = np.arange(mesh_vertices.shape[0]) mesh = wp.Mesh( points=wp.array(mesh_vertices, dtype=wp.vec3), diff --git a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py new file mode 100644 index 0000000..87af94c --- /dev/null +++ b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py @@ -0,0 +1,184 @@ +# Base class for all equilibriums + +import numpy as np +import warp as wp +import jax +from typing import Any +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class MeshDistanceBoundaryMasker(Operator): + """ + Operator for creating a boundary missing_mask from an STL file + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.WARP, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + # Raise error if used for 2d examples: + if self.velocity_set.d == 2: + raise NotImplementedError("This Operator is not implemented in 2D!") + + # Also using Warp kernels for JAX implementation + if self.compute_backend == ComputeBackend.JAX: + self.warp_functional, self.warp_kernel = self._construct_warp() + + @Operator.register_backend(ComputeBackend.JAX) + def jax_implementation( + self, + bc, + origin, + spacing, + id_number, + bc_mask, + missing_mask, + f_field, + start_index=(0, 0, 0), + ): + raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.c + _q = wp.constant(self.velocity_set.q) + _opp_indices = self.velocity_set.opp_indices + + @wp.func + def check_index_bounds(index: wp.vec3i, shape: wp.vec3i): + is_in_bounds = index[0] >= 0 and index[0] < shape[0] and index[1] >= 0 and index[1] < shape[1] and index[2] >= 0 and index[2] < shape[2] + return is_in_bounds + + @wp.func + def index_to_position(index: wp.vec3i, origin: wp.vec3, spacing: wp.vec3): + # position of the point + ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2])) + ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center + pos = wp.cw_mul(ijk, spacing) + origin + return pos + + # Construct the warp kernel + @wp.kernel + def kernel( + mesh_id: wp.uint64, + origin: wp.vec3, + spacing: wp.vec3, + id_number: wp.int32, + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + f_field: wp.array4d(dtype=Any), + start_index: wp.vec3i, + ): + # get index + i, j, k = wp.tid() + + # Get local indices + index = wp.vec3i() + index[0] = i - start_index[0] + index[1] = j - start_index[1] + index[2] = k - start_index[2] + + # position of the point + pos_solid_cell = index_to_position(index, origin, spacing) + + # Compute the maximum length + max_length = wp.sqrt( + (spacing[0] * wp.float32(missing_mask.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(missing_mask.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(missing_mask.shape[3])) ** 2.0 + ) + + # evaluate if point is inside mesh + query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_solid_cell, max_length) + if query.result and query.sign <= 0: # TODO: fix this + # Set bc_mask of solid to a large number to enable skipping LBM operations + bc_mask[0, index[0], index[1], index[2]] = wp.uint8(255) + + # Find neighboring fluid cells along each lattice direction and the their fractional distance to the mesh + for l in range(1, _q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3]) + if check_index_bounds(push_index, shape): + # find neighbouring fluid cell + pos_fluid_cell = index_to_position(push_index, origin, spacing) + query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_fluid_cell, max_length) + if query.result and query.sign > 0: + # Set the boundary id and missing_mask + bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + missing_mask[l, push_index[0], push_index[1], push_index[2]] = True + + # get position of the mesh triangle that intersects with the solid cell + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + weight = wp.length(pos_fluid_cell - pos_mesh) / wp.length(pos_fluid_cell - pos_solid_cell) + f_field[_opp_indices[l], push_index[0], push_index[1], push_index[2]] = self.store_dtype(weight) + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + bc, + origin, + spacing, + bc_mask, + missing_mask, + f_field, + start_index=(0, 0, 0), + ): + assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' + assert bc.indices is None, f"Cannot find the implicit distance to the boundary for {bc.__class__.__name__} without a mesh!" + assert ( + bc.mesh_vertices.shape[1] == self.velocity_set.d + ), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + assert ( + f_field is not None and f_field.shape == missing_mask.shape + ), "To compute and store the implicit distance to the boundary for this BC, use a population field!" + mesh_vertices = bc.mesh_vertices + id_number = bc.id + + # We are done with bc.mesh_vertices. Remove them from BC objects + bc.__dict__.pop("mesh_vertices", None) + + # Ensure this masker is called only for BCs that need implicit distance to the mesh + assert bc.needs_mesh_distance, 'Please use "MeshBoundaryMasker" if this BC does NOT need mesh distance!' + + mesh_indices = np.arange(mesh_vertices.shape[0]) + mesh = wp.Mesh( + points=wp.array(mesh_vertices, dtype=wp.vec3), + indices=wp.array(mesh_indices, dtype=int), + ) + + # Convert input tuples to warp vectors + origin = wp.vec3(origin[0], origin[1], origin[2]) + spacing = wp.vec3(spacing[0], spacing[1], spacing[2]) + start_index = wp.vec3i(start_index[0], start_index[1], start_index[2]) + mesh_id = wp.uint64(mesh.id) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + mesh_id, + origin, + spacing, + id_number, + bc_mask, + missing_mask, + f_field, + start_index, + ], + dim=missing_mask.shape[1:], + ) + + return bc_mask, missing_mask, f_field diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index bc731c6..cc2fb04 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -217,9 +217,9 @@ def decompose_shear_d3q27( s = _f_vec() # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) - two = self.self.compute_dtype(2.0) - four = self.self.compute_dtype(4.0) - six = self.self.compute_dtype(6.0) + two = self.compute_dtype(2.0) + four = self.compute_dtype(4.0) + six = self.compute_dtype(6.0) s[9] = (two * nxz - nyz) / six s[18] = (two * nxz - nyz) / six diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index da64e67..8b0aacf 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -90,6 +90,7 @@ def _construct_warp(self): _c = self.velocity_set.c _opp_indices = self.velocity_set.opp_indices _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool _no_slip_id = self.no_slip_bc_instance.id @@ -102,7 +103,8 @@ def _construct_warp(self): # Construct the warp kernel @wp.kernel def kernel2d( - f: wp.array3d(dtype=Any), + f_0: wp.array3d(dtype=Any), + f_1: wp.array3d(dtype=Any), bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), force: wp.array(dtype=Any), @@ -128,16 +130,17 @@ def kernel2d( is_edge = wp.bool(True) # If the boundary is an edge then add the momentum transfer - m = wp.vec2() + m = _u_vec() if is_edge: # Get the distribution function f_post_collision = _f_vec() for l in range(self.velocity_set.q): - f_post_collision[l] = f[l, index[0], index[1]] + f_post_collision[l] = f_0[l, index[0], index[1]] # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + timestep = 0 + f_post_stream = self.stream.warp_functional(f_0, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) # Compute the momentum transfer for d in range(self.velocity_set.d): @@ -145,14 +148,18 @@ def kernel2d( for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] - m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) + if _c[d, _opp_indices[l]] == 1: + m[d] += phi + elif _c[d, _opp_indices[l]] == -1: + m[d] -= phi wp.atomic_add(force, 0, m) # Construct the warp kernel @wp.kernel def kernel3d( - f: wp.array4d(dtype=Any), + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), force: wp.array(dtype=Any), @@ -178,16 +185,17 @@ def kernel3d( is_edge = wp.bool(True) # If the boundary is an edge then add the momentum transfer - m = wp.vec3() + m = _u_vec() if is_edge: # Get the distribution function f_post_collision = _f_vec() for l in range(self.velocity_set.q): - f_post_collision[l] = f[l, index[0], index[1], index[2]] + f_post_collision[l] = f_0[l, index[0], index[1], index[2]] # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + timestep = 0 + f_post_stream = self.stream.warp_functional(f_0, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) # Compute the momentum transfer for d in range(self.velocity_set.d): @@ -195,7 +203,10 @@ def kernel3d( for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] - m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) + if _c[d, _opp_indices[l]] == 1: + m[d] += phi + elif _c[d, _opp_indices[l]] == -1: + m[d] -= phi wp.atomic_add(force, 0, m) @@ -205,14 +216,15 @@ def kernel3d( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, bc_mask, missing_mask): + def warp_implementation(self, f_0, f_1, bc_mask, missing_mask): # Allocate the force vector (the total integral value will be computed) - force = wp.zeros((1), dtype=wp.vec3) if self.velocity_set.d == 3 else wp.zeros((1), dtype=wp.vec2) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + force = wp.zeros((1), dtype=_u_vec) # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f, bc_mask, missing_mask, force], - dim=f.shape[1:], + inputs=[f_0, f_1, bc_mask, missing_mask, force], + dim=f_0.shape[1:], ) return force.numpy()[0] diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 05cee7b..62790a6 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -88,14 +88,13 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): # Copy back to store precision f_1 = self.precision_policy.cast_to_store_jax(f_post_collision) - return f_1 + return f_0, f_1 def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - _c = self.velocity_set.c - _q = self.velocity_set.q + _opp_indices = self.velocity_set.opp_indices @wp.struct class BoundaryConditionIDStruct: @@ -111,163 +110,115 @@ class BoundaryConditionIDStruct: id_RegularizedBC_velocity: wp.uint8 id_RegularizedBC_pressure: wp.uint8 id_ExtrapolationOutflowBC: wp.uint8 + id_GradsApproximationBC: wp.uint8 @wp.func def apply_post_streaming_bc( - f_pre: Any, - f_post: Any, - f_aux: Any, - missing_mask: Any, + index: Any, + timestep: Any, _boundary_id: Any, bc_struct: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, ): # Apply post-streaming type boundary conditions + # NOTE: 'f_pre' is included here as an input to the BC functionals for consistency with the BC API, + # particularly when compared to post-collision boundary conditions (see below). + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post = self.DoNothingBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition - f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_ZouHeBC_velocity: # Zouhe boundary condition (bc type = velocity) - f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_ZouHeBC_pressure: # Zouhe boundary condition (bc type = pressure) - f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_RegularizedBC_velocity: # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_RegularizedBC_pressure: # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # Regularized boundary condition (bc type = velocity) - f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + elif _boundary_id == bc_struct.id_GradsApproximationBC: + # Reformulated Grads boundary condition + f_post = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) return f_post @wp.func def apply_post_collision_bc( - f_pre: Any, - f_post: Any, - f_aux: Any, - missing_mask: Any, + index: Any, + timestep: Any, _boundary_id: Any, bc_struct: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, ): + # Apply post-collision type boundary conditions or special boundary preparations if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition - f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - # f_aux is the neighbour's post-streaming values # Storing post-streaming data in directions that leave the domain - f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(f_pre, f_post, f_aux, missing_mask) - + f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) return f_post - @wp.func - def get_normal_vectors_2d( - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -wp.vec2i(_c[0, l], _c[1, l]) - - @wp.func - def get_normal_vectors_3d( - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l]) - @wp.func def get_thread_data_2d( - f_0: wp.array3d(dtype=Any), + f0_buffer: wp.array3d(dtype=Any), + f1_buffer: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), index: Any, ): - # Get the boundary id and missing mask - _f_post_collision = _f_vec() + # Read thread data for populations and missing mask + f0_thread = _f_vec() + f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): - # q-sized vector of pre-streaming populations - _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1]]) - - # TODO fix vec bool + f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) + f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return _f_post_collision, _missing_mask + return f0_thread, f1_thread, _missing_mask @wp.func def get_thread_data_3d( - f_0: wp.array4d(dtype=Any), + f0_buffer: wp.array4d(dtype=Any), + f1_buffer: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), index: Any, ): - # Get the boundary id and missing mask - _f_post_collision = _f_vec() + # Read thread data for populations + f0_thread = _f_vec() + f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1], index[2]]) - - # TODO fix vec bool + f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) + f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return _f_post_collision, _missing_mask - @wp.func - def get_bc_auxilary_data_2d( - f_0: wp.array3d(dtype=Any), - index: Any, - _boundary_id: Any, - _missing_mask: Any, - bc_struct: Any, - ): - # special preparation of auxiliary data - f_auxiliary = _f_vec() - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - nv = get_normal_vectors_2d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) - return f_auxiliary - - @wp.func - def get_bc_auxilary_data_3d( - f_0: wp.array4d(dtype=Any), - index: Any, - _boundary_id: Any, - _missing_mask: Any, - bc_struct: Any, - ): - # special preparation of auxiliary data - f_auxiliary = _f_vec() - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - nv = get_normal_vectors_3d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) - return f_auxiliary + return f0_thread, f1_thread, _missing_mask @wp.kernel def kernel2d( @@ -282,18 +233,20 @@ def kernel2d( i, j = wp.tid() index = wp.vec2i(i, j) # TODO warp should fix this - # Read thread data for populations and missing mask - _f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) + # Get the boundary id + _boundary_id = bc_mask[0, index[0], index[1]] + if _boundary_id == wp.uint8(255): + return # Apply streaming (pull method) _f_post_stream = self.stream.warp_functional(f_0, index) - # Prepare auxilary data for BC (if applicable) - _boundary_id = bc_mask[0, index[0], index[1]] - _f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) - # Apply post-streaming type boundary conditions - _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f0_thread, f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) + _f_post_collision = f0_thread + _f_post_stream = apply_post_streaming_bc( + index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream + ) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -305,7 +258,9 @@ def kernel2d( _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc( + index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision + ) # Set the output for l in range(self.velocity_set.q): @@ -325,18 +280,20 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) # TODO warp should fix this - # Read thread data for populations and missing mask - _f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) + # Get the boundary id + _boundary_id = bc_mask[0, index[0], index[1], index[2]] + if _boundary_id == wp.uint8(255): + return # Apply streaming (pull method) _f_post_stream = self.stream.warp_functional(f_0, index) - # Prepare auxilary data for BC (if applicable) - _boundary_id = bc_mask[0, index[0], index[1], index[2]] - _f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) - # Apply post-streaming type boundary conditions - _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f0_thread, f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) + _f_post_collision = f0_thread + _f_post_stream = apply_post_streaming_bc( + index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream + ) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -348,10 +305,17 @@ def kernel3d( _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc( + index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision + ) # Set the output for l in range(self.velocity_set.q): + # TODO 1: fix the perf drop due to l324-l236 even in cases where this BC is not used. + # TODO 2: is there better way to move these lines to a function inside BC class like "restore_bc_data" + # if _boundary_id == bc_struct.id_GradsApproximationBC: + # if _missing_mask[l] == wp.uint8(1): + # f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel @@ -408,4 +372,4 @@ def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): ], dim=f_0.shape[1:], ) - return f_1 + return f_0, f_1 diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index d85deed..7d31c8a 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -86,4 +86,4 @@ def cast_to_compute_jax(self, array): def cast_to_store_jax(self, array): store_precision = self.store_precision - return jnp.array(array, dtype=store_precision.jax_dtype) \ No newline at end of file + return jnp.array(array, dtype=store_precision.jax_dtype)