diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 2a580aa..0ce07e9 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -5,6 +5,8 @@ from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import ( FullwayBounceBackBC, + ZouHeBC, + RegularizedBC, EquilibriumBC, DoNothingBC, ) @@ -67,11 +69,16 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): inlet, outlet, walls, sphere = self.define_boundary_indices() - bc_left = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=inlet) + bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet) + # bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) - bc_do_nothing = DoNothingBC(indices=outlet) + bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) + # bc_outlet = DoNothingBC(indices=outlet) bc_sphere = FullwayBounceBackBC(indices=sphere) - self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_sphere] + self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls] + # Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because + # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. + # TODO: how to ensure about this behind in the src code? 
def setup_boundary_masks(self): indices_boundary_masker = IndicesBoundaryMasker( @@ -85,7 +92,7 @@ def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) + self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") def run(self, num_steps, post_process_interval=100): for i in range(num_steps): @@ -107,9 +114,10 @@ def post_process(self, i): # remove boundary cells u = u[:, 1:-1, 1:-1, 1:-1] + rho = rho[:, 1:-1, 1:-1, 1:-1] u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - fields = {"u_magnitude": u_magnitude} + fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho[0]} save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) @@ -117,7 +125,7 @@ def post_process(self, i): if __name__ == "__main__": # Running the simulation - grid_shape = (512, 128, 128) + grid_shape = (512 // 2, 128 // 2, 128 // 2) velocity_set = xlb.velocity_set.D3Q19() backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 4887085..506da1d 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -7,3 +7,4 @@ from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC +from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py 
b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 85a3c35..0083bae 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -47,7 +48,8 @@ def __init__( @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) def _construct_warp(self): diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index df947c6..2ed0067 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -50,7 +51,8 @@ def __init__( @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py new file mode 100644 index 0000000..413a37f --- /dev/null +++ 
b/xlb/operator/boundary_condition/bc_regularized.py @@ -0,0 +1,422 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit +import jax.lax as lax +from functools import partial +import warp as wp +from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.macroscopic.second_moment import SecondMoment + + +class RegularizedBC(ZouHeBC): + """ + Regularized boundary condition for a lattice Boltzmann method simulation. + + This class implements the regularized boundary condition, which is a non-equilibrium bounce-back boundary condition + with additional regularization. It can be used to set inflow and outflow boundary conditions with prescribed pressure + or velocity. + + Attributes + ---------- + name : str + The name of the boundary condition. For this class, it is "Regularized". + Qi : numpy.ndarray + The Qi tensor, which is used in the regularization of the distribution functions. + + References + ---------- + Latt, J. (2007). Hydrodynamic limit of lattice Boltzmann equations. PhD thesis, University of Geneva. + Latt, J., Chopard, B., Malaspinas, O., Deville, M., & Michler, A. (2008). Straight velocity boundaries in the + lattice Boltzmann method. Physical Review E, 77(5), 056703. 
doi:10.1103/PhysRevE.77.056703 + """ + + def __init__( + self, + bc_type, + prescribed_value, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + indices=None, + ): + # Call the parent constructor + super().__init__( + bc_type, + prescribed_value, + velocity_set, + precision_policy, + compute_backend, + indices, + ) + + # The operator to compute the momentum flux + self.momentum_flux = SecondMoment() + + # helper function: returns Qi as a new array and leaves velocity_set.cc untouched + def compute_qi(self): + # Qi = cc - cs^2*I; start from a copy because the in-place updates below would otherwise corrupt the shared velocity_set.cc + dim = self.velocity_set.d + Qi = self.velocity_set.cc.copy() + if dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + else: + raise ValueError(f"dim = {dim} not supported") + + # multiply off-diagonal elements by 2 because the Q tensor is symmetric + Qi[:, diagonal] += -1.0 / 3.0 + Qi[:, offdiagonal] *= 2.0 + return Qi + + @partial(jit, static_argnums=(0,), inline=True) + def regularize_fpop(self, fpop, feq): + """ + Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. + + Parameters + ---------- + fpop : jax.numpy.ndarray + The distribution functions. + feq : jax.numpy.ndarray + The equilibrium distribution functions. + + Returns + ------- + jax.numpy.ndarray + The regularized distribution functions. + """ + # Qi = cc - cs^2*I + dim = self.velocity_set.d + weights = self.velocity_set.w[(slice(None),) + (None,) * dim] + # TODO: if I use the following I get NaN ! figure out why! 
+ # Qi = jnp.array(self.compute_qi(), dtype=self.compute_dtype) + Qi = jnp.array(self.velocity_set.cc, dtype=self.compute_dtype) + if dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + else: + raise ValueError(f"dim = {dim} not supported") + + # Qi = cc - cs^2*I + # multiply off-diagonal elements by 2 because the Q tensor is symmetric + Qi = Qi.at[:, diagonal].add(-1.0 / 3.0) + Qi = Qi.at[:, offdiagonal].multiply(2.0) + + # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} + f_neq = fpop - feq + PiNeq = self.momentum_flux(f_neq) + # PiNeq = self.momentum_flux(fpop) - self.momentum_flux(feq) + + # Compute double dot product Qi:Pi1 + # QiPi1 = np.zeros_like(fpop) + # Pi1 = PiNeq + QiPi1 = jnp.tensordot(Qi, PiNeq, axes=(1, 0)) + + # assign all populations based on eq 45 of Latt et al (2008) + # fneq ~ f^1 + fpop1 = 9.0 / 2.0 * weights * QiPi1 + fpop_regularized = feq + fpop1 + return fpop_regularized + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + # creat a mask to slice boundary cells + boundary = boundary_mask == self.id + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) + + # compute the equilibrium based on prescribed values and the type of BC + feq = self.calculate_equilibrium(f_post, missing_mask) + + # set the unknown f populations based on the non-equilibrium bounce-back method + f_post_bd = self.bounceback_nonequilibrium(f_post, feq, missing_mask) + + # Regularize the boundary fpop + f_post_bd = self.regularize_fpop(f_post_bd, feq) + + # apply bc + f_post = jnp.where(boundary, f_post_bd, f_post) + return f_post + + def _construct_warp(self): + # assign placeholders for both u and rho based on prescribed_value + _d = self.velocity_set.d + _q = 
self.velocity_set.q + u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d + rho = self.prescribed_value if self.bc_type == "pressure" else 0.0 + + # Set local constants TODO: This is a hack and should be fixed with warp update + # _u_vec = wp.vec(_d, dtype=self.compute_dtype) + # compute Qi tensor and store it in self + _qi = wp.constant(wp.mat((_q, _d * (_d + 1) // 2), dtype=wp.float32)(self.compute_qi())) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _rho = wp.float32(rho) + _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) + _opp_indices = self.velocity_set.wp_opp_indices + _w = self.velocity_set.wp_w + _c = self.velocity_set.wp_c + _c32 = self.velocity_set.wp_c32 + # TODO: this is way less than ideal. we should not be making new types + + @wp.func + def get_normal_vectors_2d( + lattice_direction: Any, + ): + l = lattice_direction + if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + normals = -_u_vec(_c32[0, l], _c32[1, l]) + return normals + + @wp.func + def _get_fsum( + fpop: Any, + missing_mask: Any, + ): + fsum_known = self.compute_dtype(0.0) + fsum_middle = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[_opp_indices[l]] == wp.uint8(1): + fsum_known += 2.0 * fpop[l] + elif missing_mask[l] != wp.uint8(1): + fsum_middle += fpop[l] + return fsum_known + fsum_middle + + @wp.func + def get_normal_vectors_3d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + + @wp.func + def bounceback_nonequilibrium( + fpop: Any, + feq: Any, + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] + return fpop + + @wp.func + def regularize_fpop( + fpop: Any, + feq: Any, + ): + """ + Regularizes the 
distribution functions by adding non-equilibrium contributions based on second moments of fpop. + """ + # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} + f_neq = fpop - feq + PiNeq = self.momentum_flux.warp_functional(f_neq) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + QiPi1 = _f_vec() + for l in range(_q): + QiPi1[l] = 0.0 + for t in range(nt): + QiPi1[l] += _qi[l, t] * PiNeq[t] + + # assign all populations based on eq 45 of Latt et al (2008) + # fneq ~ f^1 + fpop1 = 9.0 / 2.0 * _w[l] * QiPi1[l] + fpop[l] = feq[l] + fpop1 + return fpop + + @wp.func + def functional3d_velocity( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + + # Find normal vector + normals = get_normal_vectors_3d(missing_mask) + + # calculate rho + fsum = _get_fsum(_f, missing_mask) + unormal = self.compute_dtype(0.0) + for d in range(_d): + unormal += _u[d] * normals[d] + _rho = fsum / (1.0 + unormal) + + # impose non-equilibrium bounceback + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) + + # Regularize the boundary fpop + _f = regularize_fpop(_f, feq) + return _f + + @wp.func + def functional3d_pressure( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + + # Find normal vector + normals = get_normal_vectors_3d(missing_mask) + + # calculate velocity + fsum = _get_fsum(_f, missing_mask) + unormal = -1.0 + fsum / _rho + _u = unormal * normals + + # impose non-equilibrium bounceback + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) + + # Regularize the boundary fpop + _f = regularize_fpop(_f, feq) + return _f + + @wp.func + def functional2d_velocity( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # 
Post-streaming values are only modified at missing direction + _f = f_post + + # Find normal vector + normals = get_normal_vectors_2d(missing_mask) + + # calculate rho + fsum = _get_fsum(_f, missing_mask) + unormal = self.compute_dtype(0.0) + for d in range(_d): + unormal += _u[d] * normals[d] + _rho = fsum / (1.0 + unormal) + + # impose non-equilibrium bounceback + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) + + # Regularize the boundary fpop + _f = regularize_fpop(_f, feq) + return _f + + @wp.func + def functional2d_pressure( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + + # Find normal vector + normals = get_normal_vectors_2d(missing_mask) + + # calculate velocity + fsum = _get_fsum(_f, missing_mask) + unormal = -1.0 + fsum / _rho + _u = unormal * normals + + # impose non-equilibrium bounceback + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) + + # Regularize the boundary fpop + _f = regularize_fpop(_f, feq) + return _f + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec3i(i, j) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + + # Apply the boundary condition + if _boundary_id == wp.uint8(self.id): + _f = functional(_f_pre, _f_post, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1]] = _f[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + 
boundary_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + + # Apply the boundary condition + if _boundary_id == wp.uint8(self.id): + _f = functional(_f_pre, _f_post, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1], index[2]] = _f[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + if self.velocity_set.d == 3 and self.bc_type == "velocity": + functional = functional3d_velocity + elif self.velocity_set.d == 3 and self.bc_type == "pressure": + functional = functional3d_pressure + elif self.bc_type == "velocity": + functional = functional2d_velocity + else: + functional = functional2d_pressure + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 06be077..3b69b21 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -26,12 +27,14 @@ class ZouHeBC(BoundaryCondition): """ Zou-He boundary condition for a lattice Boltzmann method simulation. - This class implements the Zou-He boundary condition, which is a non-equilibrium bounce-back boundary condition. 
- It can be used to set inflow and outflow boundary conditions with prescribed pressure or velocity. + This method applies the Zou-He boundary condition by first computing the equilibrium distribution functions based + on the prescribed values and the type of boundary condition, and then setting the unknown distribution functions + based on the non-equilibrium bounce-back method. + Tangential velocity is not ensured to be zero by adding transverse contributions based on + Hecht & Harting (2010) (doi:10.1088/1742-5468/2010/01/P01018) as it caused numerical instabilities at higher + Reynolds numbers. One needs to use "Regularized" BC at higher Reynolds. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, bc_type, @@ -41,7 +44,10 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, ): + # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC + # may have different types (velocity or pressure). assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." + self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__ + "_" + bc_type) self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() self.prescribed_value = prescribed_value @@ -58,7 +64,7 @@ def __init__( # Set the prescribed value for pressure or velocity dim = self.velocity_set.d if self.compute_backend == ComputeBackend.JAX: - self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None] + self.prescribed_value = jnp.atleast_1d(prescribed_value)[(slice(None),) + (None,) * dim] # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! 
@partial(jit, static_argnums=(0,), inline=True) @@ -77,7 +83,7 @@ def _get_normal_vec(self, missing_mask): @partial(jit, static_argnums=(0,), inline=True) def get_rho(self, fpop, missing_mask): if self.bc_type == "velocity": - vel = self.get_vel(fpop, missing_mask) + vel = self.prescribed_value rho = self.calculate_rho(fpop, vel, missing_mask) elif self.bc_type == "pressure": rho = self.prescribed_value @@ -90,7 +96,7 @@ def get_vel(self, fpop, missing_mask): if self.bc_type == "velocity": vel = self.prescribed_value elif self.bc_type == "pressure": - rho = self.get_rho(fpop, missing_mask) + rho = self.prescribed_value vel = self.calculate_vel(fpop, rho, missing_mask) else: raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") @@ -104,8 +110,8 @@ def calculate_vel(self, fpop, rho, missing_mask): normals = self._get_normal_vec(missing_mask) known_mask, middle_mask = self._get_known_middle_mask(missing_mask) - - unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True)) + fsum = jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True) + unormal = -1.0 + fsum / rho # Return the above unormal as a normal vector which sets the tangential velocities to zero vel = unormal * normals @@ -119,7 +125,8 @@ def calculate_rho(self, fpop, vel, missing_mask): normals = self._get_normal_vec(missing_mask) known_mask, middle_mask = self._get_known_middle_mask(missing_mask) unormal = jnp.sum(normals * vel, keepdims=True, axis=0) - rho = (1.0 / (1.0 + unormal)) * (jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True)) + fsum = jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True) + rho = fsum / (1.0 + unormal) return rho @partial(jit, static_argnums=(0,), inline=True) @@ -150,7 +157,8 @@ def 
bounceback_nonequilibrium(self, fpop, feq, missing_mask): def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): # creat a mask to slice boundary cells boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) # compute the equilibrium based on prescribed values and the type of BC feq = self.calculate_equilibrium(f_post, missing_mask) @@ -187,41 +195,33 @@ def get_normal_vectors_2d( return normals @wp.func - def get_normal_vectors_3d( - lattice_direction: Any, + def _get_fsum( + fpop: Any, + missing_mask: Any, ): - l = lattice_direction - if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - normals = -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) - return normals + fsum_known = self.compute_dtype(0.0) + fsum_middle = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[_opp_indices[l]] == wp.uint8(1): + fsum_known += 2.0 * fpop[l] + elif missing_mask[l] != wp.uint8(1): + fsum_middle += fpop[l] + return fsum_known + fsum_middle @wp.func - def _helper_functional( - fpop: Any, - fsum: Any, + def get_normal_vectors_3d( missing_mask: Any, - lattice_direction: Any, ): - l = lattice_direction - known_mask = missing_mask[_opp_indices[l]] - middle_mask = ~(missing_mask[l] | known_mask) - # fsum += fpop[l] * float(middle_mask) + 2.0 * fpop[l] * float(known_mask) - if middle_mask and known_mask: - fsum += fpop[l] + 2.0 * fpop[l] - elif middle_mask: - fsum += fpop[l] - elif known_mask: - fsum += 2.0 * fpop[l] - return fsum + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) @wp.func def bounceback_nonequilibrium( fpop: Any, + feq: Any, missing_mask: Any, - density: Any, - velocity: Any, ): - feq = 
self.equilibrium_operator.warp_functional(density, velocity) for l in range(_q): if missing_mask[l] == wp.uint8(1): fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] @@ -235,19 +235,20 @@ def functional3d_velocity( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_3d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) + # Find normal vector + normals = get_normal_vectors_3d(missing_mask) + + # calculate rho + fsum = _get_fsum(_f, missing_mask) + unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = _fsum / (1.0 + unormal) + _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f @wp.func @@ -258,18 +259,18 @@ def functional3d_pressure( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_3d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) - unormal = -1.0 + _fsum / _rho + # Find normal vector + normals = get_normal_vectors_3d(missing_mask) + + # calculate velocity + fsum = _get_fsum(_f, missing_mask) + unormal = -1.0 + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f @wp.func @@ -280,19 +281,20 @@ def functional2d_velocity( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = 
self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_2d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) + # Find normal vector + normals = get_normal_vectors_2d(missing_mask) + + # calculate rho + fsum = _get_fsum(_f, missing_mask) + unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = _fsum / (1.0 + unormal) + _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f @wp.func @@ -303,18 +305,18 @@ def functional2d_pressure( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_2d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) - unormal = -1.0 + _fsum / _rho + # Find normal vector + normals = get_normal_vectors_2d(missing_mask) + + # calculate velocity + fsum = _get_fsum(_f, missing_mask) + unormal = -1.0 + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f # Construct the warp kernel @@ -333,7 +335,7 @@ def kernel2d( _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(ZouHeBC.id): + if _boundary_id == wp.uint8(self.id): _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post @@ -358,7 +360,7 @@ def kernel3d( _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, 
f_post, boundary_mask, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(ZouHeBC.id): + if _boundary_id == wp.uint8(self.id): _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index b5c2fc4..7548cf0 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -52,10 +52,10 @@ def jax_implementation(self, bclist, boundary_mask, missing_mask, start_index=No local_indices = np.array(bc.indices) + np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] bid = bid.at[tuple(local_indices)].set(id_number) - if dim == 2: - grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) - if dim == 3: - grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) + # if dim == 2: + # grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) + # if dim == 3: + # grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) # We are done with bc.indices. 
Remove them from BC objects bc.__dict__.pop("indices", None) diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 9297363..da0aee5 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -6,12 +6,13 @@ from jax import jit import warp as wp from typing import Any +from functools import partial from xlb.velocity_set import VelocitySet, D2Q9, D3Q27 from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision from xlb.operator import Operator -from functools import partial +from xlb.operator.macroscopic import SecondMoment class KBC(Collision): @@ -28,6 +29,7 @@ def __init__( precision_policy=None, compute_backend=None, ): + self.momentum_flux = SecondMoment() self.epsilon = 1e-32 self.beta = omega * 0.5 self.inv_beta = 1.0 / self.beta @@ -94,33 +96,6 @@ def entropic_scalar_product(self, x: jnp.ndarray, y: jnp.ndarray, feq: jnp.ndarr """ return jnp.sum(x * y / feq, axis=0) - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def momentum_flux_jax( - self, - fneq: jnp.ndarray, - ): - """ - This function computes the momentum flux, which is the product of the non-equilibrium - distribution functions (fneq) and the lattice moments (cc). - - The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann - Method (LBM). - - # TODO: probably move this to equilibrium calculation - - Parameters - ---------- - fneq: jax.numpy.ndarray - The non-equilibrium distribution functions. - - Returns - ------- - jax.numpy.ndarray - The computed momentum flux. 
- """ - - return jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0)) - @partial(jit, static_argnums=(0,), inline=True) def decompose_shear_d3q27_jax(self, fneq): """ @@ -138,7 +113,7 @@ def decompose_shear_d3q27_jax(self, fneq): """ # Calculate the momentum flux - Pi = self.momentum_flux_jax(fneq) + Pi = self.momentum_flux(fneq) # Calculating Nxz and Nyz with indices moved to the first dimension Nxz = Pi[0, ...] - Pi[5, ...] Nyz = Pi[3, ...] - Pi[5, ...] @@ -187,7 +162,7 @@ def decompose_shear_d2q9_jax(self, fneq): jax.numpy.array Shear components of fneq. """ - Pi = self.momentum_flux_jax(fneq) + Pi = self.momentum_flux(fneq) N = Pi[0, ...] - Pi[2, ...] s = jnp.zeros_like(fneq) s = s.at[3, ...].set(N) @@ -207,35 +182,14 @@ def _construct_warp(self): raise NotImplementedError("Velocity set not supported for warp backend: {}".format(type(self.velocity_set))) # Set local constants TODO: This is a hack and should be fixed with warp update - _w = self.velocity_set.wp_w - _cc = self.velocity_set.wp_cc - _omega = wp.constant(self.compute_dtype(self.omega)) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2 - _pi_vec = wp.vec( - _pi_dim, - dtype=self.compute_dtype, - ) _epsilon = wp.constant(self.compute_dtype(self.epsilon)) _beta = wp.constant(self.compute_dtype(self.beta)) _inv_beta = wp.constant(self.compute_dtype(1.0 / self.beta)) - # Construct functional for computing momentum flux - @wp.func - def momentum_flux_warp( - fneq: Any, - ): - # Get momentum flux - pi = _pi_vec() - for d in range(_pi_dim): - pi[d] = 0.0 - for q in range(self.velocity_set.q): - pi[d] += _cc[q, d] * fneq[q] - return pi - @wp.func def decompose_shear_d2q9(fneq: Any): - pi = momentum_flux_warp(fneq) + pi = self.momentum_flux.warp_functional(fneq) N = pi[0] - pi[1] s = wp.vec9(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) s[3] = N @@ -254,7 +208,7 @@ def decompose_shear_d3q27( fneq: Any, ): # Get momentum flux - pi = 
momentum_flux_warp(fneq) + pi = self.momentum_flux.warp_functional(fneq) nxz = pi[0] - pi[5] nyz = pi[3] - pi[5] diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 3463078..38195cd 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1 +1,2 @@ -from xlb.operator.macroscopic.macroscopic import Macroscopic as Macroscopic +from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic +from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py new file mode 100644 index 0000000..db8fce6 --- /dev/null +++ b/xlb/operator/macroscopic/second_moment.py @@ -0,0 +1,134 @@ +# Operator for computing the second moment of distribution functions + +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class SecondMoment(Operator): + """ + Operator to calculate the second moment of distribution functions. + + The second moment may be used to compute the momentum flux in the computation of + the stress tensor in the Lattice Boltzmann Method (LBM). + + Important Note: + Note that this rank 2 symmetric tensor (dim*dim) has been converted into a rank one + vector where the diagonal and off-diagonal components correspond to the following elements of + the vector: + if self.grid.dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif self.grid.dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + + ** For any reduction operation on the full tensor it is crucial to account for the full tensor by + considering all diagonal and off-diagonal components.
+ """ + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def jax_implementation( + self, + fneq: jnp.ndarray, + ): + """ + This function computes the second order moment, which is the product of the + distribution functions (f) and the lattice moments (cc). + + Parameters + ---------- + fneq: jax.numpy.ndarray + The distribution functions. + + Returns + ------- + jax.numpy.ndarray + The computed second moment. + """ + return jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0)) + + def _construct_warp(self): + # Make constants for warp + _cc = self.velocity_set.wp_cc + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2 + _pi_vec = wp.vec( + _pi_dim, + dtype=self.compute_dtype, + ) + + # Construct functional for computing second moment + @wp.func + def functional( + fneq: Any, + ): + # Get second order moment (a symmetric tensor shaped into a vector) + pi = _pi_vec() + for d in range(_pi_dim): + pi[d] = 0.0 + for q in range(self.velocity_set.q): + pi[d] += _cc[q, d] * fneq[q] + return pi + + # Construct the kernel + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + pi: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the distribution functions + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _pi = functional(_f) + + # Set the output + for d in range(_pi_dim): + pi[d, index[0], index[1], index[2]] = _pi[d] + + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + pi: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the distribution functions + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _pi = functional(_f) + + # Set the output + for d in range(_pi_dim): + pi[d, index[0], index[1]] = _pi[d] + + kernel = kernel3d if
self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, pi): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + f, + pi, + ], + dim=pi.shape[1:], + ) + return pi diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/zero_first_moments.py similarity index 97% rename from xlb/operator/macroscopic/macroscopic.py rename to xlb/operator/macroscopic/zero_first_moments.py index 13d3817..fbf7c93 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/zero_first_moments.py @@ -10,9 +10,9 @@ from xlb.operator.operator import Operator -class Macroscopic(Operator): +class ZeroAndFirstMoments(Operator): """ - Base class for all macroscopic operators + A class to compute first and zeroth moments of distribution functions. TODO: Currently this is only used for the standard rho and u moments. In the future, this should be extended to include higher order moments diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 84a0b8f..bd2403d 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -98,7 +98,55 @@ class BoundaryConditionIDStruct: id_DoNothingBC: wp.uint8 id_HalfwayBounceBackBC: wp.uint8 id_FullwayBounceBackBC: wp.uint8 - id_ZouHeBC: wp.uint8 + id_ZouHeBC_velocity: wp.uint8 + id_ZouHeBC_pressure: wp.uint8 + id_RegularizedBC_velocity: wp.uint8 + id_RegularizedBC_pressure: wp.uint8 + + @wp.func + def apply_post_streaming_bc( + f_pre: Any, + f_post: Any, + missing_mask: Any, + _boundary_id: Any, + bc_struct: Any, + ): + # Apply post-streaming type boundary conditions + if _boundary_id == bc_struct.id_EquilibriumBC: + # Equilibrium boundary condition + f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_DoNothingBC: + # Do nothing boundary condition + f_post = 
self.DoNothingBC.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: + # Half way boundary condition + f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_ZouHeBC_velocity: + # Zouhe boundary condition (bc type = velocity) + f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_ZouHeBC_pressure: + # Zouhe boundary condition (bc type = pressure) + f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_RegularizedBC_velocity: + # Regularized boundary condition (bc type = velocity) + f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_RegularizedBC_pressure: + # Regularized boundary condition (bc type = pressure) + f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, missing_mask) + return f_post + + @wp.func + def apply_post_collision_bc( + f_pre: Any, + f_post: Any, + missing_mask: Any, + _boundary_id: Any, + bc_struct: Any, + ): + if _boundary_id == bc_struct.id_FullwayBounceBackBC: + # Full way boundary condition + f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) + return f_post @wp.kernel def kernel2d( @@ -106,7 +154,7 @@ def kernel2d( f_1: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), - bc_struct: BoundaryConditionIDStruct, + bc_struct: Any, timestep: int, ): # Get the global index @@ -131,18 +179,7 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) # Apply post-streaming type boundary conditions - if _boundary_id == bc_struct.id_EquilibriumBC: - # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_DoNothingBC: - # Do nothing boundary
condition - f_post_stream = self.do_nothing_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: - # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_ZouHeBC: - # Zouhe boundary condition - f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -159,9 +196,7 @@ def kernel2d( ) # Apply post-collision type boundary conditions - if _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Full way boundary condition - f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, f_post_collision, _missing_mask) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -174,7 +209,7 @@ def kernel3d( f_1: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), - bc_struct: BoundaryConditionIDStruct, + bc_struct: Any, timestep: int, ): # Get the global index @@ -198,19 +233,8 @@ def kernel3d( # Apply streaming (pull method) f_post_stream = self.stream.warp_functional(f_0, index) - # Apply post-streaming boundary conditions - if _boundary_id == bc_struct.id_EquilibriumBC: - # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_DoNothingBC: - # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: - # Half way boundary condition - 
f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_ZouHeBC: - # Zouhe boundary condition - f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + # Apply post-streaming type boundary conditions + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -221,10 +245,8 @@ def kernel3d( # Apply collision f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) - # Apply collision type boundary conditions - if _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Full way boundary condition - f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, f_post_collision, _missing_mask) + # Apply post-collision type boundary conditions + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -240,22 +262,29 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Get the boundary condition ids from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + # Read the list of bc_to_id created upon instantiation bc_to_id = boundary_condition_registry.bc_to_id - + id_to_bc = boundary_condition_registry.id_to_bc bc_struct = self.warp_functional() - bc_attribute_list = [] - for attribute_str in bc_to_id.keys(): - # Setting the Struct attributes based on the BC class names - setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str]) - bc_attribute_list.append("id_" + attribute_str) - - # Unused attributes of the struct are set to inernal (id=0) - ll = vars(bc_struct) - for var in ll: - if var not in bc_attribute_list and not var.startswith("_"): + active_bc_list = [] + for bc in 
self.boundary_conditions: + # Setting the Struct attributes and active BC classes based on the BC class names + bc_name = id_to_bc[bc.id] + setattr(self, bc_name, bc) + setattr(bc_struct, "id_" + bc_name, bc_to_id[bc_name]) + active_bc_list.append("id_" + bc_name) + + # Mark the Struct attributes of unused BCs as inactive and bind a fall-back BC to them + bc_fallback = self.boundary_conditions[0] + for var in vars(bc_struct): + if var not in active_bc_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 setattr(bc_struct, var, 255) + # Assign a fall-back BC for inactive BCs. This is just to ensure Warp codegen does not + # produce error when a particular BC is not used in an example. + setattr(self, var.replace("id_", ""), bc_fallback) + # Launch the warp kernel wp.launch( self.warp_kernel, diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 1608342..44aab5f 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -9,8 +9,12 @@ class Stepper(Operator): """ def __init__(self, operators, boundary_conditions): + # Get the boundary condition ids + from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + self.operators = operators self.boundary_conditions = boundary_conditions + # Get velocity set, precision policy, and compute backend velocity_sets = set([op.velocity_set for op in self.operators if op is not None]) assert len(velocity_sets) < 2, "All velocity sets must be the same. Got {}".format(velocity_sets) @@ -24,39 +28,5 @@ def __init__(self, operators, boundary_conditions): assert len(compute_backends) < 2, "All compute backends must be the same.
Got {}".format(compute_backends) compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop() - # Add boundary conditions - ############################################ - # Warp cannot handle lists of functions currently - # TODO: Fix this later - ############################################ - from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC - from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC - from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC - from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC - from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC - - # Define a list of tuples with attribute names and their corresponding classes - conditions = [ - ("equilibrium_bc", EquilibriumBC), - ("do_nothing_bc", DoNothingBC), - ("halfway_bounce_back_bc", HalfwayBounceBackBC), - ("fullway_bounce_back_bc", FullwayBounceBackBC), - ("zouhe_bc", ZouHeBC), - ] - - # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. - bc_fallback = boundary_conditions[0] - - # Iterate over each boundary condition - for attr_name, bc_class in conditions: - for bc in boundary_conditions: - if isinstance(bc, bc_class): - setattr(self, attr_name, bc) - break - elif not hasattr(self, attr_name): - setattr(self, attr_name, bc_fallback) - - ############################################ - # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend)