diff --git a/src/base.py b/src/base.py index b359b6a..0ca82a7 100644 --- a/src/base.py +++ b/src/base.py @@ -11,17 +11,16 @@ # JAX-related imports from jax import jit, lax, vmap -from jax.config import config from jax.experimental import mesh_utils from jax.experimental.multihost_utils import process_allgather from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding, PartitionSpec, PositionalSharding, Mesh import orbax.checkpoint as orb + # functools imports from functools import partial # Local/Custom Libraries -import src.models from src.utils import downsample_field jax.config.update("jax_spmd_mode", 'allow_all') diff --git a/src/boundary_conditions.py b/src/boundary_conditions.py index 3b01ccb..2bb6119 100644 --- a/src/boundary_conditions.py +++ b/src/boundary_conditions.py @@ -498,14 +498,28 @@ def configure(self, boundaryBitmask): """ # Perform index shift for halfway BB. hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices] + nbd_orig = len(self.indices[0]) idx = np.array(self.indices).T idx_trg = [] for i in range(self.lattice.q): idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) indices_new = np.unique(np.vstack(idx_trg), axis=0) self.indices = tuple(indices_new.T) + nbd_modified = len(self.indices[0]) + if (nbd_orig != nbd_modified) and self.vel is not None: + vel_avg = np.mean(self.vel, axis=0) + self.vel = jnp.zeros(indices_new.shape, dtype=self.precisionPolicy.compute_dtype) + vel_avg + print("WARNING: assuming a constant averaged velocity vector is imposed at all BC cells!") + return + @partial(jit, static_argnums=(0,)) + def impose_boundary_vel(self, fbd, bindex): + c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) + cu = 6.0 * self.lattice.w * jnp.dot(self.vel, c) + fbd = fbd.at[bindex, self.imissing].add(-cu[bindex, self.iknown]) + return fbd + @partial(jit, static_argnums=(0,)) def apply(self, fout, fin): """ @@ -526,13 +540,10 @@ def apply(self, fout, fin): nbd = len(self.indices[0]) bindex = np.arange(nbd)[:, None] fbd = fout[self.indices] - if self.vel is not None: - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - cu = 6.0 * self.lattice.w * jnp.dot(self.vel, c) - fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown] - cu[bindex, self.iknown]) - else: - fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) + fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) + if self.vel is not None: + fbd = self.impose_boundary_vel(fbd, bindex) return fbd class EquilibriumBC(BoundaryCondition): @@ -1012,10 +1023,10 @@ def apply(self, fout, fin): return fbd -class InterpolatedBounceBackDifferentiable(BoundaryCondition): +class InterpolatedBounceBackBouzidi(BounceBackHalfway): """ - A local and differentiable single-node interpolated bounce-back boundary condition for a lattice Boltzmann method - simulation. + A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice + Boltzmann method simulation. This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after the streaming step. @@ -1023,7 +1034,7 @@ class InterpolatedBounceBackDifferentiable(BoundaryCondition): Attributes ---------- name : str - The name of the boundary condition. For this class, it is "InterpolatedBounceBackDifferentiable". + The name of the boundary condition. For this class, it is "InterpolatedBounceBackBouzidi". implementationStep : str The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, it is "PostStreaming". @@ -1035,43 +1046,12 @@ class InterpolatedBounceBackDifferentiable(BoundaryCondition): An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls """ - def __init__(self, indices, implicit_distances, grid_info, precision_policy): + def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None): - super().__init__(indices, grid_info, precision_policy) - self.name = "InterpolatedBounceBackDifferentiable" - self.implementationStep = "PostStreaming" + super().__init__(indices, grid_info, precision_policy, vel=vel) + self.name = "InterpolatedBounceBackBouzidi" self.implicit_distances = implicit_distances self.weights = None - self.isSolid = True - self.needsExtraConfiguration = True - - def configure(self, boundaryBitmask): - """ - Configures the boundary condition. - - Parameters - ---------- - boundaryBitmask : array-like - The connectivity bitmask for the boundary voxels. - - Returns - ------- - None - - Notes - ----- - This method performs an index shift for the halfway bounce-back boundary condition. It updates the indices of - the boundary nodes to be the indices of fluid nodes adjacent of the solid nodes. - """ - # Perform index shift for halfway BB. - hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices] - idx = np.array(self.indices).T - idx_trg = [] - for i in range(self.lattice.q): - idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) - indices_new = np.unique(np.vstack(idx_trg), axis=0) - self.indices = tuple(indices_new.T) - return def set_proximity_ratio(self): """ @@ -1118,17 +1098,36 @@ def apply(self, fout, fin): f_postcollision_iknown = fin[self.indices][bindex, self.iknown] f_postcollision_imissing = fin[self.indices][bindex, self.imissing] f_poststreaming_iknown = fout[self.indices][bindex, self.iknown] - fmissing = ((1. - self.weights) * f_postcollision_iknown + - self.weights * (f_postcollision_imissing + f_poststreaming_iknown)) / (1.0 + self.weights) + + # if weights<0.5 + fs_near = 2. * self.weights * f_postcollision_iknown + \ + (1.0 - 2.0 * self.weights) * f_poststreaming_iknown + + # if weights>=0.5 + fs_far = 1.0 / (2. * self.weights) * f_postcollision_iknown + \ + (2.0 * self.weights - 1.0) / (2. * self.weights) * f_postcollision_imissing + + # combine near and far contributions + fmissing = jnp.where(self.weights < 0.5, fs_near, fs_far) fbd = fbd.at[bindex, self.imissing].set(fmissing) + if self.vel is not None: + fbd = self.impose_boundary_vel(fbd, bindex) return fbd -class InterpolatedBounceBackBouzidi(BoundaryCondition): +class InterpolatedBounceBackDifferentiable(InterpolatedBounceBackBouzidi): """ - A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice - Boltzmann method simulation. + A differentiable variant of the "InterpolatedBounceBackBouzidi" BC scheme. This BC is now differentiable at + self.weight = 0.5 unlike the original Bouzidi scheme which switches between 2 equations at weight=0.5. Refer to + [1] (their Appendix E) for more information. + + References + ---------- + [1] Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three + dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507–547. + doi:10.1016/j.camwa.2015.05.001. + This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after the streaming step. @@ -1136,7 +1135,7 @@ class InterpolatedBounceBackBouzidi(BoundaryCondition): Attributes ---------- name : str - The name of the boundary condition. For this class, it is "InterpolatedBounceBackBouzidi". + The name of the boundary condition. For this class, it is "InterpolatedBounceBackDifferentiable". implementationStep : str The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, it is "PostStreaming". @@ -1148,63 +1147,11 @@ class InterpolatedBounceBackBouzidi(BoundaryCondition): An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls """ - def __init__(self, indices, implicit_distances, grid_info, precision_policy): + def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None): - super().__init__(indices, grid_info, precision_policy) - self.name = "InterpolatedBounceBackBouzidi" - self.implementationStep = "PostStreaming" - self.implicit_distances = implicit_distances - self.weights = None - self.isSolid = True - self.needsExtraConfiguration = True - - def configure(self, boundaryBitmask): - """ - Configures the boundary condition. - - Parameters - ---------- - boundaryBitmask : array-like - The connectivity bitmask for the boundary voxels. - - Returns - ------- - None - - Notes - ----- - This method performs an index shift for the halfway bounce-back boundary condition. It updates the indices of - the boundary nodes to be the indices of fluid nodes adjacent of the solid nodes. - """ - # Perform index shift for halfway BB. - hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices] - idx = np.array(self.indices).T - idx_trg = [] - for i in range(self.lattice.q): - idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) - indices_new = np.unique(np.vstack(idx_trg), axis=0) - self.indices = tuple(indices_new.T) - return - - def set_proximity_ratio(self): - """ - Creates the interpolation data needed for the boundary condition. + super().__init__(indices, implicit_distances, grid_info, precision_policy, vel=vel) + self.name = "InterpolatedBounceBackDifferentiable" - Returns - ------- - None. The function updates the object's weights attribute in place. - """ - idx = np.array(self.indices).T - self.weights = np.full((idx.shape[0], self.lattice.q), 0.5) - c = np.array(self.lattice.c) - sdf_f = self.implicit_distances[self.indices] - for q in range(1, self.lattice.q): - solid_indices = idx + c[:, q] - solid_indices_tuple = tuple(map(tuple, solid_indices.T)) - sdf_s = self.implicit_distances[solid_indices_tuple] - mask = self.iknownBitmask[:, q] - self.weights[mask, q] = sdf_f[mask] / (sdf_f[mask] - sdf_s[mask]) - return @partial(jit, static_argnums=(0,)) def apply(self, fout, fin): @@ -1231,16 +1178,10 @@ def apply(self, fout, fin): f_postcollision_iknown = fin[self.indices][bindex, self.iknown] f_postcollision_imissing = fin[self.indices][bindex, self.imissing] f_poststreaming_iknown = fout[self.indices][bindex, self.iknown] - - # if weights<0.5 - fs_near = 2. * self.weights * f_postcollision_iknown + \ - (1.0 - 2.0 * self.weights) * f_poststreaming_iknown - - # if weights>=0.5 - fs_far = 1.0 / (2. * self.weights) * f_postcollision_iknown + \ - (2.0 * self.weights - 1.0) / (2. * self.weights) * f_postcollision_imissing - - # combine near and far contributions - fmissing = jnp.where(self.weights < 0.5, fs_near, fs_far) + fmissing = ((1. - self.weights) * f_postcollision_iknown + + self.weights * (f_postcollision_imissing + f_poststreaming_iknown)) / (1.0 + self.weights) fbd = fbd.at[bindex, self.imissing].set(fmissing) + + if self.vel is not None: + fbd = self.impose_boundary_vel(fbd, bindex) return fbd \ No newline at end of file