From 1cc98ec7eda2a9e05e3815b59a15ba7fca236e6d Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 2 Oct 2023 16:13:03 -0400 Subject: [PATCH] adding the modified exampels that are being used for benchmarking XLB --- examples/CFD/cavity3d.py | 58 +++++++++++-------- examples/CFD/cylinder2d.py | 114 ++++++++++++++++++++++--------------- src/models.py | 34 +++++++++++ 3 files changed, 134 insertions(+), 72 deletions(-) diff --git a/examples/CFD/cavity3d.py b/examples/CFD/cavity3d.py index d912786..f7e35c8 100644 --- a/examples/CFD/cavity3d.py +++ b/examples/CFD/cavity3d.py @@ -4,7 +4,7 @@ In this example you'll be introduced to the following concepts: -1. Lattice: The simulation employs a D2Q9 lattice. It's a 2D lattice model with nine discrete velocity directions, which is typically used for 2D simulations. +1. Lattice: The simulation employs a D3Q27 lattice. It's a 3D lattice model with 27 discrete velocity directions. 2. Boundary Conditions: The code implements two types of boundary conditions: @@ -14,21 +14,19 @@ 4. Visualization: The simulation outputs data in VTK format for visualization. The data can be visualized using software like Paraview. """ - -import os - # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' from src.models import BGKSim, KBCSim from src.lattice import LatticeD3Q27 -import numpy as np from src.utils import * from jax.config import config from src.boundary_conditions import * +import json, codecs -precision = 'f32/f32' +precision = 'f64/f64' +config.update('jax_enable_x64', True) -class Cavity(KBCSim): +class Cavity(BGKSim): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -39,7 +37,7 @@ def set_boundary_conditions(self): self.boundingBoxIndices['front'], self.boundingBoxIndices['back'], self.boundingBoxIndices['bottom'])) # apply bounce back boundary condition to the walls - self.BCs.append(BounceBack(tuple(walls.T), self.gridInfo, self.precisionPolicy)) + self.BCs.append(BounceBackHalfway(tuple(walls.T), self.gridInfo, self.precisionPolicy)) # apply inlet equilibrium boundary condition to the top wall moving_wall = self.boundingBoxIndices['top'] @@ -47,14 +45,14 @@ def set_boundary_conditions(self): rho_wall = np.ones((moving_wall.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype) vel_wall[:, 0] = prescribed_vel - self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) + self.BCs.append(BounceBackHalfway(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, vel_wall)) def output_data(self, **kwargs): # 1: -1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs['rho'][1:-1, 1:-1, 1:-1]) - u = np.array(kwargs['u'][1:-1, 1:-1, 1:-1, :]) + rho = np.array(kwargs['rho']) + u = np.array(kwargs['u']) timestep = kwargs['timestep'] - u_prev = kwargs['u_prev'][1:-1, 1:-1, 1:-1, :] + u_prev = kwargs['u_prev'] u_old = np.linalg.norm(u_prev, axis=2) u_new = np.linalg.norm(u, axis=2) @@ -62,26 +60,36 @@ def output_data(self, **kwargs): err = np.sum(np.abs(u_old - u_new)) print('error= {:07.6f}'.format(err)) fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} - save_fields_vtk(timestep, fields) + # save_fields_vtk(timestep, fields) + + # output profiles of velocity at mid-plane for benchmarking + output_filename = "./profiles_" + f"{timestep:07d}.json" + ldc_ref_result = {'ux(x=y=0)': list(u[nx//2, ny//2, :, 0]/prescribed_vel), + 'uz(z=y=0)': list(u[:, ny//2, nz//2, 2]/prescribed_vel)} + json.dump(ldc_ref_result, codecs.open(output_filename, 'w', encoding='utf-8'), + separators=(',', ':'), + sort_keys=True, + indent=4) + # Calculate the velocity magnitude - u_mag = np.linalg.norm(u, axis=2) + # u_mag = np.linalg.norm(u, axis=2) # live_volume_randering(timestep, u_mag) if __name__ == '__main__': lattice = LatticeD3Q27(precision) - nx = 101 - ny = 101 - nz = 101 + nx = 256 + ny = 256 + nz = 256 - Re = 50000.0 - prescribed_vel = 0.1 - clength = nx - 1 + Re = 1000.0 + prescribed_vel = 0.06 + clength = nx - 2 visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) print('omega = ', omega) - + os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { @@ -91,9 +99,9 @@ def output_data(self, **kwargs): 'ny': ny, 'nz': nz, 'precision': precision, - 'io_rate': 100, - 'print_info_rate': 100, - 'downsampling_factor': 2 + 'io_rate': 10000, + 'print_info_rate': 10000, + 'downsampling_factor': 1 } sim = Cavity(**kwargs) - sim.run(2000) \ No newline at end of file + sim.run(1000000) \ No newline at end of file diff --git a/examples/CFD/cylinder2d.py b/examples/CFD/cylinder2d.py index 5120668..deade62 100644 --- a/examples/CFD/cylinder2d.py +++ b/examples/CFD/cylinder2d.py @@ -27,13 +27,14 @@ from src.models import BGKSim, KBCSim import jax.numpy as jnp import os +import json # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' import jax jax.config.update('jax_enable_x64', True) -class Cylinder(KBCSim): +class Cylinder(BGKSim): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -45,8 +46,8 @@ def set_boundary_conditions(self): cylinder = (xx - cx)**2 + (yy-cy)**2 <= (diam/2.)**2 cylinder = coord[cylinder] implicit_distance = np.reshape((xx - cx)**2 + (yy-cy)**2 - (diam/2.)**2, (self.nx, self.ny)) - self.BCs.append(InterpolatedBounceBackLocal(tuple(cylinder.T), implicit_distance, - self.gridInfo, self.precisionPolicy)) + self.BCs.append(InterpolatedBounceBackBouzidi(tuple(cylinder.T), implicit_distance, self.gridInfo, self.precisionPolicy)) + # self.BCs.append(BounceBackHalfway(tuple(cylinder.T), self.gridInfo, self.precisionPolicy)) # wall = np.concatenate([cylinder, self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']]) # self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) @@ -76,53 +77,72 @@ def output_data(self, **kwargs): timestep = kwargs["timestep"] u_prev = kwargs["u_prev"][..., 1:-1, :] - # compute lift and drag over the cyliner - cylinder = self.BCs[0] - boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) - boundary_force = np.sum(boundary_force, axis=0) - drag = boundary_force[0] - lift = boundary_force[1] - cd = 2. * drag / (prescribed_vel ** 2 * diam) - cl = 2. * lift / (prescribed_vel ** 2 * diam) - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - err = np.sum(np.abs(u_old - u_new)) - print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) - save_image(timestep, u) + if timestep == 0: + self.CL_max = 0.0 + self.CD_max = 0.0 + if timestep > 0.8 * t_max: + # compute lift and drag over the cyliner + cylinder = self.BCs[0] + boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) + boundary_force = np.sum(np.array(boundary_force), axis=0) + drag = boundary_force[0] + lift = boundary_force[1] + cd = 2. * drag / (prescribed_vel ** 2 * diam) + cl = 2. * lift / (prescribed_vel ** 2 * diam) + + u_old = np.linalg.norm(u_prev, axis=2) + u_new = np.linalg.norm(u, axis=2) + err = np.sum(np.abs(u_old - u_new)) + self.CL_max = max(self.CL_max, cl) + self.CD_max = max(self.CD_max, cd) + print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) + # save_image(timestep, u) # Helper function to specify a parabolic poiseuille profile poiseuille_profile = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2)) if __name__ == '__main__': precision = 'f64/f64' - prescribed_vel = 0.005 - diam = 80 - lattice = LatticeD2Q9(precision) - - nx = int(22*diam) - ny = int(4.1*diam) - - Re = 100.0 - visc = prescribed_vel * diam / Re - omega = 1.0 / (3. * visc + 0.5) - - print('omega = ', omega) - print("Mesh size: ", nx, ny) - print("Number of voxels: ", nx * ny) - - os.system('rm -rf ./*.vtk && rm -rf ./*.png') - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': 0, - 'precision': precision, - 'io_rate': 500, - 'print_info_rate': 500, - 'return_fpost': True # Need to retain fpost-collision for computation of lift and drag - } - sim = Cylinder(**kwargs) - sim.run(1000000) + diam_list = [10, 20, 30, 40, 60, 80] + CL_list, CD_list = [], [] + result_dict = {} + result_dict['resolution_list'] = diam_list + for diam in diam_list: + scale_factor = 80 / diam + prescribed_vel = 0.003 * scale_factor + lattice = LatticeD2Q9(precision) + + nx = int(22*diam) + ny = int(4.1*diam) + + Re = 100.0 + visc = prescribed_vel * diam / Re + omega = 1.0 / (3. * visc + 0.5) + + print('omega = ', omega) + print("Mesh size: ", nx, ny) + print("Number of voxels: ", nx * ny) + + os.system('rm -rf ./*.vtk && rm -rf ./*.png') + + kwargs = { + 'lattice': lattice, + 'omega': omega, + 'nx': nx, + 'ny': ny, + 'nz': 0, + 'precision': precision, + 'io_rate': int(500 / scale_factor), + 'print_info_rate': int(10000 / scale_factor), + 'return_fpost': True # Need to retain fpost-collision for computation of lift and drag + } + sim = Cylinder(**kwargs) + t_max = int(1000000 / scale_factor) + sim.run(t_max) + CL_list.append(sim.CL_max) + CD_list.append(sim.CD_max) + + result_dict['CL'] = CL_list + result_dict['CD'] = CD_list + with open('data.json', 'w') as fp: + json.dump(result_dict, fp) diff --git a/src/models.py b/src/models.py index 7e5e825..7d0424c 100644 --- a/src/models.py +++ b/src/models.py @@ -68,6 +68,40 @@ def collision(self, f): if self.force is not None: fout = self.apply_force(fout, feq, rho, u) return self.precisionPolicy.cast_to_output(fout) + + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def collision_modified(self, f): + """ + KBC collision step for lattice. + """ + f = self.precisionPolicy.cast_to_compute(f) + tiny = 1e-32 + beta = self.omega * 0.5 + rho, u = self.update_macroscopic(f) + feq = self.equilibrium(rho, u, castOutput=False) + + # Alternative KBC: only stabalizes for voxels whose entropy decreases after BGK collision. + f_bgk = f - self.omega * (f - feq) + H_fin = jnp.sum(f * jnp.log(f / self.w), axis=-1, keepdims=True) + H_fout = jnp.sum(f_bgk * jnp.log(f_bgk / self.w), axis=-1, keepdims=True) + + # the rest is identical to collision_deprecated + fneq = f - feq + if self.dim == 2: + deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0 + else: + deltaS = self.fdecompose_shear_d3q27(fneq) * rho + deltaH = fneq - deltaS + invBeta = 1.0 / beta + gamma = invBeta - (2.0 - invBeta) * self.entropic_scalar_product(deltaS, deltaH, feq) / (tiny + self.entropic_scalar_product(deltaH, deltaH, feq)) + + f_kbc = f - beta * (2.0 * deltaS + gamma[..., None] * deltaH) + fout = jnp.where(H_fout > H_fin, f_kbc, f_bgk) + + # add external force + if self.force is not None: + fout = self.apply_force(fout, feq, rho, u) + return self.precisionPolicy.cast_to_output(fout) @partial(jit, static_argnums=(0,), inline=True) def entropic_scalar_product(self, x, y, feq):