Skip to content

Commit

Permalink
refactored the interpolating BC to follow related class inheritance. …
Browse files Browse the repository at this point in the history
…Also added more comments.
  • Loading branch information
hsalehipour committed Oct 24, 2023
1 parent 6b3d5c2 commit 8dced2b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 118 deletions.
3 changes: 1 addition & 2 deletions src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@

# JAX-related imports
from jax import jit, lax, vmap
from jax.config import config
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import process_allgather
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding, PartitionSpec, PositionalSharding, Mesh
import orbax.checkpoint as orb

# functools imports
from functools import partial

# Local/Custom Libraries
import src.models
from src.utils import downsample_field

jax.config.update("jax_spmd_mode", 'allow_all')
Expand Down
173 changes: 57 additions & 116 deletions src/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,14 +498,28 @@ def configure(self, boundaryBitmask):
"""
# Perform index shift for halfway BB.
hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices]
nbd_orig = len(self.indices[0])
idx = np.array(self.indices).T
idx_trg = []
for i in range(self.lattice.q):
idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i])
indices_new = np.unique(np.vstack(idx_trg), axis=0)
self.indices = tuple(indices_new.T)
nbd_modified = len(self.indices[0])
if (nbd_orig != nbd_modified) and self.vel is not None:
vel_avg = np.mean(self.vel, axis=0)
self.vel = jnp.zeros(indices_new.shape, dtype=self.precisionPolicy.compute_dtype) + vel_avg
print("WARNING: assuming a constant averaged velocity vector is imposed at all BC cells!")

return

@partial(jit, static_argnums=(0,))
def impose_boundary_vel(self, fbd, bindex):
c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype)
cu = 6.0 * self.lattice.w * jnp.dot(self.vel, c)
fbd = fbd.at[bindex, self.imissing].add(-cu[bindex, self.iknown])
return fbd

@partial(jit, static_argnums=(0,))
def apply(self, fout, fin):
"""
Expand All @@ -526,13 +540,10 @@ def apply(self, fout, fin):
nbd = len(self.indices[0])
bindex = np.arange(nbd)[:, None]
fbd = fout[self.indices]
if self.vel is not None:
c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype)
cu = 6.0 * self.lattice.w * jnp.dot(self.vel, c)
fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown] - cu[bindex, self.iknown])
else:
fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown])

fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown])
if self.vel is not None:
fbd = self.impose_boundary_vel(fbd, bindex)
return fbd

class EquilibriumBC(BoundaryCondition):
Expand Down Expand Up @@ -1012,18 +1023,18 @@ def apply(self, fout, fin):
return fbd


class InterpolatedBounceBackDifferentiable(BoundaryCondition):
class InterpolatedBounceBackBouzidi(BounceBackHalfway):
"""
A local and differentiable single-node interpolated bounce-back boundary condition for a lattice Boltzmann method
simulation.
A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice
Boltzmann method simulation.
This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after
the streaming step.
Attributes
----------
name : str
The name of the boundary condition. For this class, it is "InterpolatedBounceBackDifferentiable".
The name of the boundary condition. For this class, it is "InterpolatedBounceBackBouzidi".
implementationStep : str
The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class,
it is "PostStreaming".
Expand All @@ -1035,43 +1046,12 @@ class InterpolatedBounceBackDifferentiable(BoundaryCondition):
An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls
"""

def __init__(self, indices, implicit_distances, grid_info, precision_policy):
def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None):

super().__init__(indices, grid_info, precision_policy)
self.name = "InterpolatedBounceBackDifferentiable"
self.implementationStep = "PostStreaming"
super().__init__(indices, grid_info, precision_policy, vel=vel)
self.name = "InterpolatedBounceBackBouzidi"
self.implicit_distances = implicit_distances
self.weights = None
self.isSolid = True
self.needsExtraConfiguration = True

def configure(self, boundaryBitmask):
"""
Configures the boundary condition.
Parameters
----------
boundaryBitmask : array-like
The connectivity bitmask for the boundary voxels.
Returns
-------
None
Notes
-----
This method performs an index shift for the halfway bounce-back boundary condition. It updates the indices of
the boundary nodes to be the indices of fluid nodes adjacent of the solid nodes.
"""
# Perform index shift for halfway BB.
hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices]
idx = np.array(self.indices).T
idx_trg = []
for i in range(self.lattice.q):
idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i])
indices_new = np.unique(np.vstack(idx_trg), axis=0)
self.indices = tuple(indices_new.T)
return

