Skip to content

Commit

Permalink
adding the modified exampels that are being used for benchmarking XLB
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Oct 2, 2023
1 parent 047919d commit 1cc98ec
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 72 deletions.
58 changes: 33 additions & 25 deletions examples/CFD/cavity3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
In this example you'll be introduced to the following concepts:
1. Lattice: The simulation employs a D2Q9 lattice. It's a 2D lattice model with nine discrete velocity directions, which is typically used for 2D simulations.
1. Lattice: The simulation employs a D3Q27 lattice. It's a 3D lattice model with 27 discrete velocity directions.
2. Boundary Conditions: The code implements two types of boundary conditions:
Expand All @@ -14,21 +14,19 @@
4. Visualization: The simulation outputs data in VTK format for visualization. The data can be visualized using software like Paraview.
"""

import os

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q27
import numpy as np
from src.utils import *
from jax.config import config
from src.boundary_conditions import *
import json, codecs

precision = 'f32/f32'
precision = 'f64/f64'
config.update('jax_enable_x64', True)

class Cavity(KBCSim):
class Cavity(BGKSim):
def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand All @@ -39,49 +37,59 @@ def set_boundary_conditions(self):
self.boundingBoxIndices['front'], self.boundingBoxIndices['back'],
self.boundingBoxIndices['bottom']))
# apply bounce back boundary condition to the walls
self.BCs.append(BounceBack(tuple(walls.T), self.gridInfo, self.precisionPolicy))
self.BCs.append(BounceBackHalfway(tuple(walls.T), self.gridInfo, self.precisionPolicy))

# apply inlet equilibrium boundary condition to the top wall
moving_wall = self.boundingBoxIndices['top']

rho_wall = np.ones((moving_wall.shape[0], 1), dtype=self.precisionPolicy.compute_dtype)
vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype)
vel_wall[:, 0] = prescribed_vel
self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall))
self.BCs.append(BounceBackHalfway(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, vel_wall))

def output_data(self, **kwargs):
# 1: -1 to remove boundary voxels (not needed for visualization when using full-way bounce-back)
rho = np.array(kwargs['rho'][1:-1, 1:-1, 1:-1])
u = np.array(kwargs['u'][1:-1, 1:-1, 1:-1, :])
rho = np.array(kwargs['rho'])
u = np.array(kwargs['u'])
timestep = kwargs['timestep']
u_prev = kwargs['u_prev'][1:-1, 1:-1, 1:-1, :]
u_prev = kwargs['u_prev']

u_old = np.linalg.norm(u_prev, axis=2)
u_new = np.linalg.norm(u, axis=2)

err = np.sum(np.abs(u_old - u_new))
print('error= {:07.6f}'.format(err))
fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]}
save_fields_vtk(timestep, fields)
# save_fields_vtk(timestep, fields)

# output profiles of velocity at mid-plane for benchmarking
output_filename = "./profiles_" + f"{timestep:07d}.json"
ldc_ref_result = {'ux(x=y=0)': list(u[nx//2, ny//2, :, 0]/prescribed_vel),
'uz(z=y=0)': list(u[:, ny//2, nz//2, 2]/prescribed_vel)}
json.dump(ldc_ref_result, codecs.open(output_filename, 'w', encoding='utf-8'),
separators=(',', ':'),
sort_keys=True,
indent=4)

# Calculate the velocity magnitude
u_mag = np.linalg.norm(u, axis=2)
# u_mag = np.linalg.norm(u, axis=2)
# live_volume_randering(timestep, u_mag)

if __name__ == '__main__':
lattice = LatticeD3Q27(precision)

nx = 101
ny = 101
nz = 101
nx = 256
ny = 256
nz = 256

Re = 50000.0
prescribed_vel = 0.1
clength = nx - 1
Re = 1000.0
prescribed_vel = 0.06
clength = nx - 2

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
Expand All @@ -91,9 +99,9 @@ def output_data(self, **kwargs):
'ny': ny,
'nz': nz,
'precision': precision,
'io_rate': 100,
'print_info_rate': 100,
'downsampling_factor': 2
'io_rate': 10000,
'print_info_rate': 10000,
'downsampling_factor': 1
}
sim = Cavity(**kwargs)
sim.run(2000)
sim.run(1000000)
114 changes: 67 additions & 47 deletions examples/CFD/cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os
import json

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
jax.config.update('jax_enable_x64', True)

class Cylinder(KBCSim):
class Cylinder(BGKSim):
def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand All @@ -45,8 +46,8 @@ def set_boundary_conditions(self):
cylinder = (xx - cx)**2 + (yy-cy)**2 <= (diam/2.)**2
cylinder = coord[cylinder]
implicit_distance = np.reshape((xx - cx)**2 + (yy-cy)**2 - (diam/2.)**2, (self.nx, self.ny))
self.BCs.append(InterpolatedBounceBackLocal(tuple(cylinder.T), implicit_distance,
self.gridInfo, self.precisionPolicy))
self.BCs.append(InterpolatedBounceBackBouzidi(tuple(cylinder.T), implicit_distance, self.gridInfo, self.precisionPolicy))
# self.BCs.append(BounceBackHalfway(tuple(cylinder.T), self.gridInfo, self.precisionPolicy))

# wall = np.concatenate([cylinder, self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']])
# self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy))
Expand Down Expand Up @@ -76,53 +77,72 @@ def output_data(self, **kwargs):
timestep = kwargs["timestep"]
u_prev = kwargs["u_prev"][..., 1:-1, :]

# compute lift and drag over the cyliner
cylinder = self.BCs[0]
boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision'])
boundary_force = np.sum(boundary_force, axis=0)
drag = boundary_force[0]
lift = boundary_force[1]
cd = 2. * drag / (prescribed_vel ** 2 * diam)
cl = 2. * lift / (prescribed_vel ** 2 * diam)

u_old = np.linalg.norm(u_prev, axis=2)
u_new = np.linalg.norm(u, axis=2)
err = np.sum(np.abs(u_old - u_new))
print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd))
save_image(timestep, u)
if timestep == 0:
self.CL_max = 0.0
self.CD_max = 0.0
if timestep > 0.8 * t_max:
# compute lift and drag over the cyliner
cylinder = self.BCs[0]
boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision'])
boundary_force = np.sum(np.array(boundary_force), axis=0)
drag = boundary_force[0]
lift = boundary_force[1]
cd = 2. * drag / (prescribed_vel ** 2 * diam)
cl = 2. * lift / (prescribed_vel ** 2 * diam)

u_old = np.linalg.norm(u_prev, axis=2)
u_new = np.linalg.norm(u, axis=2)
err = np.sum(np.abs(u_old - u_new))
self.CL_max = max(self.CL_max, cl)
self.CD_max = max(self.CD_max, cd)
print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd))
# save_image(timestep, u)

# Helper function to specify a parabolic poiseuille profile
poiseuille_profile = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2))

if __name__ == '__main__':
precision = 'f64/f64'
prescribed_vel = 0.005
diam = 80
lattice = LatticeD2Q9(precision)

nx = int(22*diam)
ny = int(4.1*diam)

Re = 100.0
visc = prescribed_vel * diam / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny)
print("Number of voxels: ", nx * ny)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
'nz': 0,
'precision': precision,
'io_rate': 500,
'print_info_rate': 500,
'return_fpost': True # Need to retain fpost-collision for computation of lift and drag
}
sim = Cylinder(**kwargs)
sim.run(1000000)
diam_list = [10, 20, 30, 40, 60, 80]
CL_list, CD_list = [], []
result_dict = {}
result_dict['resolution_list'] = diam_list
for diam in diam_list:
scale_factor = 80 / diam
prescribed_vel = 0.003 * scale_factor
lattice = LatticeD2Q9(precision)

nx = int(22*diam)
ny = int(4.1*diam)

Re = 100.0
visc = prescribed_vel * diam / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny)
print("Number of voxels: ", nx * ny)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
'nz': 0,
'precision': precision,
'io_rate': int(500 / scale_factor),
'print_info_rate': int(10000 / scale_factor),
'return_fpost': True # Need to retain fpost-collision for computation of lift and drag
}
sim = Cylinder(**kwargs)
t_max = int(1000000 / scale_factor)
sim.run(t_max)
CL_list.append(sim.CL_max)
CD_list.append(sim.CD_max)

result_dict['CL'] = CL_list
result_dict['CD'] = CD_list
with open('data.json', 'w') as fp:
json.dump(result_dict, fp)
34 changes: 34 additions & 0 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,40 @@ def collision(self, f):
if self.force is not None:
fout = self.apply_force(fout, feq, rho, u)
return self.precisionPolicy.cast_to_output(fout)

@partial(jit, static_argnums=(0,), donate_argnums=(1,))
def collision_modified(self, f):
"""
KBC collision step for lattice.
"""
f = self.precisionPolicy.cast_to_compute(f)
tiny = 1e-32
beta = self.omega * 0.5
rho, u = self.update_macroscopic(f)
feq = self.equilibrium(rho, u, castOutput=False)

# Alternative KBC: only stabalizes for voxels whose entropy decreases after BGK collision.
f_bgk = f - self.omega * (f - feq)
H_fin = jnp.sum(f * jnp.log(f / self.w), axis=-1, keepdims=True)
H_fout = jnp.sum(f_bgk * jnp.log(f_bgk / self.w), axis=-1, keepdims=True)

# the rest is identical to collision_deprecated
fneq = f - feq
if self.dim == 2:
deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0
else:
deltaS = self.fdecompose_shear_d3q27(fneq) * rho
deltaH = fneq - deltaS
invBeta = 1.0 / beta
gamma = invBeta - (2.0 - invBeta) * self.entropic_scalar_product(deltaS, deltaH, feq) / (tiny + self.entropic_scalar_product(deltaH, deltaH, feq))

f_kbc = f - beta * (2.0 * deltaS + gamma[..., None] * deltaH)
fout = jnp.where(H_fout > H_fin, f_kbc, f_bgk)

# add external force
if self.force is not None:
fout = self.apply_force(fout, feq, rho, u)
return self.precisionPolicy.cast_to_output(fout)

@partial(jit, static_argnums=(0,), inline=True)
def entropic_scalar_product(self, x, y, feq):
Expand Down

0 comments on commit 1cc98ec

Please sign in to comment.