Skip to content

Commit

Permalink
Added a helper function for BCs to avoid repeated definition of ident…
Browse files Browse the repository at this point in the history
…ical warp functions.
  • Loading branch information
hsalehipour committed Dec 19, 2024
1 parent 237ac22 commit f2db380
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 130 deletions.
94 changes: 94 additions & 0 deletions xlb/helper/bc_warp_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from xlb import DefaultConfig
from xlb.compute_backend import ComputeBackend
from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux
import warp as wp
from typing import Any

# Set the compute and Store dtypes
if DefaultConfig.default_backend == ComputeBackend.JAX:
compute_dtype = DefaultConfig.default_precision_policy.compute_precision.jax_dtype
store_dtype = DefaultConfig.default_precision_policy.store_precision.jax_dtype
elif DefaultConfig.default_backend == ComputeBackend.WARP:
compute_dtype = DefaultConfig.default_precision_policy.compute_precision.wp_dtype
compute_dtype = DefaultConfig.default_precision_policy.store_precision.wp_dtype

# Set local constants
_d = DefaultConfig.velocity_set.d
_q = DefaultConfig.velocity_set.q
_u_vec = wp.vec(_d, dtype=compute_dtype)
_opp_indices = DefaultConfig.velocity_set.opp_indices
_w = DefaultConfig.velocity_set.w
_c = DefaultConfig.velocity_set.c
_c_float = DefaultConfig.velocity_set.c_float
_qi = DefaultConfig.velocity_set.qi


# Define the operator needed for computing the momentum flux
momentum_flux = MomentumFlux()


@wp.func
def get_bc_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = compute_dtype(0.0)
fsum_middle = compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle


@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])


@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop


@wp.func
def regularize_fpop(
fpop: Any,
feq: Any,
):
"""
Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop.
"""
# Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
f_neq = fpop - feq
PiNeq = momentum_flux.warp_functional(f_neq)

# Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
nt = _d * (_d + 1) // 2
for l in range(_q):
QiPi1 = compute_dtype(0.0)
for t in range(nt):
QiPi1 += _qi[l, t] * PiNeq[t]

# assign all populations based on eq 45 of Latt et al (2008)
# fneq ~ f^1
fpop1 = compute_dtype(4.5) * _w[l] * QiPi1
fpop[l] = feq[l] + fpop1
return fpop
85 changes: 6 additions & 79 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC
from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux


Expand Down Expand Up @@ -64,7 +62,6 @@ def __init__(
indices,
mesh_vertices,
)
# Overwrite the boundary condition registry id with the bc_type in the name
self.momentum_flux = MomentumFlux()

@partial(jit, static_argnums=(0,), inline=True)
Expand Down Expand Up @@ -127,83 +124,13 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
from xlb.helper.bc_warp_functions import get_normal_vectors, get_bc_fsum, bounceback_nonequilibrium, regularize_fpop

# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_w = self.velocity_set.w
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
_qi = self.velocity_set.qi
# TODO: related to _c_float: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def regularize_fpop(
fpop: Any,
feq: Any,
):
"""
Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop.
"""
# Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
f_neq = fpop - feq
PiNeq = self.momentum_flux.warp_functional(f_neq)

# Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
nt = _d * (_d + 1) // 2
for l in range(_q):
QiPi1 = self.compute_dtype(0.0)
for t in range(nt):
QiPi1 += _qi[l, t] * PiNeq[t]

# assign all populations based on eq 45 of Latt et al (2008)
# fneq ~ f^1
fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1
fpop[l] = feq[l] + fpop1
return fpop

@wp.func
def functional_velocity(
Expand Down Expand Up @@ -231,7 +158,7 @@ def functional_velocity(
break

# calculate rho
fsum = _get_fsum(_f, missing_mask)
fsum = get_bc_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
Expand Down Expand Up @@ -269,7 +196,7 @@ def functional_pressure(
break

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
fsum = get_bc_fsum(_f, missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

Expand Down
57 changes: 6 additions & 51 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)
from xlb.operator.equilibrium import QuadraticEquilibrium
import jax

Expand Down Expand Up @@ -277,55 +274,13 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
from xlb.helper.bc_warp_functions import get_normal_vectors, get_bc_fsum, bounceback_nonequilibrium

# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
# TODO: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def functional_velocity(
Expand All @@ -344,7 +299,7 @@ def functional_velocity(
normals = get_normal_vectors(_missing_mask)

# calculate rho
fsum = _get_fsum(_f, _missing_mask)
fsum = get_bc_fsum(_f, _missing_mask)
unormal = self.compute_dtype(0.0)

# Find the value of u from the missing directions
Expand Down Expand Up @@ -391,7 +346,7 @@ def functional_pressure(
break

# calculate velocity
fsum = _get_fsum(_f, _missing_mask)
fsum = get_bc_fsum(_f, _missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

Expand Down

0 comments on commit f2db380

Please sign in to comment.