From 4c6d7d553e436daed0e2176441b57c2f6b7af807 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 9 Aug 2024 16:40:49 -0400 Subject: [PATCH] added ZouHe in warp --- xlb/operator/boundary_condition/bc_zouhe.py | 172 ++++++++++++++++++-- xlb/operator/collision/kbc.py | 4 +- xlb/operator/stepper/nse_stepper.py | 11 +- xlb/operator/stepper/stepper.py | 2 + xlb/velocity_set/velocity_set.py | 1 + 5 files changed, 171 insertions(+), 19 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 4e2eebc..06be077 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -34,16 +34,17 @@ class ZouHeBC(BoundaryCondition): def __init__( self, - bc_type=None, - prescribed_value=None, + bc_type, + prescribed_value, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, ): - assert bc_type in ["velocity", "pressure"], f'The boundary type must be either "velocity" or "pressure"' + assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() + self.prescribed_value = prescribed_value # Call the parent constructor super().__init__( @@ -56,8 +57,9 @@ def __init__( # Set the prescribed value for pressure or velocity dim = self.velocity_set.d - self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None] - # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! + if self.compute_backend == ComputeBackend.JAX: + self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None] + # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! @partial(jit, static_argnums=(0,), inline=True) def _get_known_middle_mask(self, missing_mask): @@ -103,7 +105,7 @@ def calculate_vel(self, fpop, rho, missing_mask): normals = self._get_normal_vec(missing_mask) known_mask, middle_mask = self._get_known_middle_mask(missing_mask) - unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=-1, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=-1, keepdims=True)) + unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True)) # Return the above unormal as a normal vector which sets the tangential velocities to zero vel = unormal * normals @@ -159,26 +161,160 @@ def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): return f_post def _construct_warp(self): + # assign placeholders for both u and rho based on prescribed_value + _d = self.velocity_set.d + _q = self.velocity_set.q + u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d + rho = self.prescribed_value if self.bc_type == "pressure" else 0.0 + # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c + # _u_vec = wp.vec(_d, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _rho = wp.float32(rho) + _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.wp_opp_indices - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _c = self.velocity_set.wp_c + _c32 = self.velocity_set.wp_c32 + # TODO: this is way less than ideal. we should not be making new types + + @wp.func + def get_normal_vectors_2d( + lattice_direction: Any, + ): + l = lattice_direction + if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + normals = -_u_vec(_c32[0, l], _c32[1, l]) + return normals @wp.func - def functional( + def get_normal_vectors_3d( + lattice_direction: Any, + ): + l = lattice_direction + if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + normals = -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return normals + + @wp.func + def _helper_functional( + fpop: Any, + fsum: Any, + missing_mask: Any, + lattice_direction: Any, + ): + l = lattice_direction + known_mask = missing_mask[_opp_indices[l]] + middle_mask = ~(missing_mask[l] | known_mask) + # fsum += fpop[l] * float(middle_mask) + 2.0 * fpop[l] * float(known_mask) + if middle_mask and known_mask: + fsum += fpop[l] + 2.0 * fpop[l] + elif middle_mask: + fsum += fpop[l] + elif known_mask: + fsum += 2.0 * fpop[l] + return fsum + + @wp.func + def bounceback_nonequilibrium( + fpop: Any, + missing_mask: Any, + density: Any, + velocity: Any, + ): + feq = self.equilibrium_operator.warp_functional(density, velocity) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] + return fpop + + @wp.func + def functional3d_velocity( f_pre: Any, f_post: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post - for l in range(self.velocity_set.q): - # If the mask is missing then take the opposite index + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): if missing_mask[l] == wp.uint8(1): - # Get the pre-streaming distribution function in oppisite direction - _f[l] = f_pre[_opp_indices[l]] + normals = get_normal_vectors_3d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + for d in range(_d): + unormal += _u[d] * normals[d] + _rho = _fsum / (1.0 + unormal) + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional3d_pressure( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_3d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + unormal = -1.0 + _fsum / _rho + _u = unormal * normals + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional2d_velocity( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_2d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + for d in range(_d): + unormal += _u[d] * normals[d] + _rho = _fsum / (1.0 + unormal) + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional2d_pressure( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_2d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + unormal = -1.0 + _fsum / _rho + _u = unormal * normals + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) return _f # Construct the warp kernel @@ -232,6 +368,14 @@ def kernel3d( f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + if self.velocity_set.d == 3 and self.bc_type == "velocity": + functional = functional3d_velocity + elif self.velocity_set.d == 3 and self.bc_type == "pressure": + functional = functional3d_pressure + elif self.bc_type == "velocity": + functional = functional2d_velocity + else: + functional = functional2d_pressure return functional, kernel diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index fa0857a..9297363 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -366,7 +366,7 @@ def kernel2d( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1]] _feq[l] = feq[l, index[0], index[1]] - _u = self._warp_u_vec() + _u = self.warp_u_vec() for l in range(_d): _u[l] = u[l, index[0], index[1]] _rho = rho[0, index[0], index[1]] @@ -398,7 +398,7 @@ def kernel3d( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] - _u = self._warp_u_vec() + _u = self.warp_u_vec() for l in range(_d): _u[l] = u[l, index[0], index[1], index[2]] _rho = rho[0, index[0], index[1], index[2]] diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index bfc9d8c..84a0b8f 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -98,6 +98,7 @@ class BoundaryConditionIDStruct: id_DoNothingBC: wp.uint8 id_HalfwayBounceBackBC: wp.uint8 id_FullwayBounceBackBC: wp.uint8 + id_ZouHeBC: wp.uint8 @wp.kernel def kernel2d( @@ -139,6 +140,9 @@ def kernel2d( elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + elif _boundary_id == bc_struct.id_ZouHeBC: + # Zouhe boundary condition + f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -204,6 +208,9 @@ def kernel3d( elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + elif _boundary_id == bc_struct.id_ZouHeBC: + # Zouhe boundary condition + f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -237,9 +244,8 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): bc_struct = self.warp_functional() bc_attribute_list = [] - for bc in self.boundary_conditions: + for attribute_str in bc_to_id.keys(): # Setting the Struct attributes based on the BC class names - attribute_str = bc.__class__.__name__ setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str]) bc_attribute_list.append("id_" + attribute_str) @@ -248,7 +254,6 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): for var in ll: if var not in bc_attribute_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 - attribute_str = bc.__class__.__name__ setattr(bc_struct, var, 255) # Launch the warp kernel diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 2127ea6..1608342 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -33,6 +33,7 @@ def __init__(self, operators, boundary_conditions): from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC + from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC # Define a list of tuples with attribute names and their corresponding classes conditions = [ @@ -40,6 +41,7 @@ def __init__(self, operators, boundary_conditions): ("do_nothing_bc", DoNothingBC), ("halfway_bounce_back_bc", HalfwayBounceBackBC), ("fullway_bounce_back_bc", FullwayBounceBackBC), + ("zouhe_bc", ZouHeBC), ] # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 47bbae4..cd63b36 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -48,6 +48,7 @@ def __init__(self, d, q, c, w): self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) + self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c)) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype)