Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/HEAD' into benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Oct 20, 2023
2 parents 25fcd4c + 2180c56 commit 6b3d5c2
Show file tree
Hide file tree
Showing 19 changed files with 407 additions and 228 deletions.
27 changes: 12 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,22 @@ The following examples showcase the capabilities of XLB:

To use XLB, you must first install JAX and other dependencies using the following commands:

```bash
# Please refer to https://github.com/google/jax for the latest installation documentation

pip install --upgrade pip

# For CPU run
pip install --upgrade "jax[cpu]"
Please refer to https://github.com/google/jax for the latest installation documentation. The following table is taken from [JAX's Github page](https://github.com/google/jax).

# For GPU run
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

# CUDA 12 and cuDNN 8.8 or newer.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**Note:** We encountered challenges when executing XLB on Apple GPUs due to the lack of support for certain operations in the Metal backend. We advise using the CPU backend on Mac OS. We will be testing XLB on Apple's GPUs in the future and will update this section accordingly.

# CUDA 11 and cuDNN 8.6 or newer.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Run dependencies
Install dependencies:
```bash
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp
```

Expand All @@ -118,6 +117,4 @@ export PYTHONPATH=.
python3 examples/cavity2d.py
```
## Citing XLB
Accompanying publication coming soon:

**M. Ataei, H. Salehipour**. XLB: Hardware-Accelerated, Scalable, and Differentiable Lattice Boltzmann Simulation Framework based on JAX. TBA
Accompanying paper will be available soon.
10 changes: 3 additions & 7 deletions examples/CFD/airfoil3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
# from IPython import display
import matplotlib.pylab as plt
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *
from src.lattice import *
import numpy as np
from src.utils import *
from jax.config import config
Expand Down Expand Up @@ -105,15 +105,13 @@ def output_data(self, **kwargs):
airfoil_thickness = 30
airfoil_angle = 20
airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T

precision = 'f32/f32'
lattice = LatticeD3Q27(precision=precision)

lattice = LatticeD3Q27(precision)

nx = airfoil.shape[0]
ny = airfoil.shape[1]

print("airfoil shape: ", airfoil.shape)

ny = 3 * ny
nx = 4 * nx
nz = 101
Expand All @@ -124,7 +122,6 @@ def output_data(self, **kwargs):

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')

Expand All @@ -141,5 +138,4 @@ def output_data(self, **kwargs):
}

sim = Airfoil(**kwargs)
print('Domain size: ', sim.nx, sim.ny, sim.nz)
sim.run(20000)
13 changes: 6 additions & 7 deletions examples/CFD/cavity2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
4. Visualization: The simulation outputs data in VTK format for visualization. It also provides images of the velocity field and saves the boundary conditions at each time step. The data can be visualized using software like Paraview.
"""
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9
from src.utils import *

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax

class Cavity(KBCSim):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -71,11 +71,10 @@ def output_data(self, **kwargs):
clength = nx - 1

checkpoint_rate = 1000
checkpoint_dir = "./checkpoints"
checkpoint_dir = os.path.abspath("./checkpoints")

visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
14 changes: 8 additions & 6 deletions examples/CFD/cavity3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
"""
# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27

import numpy as np
from src.utils import *
from jax.config import config
from src.boundary_conditions import *
import json, codecs

precision = 'f64/f64'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *


config.update('jax_enable_x64', True)

class Cavity(KBCSim):
Expand Down Expand Up @@ -86,6 +89,7 @@ def output_data(self, **kwargs):
if __name__ == '__main__':
# Note:
# We have used BGK with D3Q19 (or D3Q27) for Re=(1000, 3200) and KBC with D3Q27 for Re=10,000
precision = 'f64/f64'
lattice = LatticeD3Q27(precision)

nx = 256
Expand All @@ -102,8 +106,6 @@ def output_data(self, **kwargs):

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
Expand Down
5 changes: 2 additions & 3 deletions examples/CFD/channel3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_dns_data():
}
return dns_dic

