Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Boundary conditions #40

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![GitHub star chart](https://img.shields.io/github/stars/Autodesk/XLB?style=social)](https://star-history.com/#Autodesk/XLB)
<p align="center">
<img src="assets/logo-transparent.png" alt="" width="700">
<img src="assets/logo-transparent.png" alt="" width="300">
</p>

# XLB: Distributed Multi-GPU Lattice Boltzmann Simulation Framework for Differentiable Scientific Machine Learning
# XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning

XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that leverages hardware acceleration. It's built on top of the [JAX](https://github.com/google/jax) library and is specifically designed to solve fluid dynamics problems in a computationally efficient and differentiable manner. Its unique combination of features positions it as an exceptionally suitable tool for applications in physics-based machine learning.

Expand All @@ -18,7 +18,7 @@ If you use XLB in your research, please cite the following paper:

```
@article{ataei2023xlb,
title={{XLB}: Distributed Multi-GPU Lattice Boltzmann Simulation Framework for Differentiable Scientific Machine Learning},
title={{XLB}: A Differentiable Massively Parallel Lattice Boltzmann Library in Python},
author={Ataei, Mohammadmehdi and Salehipour, Hesam},
journal={arXiv preprint arXiv:2311.16080},
year={2023},
Expand All @@ -33,7 +33,7 @@ If you use XLB in your research, please cite the following paper:
- **User-Friendly Interface:** Written entirely in Python, XLB emphasizes a highly accessible interface that allows users to extend the library with ease and quickly set up and run new simulations.
- **Leverages JAX Array and Shardmap:** The library incorporates the new JAX array unified array type and JAX shardmap, providing users with a numpy-like interface. This allows users to focus solely on the semantics, leaving performance optimizations to the compiler.
- **Platform Versatility:** The same XLB code can be executed on a variety of platforms including multi-core CPUs, single or multi-GPU systems, TPUs, and it also supports distributed runs on multi-GPU systems or TPU Pod slices.
- **Visualization:** XLB provides a variety of visualization options including in-situ rendering using [PhantomGaze](https://github.com/loliverhennigh/PhantomGaze).
- **Visualization:** XLB provides a variety of visualization options including in-situ on GPU rendering using [PhantomGaze](https://github.com/loliverhennigh/PhantomGaze).

## Showcase

Expand Down Expand Up @@ -153,4 +153,4 @@ git clone https://github.com/Autodesk/XLB
cd XLB
export PYTHONPATH=.
python3 examples/CFD/cavity2d.py
```
```
15 changes: 10 additions & 5 deletions examples/CFD/cylinder2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView.

# To run type:
nohup python3 examples/CFD/cylinder2d.py > logfile.log &
"""
import os
import json
Expand Down Expand Up @@ -50,8 +52,9 @@ def set_boundary_conditions(self):

# Outflow BC
outlet = self.boundingBoxIndices['right']
rho_outlet = np.ones(outlet.shape[0], dtype=self.precisionPolicy.compute_dtype)
rho_outlet = np.ones((outlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype)
self.BCs.append(ExtrapolationOutflow(tuple(outlet.T), self.gridInfo, self.precisionPolicy))
# self.BCs.append(ZouHe(tuple(outlet.T), self.gridInfo, self.precisionPolicy, 'pressure', rho_outlet))

# Inlet BC
inlet = self.boundingBoxIndices['left']
Expand All @@ -78,7 +81,7 @@ def output_data(self, **kwargs):
if timestep == 0:
self.CL_max = 0.0
self.CD_max = 0.0
if timestep > 0.8 * t_max:
if timestep > 0.5 * niter_max:
# compute lift and drag over the cyliner
cylinder = self.BCs[0]
boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision'])
Expand All @@ -94,7 +97,7 @@ def output_data(self, **kwargs):
self.CL_max = max(self.CL_max, cl)
self.CD_max = max(self.CD_max, cd)
print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd))
save_image(timestep, u)
# save_image(timestep, u)

# Helper function to specify a parabolic poiseuille profile
poiseuille_profile = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2))
Expand Down Expand Up @@ -131,9 +134,11 @@ def output_data(self, **kwargs):
'print_info_rate': int(10000 / scale_factor),
'return_fpost': True # Need to retain fpost-collision for computation of lift and drag
}
# characteristic time
tc = prescribed_vel/diam
niter_max = int(100//tc)
sim = Cylinder(**kwargs)
t_max = int(1000000 / scale_factor)
sim.run(t_max)
sim.run(niter_max)
CL_list.append(sim.CL_max)
CD_list.append(sim.CD_max)

Expand Down
1 change: 1 addition & 0 deletions examples/CFD_refactor/windtunnel3d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os

Check failure on line 1 in examples/CFD_refactor/windtunnel3d.py

View check run for this annotation

Autodesk Chorus / security/bandit

syntax error

Bandit could not parse this file due to a syntax error. Fix any syntax errors and push a new commit. Please note that the annotation may not point to the line with the error. If this code is Python2, please migrate to Python3. Python2 is no longer supported and poses a security risk.
import jax
import trimesh
from time import time
Expand All @@ -14,6 +14,7 @@
from xlb.operator.boundary_condition import BounceBack, BounceBackHalfway, DoNothing, EquilibriumBC



class WindTunnel(IncompressibleNavierStokesSolver):
"""
This class extends the IncompressibleNavierStokesSolver class to define the boundary conditions for the wind tunnel simulation.
Expand Down
156 changes: 156 additions & 0 deletions examples/interfaces/boundary_conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Simple script to run different boundary conditions with jax and warp backends
import time
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from typing import Any
import numpy as np
import jax.numpy as jnp
import warp as wp

wp.init()

import xlb

def run_boundary_conditions(backend):

# Set the compute backend
if backend == "warp":
compute_backend = xlb.ComputeBackend.WARP
elif backend == "jax":
compute_backend = xlb.ComputeBackend.JAX

# Set the precision policy
precision_policy = xlb.PrecisionPolicy.FP32FP32

# Set the velocity set
velocity_set = xlb.velocity_set.D3Q19()

# Make grid
nr = 256
shape = (nr, nr, nr)
if backend == "jax":
grid = xlb.grid.JaxGrid(shape=shape)
elif backend == "warp":
grid = xlb.grid.WarpGrid(shape=shape)

# Make feilds
f_pre = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32)
f_post = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32)
f = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32)
boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8)
missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL)

# Make needed operators
equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC(
rho=1.0,
u=(0.0, 0.0, 0.0),
equilibrium_operator=equilibrium,
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
halfway_bounce_back_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
fullway_bounce_back_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)

# Make indices for boundary conditions (sphere)
sphere_radius = 32
x = np.arange(nr)
y = np.arange(nr)
z = np.arange(nr)
X, Y, Z = np.meshgrid(x, y, z)
indices = np.where(
(X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2
< sphere_radius**2
)
indices = np.array(indices).T
if backend == "jax":
indices = jnp.array(indices)
elif backend == "warp":
indices = wp.from_numpy(indices, dtype=wp.int32)

# Test equilibrium boundary condition
boundary_id, missing_mask = indices_boundary_masker(
indices,
equilibrium_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
)
if backend == "jax":
f = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask)
elif backend == "warp":
f = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask, f)
print(f"Equilibrium BC test passed for {backend}")

# Test do nothing boundary condition
boundary_id, missing_mask = indices_boundary_masker(
indices,
do_nothing_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
)
if backend == "jax":
f = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask)
elif backend == "warp":
f = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask, f)

# Test halfway bounce back boundary condition
boundary_id, missing_mask = indices_boundary_masker(
indices,
halfway_bounce_back_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
)
if backend == "jax":
f = halfway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask)
elif backend == "warp":
f = halfway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask, f)
print(f"Halfway bounce back BC test passed for {backend}")

# Test the full boundary condition
boundary_id, missing_mask = indices_boundary_masker(
indices,
fullway_bounce_back_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
)
if backend == "jax":
f = fullway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask)
elif backend == "warp":
f = fullway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask, f)
print(f"Fullway bounce back BC test passed for {backend}")


if __name__ == "__main__":

# Test the boundary conditions
backends = ["warp", "jax"]
for backend in backends:
run_boundary_conditions(backend)
Loading
Loading