Skip to content

Commit

Permalink
Added multi-GPU support and mlups computation
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdiataei committed Jan 29, 2024
1 parent a05441f commit 71c952b
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 102 deletions.
62 changes: 62 additions & 0 deletions examples/refactor/example_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import xlb
from xlb.compute_backends import ComputeBackends
from xlb.precision_policy import Fp32Fp32

from xlb.solver import IncompressibleNavierStokes
from xlb.grid import Grid
from xlb.operator.macroscopic import Macroscopic
from xlb.operator.equilibrium import QuadraticEquilibrium
from xlb.utils import save_fields_vtk, save_image

xlb.init(
precision_policy=Fp32Fp32,
compute_backend=ComputeBackends.JAX,
velocity_set=xlb.velocity_set.D2Q9,
)

grid_shape = (1000, 1000)
grid = Grid.create(grid_shape)


def initializer():
rho = grid.create_field(cardinality=1) + 1.0
u = grid.create_field(cardinality=2)

circle_center = (grid_shape[0] // 2, grid_shape[1] // 2)
circle_radius = 10

for x in range(grid_shape[0]):
for y in range(grid_shape[1]):
if (x - circle_center[0]) ** 2 + (
y - circle_center[1]
) ** 2 <= circle_radius**2:
rho = rho.at[0, x, y].add(0.001)

func_eq = QuadraticEquilibrium()
f_eq = func_eq(rho, u)

return f_eq


f = initializer()

compute_macro = Macroscopic()

solver = IncompressibleNavierStokes(grid, omega=1.0)


def perform_io(f, step):
rho, u = compute_macro(f)
fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1]}
save_fields_vtk(fields, step)
save_image(rho[0], step)
print(f"Step {step + 1} complete")


num_steps = 1000
io_rate = 100
for step in range(num_steps):
f = solver.step(f, timestep=step)

if step % io_rate == 0:
perform_io(f, step)
37 changes: 0 additions & 37 deletions examples/refactor/example_mehdi.py

This file was deleted.

53 changes: 53 additions & 0 deletions examples/refactor/mlups3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import xlb
import time
import jax
import argparse
from xlb.compute_backends import ComputeBackends
from xlb.precision_policy import Fp32Fp32
from xlb.operator.initializer import EquilibriumInitializer

from xlb.solver import IncompressibleNavierStokes
from xlb.grid import Grid

parser = argparse.ArgumentParser(
description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)"
)
parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid")
parser.add_argument("num_steps", type=int, help="Timestep for the simulation")

args = parser.parse_args()

cube_edge = args.cube_edge
num_steps = args.num_steps


xlb.init(
precision_policy=Fp32Fp32,
compute_backend=ComputeBackends.JAX,
velocity_set=xlb.velocity_set.D3Q19,
)

grid_shape = (cube_edge, cube_edge, cube_edge)
grid = Grid.create(grid_shape)

f = grid.create_field(cardinality=19, callback=EquilibriumInitializer(grid))

solver = IncompressibleNavierStokes(grid, omega=1.0)

# Ahead-of-Time Compilation to remove JIT overhead


if xlb.current_backend() == ComputeBackends.JAX:
lowered = jax.jit(solver.step).lower(f, timestep=0)
solver_step_compiled = lowered.compile()

start_time = time.time()

for step in range(num_steps):
f = solver_step_compiled(f, timestep=step)

end_time = time.time()
total_lattice_updates = cube_edge**3 * num_steps
total_time_seconds = end_time - start_time
mlups = (total_lattice_updates / total_time_seconds) / 1e6
print(f"MLUPS: {mlups}")
7 changes: 5 additions & 2 deletions xlb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


# Config
from .global_config import init
from .global_config import init, current_backend


# Precision policy
Expand All @@ -25,4 +25,7 @@
import xlb.grid

# Solvers
import xlb.solver
import xlb.solver

# Utils
import xlb.utils
4 changes: 4 additions & 0 deletions xlb/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ def init(velocity_set, compute_backend, precision_policy):
GlobalConfig.velocity_set = velocity_set()
GlobalConfig.compute_backend = compute_backend
GlobalConfig.precision_policy = precision_policy()


