Skip to content

Commit

Permalink
fixed a bug in InterpDiffBB. Added more comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Nov 23, 2023
1 parent ed256cf commit b101e16
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
24 changes: 6 additions & 18 deletions src/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,15 +1035,12 @@ class InterpolatedBounceBackBouzidi(BounceBackHalfway):
----------
name : str
The name of the boundary condition. For this class, it is "InterpolatedBounceBackBouzidi".
implementationStep : str
The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class,
it is "PostStreaming".
isSolid : bool
Whether the boundary condition represents a solid boundary. For this class, it is True.
needsExtraConfiguration : bool
Whether the boundary condition needs extra configuration before it can be applied. For this class, it is True.
implicit_distances : array-like
An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls
weights : array-like
An array of shape (number_of_bc_cells, q) initialized as None and constructed using implicit_distances array
during runtime. These "weights" are associated with the fractional distance of fluid cell to the boundary
position defined as: weights(dir_i) = |x_fluid - x_boundary(dir_i)| / |x_fluid - x_solid(dir_i)|.
"""

def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None):
Expand Down Expand Up @@ -1136,15 +1133,6 @@ class InterpolatedBounceBackDifferentiable(InterpolatedBounceBackBouzidi):
----------
name : str
The name of the boundary condition. For this class, it is "InterpolatedBounceBackDifferentiable".
implementationStep : str
The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class,
it is "PostStreaming".
isSolid : bool
Whether the boundary condition represents a solid boundary. For this class, it is True.
needsExtraConfiguration : bool
Whether the boundary condition needs extra configuration before it can be applied. For this class, it is True.
implicit_distances : array-like
An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls
"""

def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None):
Expand Down Expand Up @@ -1178,8 +1166,8 @@ def apply(self, fout, fin):
f_postcollision_iknown = fin[self.indices][bindex, self.iknown]
f_postcollision_imissing = fin[self.indices][bindex, self.imissing]
f_poststreaming_iknown = fout[self.indices][bindex, self.iknown]
fmissing = ((1. - self.weights) * f_postcollision_iknown +
self.weights * (f_postcollision_imissing + f_poststreaming_iknown)) / (1.0 + self.weights)
fmissing = ((1. - self.weights) * f_poststreaming_iknown +
self.weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + self.weights)
fbd = fbd.at[bindex, self.imissing].set(fmissing)

if self.vel is not None:
Expand Down
9 changes: 8 additions & 1 deletion src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@ def collision(self, f):
@partial(jit, static_argnums=(0,), donate_argnums=(1,))
def collision_modified(self, f):
"""
KBC collision step for lattice.
Alternative KBC collision step for lattice.
Note:
At low Reynolds number the orignal KBC collision above produces inaccurate results because
it does not check for the entropy increase/decrease. The KBC stabalizations should only be
applied in principle to cells whose entropy decrease after a regular BGK collision. This is
the case in most cells at higher Reynolds numbers and hence a check may not be needed.
Overall the following alternative collision is more reliable and may replace the original
implementation. The issue at the moment is that it is about 60-80% slower than the above method.
"""
f = self.precisionPolicy.cast_to_compute(f)
tiny = 1e-32
Expand Down

0 comments on commit b101e16

Please sign in to comment.