def set_proximity_ratio(self):
"""
Expand Down Expand Up @@ -1118,25 +1098,44 @@ def apply(self, fout, fin):
f_postcollision_iknown = fin[self.indices][bindex, self.iknown]
f_postcollision_imissing = fin[self.indices][bindex, self.imissing]
f_poststreaming_iknown = fout[self.indices][bindex, self.iknown]
fmissing = ((1. - self.weights) * f_postcollision_iknown +
self.weights * (f_postcollision_imissing + f_poststreaming_iknown)) / (1.0 + self.weights)

# if weights<0.5
fs_near = 2. * self.weights * f_postcollision_iknown + \
(1.0 - 2.0 * self.weights) * f_poststreaming_iknown

# if weights>=0.5
fs_far = 1.0 / (2. * self.weights) * f_postcollision_iknown + \
(2.0 * self.weights - 1.0) / (2. * self.weights) * f_postcollision_imissing

# combine near and far contributions
fmissing = jnp.where(self.weights < 0.5, fs_near, fs_far)
fbd = fbd.at[bindex, self.imissing].set(fmissing)

if self.vel is not None:
fbd = self.impose_boundary_vel(fbd, bindex)
return fbd


class InterpolatedBounceBackBouzidi(BoundaryCondition):
class InterpolatedBounceBackDifferentiable(InterpolatedBounceBackBouzidi):
"""
A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice
Boltzmann method simulation.
A differentiable variant of the "InterpolatedBounceBackBouzidi" BC scheme. This BC is now differentiable at
self.weight = 0.5 unlike the original Bouzidi scheme which switches between 2 equations at weight=0.5. Refer to
[1] (their Appendix E) for more information.
References
----------
[1] Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three
dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507–547.
doi:10.1016/j.camwa.2015.05.001.
This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after
the streaming step.
Attributes
----------
name : str
The name of the boundary condition. For this class, it is "InterpolatedBounceBackBouzidi".
The name of the boundary condition. For this class, it is "InterpolatedBounceBackDifferentiable".
implementationStep : str
The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class,
it is "PostStreaming".
Expand All @@ -1148,63 +1147,11 @@ class InterpolatedBounceBackBouzidi(BoundaryCondition):
An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls
"""

def __init__(self, indices, implicit_distances, grid_info, precision_policy):
def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None):

super().__init__(indices, grid_info, precision_policy)
self.name = "InterpolatedBounceBackBouzidi"
self.implementationStep = "PostStreaming"
self.implicit_distances = implicit_distances
self.weights = None
self.isSolid = True
self.needsExtraConfiguration = True

def configure(self, boundaryBitmask):
"""
Configures the boundary condition.
Parameters
----------
boundaryBitmask : array-like
The connectivity bitmask for the boundary voxels.
Returns
-------
None
Notes
-----
This method performs an index shift for the halfway bounce-back boundary condition. It updates the indices of
the boundary nodes to be the indices of fluid nodes adjacent of the solid nodes.
"""
# Perform index shift for halfway BB.
hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices]
idx = np.array(self.indices).T
idx_trg = []
for i in range(self.lattice.q):
idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i])
indices_new = np.unique(np.vstack(idx_trg), axis=0)
self.indices = tuple(indices_new.T)
return

def set_proximity_ratio(self):
"""
Creates the interpolation data needed for the boundary condition.
super().__init__(indices, implicit_distances, grid_info, precision_policy, vel=vel)
self.name = "InterpolatedBounceBackDifferentiable"

Returns
-------
None. The function updates the object's weights attribute in place.
"""
idx = np.array(self.indices).T
self.weights = np.full((idx.shape[0], self.lattice.q), 0.5)
c = np.array(self.lattice.c)
sdf_f = self.implicit_distances[self.indices]
for q in range(1, self.lattice.q):
solid_indices = idx + c[:, q]
solid_indices_tuple = tuple(map(tuple, solid_indices.T))
sdf_s = self.implicit_distances[solid_indices_tuple]
mask = self.iknownBitmask[:, q]
self.weights[mask, q] = sdf_f[mask] / (sdf_f[mask] - sdf_s[mask])
return

@partial(jit, static_argnums=(0,))
def apply(self, fout, fin):
Expand All @@ -1231,16 +1178,10 @@ def apply(self, fout, fin):
f_postcollision_iknown = fin[self.indices][bindex, self.iknown]
f_postcollision_imissing = fin[self.indices][bindex, self.imissing]
f_poststreaming_iknown = fout[self.indices][bindex, self.iknown]

# if weights<0.5
fs_near = 2. * self.weights * f_postcollision_iknown + \
(1.0 - 2.0 * self.weights) * f_poststreaming_iknown

# if weights>=0.5
fs_far = 1.0 / (2. * self.weights) * f_postcollision_iknown + \
(2.0 * self.weights - 1.0) / (2. * self.weights) * f_postcollision_imissing

# combine near and far contributions
fmissing = jnp.where(self.weights < 0.5, fs_near, fs_far)
fmissing = ((1. - self.weights) * f_postcollision_iknown +
self.weights * (f_postcollision_imissing + f_poststreaming_iknown)) / (1.0 + self.weights)
fbd = fbd.at[bindex, self.imissing].set(fmissing)

if self.vel is not None:
fbd = self.impose_boundary_vel(fbd, bindex)
return fbd

0 comments on commit 8dced2b

Please sign in to comment.