class turbulentChannel(KBCSim):
class TurbulentChannel(KBCSim):
def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand All @@ -68,7 +68,7 @@ def set_boundary_conditions(self):
def initialize_macroscopic_fields(self):
rho = self.precisionPolicy.cast_to_output(1.0)
u = self.distributed_array_init((self.nx, self.ny, self.nz, self.dim),
self.precisionPolicy.compute_dtype, initVal=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
self.precisionPolicy.compute_dtype, init_val=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
u = self.precisionPolicy.cast_to_output(u)
return rho, u

Expand Down Expand Up @@ -141,7 +141,6 @@ def output_data(self, **kwargs):
zz = np.minimum(zz, zz.max() - zz)
yplus = zz * u_tau / visc

print("omega = ", omega)
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
Expand Down
11 changes: 6 additions & 5 deletions examples/CFD/couette2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
This script performs a 2D simulation of Couette flow using the lattice Boltzmann method (LBM).
"""

from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9
import os
import jax.numpy as jnp
import numpy as np
from src.utils import *
from jax.config import config
import os


from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9

# config.update('jax_disable_jit', True)
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
Expand Down Expand Up @@ -60,7 +62,6 @@ def output_data(self, **kwargs):
visc = prescribed_vel * clength / Re

omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)
assert omega < 1.98, "omega must be less than 2.0"
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
39 changes: 17 additions & 22 deletions examples/CFD/cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@
5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView.
"""

import os
import json
import jax
from time import time
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os
import json

from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
jax.config.update('jax_enable_x64', True)

class Cylinder(BGKSim):
Expand All @@ -47,31 +47,29 @@ def set_boundary_conditions(self):
cylinder = coord[cylinder]
implicit_distance = np.reshape((xx - cx)**2 + (yy-cy)**2 - (diam/2.)**2, (self.nx, self.ny))
self.BCs.append(InterpolatedBounceBackBouzidi(tuple(cylinder.T), implicit_distance, self.gridInfo, self.precisionPolicy))
# self.BCs.append(BounceBackHalfway(tuple(cylinder.T), self.gridInfo, self.precisionPolicy))

# wall = np.concatenate([cylinder, self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']])
# self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy))

# Outflow BC
outlet = self.boundingBoxIndices['right']
rho_outlet = np.ones(outlet.shape[0], dtype=self.precisionPolicy.compute_dtype)
self.BCs.append(ExtrapolationOutflow(tuple(outlet.T), self.gridInfo, self.precisionPolicy))
# self.BCs.append(Regularized(tuple(outlet.T), self.gridInfo, self.precisionPolicy, 'pressure', rho_outlet))

# Inlet BC
inlet = self.boundingBoxIndices['left']
rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype)
vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype)
yy_inlet = yy.reshape(self.nx, self.ny)[tuple(inlet.T)]
vel_inlet[:, 0] = poiseuille_profile(yy_inlet,
yy_inlet.min(),
yy_inlet.max()-yy_inlet.min(), 3.0 / 2.0 * prescribed_vel)
# self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet))
self.BCs.append(Regularized(tuple(inlet.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_inlet))

# No-slip BC for top and bottom
wall = np.concatenate([self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']])
self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy))
vel_wall = np.zeros(wall.shape, dtype=self.precisionPolicy.compute_dtype)
self.BCs.append(Regularized(tuple(wall.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall))

def output_data(self, **kwargs):
# 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back)
# 1:-1 to remove boundary voxels (not needed for visualization when using bounce-back)
rho = np.array(kwargs["rho"][..., 1:-1, :])
u = np.array(kwargs["u"][..., 1:-1, :])
timestep = kwargs["timestep"]
Expand All @@ -96,14 +94,15 @@ def output_data(self, **kwargs):
self.CL_max = max(self.CL_max, cl)
self.CD_max = max(self.CD_max, cd)
print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd))
# save_image(timestep, u)
save_image(timestep, u)

# Helper function to specify a parabolic poiseuille profile
poiseuille_profile = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2))

if __name__ == '__main__':
precision = 'f64/f64'
diam_list = [10, 20, 30, 40, 60, 80]
# diam_list = [10, 20, 30, 40, 60, 80]
diam_list = [80]
CL_list, CD_list = [], []
result_dict = {}
result_dict['resolution_list'] = diam_list
Expand All @@ -118,10 +117,6 @@ def output_data(self, **kwargs):
Re = 100.0
visc = prescribed_vel * diam / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny)
print("Number of voxels: ", nx * ny)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')

Check notice on line 121 in examples/CFD/cylinder2d.py

View check run for this annotation

Autodesk Chorus / security/bandit

B605: start_process_with_a_shell

Starting a process with a shell: Seems safe, but may be changed in the future, consider rewriting without shell secure coding id: PYTH-INJC-30.

Check notice on line 121 in examples/CFD/cylinder2d.py

View check run for this annotation

Autodesk Chorus / security/bandit

B607: start_process_with_partial_path

Starting a process with a partial executable path secure coding id: PYTH-INJC-30.

Expand Down
18 changes: 7 additions & 11 deletions examples/CFD/oscilating_cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@
"""


import os
import jax
from time import time
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
jax.config.update('jax_enable_x64', True)

class Cylinder(KBCSim):
Expand Down Expand Up @@ -119,7 +120,6 @@ def output_data(self, **kwargs):
if __name__ == '__main__':
precision = 'f64/f64'
lattice = LatticeD2Q9(precision)

prescribed_vel = 0.005
diam = 20
nx = int(22*diam)
Expand All @@ -129,10 +129,6 @@ def output_data(self, **kwargs):
visc = prescribed_vel * diam / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny)
print("Number of voxels: ", nx * ny)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')
kwargs = {
'lattice': lattice,
Expand Down
Loading

0 comments on commit 6b3d5c2

Please sign in to comment.