Fixed mixed-precision for Warp #67

Merged: 7 commits, Sep 24, 2024
60 changes: 53 additions & 7 deletions README.md
@@ -10,18 +10,21 @@ XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that

## Accompanying Paper

-Please refer to the [accompanying paper](https://arxiv.org/abs/2311.16080) for benchmarks, validation, and more details about the library.
+Please refer to the [accompanying paper](https://doi.org/10.1016/j.cpc.2024.109187) for benchmarks, validation, and more details about the library.

## Citing XLB

If you use XLB in your research, please cite the following paper:

```
-@article{ataei2023xlb,
-title={{XLB}: A Differentiable Massively Parallel Lattice Boltzmann Library in Python},
-author={Ataei, Mohammadmehdi and Salehipour, Hesam},
-journal={arXiv preprint arXiv:2311.16080},
-year={2023},
+@article{ataei2024xlb,
+title={{XLB}: A differentiable massively parallel lattice {Boltzmann} library in {Python}},
+author={Ataei, Mohammadmehdi and Salehipour, Hesam},
+journal={Computer Physics Communications},
+volume={300},
+pages={109187},
+year={2024},
+publisher={Elsevier}
}
```

@@ -153,4 +156,47 @@ git clone https://github.com/Autodesk/XLB
cd XLB
export PYTHONPATH=.
python3 examples/CFD/cavity2d.py
```
```
## Roadmap

### Work in Progress (WIP)
*Note: Some of the work-in-progress features can be found in the branches of the XLB repository. For contributions to these features, please reach out.*

- 🚀 **Warp Backend:** Achieving state-of-the-art performance by leveraging the [Warp](https://github.com/NVIDIA/warp) framework in combination with JAX.

- 🌐 **Grid Refinement:** Implementing adaptive mesh refinement techniques for enhanced simulation accuracy.

- ⚡ **Multi-GPU Acceleration using [Neon](https://github.com/Autodesk/Neon) + Warp:** Using Neon's data structure for improved scaling.

- 💾 **Out-of-Core Computations:** Enabling simulations that exceed available GPU memory, suitable for CPU+GPU coherent memory models such as NVIDIA's Grace Superchips.

- 🗜️ **GPU Accelerated Lossless Compression and Decompression**: Implementing high-performance lossless compression and decompression techniques for larger-scale simulations and improved performance.

- 🌡️ **Fluid-Thermal Simulation Capabilities:** Incorporating heat transfer and thermal effects into fluid simulations.

- 🎯 **Adjoint-based Shape and Topology Optimization:** Implementing gradient-based optimization techniques for design optimization.

- 🧠 **Machine Learning Accelerated Simulations:** Leveraging machine learning to speed up simulations and improve accuracy.

- 📉 **Reduced Order Modeling using Machine Learning:** Developing data-driven reduced-order models for efficient and accurate simulations.


### Wishlist
*Contributions to these features are welcome. Please submit PRs for the Wishlist items.*

- 🌊 **Free Surface Flows:** Simulating flows with free surfaces, such as water waves and droplets.

- 📡 **Electromagnetic Wave Propagation:** Simulating the propagation of electromagnetic waves.

- 🛩️ **Supersonic Flows:** Simulating supersonic flows.

- 🌊🧱 **Fluid-Solid Interaction:** Modeling the interaction between fluids and solid objects.

- 🧩 **Multiphase Flow Simulation:** Simulating flows with multiple immiscible fluids.

- 🔥 **Combustion:** Simulating combustion processes and reactive flows.

- 🪨 **Particle Flows and Discrete Element Method:** Incorporating particle-based methods for granular and particulate flows.

- 🔧 **Better Geometry Processing Pipelines:** Improving the handling and preprocessing of complex geometries for simulations.

2 changes: 1 addition & 1 deletion examples/cfd/flow_past_sphere_3d.py
@@ -94,7 +94,7 @@ def setup_boundary_masker(self):
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0))

def initialize_fields(self):
-self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend)
+self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")
2 changes: 1 addition & 1 deletion examples/cfd/lid_driven_cavity_2d.py
@@ -61,7 +61,7 @@ def setup_boundary_masker(self):
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask)

def initialize_fields(self):
-self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend)
+self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
2 changes: 1 addition & 1 deletion examples/cfd/turbulent_channel_3d.py
@@ -102,7 +102,7 @@ def initialize_fields(self):
u_init = jnp.full(shape=shape, fill_value=1e-2 * u_init)
else:
u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype)
-self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend, u=u_init)
+self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init)

def setup_stepper(self):
force = self.get_force()
2 changes: 1 addition & 1 deletion examples/cfd/windtunnel_3d.py
@@ -122,7 +122,7 @@ def setup_boundary_masker(self):
self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask)

def initialize_fields(self):
-self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend)
+self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC")
2 changes: 1 addition & 1 deletion examples/performance/mlups_3d.py
@@ -100,7 +100,7 @@ def main():
backend, precision_policy = setup_simulation(args)
velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
grid, f_0, f_1, missing_mask, bc_mask = create_grid_and_fields(args.cube_edge)
-f_0 = initialize_eq(f_0, grid, velocity_set, backend)
+f_0 = initialize_eq(f_0, grid, velocity_set, precision_policy, backend)

elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, args.num_steps)
mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time)
6 changes: 3 additions & 3 deletions xlb/helper/initializers.py
@@ -2,11 +2,11 @@
from xlb.operator.equilibrium import QuadraticEquilibrium


-def initialize_eq(f, grid, velocity_set, backend, rho=None, u=None):
+def initialize_eq(f, grid, velocity_set, precision_policy, backend, rho=None, u=None):
if rho is None:
-rho = grid.create_field(cardinality=1, fill_value=1.0)
+rho = grid.create_field(cardinality=1, fill_value=1.0, dtype=precision_policy.compute_precision)
if u is None:
-u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0)
+u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0, dtype=precision_policy.compute_precision)
equilibrium = QuadraticEquilibrium()

if backend == ComputeBackend.JAX:
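
For readers unfamiliar with the idea behind the new `precision_policy` argument, here is a minimal sketch using only the Python standard library. It is not XLB's implementation: `SketchPrecisionPolicy`, this `create_field` helper, and the use of `array` typecodes as stand-ins for dtypes are all assumptions made for illustration. The only point taken from the change above is that the default `rho` and `u` inputs are allocated explicitly in the compute precision instead of whatever default the grid would otherwise pick.

```python
from array import array
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchPrecisionPolicy:
    """Hypothetical stand-in for a precision policy (not XLB's class)."""

    compute_typecode: str  # typecode used for arithmetic inputs, e.g. "d" (float64)
    store_typecode: str  # typecode used for stored populations, e.g. "f" (float32)


def create_field(n_cells, cardinality, fill_value, typecode):
    """Allocate a flat field of cardinality * n_cells values with an explicit type."""
    return array(typecode, [fill_value] * (cardinality * n_cells))


policy = SketchPrecisionPolicy(compute_typecode="d", store_typecode="f")

# Default density and velocity are allocated in the compute precision,
# mirroring the dtype argument passed to create_field in the diff above.
rho = create_field(n_cells=16, cardinality=1, fill_value=1.0, typecode=policy.compute_typecode)
u = create_field(n_cells=16, cardinality=2, fill_value=0.0, typecode=policy.compute_typecode)

# The populations themselves live in the (possibly narrower) storage precision.
f = create_field(n_cells=16, cardinality=9, fill_value=0.0, typecode=policy.store_typecode)
```

Passing the policy explicitly keeps the choice of input precision in one place, so a caller that switches the storage type does not silently change the type of the equilibrium inputs.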
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_do_nothing.py
@@ -86,7 +86,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1]] = _f[l]
+f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
@@ -112,7 +112,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1], index[2]] = _f[l]
+f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_equilibrium.py
@@ -111,7 +111,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1]] = _f[l]
+f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
@@ -137,7 +137,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1], index[2]] = _f[l]
+f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
@@ -229,7 +229,7 @@ def kernel2d(

# Write the distribution function
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1]] = _f[l]
+f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
@@ -270,7 +270,7 @@ def kernel3d(

# Write the distribution function
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1], index[2]] = _f[l]
+f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_fullway_bounce_back.py
@@ -95,7 +95,7 @@ def kernel2d(

# Write the result to the output
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1]] = _f[l]
+f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
@@ -121,7 +121,7 @@ def kernel3d(

# Write the result to the output
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1], index[2]] = _f[l]
+f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_halfway_bounce_back.py
@@ -110,7 +110,7 @@ def kernel2d(

# Write the distribution function
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1]] = _f[l]
+f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
@@ -136,7 +136,7 @@ def kernel3d(

# Write the distribution function
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1], index[2]] = _f[l]
+f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_regularized.py
@@ -344,7 +344,7 @@ def kernel2d(

# Write the distribution function
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1]] = _f[l]
+f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
@@ -370,7 +370,7 @@ def kernel3d(

# Write the distribution function
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1], index[2]] = _f[l]
+f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
if self.velocity_set.d == 3 and self.bc_type == "velocity":
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_zouhe.py
@@ -352,7 +352,7 @@ def kernel2d(

# Write the distribution function
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1]] = _f[l]
+f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
@@ -378,7 +378,7 @@ def kernel3d(

# Write the distribution function
for l in range(self.velocity_set.q):
-f_post[l, index[0], index[1], index[2]] = _f[l]
+f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
if self.velocity_set.d == 3 and self.bc_type == "velocity":
8 changes: 4 additions & 4 deletions xlb/operator/boundary_condition/boundary_condition.py
@@ -81,8 +81,8 @@ def _get_thread_data_2d(
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
-_f_pre[l] = f_pre[l, index[0], index[1]]
-_f_post[l] = f_post[l, index[0], index[1]]
+_f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1]])
+_f_post[l] = self.compute_dtype(f_post[l, index[0], index[1]])

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
@@ -106,8 +106,8 @@ def _get_thread_data_3d(
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
-_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
-_f_post[l] = f_post[l, index[0], index[1], index[2]]
+_f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1], index[2]])
+_f_post[l] = self.compute_dtype(f_post[l, index[0], index[1], index[2]])

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
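
The hunks in this change follow one pattern: cast to the compute type when reading a stored value, do the arithmetic, and cast back to the storage type when writing. The sketch below emulates that round trip in plain Python, assuming half-precision storage and float64 arithmetic. The helper names (`store_fp16`, `load_as_compute`, `bgk_relax`) are hypothetical and only illustrate the pattern with a standard BGK relaxation step; they are not Warp or XLB functions.

```python
import struct


def store_fp16(x):
    """Round a Python float to IEEE half precision and return its 2-byte encoding."""
    return struct.pack("<e", x)


def load_as_compute(raw):
    """Decode a stored half-precision value into a Python float (float64)."""
    return struct.unpack("<e", raw)[0]


def bgk_relax(f_raw, feq, omega):
    """Relax one stored population toward equilibrium using the cast-up/cast-down pattern."""
    f = load_as_compute(f_raw)  # cast up on read (the compute_dtype side)
    f_out = f - omega * (f - feq)  # arithmetic carried out in float64
    return store_fp16(f_out)  # cast down on write (the store_dtype side)


f_raw = store_fp16(0.1)
out_raw = bgk_relax(f_raw, feq=0.11, omega=1.0)
out = load_as_compute(out_raw)
```

Only the stored bytes are narrow; every intermediate value stays in the wider type, which is the behavior the explicit casts in the kernels are meant to guarantee.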
4 changes: 2 additions & 2 deletions xlb/operator/collision/bgk.py
@@ -59,7 +59,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
-fout[l, index[0], index[1]] = _fout[l]
+fout[l, index[0], index[1]] = self.store_dtype(_fout[l])

# Construct the warp kernel
@wp.kernel
@@ -86,7 +86,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
-fout[l, index[0], index[1], index[2]] = _fout[l]
+fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

4 changes: 2 additions & 2 deletions xlb/operator/collision/kbc.py
@@ -337,7 +337,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
-fout[l, index[0], index[1]] = _fout[l]
+fout[l, index[0], index[1]] = self.store_dtype(_fout[l])

# Construct the warp kernel
@wp.kernel
@@ -369,7 +369,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
-fout[l, index[0], index[1], index[2]] = _fout[l]
+fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l])

functional = functional3d if self.velocity_set.d == 3 else functional2d
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
4 changes: 2 additions & 2 deletions xlb/operator/equilibrium/quadratic_equilibrium.py
@@ -79,7 +79,7 @@ def kernel3d(

# Set the output
for l in range(self.velocity_set.q):
-f[l, index[0], index[1], index[2]] = feq[l]
+f[l, index[0], index[1], index[2]] = self.store_dtype(feq[l])

@wp.kernel
def kernel2d(
@@ -100,7 +100,7 @@ def kernel2d(

# Set the output
for l in range(self.velocity_set.q):
-f[l, index[0], index[1]] = feq[l]
+f[l, index[0], index[1]] = self.store_dtype(feq[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

4 changes: 2 additions & 2 deletions xlb/operator/force/exact_difference_force.py
@@ -108,7 +108,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
-fout[l, index[0], index[1]] = _fout[l]
+fout[l, index[0], index[1]] = self.store_dtype(_fout[l])

# Construct the warp kernel
@wp.kernel
@@ -134,7 +134,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
-fout[l, index[0], index[1], index[2]] = _fout[l]
+fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
return functional, kernel
4 changes: 2 additions & 2 deletions xlb/operator/macroscopic/first_moment.py
@@ -53,7 +53,7 @@ def kernel3d(
_u = functional(_f, _rho)

for d in range(self.velocity_set.d):
-u[d, index[0], index[1], index[2]] = _u[d]
+u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d])

@wp.kernel
def kernel2d(
@@ -71,7 +71,7 @@ def kernel2d(
_u = functional(_f, _rho)

for d in range(self.velocity_set.d):
-u[d, index[0], index[1]] = _u[d]
+u[d, index[0], index[1]] = self.store_dtype(_u[d])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

8 changes: 4 additions & 4 deletions xlb/operator/macroscopic/macroscopic.py
@@ -50,9 +50,9 @@ def kernel3d(
_f[l] = f[l, index[0], index[1], index[2]]
_rho, _u = functional(_f)

-rho[0, index[0], index[1], index[2]] = _rho
+rho[0, index[0], index[1], index[2]] = self.store_dtype(_rho)
for d in range(self.velocity_set.d):
-u[d, index[0], index[1], index[2]] = _u[d]
+u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d])

@wp.kernel
def kernel2d(
@@ -68,9 +68,9 @@ def kernel2d(
_f[l] = f[l, index[0], index[1]]
_rho, _u = functional(_f)

-rho[0, index[0], index[1]] = _rho
+rho[0, index[0], index[1]] = self.store_dtype(_rho)
for d in range(self.velocity_set.d):
-u[d, index[0], index[1]] = _u[d]
+u[d, index[0], index[1]] = self.store_dtype(_u[d])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
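
As a rough illustration of what the macroscopic hunks guard against, the sketch below computes density and velocity as the zeroth and first moments of a D2Q9 population set in float64 and rounds only the final outputs to the storage type. Everything here is an assumption for illustration: the velocity ordering, the `to_store` helper (which emulates a cast to a 16-bit float), and the `macroscopic` function are not XLB's code.

```python
import struct


def to_store(x):
    """Round to half precision, emulating a cast to a 16-bit storage type."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# D2Q9 lattice velocities and weights (ordering chosen for this sketch only).
C = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (-1, -1), (1, -1), (-1, 1)]
W = [4 / 9] + [1 / 9] * 4 + [1 / 36] * 4


def macroscopic(f):
    """Return (rho, (ux, uy)): moments summed in float64, cast only on output."""
    rho = sum(f)
    ux = sum(c[0] * fi for c, fi in zip(C, f)) / rho
    uy = sum(c[1] * fi for c, fi in zip(C, f)) / rho
    return to_store(rho), (to_store(ux), to_store(uy))


# Populations at rest: each stored value is the lattice weight rounded to storage precision.
f_rest = [to_store(w) for w in W]
rho, u = macroscopic(f_rest)
```

For a fluid at rest the result should be a density close to one and a zero velocity; the density deviates slightly from one only because the stored weights were themselves rounded.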
