Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some refactoring in the base class #18

Merged
merged 9 commits into from
Oct 18, 2023
27 changes: 12 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,22 @@ The following examples showcase the capabilities of XLB:

To use XLB, you must first install JAX and other dependencies using the following commands:

```bash
# Please refer to https://github.com/google/jax for the latest installation documentation

pip install --upgrade pip

# For CPU run
pip install --upgrade "jax[cpu]"
Please refer to https://github.com/google/jax for the latest installation documentation. The following table is taken from [JAX's Github page](https://github.com/google/jax).

# For GPU run
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

# CUDA 12 and cuDNN 8.8 or newer.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**Note:** We encountered challenges when executing XLB on Apple GPUs due to the lack of support for certain operations in the Metal backend. We advise using the CPU backend on Mac OS. We will be testing XLB on Apple's GPUs in the future and will update this section accordingly.

# CUDA 11 and cuDNN 8.6 or newer.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Run dependencies
Install dependencies:
mehdiataei marked this conversation as resolved.
Show resolved Hide resolved
```bash
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp
```

Expand All @@ -118,6 +117,4 @@ export PYTHONPATH=.
python3 examples/cavity2d.py
```
## Citing XLB
Accompanying publication coming soon:

**M. Ataei, H. Salehipour**. XLB: Hardware-Accelerated, Scalable, and Differentiable Lattice Boltzmann Simulation Framework based on JAX. TBA
Accompanying paper will be available soon.
10 changes: 3 additions & 7 deletions examples/CFD/airfoil3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
# from IPython import display
import matplotlib.pylab as plt
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *
from src.lattice import *
mehdiataei marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
from src.utils import *
from jax.config import config
Expand Down Expand Up @@ -105,15 +105,13 @@ def output_data(self, **kwargs):
airfoil_thickness = 30
airfoil_angle = 20
airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T

precision = 'f32/f32'
lattice = LatticeD3Q27(precision=precision)

lattice = LatticeD3Q27(precision)

nx = airfoil.shape[0]
ny = airfoil.shape[1]

print("airfoil shape: ", airfoil.shape)

ny = 3 * ny
nx = 4 * nx
nz = 101
Expand All @@ -124,7 +122,6 @@ def output_data(self, **kwargs):

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')

Expand All @@ -141,5 +138,4 @@ def output_data(self, **kwargs):
}

sim = Airfoil(**kwargs)
print('Domain size: ', sim.nx, sim.ny, sim.nz)
sim.run(20000)
13 changes: 6 additions & 7 deletions examples/CFD/cavity2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
4. Visualization: The simulation outputs data in VTK format for visualization. It also provides images of the velocity field and saves the boundary conditions at each time step. The data can be visualized using software like Paraview.

"""
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9
from src.utils import *

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax

class Cavity(KBCSim):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -71,11 +71,10 @@ def output_data(self, **kwargs):
clength = nx - 1

checkpoint_rate = 1000
checkpoint_dir = "./checkpoints"
checkpoint_dir = os.path.abspath("./checkpoints")

visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
14 changes: 7 additions & 7 deletions examples/CFD/cavity3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q27

import numpy as np
from src.utils import *
from jax.config import config
from src.boundary_conditions import *

precision = 'f32/f32'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *

class Cavity(KBCSim):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -68,8 +68,6 @@ def output_data(self, **kwargs):
# live_volume_randering(timestep, u_mag)

if __name__ == '__main__':
lattice = LatticeD3Q27(precision)

nx = 101
ny = 101
nz = 101
Expand All @@ -78,9 +76,11 @@ def output_data(self, **kwargs):
prescribed_vel = 0.1
clength = nx - 1

precision = 'f32/f32'
lattice = LatticeD3Q27(precision)

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
5 changes: 2 additions & 3 deletions examples/CFD/channel3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_dns_data():
}
return dns_dic

class turbulentChannel(KBCSim):
class TurbulentChannel(KBCSim):
def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand All @@ -68,7 +68,7 @@ def set_boundary_conditions(self):
def initialize_macroscopic_fields(self):
rho = self.precisionPolicy.cast_to_output(1.0)
u = self.distributed_array_init((self.nx, self.ny, self.nz, self.dim),
self.precisionPolicy.compute_dtype, initVal=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
self.precisionPolicy.compute_dtype, init_val=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
u = self.precisionPolicy.cast_to_output(u)
return rho, u

Expand Down Expand Up @@ -141,7 +141,6 @@ def output_data(self, **kwargs):
zz = np.minimum(zz, zz.max() - zz)
yplus = zz * u_tau / visc

print("omega = ", omega)
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
Expand Down
11 changes: 6 additions & 5 deletions examples/CFD/couette2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
This script performs a 2D simulation of Couette flow using the lattice Boltzmann method (LBM).
"""