def current_backend():
return GlobalConfig.compute_backend
2 changes: 1 addition & 1 deletion xlb/grid/jax_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def initialize_jax_backend(self):
self.global_mesh = (
Mesh(device_mesh, axis_names=("cardinality", "x", "y"))
if self.dim == 2
else Mesh(self.devices, axis_names=("cardinality", "x", "y", "z"))
else Mesh(device_mesh, axis_names=("cardinality", "x", "y", "z"))
)
self.sharding = (
NamedSharding(self.global_mesh, P("cardinality", "x", "y"))
Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/equilibrium/quadratic_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
def jax_implementation(self, rho, u):
cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0))
usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True)
w = self.velocity_set.w.reshape(-1, 1, 1)
w = self.velocity_set.w.reshape((-1,) + (1,) * (len(rho.shape) - 1))

feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr)
return feq
1 change: 1 addition & 0 deletions xlb/operator/initializer/equilibrium_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
velocity_set = velocity_set or GlobalConfig.velocity_set
compute_backend = compute_backend or GlobalConfig.compute_backend
local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1)

self.init_values = np.zeros(
grid.global_to_local_shape(grid.pop_shape)
) + velocity_set.w.reshape(local_shape)
Expand Down
30 changes: 24 additions & 6 deletions xlb/operator/macroscopic/macroscopic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Base class for all equilibriums
from xlb.global_config import GlobalConfig
from xlb.velocity_set.velocity_set import VelocitySet
from xlb.compute_backends import ComputeBackends
from xlb.operator.operator import Operator


from functools import partial
import jax.numpy as jnp
from jax import jit

from xlb.velocity_set.velocity_set import VelocitySet
from xlb.compute_backends import ComputeBackends
from xlb.operator.operator import Operator


class Macroscopic(Operator):
"""
Expand All @@ -20,16 +21,33 @@ class Macroscopic(Operator):

def __init__(
self,
velocity_set: VelocitySet,
compute_backend=ComputeBackends.JAX,
velocity_set: VelocitySet = None,
compute_backend=None,
):
self.velocity_set = velocity_set or GlobalConfig.velocity_set
self.compute_backend = compute_backend or GlobalConfig.compute_backend

super().__init__(velocity_set, compute_backend)

@Operator.register_backend(ComputeBackends.JAX)
@partial(jit, static_argnums=(0), inline=True)
def jax_implementation(self, f):
"""
Apply the macroscopic operator to the lattice distribution function
TODO: Check if the following implementation is more efficient (
as the compiler may be able to remove operations resulting in zero)
c_x = tuple(self.velocity_set.c[0])
c_y = tuple(self.velocity_set.c[1])
u_x = 0.0
u_y = 0.0
rho = jnp.sum(f, axis=0, keepdims=True)
for i in range(self.velocity_set.q):
u_x += c_x[i] * f[i, ...]
u_y += c_y[i] * f[i, ...]
return rho, jnp.stack((u_x, u_y))
"""
rho = jnp.sum(f, axis=0, keepdims=True)
u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho
Expand Down
3 changes: 1 addition & 2 deletions xlb/operator/stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None)
super().__init__(velocity_set, compute_backend)

@Operator.register_backend(ComputeBackends.JAX)
# @partial(jit, static_argnums=(0))
@partial(jit, static_argnums=(0))
def jax_implementation(self, f):
"""
JAX implementation of the streaming step.
Expand All @@ -38,7 +38,6 @@ def jax_implementation(self, f):
mesh=self.grid.global_mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)(f)

def _streaming_jax_p(self, f):
Expand Down
6 changes: 4 additions & 2 deletions xlb/solver/nse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ class IncompressibleNavierStokes(Solver):
def __init__(
self,
grid,
omega,
velocity_set: VelocitySet = None,
compute_backend=None,
precision_policy=None,
boundary_conditions=[],
collision_kernel="BGK",
):
self.grid = grid
self.omega = omega
self.collision_kernel = collision_kernel
super().__init__(velocity_set=velocity_set, compute_backend=compute_backend, precision_policy=precision_policy, boundary_conditions=boundary_conditions)
self.create_operators()
Expand All @@ -39,13 +41,13 @@ def create_operators(self):
)
self.collision = (
KBC(
omega=1.0,
omega=self.omega,
velocity_set=self.velocity_set,
compute_backend=self.compute_backend,
)
if self.collision_kernel == "KBC"
else BGK(
omega=1.0,
omega=self.omega,
velocity_set=self.velocity_set,
compute_backend=self.compute_backend,
)
Expand Down
1 change: 1 addition & 0 deletions xlb/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .utils import downsample_field, save_image, save_fields_vtk, save_BCs_vtk, rotate_geometry, voxelize_stl, axangle2mat
Loading

0 comments on commit 71c952b

Please sign in to comment.