diff --git a/examples/CFD/cavity3d.py b/examples/CFD/cavity3d.py index 58db262..8961595 100644 --- a/examples/CFD/cavity3d.py +++ b/examples/CFD/cavity3d.py @@ -4,7 +4,7 @@ In this example you'll be introduced to the following concepts: -1. Lattice: The simulation employs a D2Q9 lattice. It's a 2D lattice model with nine discrete velocity directions, which is typically used for 2D simulations. +1. Lattice: The simulation employs a D3Q27 lattice. It's a 3D lattice model with 27 discrete velocity directions. 2. Boundary Conditions: The code implements two types of boundary conditions: @@ -14,47 +14,54 @@ 4. Visualization: The simulation outputs data in VTK format for visualization. The data can be visualized using software like Paraview. """ - -import os - # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' import numpy as np from src.utils import * from jax.config import config +import json, codecs from src.models import BGKSim, KBCSim from src.lattice import LatticeD3Q19, LatticeD3Q27 from src.boundary_conditions import * + +config.update('jax_enable_x64', True) + class Cavity(KBCSim): + # Note: We have used BGK with D3Q19 (or D3Q27) for Re=(1000, 3200) and KBC with D3Q27 for Re=10,000 def __init__(self, **kwargs): super().__init__(**kwargs) def set_boundary_conditions(self): + # Note: + # We have used halfway BB for Re=(1000, 3200) and regularized BC for Re=10,000 + + # apply inlet boundary condition to the top wall + moving_wall = self.boundingBoxIndices['top'] + vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype) + vel_wall[:, 0] = prescribed_vel + # self.BCs.append(BounceBackHalfway(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, vel_wall)) + self.BCs.append(Regularized(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall)) + # concatenate the indices of the left, right, and bottom walls walls = np.concatenate( (self.boundingBoxIndices['left'], self.boundingBoxIndices['right'], self.boundingBoxIndices['front'], self.boundingBoxIndices['back'], self.boundingBoxIndices['bottom'])) # apply bounce back boundary condition to the walls - self.BCs.append(BounceBack(tuple(walls.T), self.gridInfo, self.precisionPolicy)) - - # apply inlet equilibrium boundary condition to the top wall - moving_wall = self.boundingBoxIndices['top'] - - rho_wall = np.ones((moving_wall.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype) - vel_wall[:, 0] = prescribed_vel - self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) + # self.BCs.append(BounceBackHalfway(tuple(walls.T), self.gridInfo, self.precisionPolicy)) + vel_wall = np.zeros(walls.shape, dtype=self.precisionPolicy.compute_dtype) + self.BCs.append(Regularized(tuple(walls.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall)) + return def output_data(self, **kwargs): # 1: -1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs['rho'][1:-1, 1:-1, 1:-1]) - u = np.array(kwargs['u'][1:-1, 1:-1, 1:-1, :]) + rho = np.array(kwargs['rho']) + u = np.array(kwargs['u']) timestep = kwargs['timestep'] - u_prev = kwargs['u_prev'][1:-1, 1:-1, 1:-1, :] + u_prev = kwargs['u_prev'] u_old = np.linalg.norm(u_prev, axis=2) u_new = np.linalg.norm(u, axis=2) @@ -62,26 +69,43 @@ def output_data(self, **kwargs): err = np.sum(np.abs(u_old - u_new)) print('error= {:07.6f}'.format(err)) fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} - save_fields_vtk(timestep, fields) + # save_fields_vtk(timestep, fields) + + # output profiles of velocity at mid-plane for benchmarking + output_filename = "./profiles_" + f"{timestep:07d}.json" + ux_mid = 0.5*(u[nx//2, ny//2, :, 0] + u[nx//2+1, ny//2+1, :, 0]) + uz_mid = 0.5*(u[:, ny//2, nz//2, 2] + u[:, ny//2+1, nz//2+1, 2]) + ldc_ref_result = {'ux(x=y=0)': list(ux_mid/prescribed_vel), + 'uz(z=y=0)': list(uz_mid/prescribed_vel)} + json.dump(ldc_ref_result, codecs.open(output_filename, 'w', encoding='utf-8'), + separators=(',', ':'), + sort_keys=True, + indent=4) + # Calculate the velocity magnitude - u_mag = np.linalg.norm(u, axis=2) + # u_mag = np.linalg.norm(u, axis=2) # live_volume_randering(timestep, u_mag) if __name__ == '__main__': - nx = 101 - ny = 101 - nz = 101 + # Note: + # We have used BGK with D3Q19 (or D3Q27) for Re=(1000, 3200) and KBC with D3Q27 for Re=10,000 + precision = 'f64/f64' + lattice = LatticeD3Q27(precision) - Re = 50000.0 - prescribed_vel = 0.1 - clength = nx - 1 + nx = 256 + ny = 256 + nz = 256 - precision = 'f32/f32' - lattice = LatticeD3Q27(precision) + Re = 10000.0 + prescribed_vel = 0.06 + clength = nx - 2 + + # characteristic time + tc = prescribed_vel/clength + niter_max = int(500//tc) visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) - os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { @@ -91,9 +115,9 @@ def output_data(self, **kwargs): 'ny': ny, 'nz': nz, 'precision': precision, - 'io_rate': 100, - 'print_info_rate': 100, - 'downsampling_factor': 2 + 'io_rate': int(10//tc), + 'print_info_rate': int(10//tc), + 'downsampling_factor': 1 } sim = Cavity(**kwargs) - sim.run(2000) \ No newline at end of file + sim.run(niter_max) \ No newline at end of file diff --git a/examples/CFD/cylinder2d.py b/examples/CFD/cylinder2d.py index 9fa9779..cf49ad4 100644 --- a/examples/CFD/cylinder2d.py +++ b/examples/CFD/cylinder2d.py @@ -18,6 +18,7 @@ """ import os +import json import jax from time import time from jax.config import config @@ -33,7 +34,7 @@ # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' jax.config.update('jax_enable_x64', True) -class Cylinder(KBCSim): +class Cylinder(BGKSim): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -44,15 +45,15 @@ def set_boundary_conditions(self): cx, cy = 2.*diam, 2.*diam cylinder = (xx - cx)**2 + (yy-cy)**2 <= (diam/2.)**2 cylinder = coord[cylinder] - self.BCs.append(BounceBackHalfway(tuple(cylinder.T), self.gridInfo, self.precisionPolicy)) - # wall = np.concatenate([cylinder, self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']]) - # self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) + implicit_distance = np.reshape((xx - cx)**2 + (yy-cy)**2 - (diam/2.)**2, (self.nx, self.ny)) + self.BCs.append(InterpolatedBounceBackBouzidi(tuple(cylinder.T), implicit_distance, self.gridInfo, self.precisionPolicy)) + # Outflow BC outlet = self.boundingBoxIndices['right'] rho_outlet = np.ones(outlet.shape[0], dtype=self.precisionPolicy.compute_dtype) self.BCs.append(ExtrapolationOutflow(tuple(outlet.T), self.gridInfo, self.precisionPolicy)) - # self.BCs.append(Regularized(tuple(outlet.T), self.gridInfo, self.precisionPolicy, 'pressure', rho_outlet)) + # Inlet BC inlet = self.boundingBoxIndices['left'] rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype) @@ -60,67 +61,83 @@ def set_boundary_conditions(self): vel_inlet[:, 0] = poiseuille_profile(yy_inlet, yy_inlet.min(), yy_inlet.max()-yy_inlet.min(), 3.0 / 2.0 * prescribed_vel) - # self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet)) self.BCs.append(Regularized(tuple(inlet.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_inlet)) + # No-slip BC for top and bottom wall = np.concatenate([self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']]) - self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) + vel_wall = np.zeros(wall.shape, dtype=self.precisionPolicy.compute_dtype) + self.BCs.append(Regularized(tuple(wall.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall)) def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) + # 1:-1 to remove boundary voxels (not needed for visualization when using bounce-back) rho = np.array(kwargs["rho"][..., 1:-1, :]) u = np.array(kwargs["u"][..., 1:-1, :]) timestep = kwargs["timestep"] u_prev = kwargs["u_prev"][..., 1:-1, :] - # compute lift and drag over the cyliner - cylinder = self.BCs[0] - boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) - boundary_force = np.sum(boundary_force, axis=0) - drag = boundary_force[0] - lift = boundary_force[1] - cd = 2. * drag / (prescribed_vel ** 2 * diam) - cl = 2. * lift / (prescribed_vel ** 2 * diam) - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - err = np.sum(np.abs(u_old - u_new)) - print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) - save_image(timestep, u) + if timestep == 0: + self.CL_max = 0.0 + self.CD_max = 0.0 + if timestep > 0.8 * t_max: + # compute lift and drag over the cyliner + cylinder = self.BCs[0] + boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) + boundary_force = np.sum(np.array(boundary_force), axis=0) + drag = boundary_force[0] + lift = boundary_force[1] + cd = 2. * drag / (prescribed_vel ** 2 * diam) + cl = 2. * lift / (prescribed_vel ** 2 * diam) + + u_old = np.linalg.norm(u_prev, axis=2) + u_new = np.linalg.norm(u, axis=2) + err = np.sum(np.abs(u_old - u_new)) + self.CL_max = max(self.CL_max, cl) + self.CD_max = max(self.CD_max, cd) + print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) + save_image(timestep, u) # Helper function to specify a parabolic poiseuille profile poiseuille_profile = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2)) if __name__ == '__main__': precision = 'f64/f64' - lattice = LatticeD2Q9(precision) - - prescribed_vel = 0.005 - diam = 80 - - nx = int(22*diam) - ny = int(4.1*diam) - - Re = 100.0 - visc = prescribed_vel * diam / Re - omega = 1.0 / (3. * visc + 0.5) - - print('omega = ', omega) - print("Mesh size: ", nx, ny) - print("Number of voxels: ", nx * ny) - - os.system('rm -rf ./*.vtk && rm -rf ./*.png') - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': 0, - 'precision': precision, - 'io_rate': 500, - 'print_info_rate': 500, - 'return_fpost': True # Need to retain fpost-collision for computation of lift and drag - } - sim = Cylinder(**kwargs) - sim.run(1000000) + # diam_list = [10, 20, 30, 40, 60, 80] + diam_list = [80] + CL_list, CD_list = [], [] + result_dict = {} + result_dict['resolution_list'] = diam_list + for diam in diam_list: + scale_factor = 80 / diam + prescribed_vel = 0.003 * scale_factor + lattice = LatticeD2Q9(precision) + + nx = int(22*diam) + ny = int(4.1*diam) + + Re = 100.0 + visc = prescribed_vel * diam / Re + omega = 1.0 / (3. * visc + 0.5) + + os.system('rm -rf ./*.vtk && rm -rf ./*.png') + + kwargs = { + 'lattice': lattice, + 'omega': omega, + 'nx': nx, + 'ny': ny, + 'nz': 0, + 'precision': precision, + 'io_rate': int(500 / scale_factor), + 'print_info_rate': int(10000 / scale_factor), + 'return_fpost': True # Need to retain fpost-collision for computation of lift and drag + } + sim = Cylinder(**kwargs) + t_max = int(1000000 / scale_factor) + sim.run(t_max) + CL_list.append(sim.CL_max) + CD_list.append(sim.CD_max) + + result_dict['CL'] = CL_list + result_dict['CD'] = CD_list + with open('data.json', 'w') as fp: + json.dump(result_dict, fp) diff --git a/examples/CFD/taylor_green_vortex.py b/examples/CFD/taylor_green_vortex.py index e52b7d3..374c499 100644 --- a/examples/CFD/taylor_green_vortex.py +++ b/examples/CFD/taylor_green_vortex.py @@ -50,7 +50,7 @@ def initialize_populations(self, rho, u): ADE = AdvectionDiffusionBGK(**kwargs) ADE.initialize_macroscopic_fields = self.initialize_macroscopic_fields print("Initializing the distribution functions using the specified macroscopic fields....") - f = ADE.run(int(20000*32/nx)) + f = ADE.run(int(20000*nx/32)) return f def output_data(self, **kwargs): @@ -83,7 +83,6 @@ def output_data(self, **kwargs): ErrL2ResListRho = [] result_dict[precision] = dict.fromkeys(['vel_error', 'rho_error']) for nx in resList: - print("Running at nx = ny = {:07.6f}".format(nx)) ny = nx twopi = 2.0 * np.pi coord = np.array([(i, j) for i in range(nx) for j in range(ny)]) diff --git a/src/base.py b/src/base.py index fd7d32e..5813563 100644 --- a/src/base.py +++ b/src/base.py @@ -11,17 +11,16 @@ # JAX-related imports from jax import jit, lax, vmap -from jax.config import config from jax.experimental import mesh_utils from jax.experimental.multihost_utils import process_allgather from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding, PartitionSpec, PositionalSharding, Mesh import orbax.checkpoint as orb + # functools imports from functools import partial # Local/Custom Libraries -import src.models from src.utils import downsample_field jax.config.update("jax_spmd_mode", 'allow_all') diff --git a/src/boundary_conditions.py b/src/boundary_conditions.py index 73cf55d..8d0ad10 100644 --- a/src/boundary_conditions.py +++ b/src/boundary_conditions.py @@ -497,15 +497,29 @@ def configure(self, boundaryBitmask): the boundary nodes to be the indices of fluid nodes adjacent of the solid nodes. """ # Perform index shift for halfway BB. - shiftDir = ~boundaryBitmask[:, self.lattice.opp_indices] + hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices] + nbd_orig = len(self.indices[0]) idx = np.array(self.indices).T idx_trg = [] for i in range(self.lattice.q): - idx_trg.append(idx[shiftDir[:, i], :] + self.lattice.c[:, i]) + idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) indices_new = np.unique(np.vstack(idx_trg), axis=0) self.indices = tuple(indices_new.T) + nbd_modified = len(self.indices[0]) + if (nbd_orig != nbd_modified) and self.vel is not None: + vel_avg = np.mean(self.vel, axis=0) + self.vel = jnp.zeros(indices_new.shape, dtype=self.precisionPolicy.compute_dtype) + vel_avg + print("WARNING: assuming a constant averaged velocity vector is imposed at all BC cells!") + return + @partial(jit, static_argnums=(0,)) + def impose_boundary_vel(self, fbd, bindex): + c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) + cu = 6.0 * self.lattice.w * jnp.dot(self.vel, c) + fbd = fbd.at[bindex, self.imissing].add(-cu[bindex, self.iknown]) + return fbd + @partial(jit, static_argnums=(0,)) def apply(self, fout, fin): """ @@ -526,13 +540,10 @@ def apply(self, fout, fin): nbd = len(self.indices[0]) bindex = np.arange(nbd)[:, None] fbd = fout[self.indices] - if self.vel is not None: - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - cu = 6.0 * self.lattice.w * jnp.dot(self.vel, c) - fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown] - cu[bindex, self.iknown]) - else: - fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) + fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) + if self.vel is not None: + fbd = self.impose_boundary_vel(fbd, bindex) return fbd class EquilibriumBC(BoundaryCondition): @@ -937,11 +948,11 @@ def configure(self, boundaryBitmask): boundaryBitmask : np.ndarray The connectivity bitmask for the boundary voxels. """ - shiftDir = ~boundaryBitmask[:, self.lattice.opp_indices] + hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices] idx = np.array(self.indices).T idx_trg = [] for i in range(self.lattice.q): - idx_trg.append(idx[shiftDir[:, i], :] + self.lattice.c[:, i]) + idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) indices_nbr = np.unique(np.vstack(idx_trg), axis=0) self.indices_nbr = tuple(indices_nbr.T) @@ -1010,3 +1021,155 @@ def apply(self, fout, fin): fbd = fout[self.indices] fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) return fbd + + +class InterpolatedBounceBackBouzidi(BounceBackHalfway): + """ + A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice + Boltzmann method simulation. + + This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after + the streaming step. + + Attributes + ---------- + name : str + The name of the boundary condition. For this class, it is "InterpolatedBounceBackBouzidi". + implicit_distances : array-like + An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls + weights : array-like + An array of shape (number_of_bc_cells, q) initialized as None and constructed using implicit_distances array + during runtime. These "weights" are associated with the fractional distance of fluid cell to the boundary + position defined as: weights(dir_i) = |x_fluid - x_boundary(dir_i)| / |x_fluid - x_solid(dir_i)|. + """ + + def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None): + + super().__init__(indices, grid_info, precision_policy, vel=vel) + self.name = "InterpolatedBounceBackBouzidi" + self.implicit_distances = implicit_distances + self.weights = None + + def set_proximity_ratio(self): + """ + Creates the interpolation data needed for the boundary condition. + + Returns + ------- + None. The function updates the object's weights attribute in place. + """ + idx = np.array(self.indices).T + self.weights = np.full((idx.shape[0], self.lattice.q), 0.5) + c = np.array(self.lattice.c) + sdf_f = self.implicit_distances[self.indices] + for q in range(1, self.lattice.q): + solid_indices = idx + c[:, q] + solid_indices_tuple = tuple(map(tuple, solid_indices.T)) + sdf_s = self.implicit_distances[solid_indices_tuple] + mask = self.iknownBitmask[:, q] + self.weights[mask, q] = sdf_f[mask] / (sdf_f[mask] - sdf_s[mask]) + return + + @partial(jit, static_argnums=(0,)) + def apply(self, fout, fin): + """ + Applies the halfway bounce-back boundary condition. + + Parameters + ---------- + fout : jax.numpy.ndarray + The output distribution functions. + fin : jax.numpy.ndarray + The input distribution functions. + + Returns + ------- + jax.numpy.ndarray + The modified output distribution functions after applying the boundary condition. + """ + if self.weights is None: + self.set_proximity_ratio() + nbd = len(self.indices[0]) + bindex = np.arange(nbd)[:, None] + fbd = fout[self.indices] + f_postcollision_iknown = fin[self.indices][bindex, self.iknown] + f_postcollision_imissing = fin[self.indices][bindex, self.imissing] + f_poststreaming_iknown = fout[self.indices][bindex, self.iknown] + + # if weights<0.5 + fs_near = 2. * self.weights * f_postcollision_iknown + \ + (1.0 - 2.0 * self.weights) * f_poststreaming_iknown + + # if weights>=0.5 + fs_far = 1.0 / (2. * self.weights) * f_postcollision_iknown + \ + (2.0 * self.weights - 1.0) / (2. * self.weights) * f_postcollision_imissing + + # combine near and far contributions + fmissing = jnp.where(self.weights < 0.5, fs_near, fs_far) + fbd = fbd.at[bindex, self.imissing].set(fmissing) + + if self.vel is not None: + fbd = self.impose_boundary_vel(fbd, bindex) + return fbd + + +class InterpolatedBounceBackDifferentiable(InterpolatedBounceBackBouzidi): + """ + A differentiable variant of the "InterpolatedBounceBackBouzidi" BC scheme. This BC is now differentiable at + self.weight = 0.5 unlike the original Bouzidi scheme which switches between 2 equations at weight=0.5. Refer to + [1] (their Appendix E) for more information. + + References + ---------- + [1] Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three + dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507–547. + doi:10.1016/j.camwa.2015.05.001. + + + This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after + the streaming step. + + Attributes + ---------- + name : str + The name of the boundary condition. For this class, it is "InterpolatedBounceBackDifferentiable". + """ + + def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None): + + super().__init__(indices, implicit_distances, grid_info, precision_policy, vel=vel) + self.name = "InterpolatedBounceBackDifferentiable" + + + @partial(jit, static_argnums=(0,)) + def apply(self, fout, fin): + """ + Applies the halfway bounce-back boundary condition. + + Parameters + ---------- + fout : jax.numpy.ndarray + The output distribution functions. + fin : jax.numpy.ndarray + The input distribution functions. + + Returns + ------- + jax.numpy.ndarray + The modified output distribution functions after applying the boundary condition. + """ + if self.weights is None: + self.set_proximity_ratio() + nbd = len(self.indices[0]) + bindex = np.arange(nbd)[:, None] + fbd = fout[self.indices] + f_postcollision_iknown = fin[self.indices][bindex, self.iknown] + f_postcollision_imissing = fin[self.indices][bindex, self.imissing] + f_poststreaming_iknown = fout[self.indices][bindex, self.iknown] + fmissing = ((1. - self.weights) * f_poststreaming_iknown + + self.weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + self.weights) + fbd = fbd.at[bindex, self.imissing].set(fmissing) + + if self.vel is not None: + fbd = self.impose_boundary_vel(fbd, bindex) + return fbd \ No newline at end of file diff --git a/src/models.py b/src/models.py index a7af7b6..a0500c8 100644 --- a/src/models.py +++ b/src/models.py @@ -69,6 +69,47 @@ def collision(self, f): if self.force is not None: fout = self.apply_force(fout, feq, rho, u) return self.precisionPolicy.cast_to_output(fout) + + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def collision_modified(self, f): + """ + Alternative KBC collision step for lattice. + Note: + At low Reynolds number the orignal KBC collision above produces inaccurate results because + it does not check for the entropy increase/decrease. The KBC stabalizations should only be + applied in principle to cells whose entropy decrease after a regular BGK collision. This is + the case in most cells at higher Reynolds numbers and hence a check may not be needed. + Overall the following alternative collision is more reliable and may replace the original + implementation. The issue at the moment is that it is about 60-80% slower than the above method. + """ + f = self.precisionPolicy.cast_to_compute(f) + tiny = 1e-32 + beta = self.omega * 0.5 + rho, u = self.update_macroscopic(f) + feq = self.equilibrium(rho, u, castOutput=False) + + # Alternative KBC: only stabalizes for voxels whose entropy decreases after BGK collision. + f_bgk = f - self.omega * (f - feq) + H_fin = jnp.sum(f * jnp.log(f / self.w), axis=-1, keepdims=True) + H_fout = jnp.sum(f_bgk * jnp.log(f_bgk / self.w), axis=-1, keepdims=True) + + # the rest is identical to collision_deprecated + fneq = f - feq + if self.dim == 2: + deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0 + else: + deltaS = self.fdecompose_shear_d3q27(fneq) * rho + deltaH = fneq - deltaS + invBeta = 1.0 / beta + gamma = invBeta - (2.0 - invBeta) * self.entropic_scalar_product(deltaS, deltaH, feq) / (tiny + self.entropic_scalar_product(deltaH, deltaH, feq)) + + f_kbc = f - beta * (2.0 * deltaS + gamma[..., None] * deltaH) + fout = jnp.where(H_fout > H_fin, f_kbc, f_bgk) + + # add external force + if self.force is not None: + fout = self.apply_force(fout, feq, rho, u) + return self.precisionPolicy.cast_to_output(fout) @partial(jit, static_argnums=(0,), inline=True) def entropic_scalar_product(self, x, y, feq):