from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9
import os
import jax.numpy as jnp
import numpy as np
from src.utils import *
from jax.config import config
import os


from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9

# config.update('jax_disable_jit', True)
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
Expand Down Expand Up @@ -60,7 +62,6 @@ def output_data(self, **kwargs):
visc = prescribed_vel * clength / Re

omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)
assert omega < 1.98, "omega must be less than 2.0"
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
17 changes: 9 additions & 8 deletions examples/CFD/cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@
5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView.

"""

import os
import jax
from time import time
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
jax.config.update('jax_enable_x64', True)

class Cylinder(KBCSim):
Expand Down Expand Up @@ -93,9 +93,10 @@ def output_data(self, **kwargs):

if __name__ == '__main__':
precision = 'f64/f64'
lattice = LatticeD2Q9(precision)

prescribed_vel = 0.005
diam = 80
lattice = LatticeD2Q9(precision)

nx = int(22*diam)
ny = int(4.1*diam)
Expand Down
18 changes: 7 additions & 11 deletions examples/CFD/oscilating_cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@
"""


import os
import jax
from time import time
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
jax.config.update('jax_enable_x64', True)

class Cylinder(KBCSim):
Expand Down Expand Up @@ -119,7 +120,6 @@ def output_data(self, **kwargs):
if __name__ == '__main__':
precision = 'f64/f64'
lattice = LatticeD2Q9(precision)

prescribed_vel = 0.005
diam = 20
nx = int(22*diam)
Expand All @@ -129,10 +129,6 @@ def output_data(self, **kwargs):
visc = prescribed_vel * diam / Re
omega = 1.0 / (3. * visc + 0.5)

print('omega = ', omega)
print("Mesh size: ", nx, ny)
print("Number of voxels: ", nx * ny)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')
kwargs = {
'lattice': lattice,
Expand Down
21 changes: 11 additions & 10 deletions examples/CFD/taylor_green_vortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
"""


from src.boundary_conditions import *
from src.utils import *
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK
import os
import matplotlib.pyplot as plt
import json
import jax
import numpy as np
import matplotlib.pyplot as plt

from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK
from src.lattice import LatticeD2Q9


# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
# disable JIt compilation

jax.config.update('jax_enable_x64', True)
Expand All @@ -37,9 +39,9 @@ def set_boundary_conditions(self):

def initialize_macroscopic_fields(self):
ux, uy, rho = taylor_green_initial_fields(xx, yy, vel_ref, 1, 0., 0.)
rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, initVal=1.0, sharding=self.sharding)
rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, init_val=1.0, sharding=self.sharding)
u = np.stack([ux, uy], axis=-1)
u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, initVal=u, sharding=self.sharding)
u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, init_val=u, sharding=self.sharding)
return rho, u

def initialize_populations(self, rho, u):
Expand Down Expand Up @@ -95,7 +97,6 @@ def output_data(self, **kwargs):

visc = vel_ref * nx / Re
omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)
os.system("rm -rf ./*.vtk && rm -rf ./*.png")
kwargs = {
'lattice': lattice,
Expand Down
Loading