diff --git a/README.md b/README.md index 6b66008..4a7bac7 100644 --- a/README.md +++ b/README.md @@ -90,23 +90,22 @@ The following examples showcase the capabilities of XLB: To use XLB, you must first install JAX and other dependencies using the following commands: -```bash -# Please refer to https://github.com/google/jax for the latest installation documentation - -pip install --upgrade pip -# For CPU run -pip install --upgrade "jax[cpu]" +Please refer to https://github.com/google/jax for the latest installation documentation. The following table is taken from [JAX's Github page](https://github.com/google/jax). -# For GPU run +| Hardware | Instructions | +|------------|-----------------------------------------------------------------------------------------------------------------| +| CPU | `pip install -U "jax[cpu]"` | +| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` | +| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` | +| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). | +| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | -# CUDA 12 and cuDNN 8.8 or newer. -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +**Note:** We encountered challenges when executing XLB on Apple GPUs due to the lack of support for certain operations in the Metal backend. We advise using the CPU backend on Mac OS. We will be testing XLB on Apple's GPUs in the future and will update this section accordingly. -# CUDA 11 and cuDNN 8.6 or newer. -pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -# Run dependencies +Install dependencies: +```bash pip install jmp pyvista numpy matplotlib Rtree trimesh jmp ``` @@ -118,6 +117,4 @@ export PYTHONPATH=. python3 examples/cavity2d.py ``` ## Citing XLB -Accompanying publication coming soon: - -**M. Ataei, H. Salehipour**. XLB: Hardware-Accelerated, Scalable, and Differentiable Lattice Boltzmann Simulation Framework based on JAX. TBA +Accompanying paper will be available soon. \ No newline at end of file diff --git a/examples/CFD/airfoil3d.py b/examples/CFD/airfoil3d.py index e601551..c33879a 100644 --- a/examples/CFD/airfoil3d.py +++ b/examples/CFD/airfoil3d.py @@ -31,8 +31,8 @@ # from IPython import display import matplotlib.pylab as plt from src.models import BGKSim, KBCSim +from src.lattice import LatticeD3Q19, LatticeD3Q27 from src.boundary_conditions import * -from src.lattice import * import numpy as np from src.utils import * from jax.config import config @@ -105,15 +105,13 @@ def output_data(self, **kwargs): airfoil_thickness = 30 airfoil_angle = 20 airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T - precision = 'f32/f32' - lattice = LatticeD3Q27(precision=precision) + + lattice = LatticeD3Q27(precision) nx = airfoil.shape[0] ny = airfoil.shape[1] - print("airfoil shape: ", airfoil.shape) - ny = 3 * ny nx = 4 * nx nz = 101 @@ -124,7 +122,6 @@ def output_data(self, **kwargs): visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) os.system('rm -rf ./*.vtk && rm -rf ./*.png') @@ -141,5 +138,4 @@ def output_data(self, **kwargs): } sim = Airfoil(**kwargs) - print('Domain size: ', sim.nx, sim.ny, sim.nz) sim.run(20000) \ No newline at end of file diff --git a/examples/CFD/cavity2d.py b/examples/CFD/cavity2d.py index a7fae66..5692027 100644 --- a/examples/CFD/cavity2d.py +++ b/examples/CFD/cavity2d.py @@ -16,18 +16,18 @@ 4. Visualization: The simulation outputs data in VTK format for visualization. It also provides images of the velocity field and saves the boundary conditions at each time step. The data can be visualized using software like Paraview. """ -from src.boundary_conditions import * from jax.config import config -from src.utils import * import numpy as np -from src.lattice import LatticeD2Q9 -from src.models import BGKSim, KBCSim import jax.numpy as jnp import os +from src.boundary_conditions import * +from src.models import BGKSim, KBCSim +from src.lattice import LatticeD2Q9 +from src.utils import * + # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax class Cavity(KBCSim): def __init__(self, **kwargs): @@ -71,11 +71,10 @@ def output_data(self, **kwargs): clength = nx - 1 checkpoint_rate = 1000 - checkpoint_dir = "./checkpoints" + checkpoint_dir = os.path.abspath("./checkpoints") visc = prescribed_vel * clength / Re omega = 1.0 / (3.0 * visc + 0.5) - print("omega = ", omega) os.system("rm -rf ./*.vtk && rm -rf ./*.png") diff --git a/examples/CFD/cavity3d.py b/examples/CFD/cavity3d.py index d912786..58db262 100644 --- a/examples/CFD/cavity3d.py +++ b/examples/CFD/cavity3d.py @@ -19,14 +19,14 @@ # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -from src.models import BGKSim, KBCSim -from src.lattice import LatticeD3Q27 + import numpy as np from src.utils import * from jax.config import config -from src.boundary_conditions import * -precision = 'f32/f32' +from src.models import BGKSim, KBCSim +from src.lattice import LatticeD3Q19, LatticeD3Q27 +from src.boundary_conditions import * class Cavity(KBCSim): def __init__(self, **kwargs): @@ -68,8 +68,6 @@ def output_data(self, **kwargs): # live_volume_randering(timestep, u_mag) if __name__ == '__main__': - lattice = LatticeD3Q27(precision) - nx = 101 ny = 101 nz = 101 @@ -78,9 +76,11 @@ def output_data(self, **kwargs): prescribed_vel = 0.1 clength = nx - 1 + precision = 'f32/f32' + lattice = LatticeD3Q27(precision) + visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) os.system("rm -rf ./*.vtk && rm -rf ./*.png") diff --git a/examples/CFD/channel3d.py b/examples/CFD/channel3d.py index b356ac9..e1a0cec 100644 --- a/examples/CFD/channel3d.py +++ b/examples/CFD/channel3d.py @@ -55,7 +55,7 @@ def get_dns_data(): } return dns_dic -class turbulentChannel(KBCSim): +class TurbulentChannel(KBCSim): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -68,7 +68,7 @@ def set_boundary_conditions(self): def initialize_macroscopic_fields(self): rho = self.precisionPolicy.cast_to_output(1.0) u = self.distributed_array_init((self.nx, self.ny, self.nz, self.dim), - self.precisionPolicy.compute_dtype, initVal=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim))) + self.precisionPolicy.compute_dtype, init_val=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim))) u = self.precisionPolicy.cast_to_output(u) return rho, u @@ -141,7 +141,6 @@ def output_data(self, **kwargs): zz = np.minimum(zz, zz.max() - zz) yplus = zz * u_tau / visc - print("omega = ", omega) os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { diff --git a/examples/CFD/couette2d.py b/examples/CFD/couette2d.py index e14765c..05b60c5 100644 --- a/examples/CFD/couette2d.py +++ b/examples/CFD/couette2d.py @@ -2,14 +2,16 @@ This script performs a 2D simulation of Couette flow using the lattice Boltzmann method (LBM). """ -from src.models import BGKSim -from src.boundary_conditions import * -from src.lattice import LatticeD2Q9 +import os import jax.numpy as jnp import numpy as np from src.utils import * from jax.config import config -import os + + +from src.models import BGKSim +from src.boundary_conditions import * +from src.lattice import LatticeD2Q9 # config.update('jax_disable_jit', True) # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' @@ -60,7 +62,6 @@ def output_data(self, **kwargs): visc = prescribed_vel * clength / Re omega = 1.0 / (3.0 * visc + 0.5) - print("omega = ", omega) assert omega < 1.98, "omega must be less than 2.0" os.system("rm -rf ./*.vtk && rm -rf ./*.png") diff --git a/examples/CFD/cylinder2d.py b/examples/CFD/cylinder2d.py index 8196378..9fa9779 100644 --- a/examples/CFD/cylinder2d.py +++ b/examples/CFD/cylinder2d.py @@ -17,20 +17,20 @@ 5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView. """ - +import os +import jax from time import time -from src.boundary_conditions import * from jax.config import config -from src.utils import * import numpy as np -from src.lattice import LatticeD2Q9 -from src.models import BGKSim, KBCSim import jax.numpy as jnp -import os + +from src.utils import * +from src.boundary_conditions import * +from src.models import BGKSim, KBCSim +from src.lattice import LatticeD2Q9 # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax jax.config.update('jax_enable_x64', True) class Cylinder(KBCSim): @@ -93,9 +93,10 @@ def output_data(self, **kwargs): if __name__ == '__main__': precision = 'f64/f64' + lattice = LatticeD2Q9(precision) + prescribed_vel = 0.005 diam = 80 - lattice = LatticeD2Q9(precision) nx = int(22*diam) ny = int(4.1*diam) diff --git a/examples/CFD/oscilating_cylinder2d.py b/examples/CFD/oscilating_cylinder2d.py index 8c7fcc3..f6db4d4 100644 --- a/examples/CFD/oscilating_cylinder2d.py +++ b/examples/CFD/oscilating_cylinder2d.py @@ -19,19 +19,20 @@ """ +import os +import jax from time import time -from src.boundary_conditions import * from jax.config import config -from src.utils import * import numpy as np -from src.lattice import LatticeD2Q9 -from src.models import BGKSim, KBCSim import jax.numpy as jnp -import os + +from src.utils import * +from src.boundary_conditions import * +from src.models import BGKSim, KBCSim +from src.lattice import LatticeD2Q9 # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax jax.config.update('jax_enable_x64', True) class Cylinder(KBCSim): @@ -119,7 +120,6 @@ def output_data(self, **kwargs): if __name__ == '__main__': precision = 'f64/f64' lattice = LatticeD2Q9(precision) - prescribed_vel = 0.005 diam = 20 nx = int(22*diam) @@ -129,10 +129,6 @@ def output_data(self, **kwargs): visc = prescribed_vel * diam / Re omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) - print("Mesh size: ", nx, ny) - print("Number of voxels: ", nx * ny) - os.system('rm -rf ./*.vtk && rm -rf ./*.png') kwargs = { 'lattice': lattice, diff --git a/examples/CFD/taylor_green_vortex.py b/examples/CFD/taylor_green_vortex.py index 1071025..e52b7d3 100644 --- a/examples/CFD/taylor_green_vortex.py +++ b/examples/CFD/taylor_green_vortex.py @@ -5,18 +5,20 @@ """ -from src.boundary_conditions import * -from src.utils import * -import numpy as np -from src.lattice import LatticeD2Q9 -from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK import os -import matplotlib.pyplot as plt import json +import jax +import numpy as np +import matplotlib.pyplot as plt + +from src.utils import * +from src.boundary_conditions import * +from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK +from src.lattice import LatticeD2Q9 + # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax # disable JIt compilation jax.config.update('jax_enable_x64', True) @@ -37,9 +39,9 @@ def set_boundary_conditions(self): def initialize_macroscopic_fields(self): ux, uy, rho = taylor_green_initial_fields(xx, yy, vel_ref, 1, 0., 0.) - rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, initVal=1.0, sharding=self.sharding) + rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, init_val=1.0, sharding=self.sharding) u = np.stack([ux, uy], axis=-1) - u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, initVal=u, sharding=self.sharding) + u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, init_val=u, sharding=self.sharding) return rho, u def initialize_populations(self, rho, u): @@ -95,7 +97,6 @@ def output_data(self, **kwargs): visc = vel_ref * nx / Re omega = 1.0 / (3.0 * visc + 0.5) - print("omega = ", omega) os.system("rm -rf ./*.vtk && rm -rf ./*.png") kwargs = { 'lattice': lattice, diff --git a/examples/CFD/windtunnel3d.py b/examples/CFD/windtunnel3d.py index 780d04a..2f94b60 100644 --- a/examples/CFD/windtunnel3d.py +++ b/examples/CFD/windtunnel3d.py @@ -12,20 +12,21 @@ """ -from time import time +import os +import jax import trimesh -from src.boundary_conditions import * +from time import time +import numpy as np +import jax.numpy as jnp from jax.config import config + from src.utils import * -import numpy as np -from src.lattice import LatticeD3Q19, LatticeD3Q27 from src.models import BGKSim, KBCSim -import jax.numpy as jnp -import os +from src.lattice import LatticeD3Q19, LatticeD3Q27 +from src.boundary_conditions import * # Use 8 CPU devices # os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax # disable JIt compilation @@ -122,9 +123,6 @@ def output_data(self, **kwargs): visc = prescribed_vel * clength / Re omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) - print("Mesh size: ", nx, ny, nz) - print("Number of voxels: ", nx * ny * nz) os.system('rm -rf ./*.vtk && rm -rf ./*.png') kwargs = { diff --git a/examples/performance/MLUPS2d.py b/examples/performance/MLUPS2d.py index 30b2941..5eb7e86 100644 --- a/examples/performance/MLUPS2d.py +++ b/examples/performance/MLUPS2d.py @@ -3,17 +3,16 @@ """ import os - -from src.models import BGKSim -from src.lattice import LatticeD2Q9 - +import argparse import jax.numpy as jnp import numpy as np -from src.utils import * from jax.config import config from time import time -import argparse + +from src.utils import * from src.boundary_conditions import * +from src.lattice import LatticeD2Q9 +from src.models import BGKSim class Cavity(BGKSim): def __init__(self, **kwargs): diff --git a/examples/performance/MLUPS3d.py b/examples/performance/MLUPS3d.py index 671f484..164afe6 100644 --- a/examples/performance/MLUPS3d.py +++ b/examples/performance/MLUPS3d.py @@ -2,22 +2,22 @@ This script computes the MLUPS (Million Lattice Updates per Second) in 3D by simulating fluid flow inside a 2D cavity. """ -from src.models import BGKSim -from src.lattice import LatticeD3Q19 +import os +import argparse + +import jax import jax.numpy as jnp import numpy as np -from src.utils import * from jax.config import config -import os from time import time -import argparse -import jax #config.update('jax_disable_jit', True) # Use 8 CPU devices #os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' #config.update("jax_enable_x64", True) +from src.utils import * from src.boundary_conditions import * - +from src.models import BGKSim +from src.lattice import LatticeD3Q19 class Cavity(BGKSim): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -38,9 +38,7 @@ def set_boundary_conditions(self): if __name__ == '__main__': precision = 'f32/f32' - # Create a 3D lattice with the D3Q19 scheme lattice = LatticeD3Q19(precision) - # Create a parser that will read the command line arguments parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation") parser.add_argument("N", help="The total number of voxels all directions. The final dimension will be N*NxN", default=100, type=int) @@ -61,7 +59,6 @@ def set_boundary_conditions(self): visc = u_wall * clength / Re # Compute the relaxation parameter from the viscosity omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) kwargs = { 'lattice': lattice, diff --git a/examples/performance/MLUPS3d_distributed.py b/examples/performance/MLUPS3d_distributed.py index 2aa5814..4a418db 100644 --- a/examples/performance/MLUPS3d_distributed.py +++ b/examples/performance/MLUPS3d_distributed.py @@ -5,24 +5,36 @@ """ -from src.models import BGKSim -from src.lattice import LatticeD3Q19 +# Standard Libraries +import argparse +import os +import jax +# Initialize JAX distributed. The IP, number of processes and process id must be updated. +# Currently set on local host for testing purposes. +# Can be tested on a two GPU system as follows: +# (export PYTHONPATH=.; CUDA_VISIBLE_DEVICES=0 python3 examples/performance/MLUPS3d_distributed.py 100 100 & CUDA_VISIBLE_DEVICES=1 python3 examples/performance/MLUPS3d_distributed.py 100 100 &) +#IMPORTANT: jax distributed must be initialized before any jax computation is performed +jax.distributed.initialize(f'127.0.0.1:1234', 2, process_id=int(os.environ['CUDA_VISIBLE_DEVICES'])) + +print('Process id: ', jax.process_index()) +print('Number of total devices (over all processes): ', jax.device_count()) +print('Number of local devices:', jax.local_device_count()) + + import jax.numpy as jnp import numpy as np -from src.utils import * + from jax.config import config -import os -from time import time -import argparse -import jax -import portpicker + +from src.boundary_conditions import * +from src.models import BGKSim +from src.lattice import LatticeD3Q19 +from src.utils import * + #config.update('jax_disable_jit', True) # Use 8 CPU devices #os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' #config.update("jax_enable_x64", True) -from src.boundary_conditions import * - -precision = 'f32/f32' class Cavity(BGKSim): @@ -44,15 +56,7 @@ def set_boundary_conditions(self): self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) if __name__ == '__main__': - - # Initialize JAX distributed. The IP, number of processes and process id must be updated. - # Currently set on local host for testing purposes. - # Can be tested with - # (export PYTHONPATH=.; CUDA_VISIBLE_DEVICES=0 python3 examples/performance/MLUPS3d_distributed.py 100 100 & CUDA_VISIBLE_DEVICES=1 python3 examples/performance/MLUPS3d_distributed.py 100 100 &) - port = portpicker.pick_unused_port() - jax.distributed.initialize(f'127.0.0.1:1234', 2, int(os.environ['CUDA_VISIBLE_DEVICES'])) - - # Create a 3D lattice with the D3Q19 scheme + precision = 'f32/f32' lattice = LatticeD3Q19(precision) # Create a parser that will read the command line arguments @@ -75,7 +79,6 @@ def set_boundary_conditions(self): visc = u_wall * clength / Re # Compute the relaxation parameter from the viscosity omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) # Create a new instance of the Cavity class kwargs = { diff --git a/requirements.txt b/requirements.txt index 2c912ed..bc453d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,10 @@ -jax==0.4.11 -jaxlib==0.4.11 +jax==0.4.19 +jaxlib==0.4.19 jmp==0.0.4 -matplotlib==3.7.1 -numpy==1.24.2 -pyvista==0.38.5 +matplotlib==3.8.0 +numpy==1.26.1 +pyvista==0.42.3 Rtree==1.0.1 -trimesh==3.20.2 -orbax-checkpoint==0.2.3 -portpicker===1.5.2 +trimesh==4.0.0 +orbax-checkpoint==0.4.1 termcolor==2.3.0 \ No newline at end of file diff --git a/src/base.py b/src/base.py index c4dbf36..b359b6a 100644 --- a/src/base.py +++ b/src/base.py @@ -1,24 +1,28 @@ -from src.boundary_conditions import * -from jax.config import config -from src.utils import * -from functools import partial -from jax.sharding import NamedSharding -from jax.sharding import PartitionSpec -from jax.sharding import PositionalSharding -from jax.sharding import Mesh -from jax.experimental import mesh_utils -from jax.experimental.shard_map import shard_map -from jax.experimental.multihost_utils import process_allgather -from jax import jit, lax, vmap -from termcolor import colored -from orbax.checkpoint import * +# Standard Libraries +import os import time + +# Third-Party Libraries +import jax import jax.numpy as jnp -import numpy as np import jmp -import os -import jax +import numpy as np +from termcolor import colored + +# JAX-related imports +from jax import jit, lax, vmap +from jax.config import config +from jax.experimental import mesh_utils +from jax.experimental.multihost_utils import process_allgather +from jax.experimental.shard_map import shard_map +from jax.sharding import NamedSharding, PartitionSpec, PositionalSharding, Mesh +import orbax.checkpoint as orb +# functools imports +from functools import partial +# Local/Custom Libraries +import src.models +from src.utils import downsample_field jax.config.update("jax_spmd_mode", 'allow_all') # Disables annoying TF warnings @@ -36,26 +40,30 @@ class LBMBase(object): ny (int): Number of grid points in the y-direction. nz (int, optional): Number of grid points in the z-direction. Defaults to 0. precision (str, optional): A string specifying the precision used for the simulation. Defaults to "f32/f32". - optimize (bool, optional): Whether or not to run adjoint optimization (not functional yet). Defaults to False. """ def __init__(self, **kwargs): - # Set the precision for computation and storage - precision = kwargs.get("precision", "f32/f32") - computedType, storedType = self.set_precisions(precision) + self.omega = kwargs.get("omega") + self.nx = kwargs.get("nx") + self.ny = kwargs.get("ny") + self.nz = kwargs.get("nz") + + self.precision = kwargs.get("precision") + computedType, storedType = self.set_precisions(self.precision) self.precisionPolicy = jmp.Policy(compute_dtype=computedType, param_dtype=computedType, output_dtype=storedType) - self.optimize = kwargs.get("optimize", False) + self.lattice = kwargs.get("lattice") self.checkpointRate = kwargs.get("checkpoint_rate", 0) self.checkpointDir = kwargs.get("checkpoint_dir", './checkpoints') self.downsamplingFactor = kwargs.get("downsampling_factor", 1) - self.printInfoRate= kwargs.get("print_info_rate", 100) + self.printInfoRate = kwargs.get("print_info_rate", 100) self.ioRate = kwargs.get("io_rate", 0) self.returnFpost = kwargs.get("return_fpost", False) self.computeMLUPS = kwargs.get("compute_MLUPS", False) self.restore_checkpoint = kwargs.get("restore_checkpoint", False) self.nDevices = jax.device_count() + self.backend = jax.default_backend() if self.computeMLUPS: self.restore_checkpoint = False @@ -66,21 +74,7 @@ def __init__(self, **kwargs): # Check for distributed mode if self.nDevices > jax.local_device_count(): print("WARNING: Running in distributed mode. Make sure that jax.distributed.initialize is called before performing any JAX computations.") - print("XLA backend:", jax.default_backend()) - print("Number of XLA devices available: " + colored(f'{self.nDevices}', 'green')) - self.p_i = np.arange(self.nDevices) - - # Set the lattice and relaxation parameter - lattice = kwargs.get("lattice", None) - if lattice is None: - raise ValueError("lattice must be provided") - - omega = kwargs.get("omega", None) - if omega is None: - raise ValueError("omega must be provided") - - self.lattice = lattice - self.omega = omega + self.c = self.lattice.c self.q = self.lattice.q self.w = self.lattice.w @@ -88,10 +82,9 @@ def __init__(self, **kwargs): # Set the checkpoint manager if self.checkpointRate > 0: - mngr_options = CheckpointManagerOptions(save_interval_steps=self.checkpointRate, max_to_keep=1) - self.mngr = CheckpointManager(self.checkpointDir, PyTreeCheckpointer(), options=mngr_options) + mngr_options = orb.CheckpointManagerOptions(save_interval_steps=self.checkpointRate, max_to_keep=1) + self.mngr = orb.CheckpointManager(self.checkpointDir, orb.PyTreeCheckpointer(), options=mngr_options) else: - print("WARNING: Checkpointing is disabled for this simulation.") self.mngr = None # Adjust the number of grid points in the x direction, if necessary. @@ -107,14 +100,16 @@ def __init__(self, **kwargs): print("WARNING: nx increased from {} to {} in order to accommodate domain sharding per XLA device.".format(nx, self.nx)) self.ny = ny self.nz = nz + + self.show_simulation_parameters() # Store grid information self.gridInfo = { "nx": self.nx, "ny": self.ny, "nz": self.nz, - "dim": lattice.d, - "lattice": lattice + "dim": self.lattice.d, + "lattice": self.lattice } P = PartitionSpec @@ -124,7 +119,6 @@ def __init__(self, **kwargs): # Define the left permutation self.leftPerm = [((i + 1) % self.nDevices, i) for i in range(self.nDevices)] - # Set up the sharding and streaming for 2D and 3D simulations if self.dim == 2: self.devices = mesh_utils.create_device_mesh((self.nDevices, 1, 1)) @@ -152,11 +146,213 @@ def __init__(self, **kwargs): raise ValueError(f"dim = {self.dim} not supported") # Compute the bounding box indices for boundary conditions - self.boundingBoxIndices = self.bounding_box_indices() + self.boundingBoxIndices= self.bounding_box_indices() # Create boundary data for the simulation self._create_boundary_data() self.force = self.get_force() + @property + def lattice(self): + return self._lattice + + @lattice.setter + def lattice(self, value): + if value is None: + raise ValueError("Lattice type must be provided.") + if self.nz == 0 and value.name not in ['D2Q9']: + raise ValueError("For 2D simulations, lattice type must be LatticeD2Q9.") + if self.nz != 0 and value.name not in ['D3Q19', 'D3Q27']: + raise ValueError("For 3D simulations, lattice type must be LatticeD3Q19, or LatticeD3Q27.") + + self._lattice = value + + @property + def omega(self): + return self._omega + + @omega.setter + def omega(self, value): + if value is None: + raise ValueError("omega must be provided") + if not isinstance(value, float): + raise TypeError("omega must be a float") + self._omega = value + + @property + def nx(self): + return self._nx + + @nx.setter + def nx(self, value): + if value is None: + raise ValueError("nx must be provided") + if not isinstance(value, int): + raise TypeError("nx must be an integer") + self._nx = value + + @property + def ny(self): + return self._ny + + @ny.setter + def ny(self, value): + if value is None: + raise ValueError("ny must be provided") + if not isinstance(value, int): + raise TypeError("ny must be an integer") + self._ny = value + + @property + def nz(self): + return self._nz + + @nz.setter + def nz(self, value): + if value is None: + raise ValueError("nz must be provided") + if not isinstance(value, int): + raise TypeError("nz must be an integer") + self._nz = value + + @property + def precision(self): + return self._precision + + @precision.setter + def precision(self, value): + if not isinstance(value, str): + raise TypeError("precision must be a string") + self._precision = value + + @property + def checkpointRate(self): + return self._checkpointRate + + @checkpointRate.setter + def checkpointRate(self, value): + if not isinstance(value, int): + raise TypeError("checkpointRate must be an integer") + self._checkpointRate = value + + @property + def checkpointDir(self): + return self._checkpointDir + + @checkpointDir.setter + def checkpointDir(self, value): + if not isinstance(value, str): + raise TypeError("checkpointDir must be a string") + self._checkpointDir = value + + @property + def downsamplingFactor(self): + return self._downsamplingFactor + + @downsamplingFactor.setter + def downsamplingFactor(self, value): + if not isinstance(value, int): + raise TypeError("downsamplingFactor must be an integer") + self._downsamplingFactor = value + + @property + def printInfoRate(self): + return self._printInfoRate + + @printInfoRate.setter + def printInfoRate(self, value): + if not isinstance(value, int): + raise TypeError("printInfoRate must be an integer") + self._printInfoRate = value + + @property + def ioRate(self): + return self._ioRate + + @ioRate.setter + def ioRate(self, value): + if not isinstance(value, int): + raise TypeError("ioRate must be an integer") + self._ioRate = value + + @property + def returnFpost(self): + return self._returnFpost + + @returnFpost.setter + def returnFpost(self, value): + if not isinstance(value, bool): + raise TypeError("returnFpost must be a boolean") + self._returnFpost = value + + @property + def computeMLUPS(self): + return self._computeMLUPS + + @computeMLUPS.setter + def computeMLUPS(self, value): + if not isinstance(value, bool): + raise TypeError("computeMLUPS must be a boolean") + self._computeMLUPS = value + + @property + def restore_checkpoint(self): + return self._restore_checkpoint + + @restore_checkpoint.setter + def restore_checkpoint(self, value): + if not isinstance(value, bool): + raise TypeError("restore_checkpoint must be a boolean") + self._restore_checkpoint = value + + @property + def nDevices(self): + return self._nDevices + + @nDevices.setter + def nDevices(self, value): + if not isinstance(value, int): + raise TypeError("nDevices must be an integer") + self._nDevices = value + + def show_simulation_parameters(self): + attributes_to_show = [ + 'omega', 'nx', 'ny', 'nz', 'dim', 'precision', 'lattice', + 'checkpointRate', 'checkpointDir', 'downsamplingFactor', + 'printInfoRate', 'ioRate', 'computeMLUPS', + 'restore_checkpoint', 'backend', 'nDevices' + ] + + descriptive_names = { + 'omega': 'Omega', + 'nx': 'Grid Points in X', + 'ny': 'Grid Points in Y', + 'nz': 'Grid Points in Z', + 'dim': 'Dimensionality', + 'precision': 'Precision Policy', + 'lattice': 'Lattice Type', + 'checkpointRate': 'Checkpoint Rate', + 'checkpointDir': 'Checkpoint Directory', + 'downsamplingFactor': 'Downsampling Factor', + 'printInfoRate': 'Print Info Rate', + 'ioRate': 'I/O Rate', + 'computeMLUPS': 'Compute MLUPS', + 'restore_checkpoint': 'Restore Checkpoint', + 'backend': 'Backend', + 'nDevices': 'Number of Devices' + } + simulation_name = self.__class__.__name__ + + print(colored(f'**** Simulation Parameters for {simulation_name} ****', 'green')) + + header = f"{colored('Parameter', 'blue'):>30} | {colored('Value', 'yellow')}" + print(header) + print('-' * 50) + + for attr in attributes_to_show: + value = getattr(self, attr, 'Attribute not set') + descriptive_name = descriptive_names.get(attr, attr) # Use the attribute name as a fallback + row = f"{colored(descriptive_name, 'blue'):>30} | {colored(value, 'yellow')}" + print(row) def _create_boundary_data(self): """ @@ -182,7 +378,7 @@ def _create_boundary_data(self): print("Time to create the local bitmasks and normal arrays:", time.time() - start) # This is another non-JITed way of creating the distributed arrays. It is not used at the moment. - # def distributed_array_init(self, shape, type, initVal=None): + # def distributed_array_init(self, shape, type, init_val=None): # sharding_dim = shape[0] // self.nDevices # sharded_shape = (self.nDevices, sharding_dim, *shape[1:]) # device_shape = sharded_shape[1:] @@ -190,16 +386,16 @@ def _create_boundary_data(self): # for d, index in self.sharding.addressable_devices_indices_map(sharded_shape).items(): # jax.default_device = d - # if initVal is None: + # if init_val is None: # x = jnp.zeros(shape=device_shape, dtype=type) # else: - # x = jnp.full(shape=device_shape, fill_value=initVal, dtype=type) + # x = jnp.full(shape=device_shape, fill_value=init_val, dtype=type) # arrays += [jax.device_put(x, d)] # jax.default_device = jax.devices()[0] # return jax.make_array_from_single_device_arrays(shape, self.sharding, arrays) @partial(jit, static_argnums=(0, 1, 2, 4)) - def distributed_array_init(self, shape, type, initVal=0, sharding=None): + def distributed_array_init(self, shape, type, init_val=0, sharding=None): """ Initialize a distributed array using JAX, with a specified shape, data type, and initial value. Optionally, provide a custom sharding strategy. @@ -208,7 +404,7 @@ def distributed_array_init(self, shape, type, initVal=0, sharding=None): ---------- shape (tuple): The shape of the array to be created. type (dtype): The data type of the array to be created. - initVal (scalar, optional): The initial value to fill the array with. Defaults to 0. + init_val (scalar, optional): The initial value to fill the array with. Defaults to 0. sharding (Sharding, optional): The sharding strategy to use. Defaults to `self.sharding`. Returns @@ -217,7 +413,7 @@ def distributed_array_init(self, shape, type, initVal=0, sharding=None): """ if sharding is None: sharding = self.sharding - x = jnp.full(shape=shape, fill_value=initVal, dtype=type) + x = jnp.full(shape=shape, fill_value=init_val, dtype=type) return jax.lax.with_sharding_constraint(x, sharding) @partial(jit, static_argnums=(0,)) @@ -237,7 +433,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels): hw_x = self.nDevices hw_y = hw_z = 1 if self.dim == 2: - connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), jnp.bool_, initVal=True) + connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), jnp.bool_, init_val=True) connectivity_bitmask = connectivity_bitmask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(None))].set(False) if solid_halo_voxels is not None: solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) @@ -248,7 +444,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels): return lax.with_sharding_constraint(connectivity_bitmask, self.sharding) elif self.dim == 3: - connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.nz + 2 * hw_z, self.lattice.q), jnp.bool_, initVal=True) + connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.nz + 2 * hw_z, self.lattice.q), jnp.bool_, init_val=True) connectivity_bitmask = connectivity_bitmask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(hw_z, -hw_z), slice(None))].set(False) if solid_halo_voxels is not None: solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) @@ -273,18 +469,18 @@ def bounding_box_indices(self): # For a 2D grid, the bounding box consists of four edges: bottom, top, left, and right. # Each edge is represented as an array of indices. For example, the bottom edge includes # all points where the y-coordinate is 0, so its indices are [[i, 0] for i in range(self.nx)]. - boundingBox = {"bottom": np.array([[i, 0] for i in range(self.nx)], dtype=int), + bounding_box = {"bottom": np.array([[i, 0] for i in range(self.nx)], dtype=int), "top": np.array([[i, self.ny - 1] for i in range(self.nx)], dtype=int), "left": np.array([[0, i] for i in range(self.ny)], dtype=int), "right": np.array([[self.nx - 1, i] for i in range(self.ny)], dtype=int)} - return boundingBox + return bounding_box elif self.dim == 3: # For a 3D grid, the bounding box consists of six faces: bottom, top, left, right, front, and back. # Each face is represented as an array of indices. For example, the bottom face includes all points # where the z-coordinate is 0, so its indices are [[i, j, 0] for i in range(self.nx) for j in range(self.ny)]. - boundingBox = { + bounding_box = { "bottom": np.array([[i, j, 0] for i in range(self.nx) for j in range(self.ny)], dtype=int), "top": np.array([[i, j, self.nz - 1] for i in range(self.nx) for j in range(self.ny)],dtype=int), "left": np.array([[0, j, k] for j in range(self.ny) for k in range(self.nz)], dtype=int), @@ -292,7 +488,7 @@ def bounding_box_indices(self): "front": np.array([[i, 0, k] for i in range(self.nx) for k in range(self.nz)], dtype=int), "back": np.array([[i, self.ny - 1, k] for i in range(self.nx) for k in range(self.nz)], dtype=int)} - return boundingBox + return bounding_box def set_precisions(self, precision): """ @@ -335,7 +531,7 @@ def initialize_macroscopic_fields(self): print(" To set explicit initial density and velocity, use self.initialize_macroscopic_fields.") return None, None - def assign_fields_sharded(self, checkpoint=None): + def assign_fields_sharded(self): """ This function is used to initialize the simulation by assigning the macroscopic fields and populations. @@ -362,7 +558,7 @@ def assign_fields_sharded(self, checkpoint=None): shape = (self.nx, self.ny, self.nz, self.lattice.q) if rho0 is None or u0 is None: - f = self.distributed_array_init(shape, self.precisionPolicy.output_dtype, initVal=self.w) + f = self.distributed_array_init(shape, self.precisionPolicy.output_dtype, init_val=self.w) else: f = self.initialize_populations(rho0, u0) @@ -575,13 +771,13 @@ def compute_bitmask_i(b, i): return vmap(compute_bitmask_i, in_axes=(None, 0), out_axes=-1)(b, self.lattice.i_s) @partial(jit, static_argnums=(0, 3), inline=True) - def equilibrium(self, rho, u, castOutput=True): + def equilibrium(self, rho, u, cast_output=True): """ This function computes the equilibrium distribution function in the Lattice Boltzmann Method. The equilibrium distribution function is a function of the macroscopic density and velocity. - The function first casts the density and velocity to the compute precision if the castOutput flag is True. - The function finally casts the equilibrium distribution function to the output precision if the castOutput + The function first casts the density and velocity to the compute precision if the cast_output flag is True. + The function finally casts the equilibrium distribution function to the output precision if the cast_output flag is True. Parameters @@ -590,7 +786,7 @@ def equilibrium(self, rho, u, castOutput=True): The macroscopic density. u: jax.numpy.ndarray The macroscopic velocity. - castOutput: bool, optional + cast_output: bool, optional A flag indicating whether to cast the density, velocity, and equilibrium distribution function to the compute and output precisions. Default is True. @@ -599,8 +795,8 @@ def equilibrium(self, rho, u, castOutput=True): feq: ja.numpy.ndarray The equilibrium distribution function. """ - # Cast the density and velocity to the compute precision if the castOutput flag is True - if castOutput: + # Cast the density and velocity to the compute precision if the cast_output flag is True + if cast_output: rho, u = self.precisionPolicy.cast_to_compute((rho, u)) # Cast c to compute precision so that XLA call FXX matmul, @@ -610,7 +806,7 @@ def equilibrium(self, rho, u, castOutput=True): usqr = 1.5 * jnp.sum(jnp.square(u), axis=-1, keepdims=True) feq = rho * self.w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) - if castOutput: + if cast_output: return self.precisionPolicy.cast_to_output(feq) else: return feq @@ -736,9 +932,6 @@ def step(self, f_poststreaming, timestep, return_fpost=False): return f_poststreaming, f_postcollision else: return f_poststreaming, None - - def checkpoint_manager(self): - pass def run(self, t_max): """ @@ -768,11 +961,11 @@ def run(self, t_max): assert self.mngr is not None, "Checkpoint manager does not exist." state = {'f': f} shardings = jax.tree_map(lambda x: x.sharding, state) - restore_args = checkpoint_utils.construct_restore_args(state, shardings) + restore_args = orb.checkpoint_utils.construct_restore_args(state, shardings) try: f = self.mngr.restore(latest_step, restore_kwargs={'restore_args': restore_args})['f'] print(f"Restored checkpoint at step {latest_step}.") - except: + except ValueError: raise ValueError(f"Failed to restore checkpoint at step {latest_step}.") start_step = latest_step + 1 @@ -940,7 +1133,7 @@ def get_force(self): force: jax.numpy.ndarray The force to be applied to the fluid. """ - return + pass @partial(jit, static_argnums=(0,), inline=True) def apply_force(self, f_postcollision, feq, rho, u): @@ -972,8 +1165,8 @@ def apply_force(self, f_postcollision, feq, rho, u): Boundary conditions. Physica A, 392, 1925-1930. Krüger, T., et al. (2017). The lattice Boltzmann method. Springer International Publishing, 10.978-3, 4-15. """ - deltaU = self.get_force() - feq_force = self.equilibrium(rho, u + deltaU, castOutput=False) + delta_u = self.get_force() + feq_force = self.equilibrium(rho, u + delta_u, cast_output=False) f_postcollision = f_postcollision + feq_force - feq return f_postcollision diff --git a/src/lattice.py b/src/lattice.py index 3052795..788796b 100644 --- a/src/lattice.py +++ b/src/lattice.py @@ -1,6 +1,6 @@ +import re import numpy as np import jax.numpy as jnp -import re class Lattice(object): @@ -75,7 +75,7 @@ def construct_right_indices(self): The indices of the right velocities. """ c = self.c.T - return np.where(c[:, 0] == 1)[0] + return np.nonzero(c[:, 0] == 1)[0] def construct_left_indices(self): """ @@ -88,7 +88,7 @@ def construct_left_indices(self): The indices of the left velocities. """ c = self.c.T - return np.where(c[:, 0] == -1)[0] + return np.nonzero(c[:, 0] == -1)[0] def construct_main_indices(self): """ @@ -103,10 +103,10 @@ def construct_main_indices(self): """ c = self.c.T if self.d == 2: - return np.where((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] + return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] elif self.d == 3: - return np.where((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] + return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] def construct_lattice_velocity(self): """ @@ -165,7 +165,7 @@ def construct_lattice_weight(self): w[0] = 1.0 / 3.0 elif self.name == "D3Q27": cl = np.linalg.norm(c, axis=1) - w[cl == 1.0] = 2.0 / 27.0 + w[np.isclose(cl, 1.0, atol=1e-8)] = 2.0 / 27.0 w[(cl > 1) & (cl <= np.sqrt(2))] = 1.0 / 54.0 w[(cl > np.sqrt(2)) & (cl <= np.sqrt(3))] = 1.0 / 216.0 w[0] = 8.0 / 27.0 @@ -202,7 +202,9 @@ def construct_lattice_moment(self): cntr += 1 return cc - + + def __str__(self): + return self.name class LatticeD2Q9(Lattice): """ diff --git a/src/models.py b/src/models.py index 7e5e825..a7af7b6 100644 --- a/src/models.py +++ b/src/models.py @@ -26,7 +26,7 @@ def collision(self, f): """ f = self.precisionPolicy.cast_to_compute(f) rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, castOutput=False) + feq = self.equilibrium(rho, u, cast_output=False) fneq = f - feq fout = f - self.omega * fneq if self.force is not None: @@ -39,8 +39,9 @@ class KBCSim(LBMBase): This class implements the Karlin-Bösch-Chikatamarla (KBC) model for the collision step in the Lattice Boltzmann Method. """ - def __init__(self, **kwargs): + if kwargs.get('lattice').name != 'D3Q27' and kwargs.get('nz') > 0: + raise ValueError("KBC collision operator in 3D must only be used with D3Q27 lattice.") super().__init__(**kwargs) @partial(jit, static_argnums=(0,), donate_argnums=(1,)) @@ -52,7 +53,7 @@ def collision(self, f): tiny = 1e-32 beta = self.omega * 0.5 rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, castOutput=False) + feq = self.equilibrium(rho, u, cast_output=False) fneq = f - feq if self.dim == 2: deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0 @@ -212,7 +213,7 @@ def collision(self, f): """ f = self.precisionPolicy.cast_to_compute(f) rho =jnp.sum(f, axis=-1, keepdims=True) - feq = self.equilibrium(rho, self.vel, castOutput=False) + feq = self.equilibrium(rho, self.vel, cast_output=False) fneq = f - feq fout = f - self.omega * fneq return self.precisionPolicy.cast_to_output(fout) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index d7b2eb1..d937042 100644 --- a/src/utils.py +++ b/src/utils.py @@ -73,7 +73,7 @@ def save_image(timestep, fld, prefix=None): fname = fname + "_" + str(timestep).zfill(4) if len(fld.shape) > 3: - raise ValueError(f"The input field should be 2D!") + raise ValueError("The input field should be 2D!") elif len(fld.shape) == 3: fld = np.sqrt(fld[..., 0] ** 2 + fld[..., 1] ** 2)