Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a helper function for bc related warp functions #98

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 6 additions & 79 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC
from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux


Expand Down Expand Up @@ -64,7 +62,6 @@ def __init__(
indices,
mesh_vertices,
)
# Overwrite the boundary condition registry id with the bc_type in the name
self.momentum_flux = MomentumFlux()

@partial(jit, static_argnums=(0,), inline=True)
Expand Down Expand Up @@ -127,83 +124,13 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
from xlb.operator.boundary_condition.common_warp_functions import get_normal_vectors, get_bc_fsum, bounceback_nonequilibrium, regularize_fpop

# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_w = self.velocity_set.w
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
_qi = self.velocity_set.qi
# TODO: related to _c_float: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def regularize_fpop(
fpop: Any,
feq: Any,
):
"""
Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop.
"""
# Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
f_neq = fpop - feq
PiNeq = self.momentum_flux.warp_functional(f_neq)

# Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
nt = _d * (_d + 1) // 2
for l in range(_q):
QiPi1 = self.compute_dtype(0.0)
for t in range(nt):
QiPi1 += _qi[l, t] * PiNeq[t]

# assign all populations based on eq 45 of Latt et al (2008)
# fneq ~ f^1
fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1
fpop[l] = feq[l] + fpop1
return fpop

@wp.func
def functional_velocity(
Expand Down Expand Up @@ -231,7 +158,7 @@ def functional_velocity(
break

# calculate rho
fsum = _get_fsum(_f, missing_mask)
fsum = get_bc_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
Expand Down Expand Up @@ -269,7 +196,7 @@ def functional_pressure(
break

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
fsum = get_bc_fsum(_f, missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

Expand Down
57 changes: 6 additions & 51 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)
from xlb.operator.equilibrium import QuadraticEquilibrium
import jax

Expand Down Expand Up @@ -277,55 +274,13 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
from xlb.operator.boundary_condition.common_warp_functions import get_normal_vectors, get_bc_fsum, bounceback_nonequilibrium

# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
# TODO: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def functional_velocity(
Expand All @@ -344,7 +299,7 @@ def functional_velocity(
normals = get_normal_vectors(_missing_mask)

# calculate rho
fsum = _get_fsum(_f, _missing_mask)
fsum = get_bc_fsum(_f, _missing_mask)
unormal = self.compute_dtype(0.0)

# Find the value of u from the missing directions
Expand Down Expand Up @@ -391,7 +346,7 @@ def functional_pressure(
break

# calculate velocity
fsum = _get_fsum(_f, _missing_mask)
fsum = get_bc_fsum(_f, _missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

Expand Down
69 changes: 22 additions & 47 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,55 +71,26 @@ def __init__(
# A flag for BCs that need auxiliary data recovery after streaming
self.needs_aux_recovery = False

if self.compute_backend == ComputeBackend.WARP:
# Set local constants TODO: This is a hack and should be fixed with warp update
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

@wp.func
def update_bc_auxilary_data(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
):
return f_post

@wp.func
def _get_thread_data(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
index: wp.vec3i,
):
# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = bc_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1], index[2]])
_f_post[l] = self.compute_dtype(f_post[l, index[0], index[1], index[2]])

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
else:
_missing_mask[l] = wp.uint8(0)
return _f_pre, _f_post, _boundary_id, _missing_mask

# Construct some helper warp functions for getting tid data
if self.compute_backend == ComputeBackend.WARP:
self._get_thread_data = _get_thread_data
self.update_bc_auxilary_data = update_bc_auxilary_data
self.update_bc_auxilary_data = self.update_bc_auxilary_data_warp
else:
self.update_bc_auxilary_data = self.update_bc_auxilary_data_jax

@wp.func
def update_bc_auxilary_data_warp(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
):
return f_post

@partial(jit, static_argnums=(0,), inline=True)
def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
def update_bc_auxilary_data_jax(self, f_pre, f_post, bc_mask, missing_mask):
"""
A placeholder function for preparing the auxiliary distribution functions for the boundary condition.
Currently called after collision only.
Expand All @@ -131,6 +102,8 @@ def _construct_kernel(self, functional):
Constructs the warp kernel for the boundary condition.
The functional is specific to each boundary condition and should be passed as an argument.
"""
from xlb.operator.boundary_condition.common_warp_functions import get_thread_data

_id = wp.uint8(self.id)

# Construct the warp kernel
Expand All @@ -146,7 +119,7 @@ def kernel(
index = wp.vec3i(i, j, k)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data(f_pre, f_post, bc_mask, missing_mask, index)
_f_pre, _f_post, _boundary_id, _missing_mask = get_thread_data(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == _id:
Expand All @@ -165,6 +138,8 @@ def _construct_aux_data_init_kernel(self, functional):
"""
Constructs the warp kernel for the auxiliary data recovery.
"""
from xlb.operator.boundary_condition.common_warp_functions import get_thread_data

_id = wp.uint8(self.id)
_opp_indices = self.velocity_set.opp_indices
_num_of_aux_data = self.num_of_aux_data
Expand All @@ -182,7 +157,7 @@ def aux_data_init_kernel(
index = wp.vec3i(i, j, k)

# read tid data
_f_0, _f_1, _boundary_id, _missing_mask = self._get_thread_data(f_0, f_1, bc_mask, missing_mask, index)
_f_0, _f_1, _boundary_id, _missing_mask = get_thread_data(f_0, f_1, bc_mask, missing_mask, index)

# Apply the functional
if _boundary_id == _id:
Expand Down
Loading
Loading