diff --git a/xlb/helper/bc_warp_functions.py b/xlb/helper/bc_warp_functions.py new file mode 100644 index 0000000..c3b106f --- /dev/null +++ b/xlb/helper/bc_warp_functions.py @@ -0,0 +1,94 @@ +from xlb import DefaultConfig +from xlb.compute_backend import ComputeBackend +from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux +import warp as wp +from typing import Any + +# Set the compute and Store dtypes +if DefaultConfig.default_backend == ComputeBackend.JAX: + compute_dtype = DefaultConfig.default_precision_policy.compute_precision.jax_dtype + store_dtype = DefaultConfig.default_precision_policy.store_precision.jax_dtype +elif DefaultConfig.default_backend == ComputeBackend.WARP: + compute_dtype = DefaultConfig.default_precision_policy.compute_precision.wp_dtype + compute_dtype = DefaultConfig.default_precision_policy.store_precision.wp_dtype + +# Set local constants +_d = DefaultConfig.velocity_set.d +_q = DefaultConfig.velocity_set.q +_u_vec = wp.vec(_d, dtype=compute_dtype) +_opp_indices = DefaultConfig.velocity_set.opp_indices +_w = DefaultConfig.velocity_set.w +_c = DefaultConfig.velocity_set.c +_c_float = DefaultConfig.velocity_set.c_float +_qi = DefaultConfig.velocity_set.qi + + +# Define the operator needed for computing the momentum flux +momentum_flux = MomentumFlux() + + +@wp.func +def get_bc_fsum( + fpop: Any, + missing_mask: Any, +): + fsum_known = compute_dtype(0.0) + fsum_middle = compute_dtype(0.0) + for l in range(_q): + if missing_mask[_opp_indices[l]] == wp.uint8(1): + fsum_known += compute_dtype(2.0) * fpop[l] + elif missing_mask[l] != wp.uint8(1): + fsum_middle += fpop[l] + return fsum_known + fsum_middle + + +@wp.func +def get_normal_vectors( + missing_mask: Any, +): + if wp.static(_d == 3): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) + else: + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + return -_u_vec(_c_float[0, l], _c_float[1, l]) + + +@wp.func +def bounceback_nonequilibrium( + fpop: Any, + feq: Any, + missing_mask: Any, +): + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] + return fpop + + +@wp.func +def regularize_fpop( + fpop: Any, + feq: Any, +): + """ + Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. + """ + # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} + f_neq = fpop - feq + PiNeq = momentum_flux.warp_functional(f_neq) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + for l in range(_q): + QiPi1 = compute_dtype(0.0) + for t in range(nt): + QiPi1 += _qi[l, t] * PiNeq[t] + + # assign all populations based on eq 45 of Latt et al (2008) + # fneq ~ f^1 + fpop1 = compute_dtype(4.5) * _w[l] * QiPi1 + fpop[l] = feq[l] + fpop1 + return fpop diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index dee8679..020981d 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -15,8 +15,6 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC -from xlb.operator.boundary_condition.boundary_condition import ImplementationStep -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux @@ -64,7 +62,6 @@ def __init__( indices, mesh_vertices, ) - # Overwrite the boundary condition registry id with the bc_type in the name self.momentum_flux = MomentumFlux() @partial(jit, static_argnums=(0,), inline=True) @@ -127,83 +124,13 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): return f_post def _construct_warp(self): - # assign placeholders for both u and rho based on prescribed_value + # load helper functions + from xlb.helper.bc_warp_functions import get_normal_vectors, get_bc_fsum, bounceback_nonequilibrium, regularize_fpop + + # Set local constants _d = self.velocity_set.d _q = self.velocity_set.q - - # Set local constants TODO: This is a hack and should be fixed with warp update - # _u_vec = wp.vec(_d, dtype=self.compute_dtype) - # compute Qi tensor and store it in self - _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _opp_indices = self.velocity_set.opp_indices - _w = self.velocity_set.w - _c = self.velocity_set.c - _c_float = self.velocity_set.c_float - _qi = self.velocity_set.qi - # TODO: related to _c_float: this is way less than ideal. we should not be making new types - - @wp.func - def _get_fsum( - fpop: Any, - missing_mask: Any, - ): - fsum_known = self.compute_dtype(0.0) - fsum_middle = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[_opp_indices[l]] == wp.uint8(1): - fsum_known += self.compute_dtype(2.0) * fpop[l] - elif missing_mask[l] != wp.uint8(1): - fsum_middle += fpop[l] - return fsum_known + fsum_middle - - @wp.func - def get_normal_vectors( - missing_mask: Any, - ): - if wp.static(_d == 3): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) - else: - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l]) - - @wp.func - def bounceback_nonequilibrium( - fpop: Any, - feq: Any, - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] - return fpop - - @wp.func - def regularize_fpop( - fpop: Any, - feq: Any, - ): - """ - Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. - """ - # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} - f_neq = fpop - feq - PiNeq = self.momentum_flux.warp_functional(f_neq) - - # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) - nt = _d * (_d + 1) // 2 - for l in range(_q): - QiPi1 = self.compute_dtype(0.0) - for t in range(nt): - QiPi1 += _qi[l, t] * PiNeq[t] - - # assign all populations based on eq 45 of Latt et al (2008) - # fneq ~ f^1 - fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1 - fpop[l] = feq[l] + fpop1 - return fpop @wp.func def functional_velocity( @@ -231,7 +158,7 @@ def functional_velocity( break # calculate rho - fsum = _get_fsum(_f, missing_mask) + fsum = get_bc_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -269,7 +196,7 @@ def functional_pressure( break # calculate velocity - fsum = _get_fsum(_f, missing_mask) + fsum = get_bc_fsum(_f, missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 8e889b0..919dc1e 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -18,9 +18,6 @@ ImplementationStep, BoundaryCondition, ) -from xlb.operator.boundary_condition.boundary_condition_registry import ( - boundary_condition_registry, -) from xlb.operator.equilibrium import QuadraticEquilibrium import jax @@ -277,55 +274,13 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): return f_post def _construct_warp(self): - # assign placeholders for both u and rho based on prescribed_value + # load helper functions + from xlb.helper.bc_warp_functions import get_normal_vectors, get_bc_fsum, bounceback_nonequilibrium + + # Set local constants _d = self.velocity_set.d _q = self.velocity_set.q - - # Set local constants TODO: This is a hack and should be fixed with warp update - # _u_vec = wp.vec(_d, dtype=self.compute_dtype) - _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _opp_indices = self.velocity_set.opp_indices - _c = self.velocity_set.c - _c_float = self.velocity_set.c_float - # TODO: this is way less than ideal. we should not be making new types - - @wp.func - def _get_fsum( - fpop: Any, - missing_mask: Any, - ): - fsum_known = self.compute_dtype(0.0) - fsum_middle = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[_opp_indices[l]] == wp.uint8(1): - fsum_known += self.compute_dtype(2.0) * fpop[l] - elif missing_mask[l] != wp.uint8(1): - fsum_middle += fpop[l] - return fsum_known + fsum_middle - - @wp.func - def get_normal_vectors( - missing_mask: Any, - ): - if wp.static(_d == 3): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) - else: - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l]) - - @wp.func - def bounceback_nonequilibrium( - fpop: Any, - feq: Any, - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] - return fpop @wp.func def functional_velocity( @@ -344,7 +299,7 @@ def functional_velocity( normals = get_normal_vectors(_missing_mask) # calculate rho - fsum = _get_fsum(_f, _missing_mask) + fsum = get_bc_fsum(_f, _missing_mask) unormal = self.compute_dtype(0.0) # Find the value of u from the missing directions @@ -391,7 +346,7 @@ def functional_pressure( break # calculate velocity - fsum = _get_fsum(_f, _missing_mask) + fsum = get_bc_fsum(_f, _missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals