Skip to content

Commit

Permalink
added ZouHe in warp
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Aug 9, 2024
1 parent 90969d1 commit 4c6d7d5
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 19 deletions.
172 changes: 158 additions & 14 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,17 @@ class ZouHeBC(BoundaryCondition):

def __init__(
self,
bc_type=None,
prescribed_value=None,
bc_type,
prescribed_value,
velocity_set: VelocitySet = None,
precision_policy: PrecisionPolicy = None,
compute_backend: ComputeBackend = None,
indices=None,
):
assert bc_type in ["velocity", "pressure"], f'The boundary type must be either "velocity" or "pressure"'
assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'."
self.bc_type = bc_type
self.equilibrium_operator = QuadraticEquilibrium()
self.prescribed_value = prescribed_value

# Call the parent constructor
super().__init__(
Expand All @@ -56,8 +57,9 @@ def __init__(

# Set the prescribed value for pressure or velocity
dim = self.velocity_set.d
self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None]
# TODO: this won't work if the prescribed values are a profile with the length of bdry indices!
if self.compute_backend == ComputeBackend.JAX:
self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None]
# TODO: this won't work if the prescribed values are a profile with the length of bdry indices!

@partial(jit, static_argnums=(0,), inline=True)
def _get_known_middle_mask(self, missing_mask):
Expand Down Expand Up @@ -103,7 +105,7 @@ def calculate_vel(self, fpop, rho, missing_mask):
normals = self._get_normal_vec(missing_mask)
known_mask, middle_mask = self._get_known_middle_mask(missing_mask)

unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=-1, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=-1, keepdims=True))
unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True))

# Return the above unormal as a normal vector which sets the tangential velocities to zero
vel = unormal * normals
Expand Down Expand Up @@ -159,26 +161,160 @@ def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
_d = self.velocity_set.d
_q = self.velocity_set.q
u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d
rho = self.prescribed_value if self.bc_type == "pressure" else 0.0

# Set local constants TODO: This is a hack and should be fixed with warp update
_c = self.velocity_set.wp_c
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_rho = wp.float32(rho)
_u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1])
_opp_indices = self.velocity_set.wp_opp_indices
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool
_c = self.velocity_set.wp_c
_c32 = self.velocity_set.wp_c32
# TODO: this is way less than ideal. we should not be making new types

@wp.func
def get_normal_vectors_2d(
lattice_direction: Any,
):
# Build a 2D vector from the lattice velocity of direction `lattice_direction`.
# When the absolute values of the two velocity components of that direction sum to 1,
# the result is the negated velocity vector, constructed as a `_u_vec` from the
# float-valued table `_c32`.
# NOTE(review): `normals` is only assigned under the `if`; indentation is not visible
# in this view, so confirm what is returned for directions that fail the test.
l = lattice_direction
# `_c` is used for the test, `_c32` for building the float vector.
if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
normals = -_u_vec(_c32[0, l], _c32[1, l])
return normals

@wp.func
def functional(
def get_normal_vectors_3d(
lattice_direction: Any,
):
# Build a 3D vector from the lattice velocity of direction `lattice_direction`.
# When the absolute values of the three velocity components of that direction sum to 1,
# the result is the negated velocity vector, constructed as a `_u_vec` from the
# float-valued table `_c32`.
# NOTE(review): `normals` is only assigned under the `if`; indentation is not visible
# in this view, so confirm what is returned for directions that fail the test.
l = lattice_direction
# `_c` is used for the test, `_c32` for building the float vector.
if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
normals = -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l])
return normals

@wp.func
def _helper_functional(
    fpop: Any,
    fsum: Any,
    missing_mask: Any,
    lattice_direction: Any,
):
    # Add the contribution of one lattice direction to the running population sum
    # and return the updated sum.
    #
    # fpop:              per-direction populations at the boundary node.
    # fsum:              running sum accumulated over the directions visited so far.
    # missing_mask:      per-direction uint8 flags; 1 marks a missing direction.
    # lattice_direction: index of the direction being accumulated.
    #
    # The mask entries are uint8 values, so they are compared against wp.uint8(1)
    # explicitly. A bitwise NOT (~) on a uint8 flag yields 254 or 255, which is
    # always truthy and therefore cannot be used as a logical negation.
    l = lattice_direction
    if missing_mask[_opp_indices[l]] == wp.uint8(1):
        # "Known" direction: its opposite is missing, so it is weighted twice.
        fsum += 2.0 * fpop[l]
    elif missing_mask[l] != wp.uint8(1):
        # "Middle" direction: neither it nor its opposite is missing.
        fsum += fpop[l]
    # Missing directions contribute nothing to the sum.
    return fsum

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
missing_mask: Any,
density: Any,
velocity: Any,
):
# Overwrite the populations flagged in `missing_mask` using the opposite-direction
# population plus the difference of the equilibrium values in the two directions.
# fpop:              per-direction populations; modified and returned.
# missing_mask:      per-direction flags; entries equal to wp.uint8(1) are rewritten.
# density, velocity: passed to the equilibrium operator to evaluate `feq`.
feq = self.equilibrium_operator.warp_functional(density, velocity)
for l in range(_q):
# Only directions flagged as missing are touched.
if missing_mask[l] == wp.uint8(1):
# f[l] = f[opp(l)] + (feq[l] - feq[opp(l)])
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def functional3d_velocity(
f_pre: Any,
f_post: Any,
missing_mask: Any,
):
# Post-streaming values are only modified at missing direction
_f = f_post
for l in range(self.velocity_set.q):
# If the mask is missing then take the opposite index
_fsum = self.compute_dtype(0.0)
unormal = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
# Get the pre-streaming distribution function in the opposite direction
_f[l] = f_pre[_opp_indices[l]]
normals = get_normal_vectors_3d(l)
_fsum = _helper_functional(_f, _fsum, missing_mask, l)

for d in range(_d):
unormal += _u[d] * normals[d]
_rho = _fsum / (1.0 + unormal)

# impose non-equilibrium bounceback
_f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u)
return _f

@wp.func
def functional3d_pressure(
f_pre: Any,
f_post: Any,
missing_mask: Any,
):
# 3D pressure variant of the boundary functional: starting from `f_post`, derive a
# velocity from the accumulated population sum and the captured density `_rho`, then
# rewrite the missing populations via `bounceback_nonequilibrium`.
# `f_pre` is accepted but not read in this body.
# NOTE(review): indentation is not visible in this view, so whether the
# `_helper_functional` call sits inside the `if` cannot be told from here — confirm.
# Post-streaming values are only modified at missing direction
_f = f_post
_fsum = self.compute_dtype(0.0)
unormal = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
# Vector obtained from the lattice velocity of a missing direction.
normals = get_normal_vectors_3d(l)
_fsum = _helper_functional(_f, _fsum, missing_mask, l)

# Scalar velocity along `normals` computed from the population sum and `_rho`.
unormal = -1.0 + _fsum / _rho
_u = unormal * normals

# impose non-equilibrium bounceback
_f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u)
return _f

@wp.func
def functional2d_velocity(
f_pre: Any,
f_post: Any,
missing_mask: Any,
):
# 2D velocity variant of the boundary functional: starting from `f_post`, derive a
# density from the accumulated population sum and the captured velocity `_u`, then
# rewrite the missing populations via `bounceback_nonequilibrium`.
# `f_pre` is accepted but not read in this body.
# NOTE(review): indentation is not visible in this view, so whether the
# `_helper_functional` call sits inside the `if` cannot be told from here — confirm.
# Post-streaming values are only modified at missing direction
_f = f_post
_fsum = self.compute_dtype(0.0)
unormal = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
# Vector obtained from the lattice velocity of a missing direction.
normals = get_normal_vectors_2d(l)
_fsum = _helper_functional(_f, _fsum, missing_mask, l)

# Dot product of the captured velocity `_u` with `normals`.
for d in range(_d):
unormal += _u[d] * normals[d]
# Density computed from the population sum and that dot product.
_rho = _fsum / (1.0 + unormal)

# impose non-equilibrium bounceback
_f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u)
return _f

@wp.func
def functional2d_pressure(
f_pre: Any,
f_post: Any,
missing_mask: Any,
):
# 2D pressure variant of the boundary functional: starting from `f_post`, derive a
# velocity from the accumulated population sum and the captured density `_rho`, then
# rewrite the missing populations via `bounceback_nonequilibrium`.
# `f_pre` is accepted but not read in this body.
# NOTE(review): indentation is not visible in this view, so whether the
# `_helper_functional` call sits inside the `if` cannot be told from here — confirm.
# Post-streaming values are only modified at missing direction
_f = f_post
_fsum = self.compute_dtype(0.0)
unormal = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
# Vector obtained from the lattice velocity of a missing direction.
normals = get_normal_vectors_2d(l)
_fsum = _helper_functional(_f, _fsum, missing_mask, l)

# Scalar velocity along `normals` computed from the population sum and `_rho`.
unormal = -1.0 + _fsum / _rho
_u = unormal * normals

# impose non-equilibrium bounceback
_f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u)
return _f

# Construct the warp kernel
Expand Down Expand Up @@ -232,6 +368,14 @@ def kernel3d(
f_post[l, index[0], index[1], index[2]] = _f[l]

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
if self.velocity_set.d == 3 and self.bc_type == "velocity":
functional = functional3d_velocity
elif self.velocity_set.d == 3 and self.bc_type == "pressure":
functional = functional3d_pressure
elif self.bc_type == "velocity":
functional = functional2d_velocity
else:
functional = functional2d_pressure

return functional, kernel

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/collision/kbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def kernel2d(
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
_feq[l] = feq[l, index[0], index[1]]
_u = self._warp_u_vec()
_u = self.warp_u_vec()
for l in range(_d):
_u[l] = u[l, index[0], index[1]]
_rho = rho[0, index[0], index[1]]
Expand Down Expand Up @@ -398,7 +398,7 @@ def kernel3d(
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_feq[l] = feq[l, index[0], index[1], index[2]]
_u = self._warp_u_vec()
_u = self.warp_u_vec()
for l in range(_d):
_u[l] = u[l, index[0], index[1], index[2]]
_rho = rho[0, index[0], index[1], index[2]]
Expand Down
11 changes: 8 additions & 3 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class BoundaryConditionIDStruct:
id_DoNothingBC: wp.uint8
id_HalfwayBounceBackBC: wp.uint8
id_FullwayBounceBackBC: wp.uint8
id_ZouHeBC: wp.uint8

@wp.kernel
def kernel2d(
Expand Down Expand Up @@ -139,6 +140,9 @@ def kernel2d(
elif _boundary_id == bc_struct.id_HalfwayBounceBackBC:
# Half way boundary condition
f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask)
elif _boundary_id == bc_struct.id_ZouHeBC:
# Zouhe boundary condition
f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask)

# Compute rho and u
rho, u = self.macroscopic.warp_functional(f_post_stream)
Expand Down Expand Up @@ -204,6 +208,9 @@ def kernel3d(
elif _boundary_id == bc_struct.id_HalfwayBounceBackBC:
# Half way boundary condition
f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask)
elif _boundary_id == bc_struct.id_ZouHeBC:
# Zouhe boundary condition
f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask)

# Compute rho and u
rho, u = self.macroscopic.warp_functional(f_post_stream)
Expand Down Expand Up @@ -237,9 +244,8 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):

bc_struct = self.warp_functional()
bc_attribute_list = []
for bc in self.boundary_conditions:
for attribute_str in bc_to_id.keys():
# Setting the Struct attributes based on the BC class names
attribute_str = bc.__class__.__name__
setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str])
bc_attribute_list.append("id_" + attribute_str)

Expand All @@ -248,7 +254,6 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):
for var in ll:
if var not in bc_attribute_list and not var.startswith("_"):
# set unassigned boundaries to the maximum integer in uint8
attribute_str = bc.__class__.__name__
setattr(bc_struct, var, 255)

# Launch the warp kernel
Expand Down
2 changes: 2 additions & 0 deletions xlb/operator/stepper/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ def __init__(self, operators, boundary_conditions):
from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC
from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC
from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC

# Define a list of tuples with attribute names and their corresponding classes
conditions = [
("equilibrium_bc", EquilibriumBC),
("do_nothing_bc", DoNothingBC),
("halfway_bounce_back_bc", HalfwayBounceBackBC),
("fullway_bounce_back_bc", FullwayBounceBackBC),
("zouhe_bc", ZouHeBC),
]

# This fallback BC only ensures that Warp codegen does not produce an error when a particular BC is not used in an example.
Expand Down
1 change: 1 addition & 0 deletions xlb/velocity_set/velocity_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, d, q, c, w):
self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow
self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices))
self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc))
self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c))

def warp_lattice_vec(self, dtype):
return wp.vec(len(self.c), dtype=dtype)
Expand Down

0 comments on commit 4c6d7d5

Please sign in to comment.