Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some refactoring in the base class #18

Merged
merged 9 commits into from
Oct 18, 2023
27 changes: 12 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,22 @@ The following examples showcase the capabilities of XLB:

To use XLB, you must first install JAX and other dependencies using the following commands:

```bash
# Please refer to https://github.com/google/jax for the latest installation documentation

pip install --upgrade pip

# For CPU run
pip install --upgrade "jax[cpu]"
Please refer to https://github.com/google/jax for the latest installation documentation. The following table is taken from [JAX's Github page](https://github.com/google/jax).

# For GPU run
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

# CUDA 12 and cuDNN 8.8 or newer.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**Note:** We encountered challenges when executing XLB on Apple GPUs due to the lack of support for certain operations in the Metal backend. We advise using the CPU backend on Mac OS. We will be testing XLB on Apple's GPUs in the future and will update this section accordingly.

# CUDA 11 and cuDNN 8.6 or newer.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Run dependencies
Install dependencies:
mehdiataei marked this conversation as resolved.
Show resolved Hide resolved
```bash
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp
```

Expand All @@ -118,6 +117,4 @@ export PYTHONPATH=.
python3 examples/cavity2d.py
```
## Citing XLB
Accompanying publication coming soon:

**M. Ataei, H. Salehipour**. XLB: Hardware-Accelerated, Scalable, and Differentiable Lattice Boltzmann Simulation Framework based on JAX. TBA
Accompanying paper will be available soon.
8 changes: 0 additions & 8 deletions examples/CFD/airfoil3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import matplotlib.pylab as plt
from src.models import BGKSim, KBCSim
from src.boundary_conditions import *
from src.lattice import *
mehdiataei marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
from src.utils import *
from jax.config import config
Expand Down Expand Up @@ -105,15 +104,11 @@ def output_data(self, **kwargs):
airfoil_thickness = 30
airfoil_angle = 20
airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T

precision = 'f32/f32'
lattice = LatticeD3Q27(precision=precision)

nx = airfoil.shape[0]
ny = airfoil.shape[1]

print("airfoil shape: ", airfoil.shape)

ny = 3 * ny
nx = 4 * nx
nz = 101
Expand All @@ -124,13 +119,11 @@ def output_data(self, **kwargs):

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')

# Set the parameters for the simulation
kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
Expand All @@ -141,5 +134,4 @@ def output_data(self, **kwargs):
}

sim = Airfoil(**kwargs)
print('Domain size: ', sim.nx, sim.ny, sim.nz)
sim.run(20000)
7 changes: 1 addition & 6 deletions examples/CFD/cavity2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax

class Cavity(KBCSim):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -61,7 +59,6 @@ def output_data(self, **kwargs):

if __name__ == "__main__":
precision = "f32/f32"
lattice = LatticeD2Q9(precision)

nx = 200
ny = 200
Expand All @@ -71,16 +68,14 @@ def output_data(self, **kwargs):
clength = nx - 1

checkpoint_rate = 1000
checkpoint_dir = "./checkpoints"
checkpoint_dir = os.path.abspath("./checkpoints")

visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
Expand Down
5 changes: 0 additions & 5 deletions examples/CFD/cavity3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q27
import numpy as np
from src.utils import *
from jax.config import config
Expand Down Expand Up @@ -68,8 +67,6 @@ def output_data(self, **kwargs):
# live_volume_randering(timestep, u_mag)

if __name__ == '__main__':
lattice = LatticeD3Q27(precision)

nx = 101
ny = 101
nz = 101
Expand All @@ -80,12 +77,10 @@ def output_data(self, **kwargs):

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
Expand Down
6 changes: 2 additions & 4 deletions examples/CFD/channel3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_dns_data():
}
return dns_dic

class turbulentChannel(KBCSim):
class TurbulentChannel(KBCSim):
def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand All @@ -68,7 +68,7 @@ def set_boundary_conditions(self):
def initialize_macroscopic_fields(self):
rho = self.precisionPolicy.cast_to_output(1.0)
u = self.distributed_array_init((self.nx, self.ny, self.nz, self.dim),
self.precisionPolicy.compute_dtype, initVal=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
self.precisionPolicy.compute_dtype, init_val=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
u = self.precisionPolicy.cast_to_output(u)
return rho, u

Expand Down Expand Up @@ -141,11 +141,9 @@ def output_data(self, **kwargs):
zz = np.minimum(zz, zz.max() - zz)
yplus = zz * u_tau / visc

print("omega = ", omega)
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
Expand Down
4 changes: 0 additions & 4 deletions examples/CFD/couette2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9
import jax.numpy as jnp
import numpy as np
from src.utils import *
Expand Down Expand Up @@ -49,7 +48,6 @@ def output_data(self, **kwargs):

if __name__ == "__main__":
precision = "f32/f32"
lattice = LatticeD2Q9(precision)
nx = 501
ny = 101

Expand All @@ -60,12 +58,10 @@ def output_data(self, **kwargs):
visc = prescribed_vel * clength / Re

omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)
assert omega < 1.98, "omega must be less than 2.0"
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
Expand Down
3 changes: 0 additions & 3 deletions examples/CFD/cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os
Expand Down Expand Up @@ -95,7 +94,6 @@ def output_data(self, **kwargs):
precision = 'f64/f64'
prescribed_vel = 0.005
diam = 80
lattice = LatticeD2Q9(precision)

