diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 0ce07e9..220489d 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -3,13 +3,7 @@ from xlb.precision_policy import PrecisionPolicy from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.stepper import IncompressibleNavierStokesStepper -from xlb.operator.boundary_condition import ( - FullwayBounceBackBC, - ZouHeBC, - RegularizedBC, - EquilibriumBC, - DoNothingBC, -) +from xlb.operator.boundary_condition import FullwayBounceBackBC, ZouHeBC, RegularizedBC, EquilibriumBC, DoNothingBC, ExtrapolationOutflowBC from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image @@ -72,8 +66,9 @@ def setup_boundary_conditions(self): bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet) # bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) - bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) + # bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) # bc_outlet = DoNothingBC(indices=outlet) + bc_outlet = ExtrapolationOutflowBC(indices=outlet) bc_sphere = FullwayBounceBackBC(indices=sphere) self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls] # Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index eaac67e..460b217 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -9,6 +9,7 @@ FullwayBounceBackBC, EquilibriumBC, DoNothingBC, + ExtrapolationOutflowBC, ) from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker @@ -81,9 +82,9 @@ def setup_boundary_conditions(self, wind_speed): inlet, outlet, walls, car = self.define_boundary_indices() bc_left = EquilibriumBC(rho=1.0, u=(wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) - bc_do_nothing = DoNothingBC(indices=outlet) + bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) bc_car = FullwayBounceBackBC(indices=car) - self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_car] + self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] def setup_boundary_masks(self): indices_boundary_masker = IndicesBoundaryMasker( diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 506da1d..b7ede03 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -8,3 +8,4 @@ from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC +from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC as ExtrapolationOutflowBC diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 5049b1f..e8a91ee 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -56,6 +56,7 @@ def _construct_warp(self): def functional( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): return f_pre @@ -76,7 +77,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -101,7 +103,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 1937a17..27d5eb2 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -79,6 +79,7 @@ def _construct_warp(self): def functional( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): _f = self.equilibrium_operator.warp_functional(_rho, _u) @@ -101,7 +102,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -126,7 +128,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py new file mode 100644 index 0000000..d16068b --- /dev/null +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -0,0 +1,285 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit +import jax.lax as lax +from functools import partial +import warp as wp +from typing import Any +from collections import Counter +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + + +class ExtrapolationOutflowBC(BoundaryCondition): + """ + Extrapolation outflow boundary condition for a lattice Boltzmann method simulation. + + This class implements the extrapolation outflow boundary condition, which is a type of outflow boundary condition + that uses extrapolation to avoid strong wave reflections. + + References + ---------- + Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three + dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507-547. + doi:10.1016/j.camwa.2015.05.001. + """ + + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + indices=None, + ): + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + indices, + ) + + # find and store the normal vector using indices + self._get_normal_vec(indices) + + # Unpack the two warp functionals needed for this BC! + if self.compute_backend == ComputeBackend.WARP: + self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional + + def _get_normal_vec(self, indices): + # Get the frequency count and most common element directly + freq_counts = [Counter(coord).most_common(1)[0] for coord in indices] + + # Extract counts and elements + counts = np.array([count for _, count in freq_counts]) + elements = np.array([element for element, _ in freq_counts]) + + # Normalize the counts + self.normal = counts // counts.max() + + # Reverse the normal vector if the most frequent element is 0 + if elements[np.argmax(counts)] == 0: + self.normal *= -1 + return + + @partial(jit, static_argnums=(0,), inline=True) + def _roll(self, fld, vec): + """ + Perform rolling operation of a field with dimentions [q, nx, ny, nz] in a direction + given by vec. All q-directions are rolled at the same time. + # TODO: how to improve this for multi-gpu runs? + """ + if self.velocity_set.d == 2: + return jnp.roll(fld, (vec[0], vec[1]), axis=(1, 2)) + elif self.velocity_set.d == 3: + return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) + + @partial(jit, static_argnums=(0,), inline=True) + def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): + """ + Prepare the auxilary distribution functions for the boundary condition. + Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision + """ + sound_speed = 1.0 / jnp.sqrt(3.0) + boundary = boundary_mask == self.id + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) + + # Roll boundary mask in the opposite of the normal vector to mask its next immediate neighbour + neighbour = self._roll(boundary, -self.normal) + + # gather post-streaming values associated with previous time-step to construct the auxilary data for BC + fpop = jnp.where(boundary, f_pre, f_post) + fpop_neighbour = jnp.where(neighbour, f_pre, f_post) + + # With fpop_neighbour isolated, now roll it back to be positioned at the boundary for subsequent operations + fpop_neighbour = self._roll(fpop_neighbour, self.normal) + fpop_extrapolated = sound_speed * fpop_neighbour + (1.0 - sound_speed) * fpop + + # Use the iknown directions of f_postcollision that leave the domain during streaming to store the BC data + opp = self.velocity_set.opp_indices + known_mask = missing_mask[opp] + f_post = jnp.where(jnp.logical_and(boundary, known_mask), fpop_extrapolated[opp], f_post) + return f_post + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + boundary = boundary_mask == self.id + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) + return jnp.where( + jnp.logical_and(missing_mask, boundary), + f_pre[self.velocity_set.opp_indices], + f_post, + ) + + def _construct_warp(self): + # Set local constants + sound_speed = 1.0 / wp.sqrt(3.0) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _c = self.velocity_set.wp_c + _q = self.velocity_set.q + _opp_indices = self.velocity_set.wp_opp_indices + + @wp.func + def get_normal_vectors_2d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + return -wp.vec2i(_c[0, l], _c[1, l]) + + @wp.func + def get_normal_vectors_3d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l]) + + # Construct the functionals for this BC + @wp.func + def functional( + f_pre: Any, + f_post: Any, + f_aux: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + for l in range(self.velocity_set.q): + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + _f[l] = f_pre[_opp_indices[l]] + + return _f + + @wp.func + def prepare_bc_auxilary_data( + f_pre: Any, + f_post: Any, + f_aux: Any, + missing_mask: Any, + ): + # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and + # f_pre (posti-streaming values of the current voxel). We use directions that leave the domain + # for storing this prepared data. + _f = f_post + for l in range(self.velocity_set.q): + if missing_mask[l] == wp.uint8(1): + _f[_opp_indices[l]] = (1.0 - sound_speed) * f_pre[l] + sound_speed * f_aux[l] + return _f + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_aux = _f_vec() + + # special preparation of auxiliary data + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + nv = get_normal_vectors_2d(_missing_mask) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]] + + # Apply the boundary condition + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both + # collision and streaming? + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1]] = _f[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_aux = _f_vec() + + # special preparation of auxiliary data + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + nv = get_normal_vectors_3d(_missing_mask) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] + + # Apply the boundary condition + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both + # collision and streaming? + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1], index[2]] = _f[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return (functional, prepare_bc_auxilary_data), kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 0083bae..6272aca 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -63,6 +63,7 @@ def _construct_warp(self): def functional( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): fliped_f = _f_vec() @@ -85,7 +86,8 @@ def kernel2d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_vec() + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -110,7 +112,8 @@ def kernel3d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_vec() + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 2ed0067..a363479 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -68,6 +68,7 @@ def _construct_warp(self): def functional( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -90,14 +91,15 @@ def kernel2d( ): # Get the global index i, j = wp.tid() - index = wp.vec3i(i, j) + index = wp.vec2i(i, j) # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -122,7 +124,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 413a37f..84dbbf9 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -16,7 +16,7 @@ from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC from xlb.operator.boundary_condition.boundary_condition import ImplementationStep from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry -from xlb.operator.macroscopic.second_moment import SecondMoment +from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux class RegularizedBC(ZouHeBC): @@ -61,26 +61,7 @@ def __init__( ) # The operator to compute the momentum flux - self.momentum_flux = SecondMoment() - - # helper function - def compute_qi(self): - # Qi = cc - cs^2*I - dim = self.velocity_set.d - Qi = self.velocity_set.cc - if dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) - elif dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) - else: - raise ValueError(f"dim = {dim} not supported") - - # multiply off-diagonal elements by 2 because the Q tensor is symmetric - Qi[:, diagonal] += -1.0 / 3.0 - Qi[:, offdiagonal] *= 2.0 - return Qi + self.momentum_flux = MomentumFlux() @partial(jit, static_argnums=(0,), inline=True) def regularize_fpop(self, fpop, feq): @@ -102,22 +83,7 @@ def regularize_fpop(self, fpop, feq): # Qi = cc - cs^2*I dim = self.velocity_set.d weights = self.velocity_set.w[(slice(None),) + (None,) * dim] - # TODO: if I use the following I get NaN ! figure out why! - # Qi = jnp.array(self.compute_qi(), dtype=self.compute_dtype) - Qi = jnp.array(self.velocity_set.cc, dtype=self.compute_dtype) - if dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) - elif dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) - else: - raise ValueError(f"dim = {dim} not supported") - - # Qi = cc - cs^2*I - # multiply off-diagonal elements by 2 because the Q tensor is symmetric - Qi = Qi.at[:, diagonal].add(-1.0 / 3.0) - Qi = Qi.at[:, offdiagonal].multiply(2.0) + Qi = jnp.array(self.velocity_set.qi, dtype=self.compute_dtype) # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} f_neq = fpop - feq @@ -166,7 +132,6 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) # compute Qi tensor and store it in self - _qi = wp.constant(wp.mat((_q, _d * (_d + 1) // 2), dtype=wp.float32)(self.compute_qi())) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(rho) @@ -175,16 +140,8 @@ def _construct_warp(self): _w = self.velocity_set.wp_w _c = self.velocity_set.wp_c _c32 = self.velocity_set.wp_c32 - # TODO: this is way less than ideal. we should not be making new types - - @wp.func - def get_normal_vectors_2d( - lattice_direction: Any, - ): - l = lattice_direction - if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - normals = -_u_vec(_c32[0, l], _c32[1, l]) - return normals + _qi = self.velocity_set.wp_qi + # TODO: related to _c32: this is way less than ideal. we should not be making new types @wp.func def _get_fsum( @@ -200,6 +157,14 @@ def _get_fsum( fsum_middle += fpop[l] return fsum_known + fsum_middle + @wp.func + def get_normal_vectors_2d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + return -_u_vec(_c32[0, l], _c32[1, l]) + @wp.func def get_normal_vectors_3d( missing_mask: Any, @@ -249,6 +214,7 @@ def regularize_fpop( def functional3d_velocity( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -276,6 +242,7 @@ def functional3d_velocity( def functional3d_pressure( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -301,6 +268,7 @@ def functional3d_pressure( def functional2d_velocity( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -328,6 +296,7 @@ def functional2d_velocity( def functional2d_pressure( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -359,14 +328,15 @@ def kernel2d( ): # Get the global index i, j = wp.tid() - index = wp.vec3i(i, j) + index = wp.vec2i(i, j) # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_vec() + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -391,7 +361,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_vec() + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 3b69b21..61783f8 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -231,6 +231,7 @@ def bounceback_nonequilibrium( def functional3d_velocity( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -255,6 +256,7 @@ def functional3d_velocity( def functional3d_pressure( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -277,6 +279,7 @@ def functional3d_pressure( def functional2d_velocity( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -301,6 +304,7 @@ def functional2d_velocity( def functional2d_pressure( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -329,14 +333,15 @@ def kernel2d( ): # Get the global index i, j = wp.tid() - index = wp.vec3i(i, j) + index = wp.vec2i(i, j) # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -361,7 +366,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index f4e0a1b..29ca2db 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -5,6 +5,8 @@ from enum import Enum, auto import warp as wp from typing import Any +from jax import jit +from functools import partial from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -49,6 +51,15 @@ def __init__( _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + @wp.func + def prepare_bc_auxilary_data( + f_pre: Any, + f_post: Any, + f_aux: Any, + missing_mask: Any, + ): + return f_post + @wp.func def _get_thread_data_2d( f_pre: wp.array3d(dtype=Any), @@ -103,3 +114,12 @@ def _get_thread_data_3d( if self.compute_backend == ComputeBackend.WARP: self._get_thread_data_2d = _get_thread_data_2d self._get_thread_data_3d = _get_thread_data_3d + self.prepare_bc_auxilary_data = prepare_bc_auxilary_data + + @partial(jit, static_argnums=(0,), inline=True) + def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): + """ + A placeholder function for prepare the auxilary distribution functions for the boundary condition. + currently being called after collision only. + """ + return f_post diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index da0aee5..ddd7ecc 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -12,7 +12,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision from xlb.operator import Operator -from xlb.operator.macroscopic import SecondMoment +from xlb.operator.macroscopic import SecondMoment as MomentumFlux class KBC(Collision): @@ -29,7 +29,7 @@ def __init__( precision_policy=None, compute_backend=None, ): - self.momentum_flux = SecondMoment() + self.momentum_flux = MomentumFlux() self.epsilon = 1e-32 self.beta = omega * 0.5 self.inv_beta = 1.0 / self.beta @@ -311,7 +311,7 @@ def kernel2d( ): # Get the global index i, j = wp.tid() - index = wp.vec3i(i, j) # TODO: Warp needs to fix this + index = wp.vec2i(i, j) # TODO: Warp needs to fix this # Load needed values _f = _f_vec() diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index bd2403d..efbf847 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -58,6 +58,7 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: + f_post_collision = bc.prepare_bc_auxilary_data(f_0, f_post_collision, boundary_mask, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_0, @@ -88,6 +89,10 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _c = self.velocity_set.wp_c + _q = self.velocity_set.q + _opp_indices = self.velocity_set.wp_opp_indices + sound_speed = 1.0 / wp.sqrt(3.0) @wp.struct class BoundaryConditionIDStruct: @@ -102,11 +107,13 @@ class BoundaryConditionIDStruct: id_ZouHeBC_pressure: wp.uint8 id_RegularizedBC_velocity: wp.uint8 id_RegularizedBC_pressure: wp.uint8 + id_ExtrapolationOutflowBC: wp.uint8 @wp.func def apply_post_streaming_bc( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, _boundary_id: Any, bc_struct: Any, @@ -114,56 +121,73 @@ def apply_post_streaming_bc( # Apply post-streaming type boundary conditions if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, missing_mask) + f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post = self.DoNothingBC.warp_functional(f_pre, f_post, missing_mask) + f_post = self.DoNothingBC.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition - f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) + f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_ZouHeBC_velocity: # Zouhe boundary condition (bc type = velocity) - f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, missing_mask) + f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_ZouHeBC_pressure: # Zouhe boundary condition (bc type = pressure) - f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, missing_mask) + f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_RegularizedBC_velocity: # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, missing_mask) + f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_RegularizedBC_pressure: # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, missing_mask) + f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) + elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + # Regularized boundary condition (bc type = velocity) + f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) return f_post @wp.func def apply_post_collision_bc( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, _boundary_id: Any, bc_struct: Any, ): if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition - f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) + f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + # f_aux is the neighbour's post-streaming values + # Storing post-streaming data in directions that leave the domain + f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(f_pre, f_post, f_aux, missing_mask) + return f_post - @wp.kernel - def kernel2d( + @wp.func + def get_normal_vectors_2d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + return -wp.vec2i(_c[0, l], _c[1, l]) + + @wp.func + def get_normal_vectors_3d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l]) + + @wp.func + def get_thread_data_2d( f_0: wp.array3d(dtype=Any), - f_1: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), - bc_struct: Any, - timestep: int, + index: Any, ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) # TODO warp should fix this - # Get the boundary id and missing mask f_post_collision = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations @@ -174,12 +198,99 @@ def kernel2d( _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) + return f_post_collision, _missing_mask + + @wp.func + def get_thread_data_3d( + f_0: wp.array4d(dtype=Any), + missing_mask: wp.array4d(dtype=Any), + index: Any, + ): + # Get the boundary id and missing mask + f_post_collision = _f_vec() + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # q-sized vector of pre-streaming populations + f_post_collision[l] = f_0[l, index[0], index[1], index[2]] + + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + return f_post_collision, _missing_mask + + @wp.func + def get_bc_auxilary_data_2d( + f_0: wp.array3d(dtype=Any), + index: Any, + _boundary_id: Any, + _missing_mask: Any, + bc_struct: Any, + ): + # special preparation of auxiliary data + f_auxiliary = _f_vec() + if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + nv = get_normal_vectors_2d(_missing_mask) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1]] + return f_auxiliary + + @wp.func + def get_bc_auxilary_data_3d( + f_0: wp.array4d(dtype=Any), + index: Any, + _boundary_id: Any, + _missing_mask: Any, + bc_struct: Any, + ): + # special preparation of auxiliary data + f_auxiliary = _f_vec() + if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + nv = get_normal_vectors_3d(_missing_mask) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1], pull_index[2]] + return f_auxiliary + + @wp.kernel + def kernel2d( + f_0: wp.array3d(dtype=Any), + f_1: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=Any), + missing_mask: wp.array3d(dtype=Any), + bc_struct: Any, + timestep: int, + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) # TODO warp should fix this + + # Read thread data for populations and missing mask + f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) # Apply streaming (pull method) f_post_stream = self.stream.warp_functional(f_0, index) + # Prepare auxilary data for BC (if applicable) + _boundary_id = boundary_mask[0, index[0], index[1]] + f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) + # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, _missing_mask, _boundary_id, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -196,7 +307,7 @@ def kernel2d( ) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, _missing_mask, _boundary_id, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -216,25 +327,18 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) # TODO warp should fix this - # Get the boundary id and missing mask - f_post_collision = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of pre-streaming populations - f_post_collision[l] = f_0[l, index[0], index[1], index[2]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # Read thread data for populations and missing mask + f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) # Apply streaming (pull method) f_post_stream = self.stream.warp_functional(f_0, index) + # Prepare auxilary data for BC (if applicable) + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] + f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) + # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, _missing_mask, _boundary_id, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -246,7 +350,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, _missing_mask, _boundary_id, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -276,6 +380,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Setting the Struct attributes and active BC classes based on the BC class names bc_fallback = self.boundary_conditions[0] + # TODO: what if self.boundary_conditions is an empty list e.g. when we have periodic BC all around! for var in vars(bc_struct): if var not in active_bc_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index cd63b36..a93d039 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -41,6 +41,7 @@ def __init__(self, d, q, c, w): self.main_indices = self._construct_main_indices() self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() + self.qi = self._construct_qi() # Make warp constants for these vectors # TODO: Following warp updates these may not be necessary @@ -49,6 +50,7 @@ def __init__(self, d, q, c, w): self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c)) + self.wp_qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.qi)) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) @@ -59,6 +61,24 @@ def warp_u_vec(self, dtype): def warp_stream_mat(self, dtype): return wp.mat((self.q, self.d), dtype=dtype) + def _construct_qi(self): + # Qi = cc - cs^2*I + dim = self.d + Qi = self.cc.copy() + if dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + else: + raise ValueError(f"dim = {dim} not supported") + + # multiply off-diagonal elements by 2 because the Q tensor is symmetric + Qi[:, diagonal] += -1.0 / 3.0 + Qi[:, offdiagonal] *= 2.0 + return Qi + def _construct_lattice_moment(self): """ This function constructs the moments of the lattice.