nx = int(22*diam)
ny = int(4.1*diam)
Expand All @@ -111,7 +109,6 @@ def output_data(self, **kwargs):
os.system('rm -rf ./*.vtk && rm -rf ./*.png')

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
Expand Down
7 changes: 0 additions & 7 deletions examples/CFD/oscilating_cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os
Expand Down Expand Up @@ -118,7 +117,6 @@ def output_data(self, **kwargs):

if __name__ == '__main__':
precision = 'f64/f64'
lattice = LatticeD2Q9(precision)

prescribed_vel = 0.005
diam = 20
Expand All @@ -129,13 +127,8 @@ def output_data(self, **kwargs):
visc = prescribed_vel * diam / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny)
print("Number of voxels: ", nx * ny)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')
kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
Expand Down
5 changes: 2 additions & 3 deletions examples/CFD/taylor_green_vortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def set_boundary_conditions(self):

def initialize_macroscopic_fields(self):
ux, uy, rho = taylor_green_initial_fields(xx, yy, vel_ref, 1, 0., 0.)
rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, initVal=1.0, sharding=self.sharding)
rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, init_val=1.0, sharding=self.sharding)
u = np.stack([ux, uy], axis=-1)
u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, initVal=u, sharding=self.sharding)
u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, init_val=u, sharding=self.sharding)
return rho, u

def initialize_populations(self, rho, u):
Expand Down Expand Up @@ -95,7 +95,6 @@ def output_data(self, **kwargs):

visc = vel_ref * nx / Re
omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)
os.system("rm -rf ./*.vtk && rm -rf ./*.png")
kwargs = {
'lattice': lattice,
Expand Down
6 changes: 0 additions & 6 deletions examples/CFD/windtunnel3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os
Expand Down Expand Up @@ -109,7 +108,6 @@ def output_data(self, **kwargs):

if __name__ == '__main__':
precision = 'f32/f32'
lattice = LatticeD3Q27(precision)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be D3Q27 not D3Q19


nx = 601
ny = 351
Expand All @@ -122,13 +120,9 @@ def output_data(self, **kwargs):
visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny, nz)
print("Number of voxels: ", nx * ny * nz)
os.system('rm -rf ./*.vtk && rm -rf ./*.png')

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': nx,
'ny': ny,
Expand Down
8 changes: 2 additions & 6 deletions examples/performance/MLUPS2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
"""

import os

from src.models import BGKSim
from src.lattice import LatticeD2Q9

import jax.numpy as jnp
import numpy as np
from src.utils import *
from jax.config import config
from time import time
import argparse

from src.boundary_conditions import *
from src.models import BGKSim

class Cavity(BGKSim):
def __init__(self, **kwargs):
Expand All @@ -36,7 +34,6 @@ def set_boundary_conditions(self):

if __name__ == '__main__':
precision = 'f32/f32'
lattice = LatticeD2Q9(precision)

parser = argparse.ArgumentParser("simple_example")
parser.add_argument("N", help="The total number of voxels will be NxN", type=int)
Expand All @@ -54,7 +51,6 @@ def set_boundary_conditions(self):
print('omega = ', omega)

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': n,
'ny': n,
Expand Down
5 changes: 0 additions & 5 deletions examples/performance/MLUPS3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

from src.models import BGKSim
from src.lattice import LatticeD3Q19
import jax.numpy as jnp
import numpy as np
from src.utils import *
Expand Down Expand Up @@ -38,8 +37,6 @@ def set_boundary_conditions(self):

if __name__ == '__main__':
precision = 'f32/f32'
# Create a 3D lattice with the D3Q19 scheme
lattice = LatticeD3Q19(precision)

# Create a parser that will read the command line arguments
parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation")
Expand All @@ -61,10 +58,8 @@ def set_boundary_conditions(self):
visc = u_wall * clength / Re
# Compute the relaxation parameter from the viscosity
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

kwargs = {
'lattice': lattice,
'omega': omega,
'nx': n,
'ny': n,
Expand Down
Loading