From b3acb4ac89fd58da608101c6525ac24a3d9534ba Mon Sep 17 00:00:00 2001 From: oliver Date: Fri, 15 Dec 2023 15:06:34 -0800 Subject: [PATCH] xlb refactor --- examples/backend_comparisons/README.md | 14 + .../backend_comparisons/lattice_boltzmann.py | 1136 ++++++++++++++++ examples/refactor/README.md | 31 + examples/refactor/example_jax.py | 107 ++ examples/refactor/example_jax_out_of_core.py | 336 +++++ examples/refactor/example_numba.py | 78 ++ setup.py | 11 + src/boundary_conditions.py | 1175 ----------------- src/lattice.py | 281 ---- src/models.py | 260 ---- xlb/__init__.py | 18 + {src => xlb}/base.py | 0 xlb/compute_backend.py | 9 + xlb/experimental/__init__.py | 1 + xlb/experimental/ooc/__init__.py | 2 + xlb/experimental/ooc/ooc_array.py | 485 +++++++ xlb/experimental/ooc/out_of_core.py | 110 ++ xlb/experimental/ooc/tiles/__init__.py | 0 xlb/experimental/ooc/tiles/compressed_tile.py | 273 ++++ xlb/experimental/ooc/tiles/dense_tile.py | 96 ++ xlb/experimental/ooc/tiles/dynamic_array.py | 72 + xlb/experimental/ooc/tiles/tile.py | 115 ++ xlb/experimental/ooc/utils.py | 79 ++ xlb/operator/__init__.py | 1 + xlb/operator/boundary_condition/__init__.py | 5 + .../boundary_condition/boundary_condition.py | 100 ++ xlb/operator/boundary_condition/do_nothing.py | 56 + .../equilibrium_boundary.py | 69 + .../boundary_condition/full_bounce_back.py | 57 + .../boundary_condition/halfway_bounce_back.py | 97 ++ xlb/operator/collision/__init__.py | 3 + xlb/operator/collision/bgk.py | 109 ++ xlb/operator/collision/collision.py | 48 + xlb/operator/collision/kbc.py | 203 +++ xlb/operator/equilibrium/__init__.py | 1 + xlb/operator/equilibrium/equilibrium.py | 88 ++ xlb/operator/macroscopic/__init__.py | 1 + xlb/operator/macroscopic/macroscopic.py | 38 + xlb/operator/operator.py | 81 ++ xlb/operator/stepper/__init__.py | 2 + xlb/operator/stepper/nse.py | 93 ++ xlb/operator/stepper/stepper.py | 84 ++ xlb/operator/stream/__init__.py | 1 + xlb/operator/stream/stream.py | 88 ++ 
xlb/physics_type.py | 7 + xlb/precision_policy/__init__.py | 2 + xlb/precision_policy/fp32fp32.py | 20 + xlb/precision_policy/precision_policy.py | 159 +++ {src => xlb/utils}/utils.py | 147 +-- xlb/velocity_set/__init__.py | 4 + xlb/velocity_set/d2q9.py | 26 + xlb/velocity_set/d3q19.py | 36 + xlb/velocity_set/d3q27.py | 32 + xlb/velocity_set/velocity_set.py | 200 +++ 54 files changed, 4734 insertions(+), 1813 deletions(-) create mode 100644 examples/backend_comparisons/README.md create mode 100644 examples/backend_comparisons/lattice_boltzmann.py create mode 100644 examples/refactor/README.md create mode 100644 examples/refactor/example_jax.py create mode 100644 examples/refactor/example_jax_out_of_core.py create mode 100644 examples/refactor/example_numba.py delete mode 100644 src/boundary_conditions.py delete mode 100644 src/lattice.py delete mode 100644 src/models.py create mode 100644 xlb/__init__.py rename {src => xlb}/base.py (100%) create mode 100644 xlb/compute_backend.py create mode 100644 xlb/experimental/__init__.py create mode 100644 xlb/experimental/ooc/__init__.py create mode 100644 xlb/experimental/ooc/ooc_array.py create mode 100644 xlb/experimental/ooc/out_of_core.py create mode 100644 xlb/experimental/ooc/tiles/__init__.py create mode 100644 xlb/experimental/ooc/tiles/compressed_tile.py create mode 100644 xlb/experimental/ooc/tiles/dense_tile.py create mode 100644 xlb/experimental/ooc/tiles/dynamic_array.py create mode 100644 xlb/experimental/ooc/tiles/tile.py create mode 100644 xlb/experimental/ooc/utils.py create mode 100644 xlb/operator/__init__.py create mode 100644 xlb/operator/boundary_condition/__init__.py create mode 100644 xlb/operator/boundary_condition/boundary_condition.py create mode 100644 xlb/operator/boundary_condition/do_nothing.py create mode 100644 xlb/operator/boundary_condition/equilibrium_boundary.py create mode 100644 xlb/operator/boundary_condition/full_bounce_back.py create mode 100644 
xlb/operator/boundary_condition/halfway_bounce_back.py create mode 100644 xlb/operator/collision/__init__.py create mode 100644 xlb/operator/collision/bgk.py create mode 100644 xlb/operator/collision/collision.py create mode 100644 xlb/operator/collision/kbc.py create mode 100644 xlb/operator/equilibrium/__init__.py create mode 100644 xlb/operator/equilibrium/equilibrium.py create mode 100644 xlb/operator/macroscopic/__init__.py create mode 100644 xlb/operator/macroscopic/macroscopic.py create mode 100644 xlb/operator/operator.py create mode 100644 xlb/operator/stepper/__init__.py create mode 100644 xlb/operator/stepper/nse.py create mode 100644 xlb/operator/stepper/stepper.py create mode 100644 xlb/operator/stream/__init__.py create mode 100644 xlb/operator/stream/stream.py create mode 100644 xlb/physics_type.py create mode 100644 xlb/precision_policy/__init__.py create mode 100644 xlb/precision_policy/fp32fp32.py create mode 100644 xlb/precision_policy/precision_policy.py rename {src => xlb/utils}/utils.py (73%) create mode 100644 xlb/velocity_set/__init__.py create mode 100644 xlb/velocity_set/d2q9.py create mode 100644 xlb/velocity_set/d3q19.py create mode 100644 xlb/velocity_set/d3q27.py create mode 100644 xlb/velocity_set/velocity_set.py diff --git a/examples/backend_comparisons/README.md b/examples/backend_comparisons/README.md new file mode 100644 index 0000000..a198eb4 --- /dev/null +++ b/examples/backend_comparisons/README.md @@ -0,0 +1,14 @@ +# Performance Comparisons + +This directory contains a minimal LBM implementation in Warp, Numba, and Jax. The +code can be run with the following command: + +```bash +python3 lattice_boltzmann.py +``` + +This will give MLUPs numbers for each implementation. The Warp implementation +is the fastest, followed by Numba, and then Jax. + +This example should be used as a test for properly implementing more backends in +XLB. 
diff --git a/examples/backend_comparisons/lattice_boltzmann.py b/examples/backend_comparisons/lattice_boltzmann.py new file mode 100644 index 0000000..a8e1c39 --- /dev/null +++ b/examples/backend_comparisons/lattice_boltzmann.py @@ -0,0 +1,1136 @@ +# Description: This file contains a simple example of using the OOCmap +# decorator to apply a function to a distributed array. +# Solves Lattice Boltzmann Taylor Green vortex decay + +import time +import warp as wp +import matplotlib.pyplot as plt +from tqdm import tqdm +import numpy as np +import cupy as cp +import time +from tqdm import tqdm +from numba import cuda +import numba +import math +import jax.numpy as jnp +import jax +from jax import jit +from functools import partial + +# Initialize Warp +wp.init() + +@wp.func +def warp_set_f( + f: wp.array4d(dtype=float), + value: float, + q: int, + i: int, + j: int, + k: int, + width: int, + height: int, + length: int, +): + # Modulo + if i < 0: + i += width + if j < 0: + j += height + if k < 0: + k += length + if i >= width: + i -= width + if j >= height: + j -= height + if k >= length: + k -= length + f[q, i, j, k] = value + +@wp.kernel +def warp_collide_stream( + f0: wp.array4d(dtype=float), + f1: wp.array4d(dtype=float), + width: int, + height: int, + length: int, + tau: float, +): + + # get index + x, y, z = wp.tid() + + # sample needed points + f_1_1_1 = f0[0, x, y, z] + f_2_1_1 = f0[1, x, y, z] + f_0_1_1 = f0[2, x, y, z] + f_1_2_1 = f0[3, x, y, z] + f_1_0_1 = f0[4, x, y, z] + f_1_1_2 = f0[5, x, y, z] + f_1_1_0 = f0[6, x, y, z] + f_1_2_2 = f0[7, x, y, z] + f_1_0_0 = f0[8, x, y, z] + f_1_2_0 = f0[9, x, y, z] + f_1_0_2 = f0[10, x, y, z] + f_2_1_2 = f0[11, x, y, z] + f_0_1_0 = f0[12, x, y, z] + f_2_1_0 = f0[13, x, y, z] + f_0_1_2 = f0[14, x, y, z] + f_2_2_1 = f0[15, x, y, z] + f_0_0_1 = f0[16, x, y, z] + f_2_0_1 = f0[17, x, y, z] + f_0_2_1 = f0[18, x, y, z] + + # compute u and p + p = (f_1_1_1 + + f_2_1_1 + f_0_1_1 + + f_1_2_1 + f_1_0_1 + + f_1_1_2 + f_1_1_0 + + 
f_1_2_2 + f_1_0_0 + + f_1_2_0 + f_1_0_2 + + f_2_1_2 + f_0_1_0 + + f_2_1_0 + f_0_1_2 + + f_2_2_1 + f_0_0_1 + + f_2_0_1 + f_0_2_1) + u = (f_2_1_1 - f_0_1_1 + + f_2_1_2 - f_0_1_0 + + f_2_1_0 - f_0_1_2 + + f_2_2_1 - f_0_0_1 + + f_2_0_1 - f_0_2_1) + v = (f_1_2_1 - f_1_0_1 + + f_1_2_2 - f_1_0_0 + + f_1_2_0 - f_1_0_2 + + f_2_2_1 - f_0_0_1 + - f_2_0_1 + f_0_2_1) + w = (f_1_1_2 - f_1_1_0 + + f_1_2_2 - f_1_0_0 + - f_1_2_0 + f_1_0_2 + + f_2_1_2 - f_0_1_0 + - f_2_1_0 + f_0_1_2) + res_p = 1.0 / p + u = u * res_p + v = v * res_p + w = w * res_p + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = 0 + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = 1.5 + factor_2 = 4.5 + weight_0 = 0.33333333 + weight_1 = 0.05555555 + weight_2 = 0.02777777 + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + 1.0)) + f_eq_2_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + 1.0)) + f_eq_0_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + 1.0)) + f_eq_1_2_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) + f_eq_1_0_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) + f_eq_1_1_2 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + 1.0)) + f_eq_1_1_0 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + 1.0)) + f_eq_1_2_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + 1.0)) + f_eq_1_0_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_0 - uxu) + 
factor_2 * (exu_1_0_0 * exu_1_0_0) + 1.0)) + f_eq_1_2_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + 1.0)) + f_eq_1_0_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * exu_1_0_2) + 1.0)) + f_eq_2_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + 1.0)) + f_eq_0_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + 1.0)) + f_eq_2_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + 1.0)) + f_eq_0_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + 1.0)) + f_eq_2_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + 1.0)) + f_eq_0_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + 1.0)) + f_eq_2_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + 1.0)) + f_eq_0_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * exu_0_2_1) + 1.0)) + + # set next lattice state + inv_tau = (1.0 / tau) + warp_set_f(f1, f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1), 0, x, y, z, width, height, length) + warp_set_f(f1, f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1), 1, x + 1, y, z, width, height, length) + warp_set_f(f1, f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1), 2, x - 1, y, z, width, height, length) + warp_set_f(f1, f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1), 3, x, y + 1, z, width, height, length) + warp_set_f(f1, f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1), 4, x, y - 1, z, width, height, length) + warp_set_f(f1, f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2), 5, x, y, z + 1, width, height, length) + warp_set_f(f1, f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0), 6, x, y, z - 1, width, height, length) + warp_set_f(f1, f_1_2_2 - inv_tau * (f_1_2_2 - f_eq_1_2_2), 7, x, y + 1, z + 1, width, 
height, length) + warp_set_f(f1, f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0), 8, x, y - 1, z - 1, width, height, length) + warp_set_f(f1, f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0), 9, x, y + 1, z - 1, width, height, length) + warp_set_f(f1, f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2), 10, x, y - 1, z + 1, width, height, length) + warp_set_f(f1, f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2), 11, x + 1, y, z + 1, width, height, length) + warp_set_f(f1, f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0), 12, x - 1, y, z - 1, width, height, length) + warp_set_f(f1, f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0), 13, x + 1, y, z - 1, width, height, length) + warp_set_f(f1, f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2), 14, x - 1, y, z + 1, width, height, length) + warp_set_f(f1, f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1), 15, x + 1, y + 1, z, width, height, length) + warp_set_f(f1, f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1), 16, x - 1, y - 1, z, width, height, length) + warp_set_f(f1, f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1), 17, x + 1, y - 1, z, width, height, length) + warp_set_f(f1, f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1), 18, x - 1, y + 1, z, width, height, length) + +@wp.kernel +def warp_initialize_taylor_green( + f: wp.array4d(dtype=wp.float32), + dx: float, + vel: float, + start_x: int, + start_y: int, + start_z: int, +): + + # get index + i, j, k = wp.tid() + + # get real pos + x = wp.float(i + start_x) * dx + y = wp.float(j + start_y) * dx + z = wp.float(k + start_z) * dx + + # compute u + u = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) + v = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) + w = 0.0 + + # compute p + p = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * (wp.cos(2.0 * x) + wp.cos(2.0 * y) * (wp.cos(2.0 * z) + 2.0)) + + 1.0 + ) + + # compute u X u + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = 0.0 + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + 
exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = 1.5 + factor_2 = 4.5 + weight_0 = 0.33333333 + weight_1 = 0.05555555 + weight_2 = 0.02777777 + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (-uxu) + 1.0)) + f_eq_2_1_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_1 - uxu) + + factor_2 * (exu_2_1_1 * exu_2_1_1) + + 1.0 + ) + ) + f_eq_0_1_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_1 - uxu) + + factor_2 * (exu_0_1_1 * exu_0_1_1) + + 1.0 + ) + ) + f_eq_1_2_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_1 - uxu) + + factor_2 * (exu_1_2_1 * exu_1_2_1) + + 1.0 + ) + ) + f_eq_1_0_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_1 - uxu) + + factor_2 * (exu_1_2_1 * exu_1_2_1) + + 1.0 + ) + ) + f_eq_1_1_2 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_1_2 - uxu) + + factor_2 * (exu_1_1_2 * exu_1_1_2) + + 1.0 + ) + ) + f_eq_1_1_0 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_1_0 - uxu) + + factor_2 * (exu_1_1_0 * exu_1_1_0) + + 1.0 + ) + ) + f_eq_1_2_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_2 - uxu) + + factor_2 * (exu_1_2_2 * exu_1_2_2) + + 1.0 + ) + ) + f_eq_1_0_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_0 - uxu) + + factor_2 * (exu_1_0_0 * exu_1_0_0) + + 1.0 + ) + ) + f_eq_1_2_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_0 - uxu) + + factor_2 * (exu_1_2_0 * exu_1_2_0) + + 1.0 + ) + ) + f_eq_1_0_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_2 - uxu) + + factor_2 * (exu_1_0_2 * exu_1_0_2) + + 1.0 + ) + ) + f_eq_2_1_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_2 - uxu) + + factor_2 * (exu_2_1_2 * exu_2_1_2) + + 1.0 + ) + ) + f_eq_0_1_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_0 - uxu) + + factor_2 * (exu_0_1_0 * exu_0_1_0) + + 1.0 + ) + ) + f_eq_2_1_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * 
exu_2_1_0 - uxu) + + factor_2 * (exu_2_1_0 * exu_2_1_0) + + 1.0 + ) + ) + f_eq_0_1_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_2 - uxu) + + factor_2 * (exu_0_1_2 * exu_0_1_2) + + 1.0 + ) + ) + f_eq_2_2_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_2_1 - uxu) + + factor_2 * (exu_2_2_1 * exu_2_2_1) + + 1.0 + ) + ) + f_eq_0_0_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_0_1 - uxu) + + factor_2 * (exu_0_0_1 * exu_0_0_1) + + 1.0 + ) + ) + f_eq_2_0_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_0_1 - uxu) + + factor_2 * (exu_2_0_1 * exu_2_0_1) + + 1.0 + ) + ) + f_eq_0_2_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_2_1 - uxu) + + factor_2 * (exu_0_2_1 * exu_0_2_1) + + 1.0 + ) + ) + + # set next lattice state + f[0, i, j, k] = f_eq_1_1_1 + f[1, i, j, k] = f_eq_2_1_1 + f[2, i, j, k] = f_eq_0_1_1 + f[3, i, j, k] = f_eq_1_2_1 + f[4, i, j, k] = f_eq_1_0_1 + f[5, i, j, k] = f_eq_1_1_2 + f[6, i, j, k] = f_eq_1_1_0 + f[7, i, j, k] = f_eq_1_2_2 + f[8, i, j, k] = f_eq_1_0_0 + f[9, i, j, k] = f_eq_1_2_0 + f[10, i, j, k] = f_eq_1_0_2 + f[11, i, j, k] = f_eq_2_1_2 + f[12, i, j, k] = f_eq_0_1_0 + f[13, i, j, k] = f_eq_2_1_0 + f[14, i, j, k] = f_eq_0_1_2 + f[15, i, j, k] = f_eq_2_2_1 + f[16, i, j, k] = f_eq_0_0_1 + f[17, i, j, k] = f_eq_2_0_1 + f[18, i, j, k] = f_eq_0_2_1 + + +def warp_initialize_f(f, dx: float): + # Get inputs + cs = 1.0 / np.sqrt(3.0) + vel = 0.1 * cs + + # Launch kernel + wp.launch( + kernel=warp_initialize_taylor_green, + dim=list(f.shape[1:]), + inputs=[f, dx, vel, 0, 0, 0], + device=f.device, + ) + + return f + + +def warp_apply_collide_stream(f0, f1, tau: float): + # Apply streaming and collision step + wp.launch( + kernel=warp_collide_stream, + dim=list(f0.shape[1:]), + inputs=[f0, f1, f0.shape[1], f0.shape[2], f0.shape[3], tau], + device=f0.device, + ) + + return f1, f0 + + +@cuda.jit("void(float32[:,:,:,::1], float32, int32, int32, int32, int32, int32, int32, int32)", device=True) +def numba_set_f( + f: 
numba.cuda.cudadrv.devicearray.DeviceNDArray, + value: float, + q: int, + i: int, + j: int, + k: int, + width: int, + height: int, + length: int, +): + # Modulo + if i < 0: + i += width + if j < 0: + j += height + if k < 0: + k += length + if i >= width: + i -= width + if j >= height: + j -= height + if k >= length: + k -= length + f[i, j, k, q] = value + +#@cuda.jit +@cuda.jit("void(float32[:,:,:,::1], float32[:,:,:,::1], int32, int32, int32, float32)") +def numba_collide_stream( + f0: numba.cuda.cudadrv.devicearray.DeviceNDArray, + f1: numba.cuda.cudadrv.devicearray.DeviceNDArray, + width: int, + height: int, + length: int, + tau: float, +): + + x, y, z = cuda.grid(3) + + # sample needed points + f_1_1_1 = f0[x, y, z, 0] + f_2_1_1 = f0[x, y, z, 1] + f_0_1_1 = f0[x, y, z, 2] + f_1_2_1 = f0[x, y, z, 3] + f_1_0_1 = f0[x, y, z, 4] + f_1_1_2 = f0[x, y, z, 5] + f_1_1_0 = f0[x, y, z, 6] + f_1_2_2 = f0[x, y, z, 7] + f_1_0_0 = f0[x, y, z, 8] + f_1_2_0 = f0[x, y, z, 9] + f_1_0_2 = f0[x, y, z, 10] + f_2_1_2 = f0[x, y, z, 11] + f_0_1_0 = f0[x, y, z, 12] + f_2_1_0 = f0[x, y, z, 13] + f_0_1_2 = f0[x, y, z, 14] + f_2_2_1 = f0[x, y, z, 15] + f_0_0_1 = f0[x, y, z, 16] + f_2_0_1 = f0[x, y, z, 17] + f_0_2_1 = f0[x, y, z, 18] + + # compute u and p + p = (f_1_1_1 + + f_2_1_1 + f_0_1_1 + + f_1_2_1 + f_1_0_1 + + f_1_1_2 + f_1_1_0 + + f_1_2_2 + f_1_0_0 + + f_1_2_0 + f_1_0_2 + + f_2_1_2 + f_0_1_0 + + f_2_1_0 + f_0_1_2 + + f_2_2_1 + f_0_0_1 + + f_2_0_1 + f_0_2_1) + u = (f_2_1_1 - f_0_1_1 + + f_2_1_2 - f_0_1_0 + + f_2_1_0 - f_0_1_2 + + f_2_2_1 - f_0_0_1 + + f_2_0_1 - f_0_2_1) + v = (f_1_2_1 - f_1_0_1 + + f_1_2_2 - f_1_0_0 + + f_1_2_0 - f_1_0_2 + + f_2_2_1 - f_0_0_1 + - f_2_0_1 + f_0_2_1) + w = (f_1_1_2 - f_1_1_0 + + f_1_2_2 - f_1_0_0 + - f_1_2_0 + f_1_0_2 + + f_2_1_2 - f_0_1_0 + - f_2_1_0 + f_0_1_2) + res_p = numba.float32(1.0) / p + u = u * res_p + v = v * res_p + w = w * res_p + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = numba.float32(0.0) + exu_2_1_1 = u + exu_0_1_1 
= -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = numba.float32(1.5) + factor_2 = numba.float32(4.5) + weight_0 = numba.float32(0.33333333) + weight_1 = numba.float32(0.05555555) + weight_2 = numba.float32(0.02777777) + + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + numba.float32(1.0))) + f_eq_2_1_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + numba.float32(1.0))) + f_eq_0_1_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + numba.float32(1.0))) + f_eq_1_2_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + numba.float32(1.0))) + f_eq_1_0_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + numba.float32(1.0))) + f_eq_1_1_2 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + numba.float32(1.0))) + f_eq_1_1_0 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + numba.float32(1.0))) + f_eq_1_2_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + numba.float32(1.0))) + f_eq_1_0_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_0 - uxu) + factor_2 * (exu_1_0_0 * exu_1_0_0) + numba.float32(1.0))) + f_eq_1_2_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + numba.float32(1.0))) + f_eq_1_0_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * 
exu_1_0_2) + numba.float32(1.0))) + f_eq_2_1_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + numba.float32(1.0))) + f_eq_0_1_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + numba.float32(1.0))) + f_eq_2_1_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + numba.float32(1.0))) + f_eq_0_1_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + numba.float32(1.0))) + f_eq_2_2_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + numba.float32(1.0))) + f_eq_0_0_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + numba.float32(1.0))) + f_eq_2_0_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + numba.float32(1.0))) + f_eq_0_2_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * exu_0_2_1) + numba.float32(1.0))) + + # set next lattice state + inv_tau = numba.float32((numba.float32(1.0) / tau)) + numba_set_f(f1, f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1), 0, x, y, z, width, height, length) + numba_set_f(f1, f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1), 1, x + 1, y, z, width, height, length) + numba_set_f(f1, f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1), 2, x - 1, y, z, width, height, length) + numba_set_f(f1, f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1), 3, x, y + 1, z, width, height, length) + numba_set_f(f1, f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1), 4, x, y - 1, z, width, height, length) + numba_set_f(f1, f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2), 5, x, y, z + 1, width, height, length) + numba_set_f(f1, f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0), 6, x, y, z - 1, width, height, length) + numba_set_f(f1, f_1_2_2 - inv_tau * 
(f_1_2_2 - f_eq_1_2_2), 7, x, y + 1, z + 1, width, height, length) + numba_set_f(f1, f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0), 8, x, y - 1, z - 1, width, height, length) + numba_set_f(f1, f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0), 9, x, y + 1, z - 1, width, height, length) + numba_set_f(f1, f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2), 10, x, y - 1, z + 1, width, height, length) + numba_set_f(f1, f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2), 11, x + 1, y, z + 1, width, height, length) + numba_set_f(f1, f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0), 12, x - 1, y, z - 1, width, height, length) + numba_set_f(f1, f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0), 13, x + 1, y, z - 1, width, height, length) + numba_set_f(f1, f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2), 14, x - 1, y, z + 1, width, height, length) + numba_set_f(f1, f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1), 15, x + 1, y + 1, z, width, height, length) + numba_set_f(f1, f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1), 16, x - 1, y - 1, z, width, height, length) + numba_set_f(f1, f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1), 17, x + 1, y - 1, z, width, height, length) + numba_set_f(f1, f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1), 18, x - 1, y + 1, z, width, height, length) + + +@cuda.jit +def numba_initialize_taylor_green( + f, + dx, + vel, + start_x, + start_y, + start_z, +): + + i, j, k = cuda.grid(3) + + # get real pos + x = numba.float32(i + start_x) * dx + y = numba.float32(j + start_y) * dx + z = numba.float32(k + start_z) * dx + + # compute u + u = vel * math.sin(x) * math.cos(y) * math.cos(z) + v = -vel * math.cos(x) * math.sin(y) * math.cos(z) + w = 0.0 + + # compute p + p = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * (math.cos(2.0 * x) + math.cos(2.0 * y) * (math.cos(2.0 * z) + 2.0)) + + 1.0 + ) + + # compute u X u + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = 0.0 + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = 
-v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = 1.5 + factor_2 = 4.5 + weight_0 = 0.33333333 + weight_1 = 0.05555555 + weight_2 = 0.02777777 + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (-uxu) + 1.0)) + f_eq_2_1_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_1 - uxu) + + factor_2 * (exu_2_1_1 * exu_2_1_1) + + 1.0 + ) + ) + f_eq_0_1_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_1 - uxu) + + factor_2 * (exu_0_1_1 * exu_0_1_1) + + 1.0 + ) + ) + f_eq_1_2_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_1 - uxu) + + factor_2 * (exu_1_2_1 * exu_1_2_1) + + 1.0 + ) + ) + f_eq_1_0_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_1 - uxu) + + factor_2 * (exu_1_2_1 * exu_1_2_1) + + 1.0 + ) + ) + f_eq_1_1_2 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_1_2 - uxu) + + factor_2 * (exu_1_1_2 * exu_1_1_2) + + 1.0 + ) + ) + f_eq_1_1_0 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_1_0 - uxu) + + factor_2 * (exu_1_1_0 * exu_1_1_0) + + 1.0 + ) + ) + f_eq_1_2_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_2 - uxu) + + factor_2 * (exu_1_2_2 * exu_1_2_2) + + 1.0 + ) + ) + f_eq_1_0_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_0 - uxu) + + factor_2 * (exu_1_0_0 * exu_1_0_0) + + 1.0 + ) + ) + f_eq_1_2_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_0 - uxu) + + factor_2 * (exu_1_2_0 * exu_1_2_0) + + 1.0 + ) + ) + f_eq_1_0_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_2 - uxu) + + factor_2 * (exu_1_0_2 * exu_1_0_2) + + 1.0 + ) + ) + f_eq_2_1_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_2 - uxu) + + factor_2 * (exu_2_1_2 * exu_2_1_2) + + 1.0 + ) + ) + f_eq_0_1_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_0 - uxu) + + factor_2 * (exu_0_1_0 * exu_0_1_0) + + 1.0 + ) + ) + f_eq_2_1_0 = weight_2 * ( + p 
+ * ( + factor_1 * (2.0 * exu_2_1_0 - uxu) + + factor_2 * (exu_2_1_0 * exu_2_1_0) + + 1.0 + ) + ) + f_eq_0_1_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_2 - uxu) + + factor_2 * (exu_0_1_2 * exu_0_1_2) + + 1.0 + ) + ) + f_eq_2_2_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_2_1 - uxu) + + factor_2 * (exu_2_2_1 * exu_2_2_1) + + 1.0 + ) + ) + f_eq_0_0_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_0_1 - uxu) + + factor_2 * (exu_0_0_1 * exu_0_0_1) + + 1.0 + ) + ) + f_eq_2_0_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_0_1 - uxu) + + factor_2 * (exu_2_0_1 * exu_2_0_1) + + 1.0 + ) + ) + f_eq_0_2_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_2_1 - uxu) + + factor_2 * (exu_0_2_1 * exu_0_2_1) + + 1.0 + ) + ) + + # set next lattice state + f[i, j, k, 0] = f_eq_1_1_1 + f[i, j, k, 1] = f_eq_2_1_1 + f[i, j, k, 2] = f_eq_0_1_1 + f[i, j, k, 3] = f_eq_1_2_1 + f[i, j, k, 4] = f_eq_1_0_1 + f[i, j, k, 5] = f_eq_1_1_2 + f[i, j, k, 6] = f_eq_1_1_0 + f[i, j, k, 7] = f_eq_1_2_2 + f[i, j, k, 8] = f_eq_1_0_0 + f[i, j, k, 9] = f_eq_1_2_0 + f[ i, j, k, 10] = f_eq_1_0_2 + f[ i, j, k, 11] = f_eq_2_1_2 + f[ i, j, k, 12] = f_eq_0_1_0 + f[ i, j, k, 13] = f_eq_2_1_0 + f[ i, j, k, 14] = f_eq_0_1_2 + f[ i, j, k, 15] = f_eq_2_2_1 + f[ i, j, k, 16] = f_eq_0_0_1 + f[ i, j, k, 17] = f_eq_2_0_1 + f[ i, j, k, 18] = f_eq_0_2_1 + + +def numba_initialize_f(f, dx: float): + # Get inputs + cs = 1.0 / np.sqrt(3.0) + vel = 0.1 * cs + + # Launch kernel + blockdim = (16, 16, 1) + griddim = ( + int(np.ceil(f.shape[0] / blockdim[0])), + int(np.ceil(f.shape[1] / blockdim[1])), + int(np.ceil(f.shape[2] / blockdim[2])), + ) + numba_initialize_taylor_green[griddim, blockdim]( + f, dx, vel, 0, 0, 0 + ) + + return f + +def numba_apply_collide_stream(f0, f1, tau: float): + # Apply streaming and collision step + blockdim = (8, 8, 8) + griddim = ( + int(np.ceil(f0.shape[0] / blockdim[0])), + int(np.ceil(f0.shape[1] / blockdim[1])), + int(np.ceil(f0.shape[2] / blockdim[2])), + ) + 
numba_collide_stream[griddim, blockdim]( + f0, f1, f0.shape[0], f0.shape[1], f0.shape[2], tau + ) + + return f1, f0 + +@partial(jit, static_argnums=(1), donate_argnums=(0)) +def jax_apply_collide_stream(f, tau: float): + + # Get f directions + f_1_1_1 = f[:, :, :, 0] + f_2_1_1 = f[:, :, :, 1] + f_0_1_1 = f[:, :, :, 2] + f_1_2_1 = f[:, :, :, 3] + f_1_0_1 = f[:, :, :, 4] + f_1_1_2 = f[:, :, :, 5] + f_1_1_0 = f[:, :, :, 6] + f_1_2_2 = f[:, :, :, 7] + f_1_0_0 = f[:, :, :, 8] + f_1_2_0 = f[:, :, :, 9] + f_1_0_2 = f[:, :, :, 10] + f_2_1_2 = f[:, :, :, 11] + f_0_1_0 = f[:, :, :, 12] + f_2_1_0 = f[:, :, :, 13] + f_0_1_2 = f[:, :, :, 14] + f_2_2_1 = f[:, :, :, 15] + f_0_0_1 = f[:, :, :, 16] + f_2_0_1 = f[:, :, :, 17] + f_0_2_1 = f[:, :, :, 18] + + # compute u and p + p = (f_1_1_1 + + f_2_1_1 + f_0_1_1 + + f_1_2_1 + f_1_0_1 + + f_1_1_2 + f_1_1_0 + + f_1_2_2 + f_1_0_0 + + f_1_2_0 + f_1_0_2 + + f_2_1_2 + f_0_1_0 + + f_2_1_0 + f_0_1_2 + + f_2_2_1 + f_0_0_1 + + f_2_0_1 + f_0_2_1) + u = (f_2_1_1 - f_0_1_1 + + f_2_1_2 - f_0_1_0 + + f_2_1_0 - f_0_1_2 + + f_2_2_1 - f_0_0_1 + + f_2_0_1 - f_0_2_1) + v = (f_1_2_1 - f_1_0_1 + + f_1_2_2 - f_1_0_0 + + f_1_2_0 - f_1_0_2 + + f_2_2_1 - f_0_0_1 + - f_2_0_1 + f_0_2_1) + w = (f_1_1_2 - f_1_1_0 + + f_1_2_2 - f_1_0_0 + - f_1_2_0 + f_1_0_2 + + f_2_1_2 - f_0_1_0 + - f_2_1_0 + f_0_1_2) + res_p = 1.0 / p + u = u * res_p + v = v * res_p + w = w * res_p + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = 0 + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = 1.5 + factor_2 = 4.5 + weight_0 = 0.33333333 + weight_1 = 0.05555555 + weight_2 = 0.02777777 + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + 1.0)) + 
f_eq_2_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + 1.0)) + f_eq_0_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + 1.0)) + f_eq_1_2_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) + f_eq_1_0_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) + f_eq_1_1_2 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + 1.0)) + f_eq_1_1_0 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + 1.0)) + f_eq_1_2_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + 1.0)) + f_eq_1_0_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_0 - uxu) + factor_2 * (exu_1_0_0 * exu_1_0_0) + 1.0)) + f_eq_1_2_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + 1.0)) + f_eq_1_0_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * exu_1_0_2) + 1.0)) + f_eq_2_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + 1.0)) + f_eq_0_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + 1.0)) + f_eq_2_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + 1.0)) + f_eq_0_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + 1.0)) + f_eq_2_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + 1.0)) + f_eq_0_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + 1.0)) + f_eq_2_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + 1.0)) + f_eq_0_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * 
exu_0_2_1) + 1.0)) + + # set next lattice state + inv_tau = (1.0 / tau) + f_1_1_1 = f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1) + f_2_1_1 = f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1) + f_0_1_1 = f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1) + f_1_2_1 = f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1) + f_1_0_1 = f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1) + f_1_1_2 = f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2) + f_1_1_0 = f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0) + f_1_2_2 = f_1_2_2 - inv_tau * (f_1_2_2 - f_eq_1_2_2) + f_1_0_0 = f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0) + f_1_2_0 = f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0) + f_1_0_2 = f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2) + f_2_1_2 = f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2) + f_0_1_0 = f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0) + f_2_1_0 = f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0) + f_0_1_2 = f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2) + f_2_2_1 = f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1) + f_0_0_1 = f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1) + f_2_0_1 = f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1) + f_0_2_1 = f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1) + + # Roll fs and concatenate + f_2_1_1 = jnp.roll(f_2_1_1, -1, axis=0) + f_0_1_1 = jnp.roll(f_0_1_1, 1, axis=0) + f_1_2_1 = jnp.roll(f_1_2_1, -1, axis=1) + f_1_0_1 = jnp.roll(f_1_0_1, 1, axis=1) + f_1_1_2 = jnp.roll(f_1_1_2, -1, axis=2) + f_1_1_0 = jnp.roll(f_1_1_0, 1, axis=2) + f_1_2_2 = jnp.roll(jnp.roll(f_1_2_2, -1, axis=1), -1, axis=2) + f_1_0_0 = jnp.roll(jnp.roll(f_1_0_0, 1, axis=1), 1, axis=2) + f_1_2_0 = jnp.roll(jnp.roll(f_1_2_0, -1, axis=1), 1, axis=2) + f_1_0_2 = jnp.roll(jnp.roll(f_1_0_2, 1, axis=1), -1, axis=2) + f_2_1_2 = jnp.roll(jnp.roll(f_2_1_2, -1, axis=0), -1, axis=2) + f_0_1_0 = jnp.roll(jnp.roll(f_0_1_0, 1, axis=0), 1, axis=2) + f_2_1_0 = jnp.roll(jnp.roll(f_2_1_0, -1, axis=0), 1, axis=2) + f_0_1_2 = jnp.roll(jnp.roll(f_0_1_2, 1, axis=0), -1, axis=2) + f_2_2_1 = jnp.roll(jnp.roll(f_2_2_1, -1, axis=0), -1, axis=1) + f_0_0_1 = 
jnp.roll(jnp.roll(f_0_0_1, 1, axis=0), 1, axis=1) + f_2_0_1 = jnp.roll(jnp.roll(f_2_0_1, -1, axis=0), 1, axis=1) + f_0_2_1 = jnp.roll(jnp.roll(f_0_2_1, 1, axis=0), -1, axis=1) + + return jnp.stack( + [ + f_1_1_1, + f_2_1_1, + f_0_1_1, + f_1_2_1, + f_1_0_1, + f_1_1_2, + f_1_1_0, + f_1_2_2, + f_1_0_0, + f_1_2_0, + f_1_0_2, + f_2_1_2, + f_0_1_0, + f_2_1_0, + f_0_1_2, + f_2_2_1, + f_0_0_1, + f_2_0_1, + f_0_2_1, + ], + axis=-1, + ) + + + +if __name__ == "__main__": + + # Sim Parameters + n = 256 + tau = 0.505 + dx = 2.0 * np.pi / n + nr_steps = 128 + + # Bar plot + backend = [] + mlups = [] + + ######### Warp ######### + # Make f0, f1 + f0 = wp.empty((19, n, n, n), dtype=wp.float32, device="cuda:0") + f1 = wp.empty((19, n, n, n), dtype=wp.float32, device="cuda:0") + + # Initialize f0 + f0 = warp_initialize_f(f0, dx) + + # Apply streaming and collision + t0 = time.time() + for _ in tqdm(range(nr_steps)): + f0, f1 = warp_apply_collide_stream(f0, f1, tau) + wp.synchronize() + t1 = time.time() + + # Compute MLUPS + mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 + backend.append("Warp") + mlups.append(mlups) + + # Plot results + np_f = f0.numpy() + plt.imshow(np_f[3, :, :, 0]) + plt.colorbar() + plt.savefig("warp_f_.png") + plt.close() + + ######### Numba ######### + # Make f0, f1 + f0 = cp.ascontiguousarray(cp.empty((n, n, n, 19), dtype=np.float32)) + f1 = cp.ascontiguousarray(cp.empty((n, n, n, 19), dtype=np.float32)) + + # Initialize f0 + f0 = numba_initialize_f(f0, dx) + + # Apply streaming and collision + t0 = time.time() + for _ in tqdm(range(nr_steps)): + f0, f1 = numba_apply_collide_stream(f0, f1, tau) + cp.cuda.Device(0).synchronize() + t1 = time.time() + + # Compute MLUPS + mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 + backend.append("Numba") + mlups.append(mlups) + + # Plot results + np_f = f0 + plt.imshow(np_f[:, :, 0, 3].get()) + plt.colorbar() + plt.savefig("numba_f_.png") + plt.close() + + ######### Jax ######### + # Make f0, f1 + f = jnp.zeros((n, n, 
n, 19), dtype=jnp.float32) + + # Initialize f0 + # f = jax_initialize_f(f, dx) + + # Apply streaming and collision + t0 = time.time() + for _ in tqdm(range(nr_steps)): + f = jax_apply_collide_stream(f, tau) + t1 = time.time() + + # Compute MLUPS + mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 + backend.append("Jax") + mlups.append(mlups) + + # Plot results + np_f = f + plt.imshow(np_f[:, :, 0, 3]) + plt.colorbar() + plt.savefig("jax_f_.png") + plt.close() + + + + diff --git a/examples/refactor/README.md b/examples/refactor/README.md new file mode 100644 index 0000000..37b17e7 --- /dev/null +++ b/examples/refactor/README.md @@ -0,0 +1,31 @@ +# Refactor Examples + +This directory contains several examples of using the refactored XLB library. + +These examples are not meant to be viewed as the new interface to XLB but only how +to expose the compute kernels to a user. Development is still ongoing. + +## Examples + +### JAX Example + +The JAX example is a simple example of using the refactored XLB library +with JAX. The example is located in `example_jax.py`. It shows +a very basic flow past a cylinder. + +### NUMBA Example + +TODO: Not working yet + +The NUMBA example is a simple example of using the refactored XLB library +with NUMBA. The example is located in `example_numba.py`. It shows +a very basic flow past a cylinder. This example is not working yet though and +is still under development for the numba backend. + +### Out of Core JAX Example + +This shows how we can use out of core memory with JAX. The example is located +in `example_jax_out_of_core.py`. It shows a very basic flow past a cylinder. +The basic idea is to create an out of core memory array using the implementation +in XLB. Then we run the simulation using the jax functions implementation obtained +from XLB. Some rendering is done using PhantomGaze.
diff --git a/examples/refactor/example_jax.py b/examples/refactor/example_jax.py new file mode 100644 index 0000000..9092022 --- /dev/null +++ b/examples/refactor/example_jax.py @@ -0,0 +1,107 @@ +# from IPython import display +import numpy as np +import jax +import jax.numpy as jnp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt + +import xlb + +if __name__ == "__main__": + # Simulation parameters + nr = 128 + vel = 0.05 + visc = 0.00001 + omega = 1.0 / (3.0 * visc + 0.5) + length = 2 * np.pi + + # Geometry (sphere) + lin = np.linspace(0, length, nr) + X, Y, Z = np.meshgrid(lin, lin, lin, indexing="ij") + XYZ = np.stack([X, Y, Z], axis=-1) + radius = np.pi / 8.0 + + # XLB precision policy + precision_policy = xlb.precision_policy.Fp32Fp32() + + # XLB lattice + velocity_set = xlb.velocity_set.D3Q27() + + # XLB equilibrium + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(velocity_set=velocity_set) + + # XLB macroscopic + macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set) + + # XLB collision + collision = xlb.operator.collision.KBC(omega=omega, velocity_set=velocity_set) + + # XLB stream + stream = xlb.operator.stream.Stream(velocity_set=velocity_set) + + # XLB noslip boundary condition (sphere) + in_cylinder = ((X - np.pi/2.0)**2 + (Y - np.pi)**2 + (Z - np.pi)**2) < radius**2 + indices = np.argwhere(in_cylinder) + bounce_back = xlb.operator.boundary_condition.FullBounceBack.from_indices( + indices=indices, + velocity_set=velocity_set + ) + + # XLB outflow boundary condition + outflow = xlb.operator.boundary_condition.DoNothing.from_indices( + indices=np.argwhere(XYZ[..., 0] == length), + velocity_set=velocity_set + ) + + # XLB inflow boundary condition + inflow = xlb.operator.boundary_condition.EquilibriumBoundary.from_indices( + indices=np.argwhere(XYZ[..., 0] == 0.0), + velocity_set=velocity_set, + rho=1.0, + u=np.array([vel, 0.0, 0.0]), + equilibrium=equilibrium + ) + + # XLB stepper + 
stepper = xlb.operator.stepper.NSE( + collision=collision, + stream=stream, + equilibrium=equilibrium, + macroscopic=macroscopic, + boundary_conditions=[bounce_back, outflow, inflow], + precision_policy=precision_policy, + ) + + # Make initial condition + u = jnp.stack([vel * jnp.ones_like(X), jnp.zeros_like(X), jnp.zeros_like(X)], axis=-1) + rho = jnp.expand_dims(jnp.ones_like(X), axis=-1) + f = equilibrium(rho, u) + + # Get boundary id and mask + ijk = jnp.meshgrid(jnp.arange(nr), jnp.arange(nr), jnp.arange(nr), indexing="ij") + boundary_id, mask = stepper.set_boundary(jnp.stack(ijk, axis=-1)) + + # Run simulation + tic = time.time() + nr_iter = 4096 + for i in tqdm(range(nr_iter)): + f = stepper(f, boundary_id, mask, i) + + if i % 32 == 0: + # Get u, rho from f + rho, u = macroscopic(f) + norm_u = jnp.linalg.norm(u, axis=-1) + norm_u = (1.0 - jnp.minimum(boundary_id, 1.0)) * norm_u + + # Plot + plt.imshow(norm_u[..., nr//2], cmap="jet") + plt.colorbar() + plt.savefig(f"img_{str(i).zfill(5)}.png") + plt.close() + + # Sync to host + f = f.block_until_ready() + toc = time.time() + print(f"MLUPS: {(nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/examples/refactor/example_jax_out_of_core.py b/examples/refactor/example_jax_out_of_core.py new file mode 100644 index 0000000..aeb604f --- /dev/null +++ b/examples/refactor/example_jax_out_of_core.py @@ -0,0 +1,336 @@ +# from IPython import display +import os +os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.7' + +import numpy as np +import jax +import jax.numpy as jnp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt +from mpi4py import MPI +import cupy as cp + +import xlb +from xlb.experimental.ooc import OOCmap, OOCArray + +import phantomgaze as pg + +comm = MPI.COMM_WORLD + +@jax.jit +def q_criterion(u): + # Compute derivatives + u_x = u[..., 0] + u_y = u[..., 1] + u_z = u[..., 2] + + # Compute derivatives + u_x_dx = (u_x[2:, 1:-1, 1:-1] - u_x[:-2, 1:-1, 1:-1]) / 2 + u_x_dy = 
(u_x[1:-1, 2:, 1:-1] - u_x[1:-1, :-2, 1:-1]) / 2 + u_x_dz = (u_x[1:-1, 1:-1, 2:] - u_x[1:-1, 1:-1, :-2]) / 2 + u_y_dx = (u_y[2:, 1:-1, 1:-1] - u_y[:-2, 1:-1, 1:-1]) / 2 + u_y_dy = (u_y[1:-1, 2:, 1:-1] - u_y[1:-1, :-2, 1:-1]) / 2 + u_y_dz = (u_y[1:-1, 1:-1, 2:] - u_y[1:-1, 1:-1, :-2]) / 2 + u_z_dx = (u_z[2:, 1:-1, 1:-1] - u_z[:-2, 1:-1, 1:-1]) / 2 + u_z_dy = (u_z[1:-1, 2:, 1:-1] - u_z[1:-1, :-2, 1:-1]) / 2 + u_z_dz = (u_z[1:-1, 1:-1, 2:] - u_z[1:-1, 1:-1, :-2]) / 2 + + # Compute vorticity + mu_x = u_z_dy - u_y_dz + mu_y = u_x_dz - u_z_dx + mu_z = u_y_dx - u_x_dy + norm_mu = jnp.sqrt(mu_x ** 2 + mu_y ** 2 + mu_z ** 2) + + # Compute strain rate + s_0_0 = u_x_dx + s_0_1 = 0.5 * (u_x_dy + u_y_dx) + s_0_2 = 0.5 * (u_x_dz + u_z_dx) + s_1_0 = s_0_1 + s_1_1 = u_y_dy + s_1_2 = 0.5 * (u_y_dz + u_z_dy) + s_2_0 = s_0_2 + s_2_1 = s_1_2 + s_2_2 = u_z_dz + s_dot_s = ( + s_0_0 ** 2 + s_0_1 ** 2 + s_0_2 ** 2 + + s_1_0 ** 2 + s_1_1 ** 2 + s_1_2 ** 2 + + s_2_0 ** 2 + s_2_1 ** 2 + s_2_2 ** 2 + ) + + # Compute omega + omega_0_0 = 0.0 + omega_0_1 = 0.5 * (u_x_dy - u_y_dx) + omega_0_2 = 0.5 * (u_x_dz - u_z_dx) + omega_1_0 = -omega_0_1 + omega_1_1 = 0.0 + omega_1_2 = 0.5 * (u_y_dz - u_z_dy) + omega_2_0 = -omega_0_2 + omega_2_1 = -omega_1_2 + omega_2_2 = 0.0 + omega_dot_omega = ( + omega_0_0 ** 2 + omega_0_1 ** 2 + omega_0_2 ** 2 + + omega_1_0 ** 2 + omega_1_1 ** 2 + omega_1_2 ** 2 + + omega_2_0 ** 2 + omega_2_1 ** 2 + omega_2_2 ** 2 + ) + + # Compute q-criterion + q = 0.5 * (omega_dot_omega - s_dot_s) + + return norm_mu, q + + +if __name__ == "__main__": + # Simulation parameters + nr = 256 + nx = 3 * nr + ny = nr + nz = nr + vel = 0.05 + visc = 0.00001 + omega = 1.0 / (3.0 * visc + 0.5) + length = 2 * np.pi + dx = length / (ny - 1) + radius = np.pi / 3.0 + + # OOC parameters + sub_steps = 8 + sub_nr = 128 + padding = (sub_steps, sub_steps, sub_steps, 0) + + # XLB precision policy + precision_policy = xlb.precision_policy.Fp32Fp32() + + # XLB lattice + velocity_set = 
xlb.velocity_set.D3Q27() + + # XLB equilibrium + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(velocity_set=velocity_set) + + # XLB macroscopic + macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set) + + # XLB collision + collision = xlb.operator.collision.KBC(omega=omega, velocity_set=velocity_set) + + # XLB stream + stream = xlb.operator.stream.Stream(velocity_set=velocity_set) + + # XLB noslip boundary condition (sphere) + # Create a mask function + def set_boundary_sphere(ijk, boundary_id, mask, id_number): + # Get XYZ + XYZ = ijk * dx + sphere_mask = jnp.linalg.norm(XYZ - length / 2.0, axis=-1) < radius + boundary_id = boundary_id.at[sphere_mask].set(id_number) + mask = mask.at[sphere_mask].set(True) + return boundary_id, mask + bounce_back = xlb.operator.boundary_condition.FullBounceBack( + set_boundary=set_boundary_sphere, + velocity_set=velocity_set + ) + + # XLB outflow boundary condition + def set_boundary_outflow(ijk, boundary_id, mask, id_number): + # Get XYZ + XYZ = ijk * dx + outflow_mask = XYZ[..., 0] >= (length * 3.0) - dx + boundary_id = boundary_id.at[outflow_mask].set(id_number) + mask = mask.at[outflow_mask].set(True) + return boundary_id, mask + outflow = xlb.operator.boundary_condition.DoNothing( + set_boundary=set_boundary_outflow, + velocity_set=velocity_set + ) + + # XLB inflow boundary condition + def set_boundary_inflow(ijk, boundary_id, mask, id_number): + # Get XYZ + XYZ = ijk * dx + inflow_mask = XYZ[..., 0] == 0.0 + boundary_id = boundary_id.at[inflow_mask].set(id_number) + mask = mask.at[inflow_mask].set(True) + return boundary_id, mask + inflow = xlb.operator.boundary_condition.EquilibriumBoundary( + set_boundary=set_boundary_inflow, + velocity_set=velocity_set, + rho=1.0, + u=np.array([vel, 0.0, 0.0]), + equilibrium=equilibrium + ) + + # XLB stepper + stepper = xlb.operator.stepper.NSE( + collision=collision, + stream=stream, + equilibrium=equilibrium, + macroscopic=macroscopic, + 
boundary_conditions=[bounce_back, outflow, inflow], + precision_policy=precision_policy, + ) + + # Make OOC arrays + f = OOCArray( + shape=(nx, ny, nz, velocity_set.q), + dtype=np.float32, + tile_shape=(sub_nr, sub_nr, sub_nr, velocity_set.q), + padding=padding, + comm=comm, + devices=[cp.cuda.Device(0) for i in range(comm.size)], + codec=None, + nr_compute_tiles=1, + ) + + camera_radius = length * 2.0 + focal_point = (3.0 * length / 2.0, length / 2.0, length / 2.0) + angle = 1 * 0.0001 + camera_position = (focal_point[0] + camera_radius * np.sin(angle), focal_point[1], focal_point[2] + camera_radius * np.cos(angle)) + camera = pg.Camera( + position=camera_position, + focal_point=focal_point, + view_up=(0.0, 1.0, 0.0), + height=1440, + width=2560, + max_depth=6.0 * length, + ) + screen_buffer = pg.ScreenBuffer.from_camera(camera) + + + # Initialize f + @OOCmap(comm, (0,), backend="jax") + def initialize_f(f): + # Get inputs + shape = f.shape[:-1] + u = jnp.stack([vel * jnp.ones(shape), jnp.zeros(shape), jnp.zeros(shape)], axis=-1) + rho = jnp.expand_dims(jnp.ones(shape), axis=-1) + f = equilibrium(rho, u) + return f + f = initialize_f(f) + + # Stepping function + @OOCmap(comm, (0,), backend="jax", add_index=True) + def ooc_stepper(f): + + # Get tensors + f, global_index = f + + # Get ijk + lin_i = jnp.arange(global_index[0], global_index[0] + f.shape[0]) + lin_j = jnp.arange(global_index[1], global_index[1] + f.shape[1]) + lin_k = jnp.arange(global_index[2], global_index[2] + f.shape[2]) + ijk = jnp.meshgrid(lin_i, lin_j, lin_k, indexing="ij") + ijk = jnp.stack(ijk, axis=-1) + + # Set boundary_id and mask + boundary_id, mask = stepper.set_boundary(ijk) + + # Run stepper + for _ in range(sub_steps): + f = stepper(f, boundary_id, mask, _) + + # Wait till f is computed using jax + f = f.block_until_ready() + + return f + + # Make a render function + @OOCmap(comm, (0,), backend="jax", add_index=True) + def render(f, screen_buffer, camera): + + # Get tensors + f, 
global_index = f + + # Get ijk + lin_i = jnp.arange(global_index[0], global_index[0] + f.shape[0]) + lin_j = jnp.arange(global_index[1], global_index[1] + f.shape[1]) + lin_k = jnp.arange(global_index[2], global_index[2] + f.shape[2]) + ijk = jnp.meshgrid(lin_i, lin_j, lin_k, indexing="ij") + ijk = jnp.stack(ijk, axis=-1) + + # Set boundary_id and mask + boundary_id, mask = stepper.set_boundary(ijk) + sphere = (boundary_id == 1).astype(jnp.float32)[1:-1, 1:-1, 1:-1] + + # Get rho, u + rho, u = macroscopic(f) + + # Get q-cr + norm_mu, q = q_criterion(u) + + # Make volumes + origin = ((global_index[0] + 1) * dx, (global_index[1] + 1) * dx, (global_index[2] + 1) * dx) + q_volume = pg.objects.Volume( + q, spacing=(dx, dx, dx), origin=origin + ) + norm_mu_volume = pg.objects.Volume( + norm_mu, spacing=(dx, dx, dx), origin=origin + ) + sphere_volume = pg.objects.Volume( + sphere, spacing=(dx, dx, dx), origin=origin + ) + + # Render + screen_buffer = pg.render.contour( + q_volume, + threshold=0.000005, + color=norm_mu_volume, + colormap=pg.Colormap("jet", vmin=0.0, vmax=0.025), + camera=camera, + screen_buffer=screen_buffer, + ) + screen_buffer = pg.render.contour( + sphere_volume, + threshold=0.5, + camera=camera, + screen_buffer=screen_buffer, + ) + + return f + + # Run simulation + tic = time.time() + nr_iter = 128 * nr // sub_steps + nr_frames = 1024 + for i in tqdm(range(nr_iter)): + f = ooc_stepper(f) + + if i % (nr_iter // nr_frames) == 0: + # Rotate camera + camera_radius = length * 1.0 + focal_point = (length / 2.0, length / 2.0, length / 2.0) + angle = (np.pi / nr_iter) * i + camera_position = (focal_point[0] + camera_radius * np.sin(angle), focal_point[1], focal_point[2] + camera_radius * np.cos(angle)) + camera = pg.Camera( + position=camera_position, + focal_point=focal_point, + view_up=(0.0, 1.0, 0.0), + height=1080, + width=1920, + max_depth=6.0 * length, + ) + + # Render global setup + screen_buffer = pg.render.wireframe( + lower_bound=(0.0, 0.0, 0.0), + 
upper_bound=(3.0*length, length, length), + thickness=length/100.0, + camera=camera, + ) + screen_buffer = pg.render.axes( + size=length/30.0, + center=(0.0, 0.0, length*1.1), + camera=camera, + screen_buffer=screen_buffer + ) + + # Render + render(f, screen_buffer, camera) + + # Save image + plt.imsave('./q_criterion_' + str(i).zfill(7) + '.png', np.minimum(screen_buffer.image.get(), 1.0)) + + # Sync to host + cp.cuda.runtime.deviceSynchronize() + toc = time.time() + print(f"MLUPS: {(sub_steps * nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/examples/refactor/example_numba.py b/examples/refactor/example_numba.py new file mode 100644 index 0000000..9183067 --- /dev/null +++ b/examples/refactor/example_numba.py @@ -0,0 +1,78 @@ +# from IPython import display +import numpy as np +import cupy as cp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt +from numba import cuda, config + +import xlb + +config.CUDA_ARRAY_INTERFACE_SYNC = False + +if __name__ == "__main__": + # XLB precision policy + precision_policy = xlb.precision_policy.Fp32Fp32() + + # XLB lattice + lattice = xlb.lattice.D3Q19() + + # XLB collision model + collision = xlb.collision.BGK() + + # Make XLB compute kernels + compute = xlb.compute_constructor.NumbaConstructor( + lattice=lattice, + collision=collision, + boundary_conditions=[], + forcing=None, + precision_policy=precision_policy, + ) + + # Make taylor green vortex initial condition + tau = 0.505 + vel = 0.1 * 1.0 / np.sqrt(3.0) + nr = 256 + lin = cp.linspace(0, 2 * np.pi, nr, endpoint=False, dtype=cp.float32) + X, Y, Z = cp.meshgrid(lin, lin, lin, indexing="ij") + X = X[None, ...] + Y = Y[None, ...] + Z = Z[None, ...] 
+ u = vel * cp.sin(X) * cp.cos(Y) * cp.cos(Z) + v = -vel * cp.cos(X) * cp.sin(Y) * cp.cos(Z) + w = cp.zeros_like(X) + rho = ( + 3.0 + * vel**2 + * (1.0 / 16.0) + * (cp.cos(2 * X) + cp.cos(2 * Y) + cp.cos(2 * Z)) + + 1.0) + u = cp.concatenate([u, v, w], axis=-1) + + # Allocate f + f0 = cp.zeros((19, nr, nr, nr), dtype=cp.float32) + f1 = cp.zeros((19, nr, nr, nr), dtype=cp.float32) + + # Get f from u, rho + compute.equilibrium(rho, u, f0) + + # Run compute kernel on f + tic = time.time() + nr_iter = 128 + for i in tqdm(range(nr_iter)): + compute.step(f0, f1, i) + f0, f1 = f1, f0 + + if i % 4 == 0: + ## Get u, rho from f + #rho, u = compute.macroscopic(f) + #norm_u = jnp.linalg.norm(u, axis=-1) + + # Plot + plt.imsave(f"img_{str(i).zfill(5)}.png", f0[8, nr // 2, :, :].get(), cmap="jet") + + # Sync to host + cp.cuda.stream.get_current_stream().synchronize() + toc = time.time() + print(f"MLUPS: {(nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/setup.py b/setup.py index e69de29..5ef3ed7 100644 --- a/setup.py +++ b/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup, find_packages + +setup( + name="XLB", + version="0.0.1", + author="", + packages=find_packages(), + install_requires=[ + ], + include_package_data=True, +) diff --git a/src/boundary_conditions.py b/src/boundary_conditions.py deleted file mode 100644 index cd98c91..0000000 --- a/src/boundary_conditions.py +++ /dev/null @@ -1,1175 +0,0 @@ -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np -class BoundaryCondition(object): - """ - Base class for boundary conditions in a LBM simulation. - - This class provides a general structure for implementing boundary conditions. It includes methods for preparing the - boundary attributes and for applying the boundary condition. Specific boundary conditions should be implemented as - subclasses of this class, with the `apply` method overridden as necessary. 
- - Attributes - ---------- - lattice : Lattice - The lattice used in the simulation. - nx: - The number of nodes in the x direction. - ny: - The number of nodes in the y direction. - nz: - The number of nodes in the z direction. - dim : int - The number of dimensions in the simulation (2 or 3). - precision_policy : PrecisionPolicy - The precision policy used in the simulation. - indices : array-like - The indices of the boundary nodes. - name : str or None - The name of the boundary condition. This should be set in subclasses. - isSolid : bool - Whether the boundary condition is for a solid boundary. This should be set in subclasses. - isDynamic : bool - Whether the boundary condition is dynamic (changes over time). This should be set in subclasses. - needsExtraConfiguration : bool - Whether the boundary condition requires extra configuration. This should be set in subclasses. - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. This should be set in subclasses. - """ - - def __init__(self, indices, gridInfo, precision_policy): - self.lattice = gridInfo["lattice"] - self.nx = gridInfo["nx"] - self.ny = gridInfo["ny"] - self.nz = gridInfo["nz"] - self.dim = gridInfo["dim"] - self.precisionPolicy = precision_policy - self.indices = indices - self.name = None - self.isSolid = False - self.isDynamic = False - self.needsExtraConfiguration = False - self.implementationStep = "PostStreaming" - - def create_local_mask_and_normal_arrays(self, grid_mask): - - """ - Creates local mask and normal arrays for the boundary condition. - - Parameters - ---------- - grid_mask : array-like - The grid mask for the lattice. - - Returns - ------- - None - - Notes - ----- - This method creates local mask and normal arrays for the boundary condition based on the grid mask. - If the boundary condition requires extra configuration, the `configure` method is called. 
- """ - - if self.needsExtraConfiguration: - boundaryMask = self.get_boundary_mask(grid_mask) - self.configure(boundaryMask) - self.needsExtraConfiguration = False - - boundaryMask = self.get_boundary_mask(grid_mask) - self.normals = self.get_normals(boundaryMask) - self.imissing, self.iknown = self.get_missing_indices(boundaryMask) - self.imissingMask, self.iknownMask, self.imiddleMask = self.get_missing_mask(boundaryMask) - - return - - def get_boundary_mask(self, grid_mask): - """ - Add jax.device_count() to the self.indices in x-direction, and 1 to the self.indices other directions - This is to make sure the boundary condition is applied to the correct nodes as grid_mask is - expanded by (jax.device_count(), 1, 1) - - Parameters - ---------- - grid_mask : array-like - The grid mask for the lattice. - - Returns - ------- - boundaryMask : array-like - """ - shifted_indices = np.array(self.indices) - shifted_indices[0] += device_count() - shifted_indices[1:] += 1 - # Convert back to tuple - shifted_indices = tuple(shifted_indices) - boundaryMask = np.array(grid_mask[shifted_indices]) - - return boundaryMask - - def configure(self, boundaryMask): - """ - Configures the boundary condition. - - Parameters - ---------- - boundaryMask : array-like - The grid mask for the boundary voxels. - - Returns - ------- - None - - Notes - ----- - This method should be overridden in subclasses if the boundary condition requires extra configuration. - """ - return - - @partial(jit, static_argnums=(0, 3), inline=True) - def prepare_populations(self, fout, fin, implementation_step): - """ - Prepares the distribution functions for the boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The incoming distribution functions. - fin : jax.numpy.ndarray - The outgoing distribution functions. - implementation_step : str - The step in the lattice Boltzmann method algorithm at which the preparation is applied. 
- - Returns - ------- - jax.numpy.ndarray - The prepared distribution functions. - - Notes - ----- - This method should be overridden in subclasses if the boundary condition requires preparation of the distribution functions during post-collision or post-streaming. See ExtrapolationBoundaryCondition for an example. - """ - return fout - - def get_normals(self, boundaryMask): - """ - Calculates the normal vectors at the boundary nodes. - - Parameters - ---------- - boundaryMask : array-like - The boundary mask for the lattice. - - Returns - ------- - array-like - The normal vectors at the boundary nodes. - - Notes - ----- - This method calculates the normal vectors by dotting the boundary mask with the main lattice directions. - """ - main_c = self.lattice.c.T[self.lattice.main_indices] - m = boundaryMask[..., self.lattice.main_indices] - normals = -np.dot(m, main_c) - return normals - - def get_missing_indices(self, boundaryMask): - """ - Returns two int8 arrays the same shape as boundaryMask. The non-zero entries of these arrays indicate missing - directions that require BCs (imissing) as well as their corresponding opposite directions (iknown). - - Parameters - ---------- - boundaryMask : array-like - The boundary mask for the lattice. - - Returns - ------- - tuple of array-like - The missing and known indices for the boundary condition. - - Notes - ----- - This method calculates the missing and known indices based on the boundary mask. The missing indices are the - non-zero entries of the boundary mask, and the known indices are their corresponding opposite directions. 
- """ - - # Find imissing, iknown 1-to-1 corresponding indices - # Note: the "zero" index is used as default value here and won't affect BC computations - nbd = len(self.indices[0]) - imissing = np.vstack([np.arange(self.lattice.q, dtype='uint8')] * nbd) - iknown = np.vstack([self.lattice.opp_indices] * nbd) - imissing[~boundaryMask] = 0 - iknown[~boundaryMask] = 0 - return imissing, iknown - - def get_missing_mask(self, boundaryMask): - """ - Returns three boolean arrays the same shape as boundaryMask. - Note: these boundary masks are useful for reduction (eg. summation) operators of selected q-directions. - - Parameters - ---------- - boundaryMask : array-like - The boundary mask for the lattice. - - Returns - ------- - tuple of array-like - The missing, known, and middle masks for the boundary condition. - - Notes - ----- - This method calculates the missing, known, and middle masks based on the boundary mask. The missing mask - is the boundary mask, the known mask is the opposite directions of the missing mask, and the middle mask - is the directions that are neither missing nor known. - """ - # Find masks for imissing, iknown and imiddle - imissingMask = boundaryMask - iknownMask = imissingMask[:, self.lattice.opp_indices] - imiddleMask = ~(imissingMask | iknownMask) - return imissingMask, iknownMask, imiddleMask - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - None - - Notes - ----- - This method should be overridden in subclasses to implement the specific boundary condition. The method should - modify the output distribution functions in place to apply the boundary condition. - """ - pass - - @partial(jit, static_argnums=(0,)) - def equilibrium(self, rho, u): - """ - Compute equilibrium distribution function. 
- - Parameters - ---------- - rho : jax.numpy.ndarray - The density at each node in the lattice. - u : jax.numpy.ndarray - The velocity at each node in the lattice. - - Returns - ------- - jax.numpy.ndarray - The equilibrium distribution function at each node in the lattice. - - Notes - ----- - This method computes the equilibrium distribution function based on the density and velocity. The computation is - performed in the compute precision specified by the precision policy. The result is not cast to the output precision as - this is function is used inside other functions that require the compute precision. - """ - rho, u = self.precisionPolicy.cast_to_compute((rho, u)) - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - cu = 3.0 * jnp.dot(u, c) - usqr = 1.5 * jnp.sum(u**2, axis=-1, keepdims=True) - feq = rho * self.lattice.w * (1.0 + 1.0 * cu + 0.5 * cu**2 - usqr) - - return feq - - @partial(jit, static_argnums=(0,)) - def momentum_flux(self, fneq): - """ - Compute the momentum flux. - - Parameters - ---------- - fneq : jax.numpy.ndarray - The non-equilibrium distribution function at each node in the lattice. - - Returns - ------- - jax.numpy.ndarray - The momentum flux at each node in the lattice. - - Notes - ----- - This method computes the momentum flux by dotting the non-equilibrium distribution function with the lattice - direction vectors. - """ - return jnp.dot(fneq, self.lattice.cc) - - @partial(jit, static_argnums=(0,)) - def momentum_exchange_force(self, f_poststreaming, f_postcollision): - """ - Using the momentum exchange method to compute the boundary force vector exerted on the solid geometry - based on [1] as described in [3]. Ref [2] shows how [1] is applicable to curved geometries only by using a - bounce-back method (e.g. Bouzidi) that accounts for curved boundaries. - NOTE: this function should be called after BC's are imposed. - [1] A.J.C. 
Ladd, Numerical simulations of particular suspensions via a discretized Boltzmann equation. - Part 2 (numerical results), J. Fluid Mech. 271 (1994) 311-339. - [2] R. Mei, D. Yu, W. Shyy, L.-S. Luo, Force evaluation in the lattice Boltzmann method involving - curved geometry, Phys. Rev. E 65 (2002) 041203. - [3] Caiazzo, A., & Junk, M. (2008). Boundary forces in lattice Boltzmann: Analysis of momentum exchange - algorithm. Computers & Mathematics with Applications, 55(7), 1415-1423. - - Parameters - ---------- - f_poststreaming : jax.numpy.ndarray - The post-streaming distribution function at each node in the lattice. - f_postcollision : jax.numpy.ndarray - The post-collision distribution function at each node in the lattice. - - Returns - ------- - jax.numpy.ndarray - The force exerted on the solid geometry at each boundary node. - - Notes - ----- - This method computes the force exerted on the solid geometry at each boundary node using the momentum exchange method. - The force is computed based on the post-streaming and post-collision distribution functions. This method - should be called after the boundary conditions are imposed. - """ - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - phi = f_postcollision[self.indices][bindex, self.iknown] + \ - f_poststreaming[self.indices][bindex, self.imissing] - force = jnp.sum(c[:, self.iknown] * phi, axis=-1).T - return force - -class BounceBack(BoundaryCondition): - """ - Bounce-back boundary condition for a lattice Boltzmann method simulation. - - This class implements a full-way bounce-back boundary condition, where particles hitting the boundary are reflected - back in the direction they came from. The boundary condition is applied after the collision step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "BounceBackFullway". 
- implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostCollision". - """ - def __init__(self, indices, gridInfo, precision_policy): - super().__init__(indices, gridInfo, precision_policy) - self.name = "BounceBackFullway" - self.implementationStep = "PostCollision" - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the bounce-back boundary condition by reflecting the input distribution functions at the - boundary nodes in the opposite direction. - """ - - return fin[self.indices][..., self.lattice.opp_indices] - -class BounceBackMoving(BoundaryCondition): - """ - Moving bounce-back boundary condition for a lattice Boltzmann method simulation. - - This class implements a moving bounce-back boundary condition, where particles hitting the boundary are reflected - back in the direction they came from, with an additional velocity due to the movement of the boundary. The boundary - condition is applied after the collision step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "BounceBackFullwayMoving". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostCollision". - isDynamic : bool - Whether the boundary condition is dynamic (changes over time). For this class, it is True. - update_function : function - A function that updates the boundary condition. 
For this class, it is a function that updates the boundary - condition based on the current time step. The signature of the function is `update_function(time) -> (indices, vel)`, - - """ - def __init__(self, gridInfo, precision_policy, update_function=None): - # We get the indices at time zero to pass to the parent class for initialization - indices, _ = update_function(0) - super().__init__(indices, gridInfo, precision_policy) - self.name = "BounceBackFullwayMoving" - self.implementationStep = "PostCollision" - self.isDynamic = True - self.update_function = jit(update_function) - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin, time): - """ - Applies the moving bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - time : int - The current time step. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - indices, vel = self.update_function(time) - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - cu = 6.0 * self.lattice.w * jnp.dot(vel, c) - return fout.at[indices].set(fin[indices][..., self.lattice.opp_indices] - cu) - - -class BounceBackHalfway(BoundaryCondition): - """ - Halfway bounce-back boundary condition for a lattice Boltzmann method simulation. - - This class implements a halfway bounce-back boundary condition. The boundary condition is applied after - the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "BounceBackHalfway". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostStreaming". - needsExtraConfiguration : bool - Whether the boundary condition needs extra configuration before it can be applied. For this class, it is True. 
- isSolid : bool - Whether the boundary condition represents a solid boundary. For this class, it is True. - vel : array-like - The prescribed value of velocity vector for the boundary condition. No-slip BC is assumed if vel=None (default). - """ - def __init__(self, indices, gridInfo, precision_policy, vel=None): - super().__init__(indices, gridInfo, precision_policy) - self.name = "BounceBackHalfway" - self.implementationStep = "PostStreaming" - self.needsExtraConfiguration = True - self.isSolid = True - self.vel = vel - - def configure(self, boundaryMask): - """ - Configures the boundary condition. - - Parameters - ---------- - boundaryMask : array-like - The grid mask for the boundary voxels. - - Returns - ------- - None - - Notes - ----- - This method performs an index shift for the halfway bounce-back boundary condition. It updates the indices of - the boundary nodes to be the indices of fluid nodes adjacent of the solid nodes. - """ - # Perform index shift for halfway BB. - hasFluidNeighbour = ~boundaryMask[:, self.lattice.opp_indices] - nbd_orig = len(self.indices[0]) - idx = np.array(self.indices).T - idx_trg = [] - for i in range(self.lattice.q): - idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) - indices_new = np.unique(np.vstack(idx_trg), axis=0) - self.indices = tuple(indices_new.T) - nbd_modified = len(self.indices[0]) - if (nbd_orig != nbd_modified) and self.vel is not None: - vel_avg = np.mean(self.vel, axis=0) - self.vel = jnp.zeros(indices_new.shape, dtype=self.precisionPolicy.compute_dtype) + vel_avg - print("WARNING: assuming a constant averaged velocity vector is imposed at all BC cells!") - - return - - @partial(jit, static_argnums=(0,)) - def impose_boundary_vel(self, fbd, bindex): - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - cu = 6.0 * self.lattice.w * jnp.dot(self.vel, c) - fbd = fbd.at[bindex, self.imissing].add(-cu[bindex, self.iknown]) - return fbd - - @partial(jit, 
static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the halfway bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fout[self.indices] - - fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) - if self.vel is not None: - fbd = self.impose_boundary_vel(fbd, bindex) - return fbd - -class EquilibriumBC(BoundaryCondition): - """ - Equilibrium boundary condition for a lattice Boltzmann method simulation. - - This class implements an equilibrium boundary condition, where the distribution function at the boundary nodes is - set to the equilibrium distribution function. The boundary condition is applied after the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "EquilibriumBC". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostStreaming". - out : jax.numpy.ndarray - The equilibrium distribution function at the boundary nodes. - """ - - def __init__(self, indices, gridInfo, precision_policy, rho, u): - super().__init__(indices, gridInfo, precision_policy) - self.out = self.precisionPolicy.cast_to_output(self.equilibrium(rho, u)) - self.name = "EquilibriumBC" - self.implementationStep = "PostStreaming" - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the equilibrium boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. 
- - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the equilibrium boundary condition by setting the output distribution functions at the - boundary nodes to the equilibrium distribution function. - """ - return self.out - -class DoNothing(BoundaryCondition): - def __init__(self, indices, gridInfo, precision_policy): - """ - Do-nothing boundary condition for a lattice Boltzmann method simulation. - - This class implements a do-nothing boundary condition, where no action is taken at the boundary nodes. The boundary - condition is applied after the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "DoNothing". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostStreaming". - - Notes - ----- - This boundary condition enforces skipping of streaming altogether as it sets post-streaming equal to post-collision - populations (so no streaming at this BC voxels). The problem with returning post-streaming values or "fout[self.indices] - is that the information that exit the domain on the opposite side of this boundary, would "re-enter". This is because - we roll the entire array and so the boundary condition acts like a one-way periodic BC. If EquilibriumBC is used as - the BC for that opposite boundary, then the rolled-in values are taken from the initial condition at equilibrium. - Otherwise if ZouHe is used for example the simulation looks like a run-down simulation at low-Re. The opposite boundary - may be even a wall (consider pipebend example). If we correct imissing directions and assign "fin", this method becomes - much less stable and also one needs to correctly take care of corner cases. 
- """ - super().__init__(indices, gridInfo, precision_policy) - self.name = "DoNothing" - self.implementationStep = "PostStreaming" - - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the do-nothing boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the do-nothing boundary condition by simply returning the input distribution functions at the - boundary nodes. - """ - return fin[self.indices] - - -class ZouHe(BoundaryCondition): - """ - Zou-He boundary condition for a lattice Boltzmann method simulation. - - This class implements the Zou-He boundary condition, which is a non-equilibrium bounce-back boundary condition. - It can be used to set inflow and outflow boundary conditions with prescribed pressure or velocity. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "ZouHe". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostStreaming". - type : str - The type of the boundary condition. It can be either 'velocity' for a prescribed velocity boundary condition, - or 'pressure' for a prescribed pressure boundary condition. - prescribed : float or array-like - The prescribed values for the boundary condition. It can be either the prescribed velocities for a 'velocity' - boundary condition, or the prescribed pressures for a 'pressure' boundary condition. - - References - ---------- - Zou, Q., & He, X. (1997). On pressure and velocity boundary conditions for the lattice Boltzmann BGK model. - Physics of Fluids, 9(6), 1591-1598. 
doi:10.1063/1.869307 - """ - def __init__(self, indices, gridInfo, precision_policy, type, prescribed): - super().__init__(indices, gridInfo, precision_policy) - self.name = "ZouHe" - self.implementationStep = "PostStreaming" - self.type = type - self.prescribed = prescribed - self.needsExtraConfiguration = True - - def configure(self, boundaryMask): - """ - Correct boundary indices to ensure that only voxelized surfaces with normal vectors along main cartesian axes - are assigned this type of BC. - """ - nv = np.dot(self.lattice.c, ~boundaryMask.T) - corner_voxels = np.count_nonzero(nv, axis=0) > 1 - # removed_voxels = np.array(self.indices)[:, corner_voxels] - self.indices = tuple(np.array(self.indices)[:, ~corner_voxels]) - self.prescribed = self.prescribed[~corner_voxels] - return - - @partial(jit, static_argnums=(0,), inline=True) - def calculate_vel(self, fpop, rho): - """ - Calculate velocity based on the prescribed pressure/density (Zou/He BC) - """ - unormal = -1. + 1. / rho * (jnp.sum(fpop[self.indices] * self.imiddleMask, axis=1) + - 2. * jnp.sum(fpop[self.indices] * self.iknownMask, axis=1)) - - # Return the above unormal as a normal vector which sets the tangential velocities to zero - vel = unormal[:, None] * self.normals - return vel - - @partial(jit, static_argnums=(0,), inline=True) - def calculate_rho(self, fpop, vel): - """ - Calculate density based on the prescribed velocity (Zou/He BC) - """ - unormal = np.sum(self.normals*vel, axis=1) - - rho = (1.0/(1.0 + unormal))[..., None] * (jnp.sum(fpop[self.indices] * self.imiddleMask, axis=1, keepdims=True) + - 2.*jnp.sum(fpop[self.indices] * self.iknownMask, axis=1, keepdims=True)) - return rho - - @partial(jit, static_argnums=(0,), inline=True) - def calculate_equilibrium(self, fpop): - """ - This is the ZouHe method of calculating the missing macroscopic variables at the boundary. 
- """ - if self.type == 'velocity': - vel = self.prescribed - rho = self.calculate_rho(fpop, vel) - elif self.type == 'pressure': - rho = self.prescribed - vel = self.calculate_vel(fpop, rho) - else: - raise ValueError(f"type = {self.type} not supported! Use \'pressure\' or \'velocity\'.") - - # compute feq at the boundary - feq = self.equilibrium(rho, vel) - return feq - - @partial(jit, static_argnums=(0,), inline=True) - def bounceback_nonequilibrium(self, fpop, feq): - """ - Calculate unknown populations using bounce-back of non-equilibrium populations - a la original Zou & He formulation - """ - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fpop[self.indices] - fknown = fpop[self.indices][bindex, self.iknown] + feq[bindex, self.imissing] - feq[bindex, self.iknown] - fbd = fbd.at[bindex, self.imissing].set(fknown) - return fbd - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, _): - """ - Applies the Zou-He boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - _ : jax.numpy.ndarray - The input distribution functions. This is not used in this method. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the Zou-He boundary condition by first computing the equilibrium distribution functions based - on the prescribed values and the type of boundary condition, and then setting the unknown distribution functions - based on the non-equilibrium bounce-back method. - Tangential velocity is not ensured to be zero by adding transverse contributions based on - Hecth & Harting (2010) (doi:10.1088/1742-5468/2010/01/P01018) as it caused numerical instabilities at higher - Reynolds numbers. One needs to use "Regularized" BC at higher Reynolds. 
- """ - # compute the equilibrium based on prescribed values and the type of BC - feq = self.calculate_equilibrium(fout) - - # set the unknown f populations based on the non-equilibrium bounce-back method - fbd = self.bounceback_nonequilibrium(fout, feq) - - - return fbd - -class Regularized(ZouHe): - """ - Regularized boundary condition for a lattice Boltzmann method simulation. - - This class implements the regularized boundary condition, which is a non-equilibrium bounce-back boundary condition - with additional regularization. It can be used to set inflow and outflow boundary conditions with prescribed pressure - or velocity. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "Regularized". - Qi : numpy.ndarray - The Qi tensor, which is used in the regularization of the distribution functions. - - References - ---------- - Latt, J. (2007). Hydrodynamic limit of lattice Boltzmann equations. PhD thesis, University of Geneva. - Latt, J., Chopard, B., Malaspinas, O., Deville, M., & Michler, A. (2008). Straight velocity boundaries in the - lattice Boltzmann method. Physical Review E, 77(5), 056703. doi:10.1103/PhysRevE.77.056703 - """ - - def __init__(self, indices, gridInfo, precision_policy, type, prescribed): - super().__init__(indices, gridInfo, precision_policy, type, prescribed) - self.name = "Regularized" - #TODO for Hesam: check to understand why corner cases cause instability here. - # self.needsExtraConfiguration = False - self.construct_symmetric_lattice_moment() - - def construct_symmetric_lattice_moment(self): - """ - Construct the symmetric lattice moment Qi. - - The Qi tensor is used in the regularization of the distribution functions. It is defined as Qi = cc - cs^2*I, - where cc is the tensor of lattice velocities, cs is the speed of sound, and I is the identity tensor. 
- """ - Qi = self.lattice.cc - if self.dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) - elif self.dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) - else: - raise ValueError(f"dim = {self.dim} not supported") - - # Qi = cc - cs^2*I - Qi = Qi.at[:, diagonal].set(self.lattice.cc[:, diagonal] - 1./3.) - - # multiply off-diagonal elements by 2 because the Q tensor is symmetric - Qi = Qi.at[:, offdiagonal].set(self.lattice.cc[:, offdiagonal] * 2.0) - - self.Qi = Qi.T - return - - @partial(jit, static_argnums=(0,), inline=True) - def regularize_fpop(self, fpop, feq): - """ - Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. - - Parameters - ---------- - fpop : jax.numpy.ndarray - The distribution functions. - feq : jax.numpy.ndarray - The equilibrium distribution functions. - - Returns - ------- - jax.numpy.ndarray - The regularized distribution functions. - """ - - # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} - f_neq = fpop - feq - PiNeq = self.momentum_flux(f_neq) - # PiNeq = self.momentum_flux(fpop) - self.momentum_flux(feq) - - # Compute double dot product Qi:Pi1 - # QiPi1 = np.zeros_like(fpop) - # Pi1 = PiNeq - # QiPi1 = jnp.dot(Qi, Pi1) - QiPi1 = jnp.dot(PiNeq, self.Qi) - - # assign all populations based on eq 45 of Latt et al (2008) - # fneq ~ f^1 - fpop1 = 9. / 2. * self.lattice.w[None, :] * QiPi1 - fpop_regularized = feq + fpop1 - - return fpop_regularized - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, _): - """ - Applies the regularized boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - _ : jax.numpy.ndarray - The input distribution functions. This is not used in this method. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. 
- - Notes - ----- - This method applies the regularized boundary condition by first computing the equilibrium distribution functions based - on the prescribed values and the type of boundary condition, then setting the unknown distribution functions - based on the non-equilibrium bounce-back method, and finally regularizing the distribution functions. - """ - - # compute the equilibrium based on prescribed values and the type of BC - feq = self.calculate_equilibrium(fout) - - # set the unknown f populations based on the non-equilibrium bounce-back method - fbd = self.bounceback_nonequilibrium(fout, feq) - - # Regularize the boundary fpop - fbd = self.regularize_fpop(fbd, feq) - return fbd - - -class ExtrapolationOutflow(BoundaryCondition): - """ - Extrapolation outflow boundary condition for a lattice Boltzmann method simulation. - - This class implements the extrapolation outflow boundary condition, which is a type of outflow boundary condition - that uses extrapolation to avoid strong wave reflections. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "ExtrapolationOutflow". - sound_speed : float - The speed of sound in the simulation. - - References - ---------- - Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three - dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507–547. - doi:10.1016/j.camwa.2015.05.001. - """ - - def __init__(self, indices, gridInfo, precision_policy): - super().__init__(indices, gridInfo, precision_policy) - self.name = "ExtrapolationOutflow" - self.needsExtraConfiguration = True - self.sound_speed = 1./jnp.sqrt(3.) - - def configure(self, boundaryMask): - """ - Configure the boundary condition by finding neighbouring voxel indices. - - Parameters - ---------- - boundaryMask : np.ndarray - The grid mask for the boundary voxels. 
- """ - hasFluidNeighbour = ~boundaryMask[:, self.lattice.opp_indices] - idx = np.array(self.indices).T - idx_trg = [] - for i in range(self.lattice.q): - idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) - indices_nbr = np.unique(np.vstack(idx_trg), axis=0) - self.indices_nbr = tuple(indices_nbr.T) - - return - - @partial(jit, static_argnums=(0, 3), inline=True) - def prepare_populations(self, fout, fin, implementation_step): - """ - Prepares the distribution functions for the boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The incoming distribution functions. - fin : jax.numpy.ndarray - The outgoing distribution functions. - implementation_step : str - The step in the lattice Boltzmann method algorithm at which the preparation is applied. - - Returns - ------- - jax.numpy.ndarray - The prepared distribution functions. - - Notes - ----- - Because this function is called "PostCollision", f_poststreaming refers to previous time step or t-1 - """ - f_postcollision = fout - f_poststreaming = fin - if implementation_step == 'PostStreaming': - return f_postcollision - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fps_bdr = f_poststreaming[self.indices] - fps_nbr = f_poststreaming[self.indices_nbr] - fpc_bdr = f_postcollision[self.indices] - fpop = fps_bdr[bindex, self.imissing] - fpop_neighbour = fps_nbr[bindex, self.imissing] - fpop_extrapolated = self.sound_speed * fpop_neighbour + (1. - self.sound_speed) * fpop - - # Use the iknown directions of f_postcollision that leave the domain during streaming to store the BC data - fpc_bdr = fpc_bdr.at[bindex, self.iknown].set(fpop_extrapolated) - f_postcollision = f_postcollision.at[self.indices].set(fpc_bdr) - return f_postcollision - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the extrapolation outflow boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. 
- fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fout[self.indices] - fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) - return fbd - - -class InterpolatedBounceBackBouzidi(BounceBackHalfway): - """ - A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice - Boltzmann method simulation. - - This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after - the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "InterpolatedBounceBackBouzidi". - implicit_distances : array-like - An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls - weights : array-like - An array of shape (number_of_bc_cells, q) initialized as None and constructed using implicit_distances array - during runtime. These "weights" are associated with the fractional distance of fluid cell to the boundary - position defined as: weights(dir_i) = |x_fluid - x_boundary(dir_i)| / |x_fluid - x_solid(dir_i)|. - """ - - def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None): - - super().__init__(indices, grid_info, precision_policy, vel=vel) - self.name = "InterpolatedBounceBackBouzidi" - self.implicit_distances = implicit_distances - self.weights = None - - def set_proximity_ratio(self): - """ - Creates the interpolation data needed for the boundary condition. - - Returns - ------- - None. The function updates the object's weights attribute in place. 
- """ - idx = np.array(self.indices).T - self.weights = np.full((idx.shape[0], self.lattice.q), 0.5) - c = np.array(self.lattice.c) - sdf_f = self.implicit_distances[self.indices] - for q in range(1, self.lattice.q): - solid_indices = idx + c[:, q] - solid_indices_tuple = tuple(map(tuple, solid_indices.T)) - sdf_s = self.implicit_distances[solid_indices_tuple] - mask = self.iknownMask[:, q] - self.weights[mask, q] = sdf_f[mask] / (sdf_f[mask] - sdf_s[mask]) - return - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the halfway bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - if self.weights is None: - self.set_proximity_ratio() - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fout[self.indices] - f_postcollision_iknown = fin[self.indices][bindex, self.iknown] - f_postcollision_imissing = fin[self.indices][bindex, self.imissing] - f_poststreaming_iknown = fout[self.indices][bindex, self.iknown] - - # if weights<0.5 - fs_near = 2. * self.weights * f_postcollision_iknown + \ - (1.0 - 2.0 * self.weights) * f_poststreaming_iknown - - # if weights>=0.5 - fs_far = 1.0 / (2. * self.weights) * f_postcollision_iknown + \ - (2.0 * self.weights - 1.0) / (2. * self.weights) * f_postcollision_imissing - - # combine near and far contributions - fmissing = jnp.where(self.weights < 0.5, fs_near, fs_far) - fbd = fbd.at[bindex, self.imissing].set(fmissing) - - if self.vel is not None: - fbd = self.impose_boundary_vel(fbd, bindex) - return fbd - - -class InterpolatedBounceBackDifferentiable(InterpolatedBounceBackBouzidi): - """ - A differentiable variant of the "InterpolatedBounceBackBouzidi" BC scheme. 
This BC is now differentiable at - self.weight = 0.5 unlike the original Bouzidi scheme which switches between 2 equations at weight=0.5. Refer to - [1] (their Appendix E) for more information. - - References - ---------- - [1] Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three - dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507–547. - doi:10.1016/j.camwa.2015.05.001. - - - This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after - the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "InterpolatedBounceBackDifferentiable". - """ - - def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None): - - super().__init__(indices, implicit_distances, grid_info, precision_policy, vel=vel) - self.name = "InterpolatedBounceBackDifferentiable" - - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the halfway bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - if self.weights is None: - self.set_proximity_ratio() - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fout[self.indices] - f_postcollision_iknown = fin[self.indices][bindex, self.iknown] - f_postcollision_imissing = fin[self.indices][bindex, self.imissing] - f_poststreaming_iknown = fout[self.indices][bindex, self.iknown] - fmissing = ((1. 
- self.weights) * f_poststreaming_iknown + - self.weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + self.weights) - fbd = fbd.at[bindex, self.imissing].set(fmissing) - - if self.vel is not None: - fbd = self.impose_boundary_vel(fbd, bindex) - return fbd \ No newline at end of file diff --git a/src/lattice.py b/src/lattice.py deleted file mode 100644 index 788796b..0000000 --- a/src/lattice.py +++ /dev/null @@ -1,281 +0,0 @@ -import re -import numpy as np -import jax.numpy as jnp - - -class Lattice(object): - """ - This class represents a lattice in the Lattice Boltzmann Method. - - It stores the properties of the lattice, including the dimensions, the number of - velocities, the velocity vectors, the weights, the moments, and the indices of the - opposite, main, right, and left velocities. - - The class also provides methods to construct these properties based on the name of the - lattice. - - Parameters - ---------- - name: str - The name of the lattice, which specifies the dimensions and the number of velocities. - For example, "D2Q9" represents a 2D lattice with 9 velocities. - precision: str, optional - The precision of the computations. It can be "f32/f32", "f32/f16", "f64/f64", - "f64/f32", or "f64/f16". The first part before the slash is the precision of the - computations, and the second part after the slash is the precision of the outputs. 
- """ - def __init__(self, name, precision="f32/f32") -> None: - self.name = name - dq = re.findall(r"\d+", name) - self.precision = precision - self.d = int(dq[0]) - self.q = int(dq[1]) - if precision == "f32/f32" or precision == "f32/f16": - self.precisionPolicy = jnp.float32 - elif precision == "f64/f64" or precision == "f64/f32" or precision == "f64/f16": - self.precisionPolicy = jnp.float64 - elif precision == "f16/f16": - self.precisionPolicy = jnp.float16 - else: - raise ValueError("precision not supported") - - # Construct the properties of the lattice - self.c = jnp.array(self.construct_lattice_velocity(), dtype=jnp.int8) - self.w = jnp.array(self.construct_lattice_weight(), dtype=self.precisionPolicy) - self.cc = jnp.array(self.construct_lattice_moment(), dtype=self.precisionPolicy) - self.opp_indices = jnp.array(self.construct_opposite_indices(), dtype=jnp.int8) - self.main_indices = jnp.array(self.construct_main_indices(), dtype=jnp.int8) - self.right_indices = np.array(self.construct_right_indices(), dtype=jnp.int8) - self.left_indices = np.array(self.construct_left_indices(), dtype=jnp.int8) - - def construct_opposite_indices(self): - """ - This function constructs the indices of the opposite velocities for each velocity. - - The opposite velocity of a velocity is the velocity that has the same magnitude but the - opposite direction. - - Returns - ------- - opposite: numpy.ndarray - The indices of the opposite velocities. - """ - c = self.c.T - opposite = np.array([c.tolist().index((-c[i]).tolist()) for i in range(self.q)]) - return opposite - - def construct_right_indices(self): - """ - This function constructs the indices of the velocities that point in the positive - x-direction. - - Returns - ------- - numpy.ndarray - The indices of the right velocities. 
- """ - c = self.c.T - return np.nonzero(c[:, 0] == 1)[0] - - def construct_left_indices(self): - """ - This function constructs the indices of the velocities that point in the negative - x-direction. - - Returns - ------- - numpy.ndarray - The indices of the left velocities. - """ - c = self.c.T - return np.nonzero(c[:, 0] == -1)[0] - - def construct_main_indices(self): - """ - This function constructs the indices of the main velocities. - - The main velocities are the velocities that have a magnitude of 1 in lattice units. - - Returns - ------- - numpy.ndarray - The indices of the main velocities. - """ - c = self.c.T - if self.d == 2: - return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] - - elif self.d == 3: - return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] - - def construct_lattice_velocity(self): - """ - This function constructs the velocity vectors of the lattice. - - The velocity vectors are defined based on the name of the lattice. For example, for a D2Q9 - lattice, there are 9 velocities: (0,0), (1,0), (-1,0), (0,1), (0,-1), (1,1), (-1,-1), - (1,-1), and (-1,1). - - Returns - ------- - c.T: numpy.ndarray - The velocity vectors of the lattice. - """ - if self.name == "D2Q9": # D2Q9 - cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] - cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] - c = np.array(tuple(zip(cx, cy))) - elif self.name == "D3Q19": # D3Q19 - c = [(x, y, z) for x in [0, -1, 1] for y in [0, -1, 1] for z in [0, -1, 1]] - c = np.array([ci for ci in c if np.linalg.norm(ci) < 1.5]) - elif self.name == "D3Q27": # D3Q27 - c = [(x, y, z) for x in [0, -1, 1] for y in [0, -1, 1] for z in [0, -1, 1]] - # c = np.array([ci for ci in c if np.linalg.norm(ci) < 1.5]) - c = np.array(c) - else: - raise ValueError("Supported Lattice types are D2Q9, D3Q19 and D3Q27") - - return c.T - - def construct_lattice_weight(self): - """ - This function constructs the weights of the lattice. - - The weights are defined based on the name of the lattice. 
For example, for a D2Q9 lattice, - the weights are 4/9 for the rest velocity, 1/9 for the main velocities, and 1/36 for the - diagonal velocities. - - Returns - ------- - w: numpy.ndarray - The weights of the lattice. - """ - # Get the transpose of the lattice vector - c = self.c.T - - # Initialize the weights to be 1/36 - w = 1.0 / 36.0 * np.ones(self.q) - - # Update the weights for 2D and 3D lattices - if self.name == "D2Q9": - w[np.linalg.norm(c, axis=1) < 1.1] = 1.0 / 9.0 - w[0] = 4.0 / 9.0 - elif self.name == "D3Q19": - w[np.linalg.norm(c, axis=1) < 1.1] = 2.0 / 36.0 - w[0] = 1.0 / 3.0 - elif self.name == "D3Q27": - cl = np.linalg.norm(c, axis=1) - w[np.isclose(cl, 1.0, atol=1e-8)] = 2.0 / 27.0 - w[(cl > 1) & (cl <= np.sqrt(2))] = 1.0 / 54.0 - w[(cl > np.sqrt(2)) & (cl <= np.sqrt(3))] = 1.0 / 216.0 - w[0] = 8.0 / 27.0 - else: - raise ValueError("Supported Lattice types are D2Q9, D3Q19 and D3Q27") - - # Return the weights - return w - - def construct_lattice_moment(self): - """ - This function constructs the moments of the lattice. - - The moments are the products of the velocity vectors, which are used in the computation of - the equilibrium distribution functions and the collision operator in the Lattice Boltzmann - Method (LBM). - - Returns - ------- - cc: numpy.ndarray - The moments of the lattice. - """ - c = self.c.T - # Counter for the loop - cntr = 0 - - # nt: number of independent elements of a symmetric tensor - nt = self.d * (self.d + 1) // 2 - - cc = np.zeros((self.q, nt)) - for a in range(0, self.d): - for b in range(a, self.d): - cc[:, cntr] = c[:, a] * c[:, b] - cntr += 1 - - return cc - - def __str__(self): - return self.name - -class LatticeD2Q9(Lattice): - """ - Lattice class for 2D D2Q9 lattice. - - D2Q9 stands for two-dimensional nine-velocity model. It is a common model used in the - Lat tice Boltzmann Method for simulating fluid flows in two dimensions. - - Parameters - ---------- - precision: str, optional - The precision of the lattice. 
The default is "f32/f32" - """ - def __init__(self, precision="f32/f32"): - super().__init__("D2Q9", precision) - self._set_constants() - - def _set_constants(self): - self.cs = jnp.sqrt(3) / 3.0 - self.cs2 = 1.0 / 3.0 - self.inv_cs2 = 3.0 - self.i_s = jnp.asarray(list(range(9))) - self.im = 3 # Number of imiddles (includes center) - self.ik = 3 # Number of iknowns or iunknowns - - -class LatticeD3Q19(Lattice): - """ - Lattice class for 3D D3Q19 lattice. - - D3Q19 stands for three-dimensional nineteen-velocity model. It is a common model used in the - Lattice Boltzmann Method for simulating fluid flows in three dimensions. - - Parameters - ---------- - precision: str, optional - The precision of the lattice. The default is "f32/f32" - """ - def __init__(self, precision="f32/f32"): - super().__init__("D3Q19", precision) - self._set_constants() - - def _set_constants(self): - self.cs = jnp.sqrt(3) / 3.0 - self.cs2 = 1.0 / 3.0 - self.inv_cs2 = 3.0 - self.i_s = jnp.asarray(list(range(19)), dtype=jnp.int8) - - self.im = 9 # Number of imiddles (includes center) - self.ik = 5 # Number of iknowns or iunknowns - - -class LatticeD3Q27(Lattice): - """ - Lattice class for 3D D3Q27 lattice. - - D3Q27 stands for three-dimensional twenty-seven-velocity model. It is a common model used in the - Lattice Boltzmann Method for simulating fluid flows in three dimensions. - - Parameters - ---------- - precision: str, optional - The precision of the lattice. 
The default is "f32/f32" - """ - - def __init__(self, precision="f32/f32"): - super().__init__("D3Q27", precision) - self._set_constants() - - def _set_constants(self): - self.cs = jnp.sqrt(3) / 3.0 - self.cs2 = 1.0 / 3.0 - self.inv_cs2 = 3.0 - self.i_s = jnp.asarray(list(range(27)), dtype=jnp.int8) \ No newline at end of file diff --git a/src/models.py b/src/models.py deleted file mode 100644 index a0500c8..0000000 --- a/src/models.py +++ /dev/null @@ -1,260 +0,0 @@ -import jax.numpy as jnp -from jax import jit -from functools import partial -from src.base import LBMBase -""" -Collision operators are defined in this file for different models. -""" - -class BGKSim(LBMBase): - """ - BGK simulation class. - - This class implements the Bhatnagar-Gross-Krook (BGK) approximation for the collision step in the Lattice Boltzmann Method. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision(self, f): - """ - BGK collision step for lattice. - - The collision step is where the main physics of the LBM is applied. In the BGK approximation, - the distribution function is relaxed towards the equilibrium distribution function. - """ - f = self.precisionPolicy.cast_to_compute(f) - rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, cast_output=False) - fneq = f - feq - fout = f - self.omega * fneq - if self.force is not None: - fout = self.apply_force(fout, feq, rho, u) - return self.precisionPolicy.cast_to_output(fout) - -class KBCSim(LBMBase): - """ - KBC simulation class. - - This class implements the Karlin-Bösch-Chikatamarla (KBC) model for the collision step in the Lattice Boltzmann Method. 
- """ - def __init__(self, **kwargs): - if kwargs.get('lattice').name != 'D3Q27' and kwargs.get('nz') > 0: - raise ValueError("KBC collision operator in 3D must only be used with D3Q27 lattice.") - super().__init__(**kwargs) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision(self, f): - """ - KBC collision step for lattice. - """ - f = self.precisionPolicy.cast_to_compute(f) - tiny = 1e-32 - beta = self.omega * 0.5 - rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, cast_output=False) - fneq = f - feq - if self.dim == 2: - deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0 - else: - deltaS = self.fdecompose_shear_d3q27(fneq) * rho - deltaH = fneq - deltaS - invBeta = 1.0 / beta - gamma = invBeta - (2.0 - invBeta) * self.entropic_scalar_product(deltaS, deltaH, feq) / (tiny + self.entropic_scalar_product(deltaH, deltaH, feq)) - - fout = f - beta * (2.0 * deltaS + gamma[..., None] * deltaH) - - # add external force - if self.force is not None: - fout = self.apply_force(fout, feq, rho, u) - return self.precisionPolicy.cast_to_output(fout) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision_modified(self, f): - """ - Alternative KBC collision step for lattice. - Note: - At low Reynolds number the orignal KBC collision above produces inaccurate results because - it does not check for the entropy increase/decrease. The KBC stabalizations should only be - applied in principle to cells whose entropy decrease after a regular BGK collision. This is - the case in most cells at higher Reynolds numbers and hence a check may not be needed. - Overall the following alternative collision is more reliable and may replace the original - implementation. The issue at the moment is that it is about 60-80% slower than the above method. 
- """ - f = self.precisionPolicy.cast_to_compute(f) - tiny = 1e-32 - beta = self.omega * 0.5 - rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, castOutput=False) - - # Alternative KBC: only stabalizes for voxels whose entropy decreases after BGK collision. - f_bgk = f - self.omega * (f - feq) - H_fin = jnp.sum(f * jnp.log(f / self.w), axis=-1, keepdims=True) - H_fout = jnp.sum(f_bgk * jnp.log(f_bgk / self.w), axis=-1, keepdims=True) - - # the rest is identical to collision_deprecated - fneq = f - feq - if self.dim == 2: - deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0 - else: - deltaS = self.fdecompose_shear_d3q27(fneq) * rho - deltaH = fneq - deltaS - invBeta = 1.0 / beta - gamma = invBeta - (2.0 - invBeta) * self.entropic_scalar_product(deltaS, deltaH, feq) / (tiny + self.entropic_scalar_product(deltaH, deltaH, feq)) - - f_kbc = f - beta * (2.0 * deltaS + gamma[..., None] * deltaH) - fout = jnp.where(H_fout > H_fin, f_kbc, f_bgk) - - # add external force - if self.force is not None: - fout = self.apply_force(fout, feq, rho, u) - return self.precisionPolicy.cast_to_output(fout) - - @partial(jit, static_argnums=(0,), inline=True) - def entropic_scalar_product(self, x, y, feq): - """ - Compute the entropic scalar product of x and y to approximate gamma in KBC. - - Returns - ------- - jax.numpy.array - Entropic scalar product of x, y, and feq. - """ - return jnp.sum(x * y / feq, axis=-1) - - @partial(jit, static_argnums=(0,), inline=True) - def fdecompose_shear_d2q9(self, fneq): - """ - Decompose fneq into shear components for D2Q9 lattice. - - Parameters - ---------- - fneq : jax.numpy.array - Non-equilibrium distribution function. - - Returns - ------- - jax.numpy.array - Shear components of fneq. 
- """ - Pi = self.momentum_flux(fneq) - N = Pi[..., 0] - Pi[..., 2] - s = jnp.zeros_like(fneq) - s = s.at[..., 6].set(N) - s = s.at[..., 3].set(N) - s = s.at[..., 2].set(-N) - s = s.at[..., 1].set(-N) - s = s.at[..., 8].set(Pi[..., 1]) - s = s.at[..., 4].set(-Pi[..., 1]) - s = s.at[..., 5].set(-Pi[..., 1]) - s = s.at[..., 7].set(Pi[..., 1]) - - return s - - @partial(jit, static_argnums=(0,), inline=True) - def fdecompose_shear_d3q27(self, fneq): - """ - Decompose fneq into shear components for D3Q27 lattice. - - Parameters - ---------- - fneq : jax.numpy.ndarray - Non-equilibrium distribution function. - - Returns - ------- - jax.numpy.ndarray - Shear components of fneq. - """ - # if self.grid.dim == 3: - # diagonal = (0, 3, 5) - # offdiagonal = (1, 2, 4) - # elif self.grid.dim == 2: - # diagonal = (0, 2) - # offdiagonal = (1,) - - # c= - # array([[0, 0, 0],-----0 - # [0, 0, -1],----1 - # [0, 0, 1],-----2 - # [0, -1, 0],----3 - # [0, -1, -1],---4 - # [0, -1, 1],----5 - # [0, 1, 0],-----6 - # [0, 1, -1],----7 - # [0, 1, 1],-----8 - # [-1, 0, 0],----9 - # [-1, 0, -1],--10 - # [-1, 0, 1],---11 - # [-1, -1, 0],--12 - # [-1, -1, -1],-13 - # [-1, -1, 1],--14 - # [-1, 1, 0],---15 - # [-1, 1, -1],--16 - # [-1, 1, 1],---17 - # [1, 0, 0],----18 - # [1, 0, -1],---19 - # [1, 0, 1],----20 - # [1, -1, 0],---21 - # [1, -1, -1],--22 - # [1, -1, 1],---23 - # [1, 1, 0],----24 - # [1, 1, -1],---25 - # [1, 1, 1]])---26 - Pi = self.momentum_flux(fneq) - Nxz = Pi[..., 0] - Pi[..., 5] - Nyz = Pi[..., 3] - Pi[..., 5] - - # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) - s = jnp.zeros_like(fneq) - s = s.at[..., 9].set((2.0 * Nxz - Nyz) / 6.0) - s = s.at[..., 18].set((2.0 * Nxz - Nyz) / 6.0) - s = s.at[..., 3].set((-Nxz + 2.0 * Nyz) / 6.0) - s = s.at[..., 6].set((-Nxz + 2.0 * Nyz) / 6.0) - s = s.at[..., 1].set((-Nxz - Nyz) / 6.0) - s = s.at[..., 2].set((-Nxz - Nyz) / 6.0) - - # For c = (i, j, 0) - s = s.at[..., 12].set(Pi[..., 1] / 4.0) - s = s.at[..., 24].set(Pi[..., 1] / 4.0) - s = 
s.at[..., 21].set(-Pi[..., 1] / 4.0) - s = s.at[..., 15].set(-Pi[..., 1] / 4.0) - - # For c = (i, 0, k) - s = s.at[..., 10].set(Pi[..., 2] / 4.0) - s = s.at[..., 20].set(Pi[..., 2] / 4.0) - s = s.at[..., 19].set(-Pi[..., 2] / 4.0) - s = s.at[..., 11].set(-Pi[..., 2] / 4.0) - - # For c = (0, j, k) - s = s.at[..., 8].set(Pi[..., 4] / 4.0) - s = s.at[..., 4].set(Pi[..., 4] / 4.0) - s = s.at[..., 7].set(-Pi[..., 4] / 4.0) - s = s.at[..., 5].set(-Pi[..., 4] / 4.0) - - return s - - -class AdvectionDiffusionBGK(LBMBase): - """ - Advection Diffusion Model based on the BGK model. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.vel = kwargs.get("vel", None) - if self.vel is None: - raise ValueError("Velocity must be specified for AdvectionDiffusionBGK.") - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision(self, f): - """ - BGK collision step for lattice. - """ - f = self.precisionPolicy.cast_to_compute(f) - rho =jnp.sum(f, axis=-1, keepdims=True) - feq = self.equilibrium(rho, self.vel, cast_output=False) - fneq = f - feq - fout = f - self.omega * fneq - return self.precisionPolicy.cast_to_output(fout) \ No newline at end of file diff --git a/xlb/__init__.py b/xlb/__init__.py new file mode 100644 index 0000000..8f49baa --- /dev/null +++ b/xlb/__init__.py @@ -0,0 +1,18 @@ +# Enum classes +from xlb.compute_backend import ComputeBackend +from xlb.physics_type import PhysicsType + +# Precision policy +import xlb.precision_policy + +# Velocity Set +import xlb.velocity_set + +# Operators +import xlb.operator.collision +import xlb.operator.stream +import xlb.operator.boundary_condition +# import xlb.operator.force +import xlb.operator.equilibrium +import xlb.operator.macroscopic +import xlb.operator.stepper diff --git a/src/base.py b/xlb/base.py similarity index 100% rename from src/base.py rename to xlb/base.py diff --git a/xlb/compute_backend.py b/xlb/compute_backend.py new file mode 100644 index 0000000..dee998f --- /dev/null 
+++ b/xlb/compute_backend.py @@ -0,0 +1,9 @@ +# Enum used to keep track of the compute backends + +from enum import Enum + +class ComputeBackend(Enum): + JAX = 1 + NUMBA = 2 + PYTORCH = 3 + WARP = 4 diff --git a/xlb/experimental/__init__.py b/xlb/experimental/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/xlb/experimental/__init__.py @@ -0,0 +1 @@ + diff --git a/xlb/experimental/ooc/__init__.py b/xlb/experimental/ooc/__init__.py new file mode 100644 index 0000000..5206cc1 --- /dev/null +++ b/xlb/experimental/ooc/__init__.py @@ -0,0 +1,2 @@ +from xlb.experimental.ooc.out_of_core import OOCmap +from xlb.experimental.ooc.ooc_array import OOCArray diff --git a/xlb/experimental/ooc/ooc_array.py b/xlb/experimental/ooc/ooc_array.py new file mode 100644 index 0000000..f5e2c34 --- /dev/null +++ b/xlb/experimental/ooc/ooc_array.py @@ -0,0 +1,485 @@ +import numpy as np +import cupy as cp +#from mpi4py import MPI +import itertools +from dataclasses import dataclass + +from xlb.experimental.ooc.tiles.dense_tile import DenseTile, DenseGPUTile, DenseCPUTile +from xlb.experimental.ooc.tiles.compressed_tile import CompressedTile, CompressedGPUTile, CompressedCPUTile + + +class OOCArray: + """An out-of-core distributed array class. + + Parameters + ---------- + shape : tuple + The shape of the array. + dtype : cp.dtype + The data type of the array. + tile_shape : tuple + The shape of the tiles. Should be a factor of the shape. + padding : int or tuple + The padding of the tiles. + comm : MPI communicator + The MPI communicator. + devices : list of cp.cuda.Device + The list of GPU devices to use. + codec : Codec + The codec to use for compression. None for no compression (Dense tiles). + nr_compute_tiles : int + The number of compute tiles used for asynchronous copies. + TODO currently only 1 is supported when using JAX. 
+ """ + + def __init__( + self, + shape, + dtype, + tile_shape, + padding=1, + comm=None, + devices=[cp.cuda.Device(0)], + codec=None, + nr_compute_tiles=1, + ): + self.shape = shape + self.tile_shape = tile_shape + self.dtype = dtype + if isinstance(padding, int): + padding = (padding,) * len(shape) + self.padding = padding + self.comm = comm + self.devices = devices + self.codec = codec + self.nr_compute_tiles = nr_compute_tiles + + # Set tile class + if self.codec is None: + self.Tile = DenseTile + self.DeviceTile = DenseGPUTile + self.HostTile = ( + DenseCPUTile # TODO: Possibly make HardDiskTile or something + ) + + else: + self.Tile = CompressedTile + self.DeviceTile = CompressedGPUTile + self.HostTile = CompressedCPUTile + + # Get process id and number of processes + self.pid = self.comm.Get_rank() + self.nr_proc = self.comm.Get_size() + + # Check that the tile shape divides the array shape. + if any([shape[i] % tile_shape[i] != 0 for i in range(len(shape))]): + raise ValueError(f"Tile shape {tile_shape} does not divide shape {shape}.") + self.tile_dims = tuple([shape[i] // tile_shape[i] for i in range(len(shape))]) + self.nr_tiles = np.prod(self.tile_dims) + + # Get number of tiles per process + if self.nr_tiles % self.nr_proc != 0: + raise ValueError( + f"Number of tiles {self.nr_tiles} does not divide number of processes {self.nr_proc}." + ) + self.nr_tiles_per_proc = self.nr_tiles // self.nr_proc + + # Make the tile mapppings + self.tile_process_map = {} + self.tile_device_map = {} + for i, tile_index in enumerate( + itertools.product(*[range(n) for n in self.tile_dims]) + ): + self.tile_process_map[tile_index] = i % self.nr_proc + self.tile_device_map[tile_index] = devices[ + i % len(devices) + ] # Checkoboard pattern, TODO: may not be optimal + + # Get my device + if self.nr_proc != len(self.devices): + raise ValueError( + f"Number of processes {self.nr_proc} does not equal number of devices {len(self.devices)}." 
+ ) + self.device = self.devices[self.pid] + + # Make the tiles + self.tiles = {} + for tile_index in self.tile_process_map.keys(): + if self.pid == self.tile_process_map[tile_index]: + self.tiles[tile_index] = self.HostTile( + self.tile_shape, self.dtype, self.padding, self.codec + ) + + # Make GPU tiles for copying data between CPU and GPU + if self.nr_tiles % self.nr_compute_tiles != 0: + raise ValueError( + f"Number of tiles {self.nr_tiles} does not divide number of compute tiles {self.nr_compute_tiles}. This is used for asynchronous copies." + ) + compute_array_shape = [ + s + 2 * p for (s, p) in zip(self.tile_shape, self.padding) + ] + self.compute_tiles_htd = [] + self.compute_tiles_dth = [] + self.compute_streams_htd = [] + self.compute_streams_dth = [] + self.compute_arrays = [] + self.current_compute_index = 0 + with cp.cuda.Device(self.device): + for i in range(self.nr_compute_tiles): + # Make compute tiles for copying data + compute_tile = self.DeviceTile( + self.tile_shape, self.dtype, self.padding, self.codec + ) + self.compute_tiles_htd.append(compute_tile) + compute_tile = self.DeviceTile( + self.tile_shape, self.dtype, self.padding, self.codec + ) + self.compute_tiles_dth.append(compute_tile) + + # Make cupy stream + self.compute_streams_htd.append(cp.cuda.Stream(non_blocking=True)) + self.compute_streams_dth.append(cp.cuda.Stream(non_blocking=True)) + + # Make compute array + + self.compute_arrays.append(cp.empty(compute_array_shape, self.dtype)) + + # Make compute tile mappings + self.compute_tile_mapping_htd = {} + self.compute_tile_mapping_dth = {} + self.compute_stream_mapping_htd = {} + + def size(self): + """Return number of allocated bytes for all host tiles.""" + return sum([tile.size() for tile in self.tiles.values()]) + + def nbytes(self): + """Return number of bytes for all host tiles.""" + return sum([tile.nbytes for tile in self.tiles.values()]) + + def compression_ratio(self): + """Return the compression ratio for all host tiles.""" 
+ return self.nbytes() / self.size() + + def compression_ratio(self): + """Return the compression ratio aggregated over all tiles.""" + + if self.codec is None: + return 1.0 + else: + total_bytes = 0 + total_uncompressed_bytes = 0 + for tile in self.tiles.values(): + ( + tile_total_bytes_uncompressed, + tile_total_bytes_compressed, + ) = tile.compression_ratio() + total_bytes += tile_total_bytes_compressed + total_uncompressed_bytes += tile_total_bytes_uncompressed + return total_uncompressed_bytes / total_bytes + + def update_compute_index(self): + """Update the current compute index.""" + self.current_compute_index = ( + self.current_compute_index + 1 + ) % self.nr_compute_tiles + + def _guess_next_tile_index(self, tile_index): + """Guess the next tile index to use for the compute array.""" + # TODO: This assumes access is sequential + tile_indices = list(self.tiles.keys()) + current_ind = tile_indices.index(tile_index) + next_ind = current_ind + 1 + if next_ind >= len(tile_indices): + return None + else: + return tile_indices[next_ind] + + def reset_queue_htd(self): + """Reset the queue for host to device copies.""" + + self.compute_tile_mapping_htd = {} + self.compute_stream_mapping_htd = {} + self.current_compute_index = 0 + + def managed_compute_tiles_htd(self, tile_index): + """Get the compute tiles needed for computation. + + Parameters + ---------- + tile_index : tuple + The tile index. + + Returns + ------- + compute_tile : ComputeTile + The compute tile needed for computation. 
+ """ + + ################################################### + # TODO: This assumes access is sequential for tiles + ################################################### + + # Que up the next tiles + cur_tile_index = tile_index + cur_compute_index = self.current_compute_index + for i in range(self.nr_compute_tiles): + # Check if already in compute tile map and if not que it + if cur_tile_index not in self.compute_tile_mapping_htd.keys(): + # Get the store tile + tile = self.tiles[cur_tile_index] + + # Get the compute tile + compute_tile = self.compute_tiles_htd[cur_compute_index] + + # Get the compute stream + compute_stream = self.compute_streams_htd[cur_compute_index] + + # Copy the tile to the compute tile using the compute stream + with compute_stream: + tile.to_gpu_tile(compute_tile) + tile.to_gpu_tile(compute_tile) + + # Set the compute tile mapping + self.compute_tile_mapping_htd[cur_tile_index] = compute_tile + self.compute_stream_mapping_htd[cur_tile_index] = compute_stream + + # Update the tile index and compute index + cur_tile_index = self._guess_next_tile_index(cur_tile_index) + if cur_tile_index is None: + break + cur_compute_index = (cur_compute_index + 1) % self.nr_compute_tiles + + # Get the compute tile + self.compute_stream_mapping_htd[tile_index].synchronize() + compute_tile = self.compute_tile_mapping_htd[tile_index] + + # Pop the tile from the compute tile map + self.compute_tile_mapping_htd.pop(tile_index) + self.compute_stream_mapping_htd.pop(tile_index) + + # Return the compute tile + return compute_tile + + def get_compute_array(self, tile_index): + """Given a tile index, copy the tile to the compute array. + + Parameters + ---------- + tile_index : tuple + The tile index. + + Returns + ------- + compute_array : array + The compute array. + global_index : tuple + The lower bound index that the compute array corresponds to in the global array. 
+ For example, if the compute array is the 0th tile and has padding 1, then the + global index will be (-1, -1, ..., -1). + """ + + # Get the compute tile + compute_tile = self.managed_compute_tiles_htd(tile_index) + + # Concatenate the sub-arrays to make the compute array + compute_tile.to_array(self.compute_arrays[self.current_compute_index]) + + # Return the compute array index in global array + global_index = tuple( + [i * s - p for (i, s, p) in zip(tile_index, self.tile_shape, self.padding)] + ) + + return self.compute_arrays[self.current_compute_index], global_index + + def set_tile(self, compute_array, tile_index): + """Given a tile index, copy the compute array to the tile. + + Parameters + ---------- + compute_array : array + The compute array. + tile_index : tuple + The tile index. + """ + + # Syncronize the current stream dth stream + stream = self.compute_streams_dth[self.current_compute_index] + stream.synchronize() + cp.cuda.get_current_stream().synchronize() + + # Set the compute tile to the correct one + compute_tile = self.compute_tiles_dth[self.current_compute_index] + + # Split the compute array into a tile + compute_tile.from_array(compute_array) + + # Syncronize the current stream and the compute stream + cp.cuda.get_current_stream().synchronize() + + # Copy the tile from the compute tile to the store tile + with stream: + compute_tile.to_cpu_tile(self.tiles[tile_index]) + compute_tile.to_cpu_tile(self.tiles[tile_index]) + + def update_padding(self): + """Perform a padding swap between neighboring tiles.""" + + # Get padding indices + pad_ind = self.compute_tiles_htd[0].pad_ind + + # Loop over tiles + comm_tag = 0 + for tile_index in self.tile_process_map.keys(): + # Loop over all padding + for pad_index in pad_ind: + # Get neighboring tile index + neigh_tile_index = tuple( + [ + (i + p) % s + for (i, p, s) in zip(tile_index, pad_index, self.tile_dims) + ] + ) + neigh_pad_index = tuple([-p for p in pad_index]) # flip + + # 4 cases: + # 1. 
the tile and neighboring tile are on the same process + # 2. the tile is on this process and the neighboring tile is on another process + # 3. the tile is on another process and the neighboring tile is on this process + # 4. the tile and neighboring tile are on different processes + + # Case 1: the tile and neighboring tile are on the same process + if ( + self.pid == self.tile_process_map[tile_index] + and self.pid == self.tile_process_map[neigh_tile_index] + ): + # Get the tile and neighboring tile + tile = self.tiles[tile_index] + neigh_tile = self.tiles[neigh_tile_index] + + # Get pointer to padding and neighboring padding + padding = tile._padding[pad_index] + neigh_padding = neigh_tile._buf_padding[neigh_pad_index] + + # Swap padding + tile._padding[pad_index] = neigh_padding + neigh_tile._buf_padding[neigh_pad_index] = padding + + # Case 2: the tile is on this process and the neighboring tile is on another process + if ( + self.pid == self.tile_process_map[tile_index] + and self.pid != self.tile_process_map[neigh_tile_index] + ): + # Get the tile and padding + tile = self.tiles[tile_index] + padding = tile._padding[pad_index] + + # Send padding to neighboring process + self.comm.Send( + padding, + dest=self.tile_process_map[neigh_tile_index], + tag=comm_tag, + ) + + # Case 3: the tile is on another process and the neighboring tile is on this process + if ( + self.pid != self.tile_process_map[tile_index] + and self.pid == self.tile_process_map[neigh_tile_index] + ): + # Get the neighboring tile and padding + neigh_tile = self.tiles[neigh_tile_index] + neigh_padding = neigh_tile._buf_padding[neigh_pad_index] + + # Receive padding from neighboring process + self.comm.Recv( + neigh_padding, + source=self.tile_process_map[tile_index], + tag=comm_tag, + ) + + # Case 4: the tile and neighboring tile are on different processes + if ( + self.pid != self.tile_process_map[tile_index] + and self.pid != self.tile_process_map[neigh_tile_index] + ): + pass + + # Increment 
the communication tag + comm_tag += 1 + + # Shuffle padding with buffers + for tile in self.tiles.values(): + tile.swap_buf_padding() + + def get_array(self): + """Get the full array out from all the sub-arrays. This should only be used for testing.""" + + # Get the full array + if self.comm.rank == 0: + array = np.ones(self.shape, dtype=self.dtype) + else: + array = None + + # Loop over tiles + comm_tag = 0 + for tile_index in self.tile_process_map.keys(): + # Set the center array in the full array + slice_index = tuple( + [ + slice(i * s, (i + 1) * s) + for (i, s) in zip(tile_index, self.tile_shape) + ] + ) + + # if tile on this process compute the center array + if self.comm.rank == self.tile_process_map[tile_index]: + # Get the tile + tile = self.tiles[tile_index] + + # Copy the tile to the compute tile + tile.to_gpu_tile(self.compute_tiles_htd[0]) + + # Get the compute array + self.compute_tiles_htd[0].to_array(self.compute_arrays[0]) + + # Get the center array + center_array = self.compute_arrays[0][tile._slice_center].get() + + # 4 cases: + # 1. the tile is on rank 0 and this process is rank 0 + # 2. the tile is on another rank and this process is rank 0 + # 3. the tile is on this rank and this process is not rank 0 + # 4. 
the tile is not on rank 0 and this process is not rank 0 + + # Case 1: the tile is on rank 0 + if self.comm.rank == 0 and self.tile_process_map[tile_index] == 0: + # Set the center array in the full array + array[slice_index] = center_array + + # Case 2: the tile is on another rank and this process is rank 0 + if self.comm.rank == 0 and self.tile_process_map[tile_index] != 0: + # Get the data from the other rank + center_array = np.empty(self.tile_shape, dtype=self.dtype) + self.comm.Recv( + center_array, source=self.tile_process_map[tile_index], tag=comm_tag + ) + + # Set the center array in the full array + array[slice_index] = center_array + + # Case 3: the tile is on this rank and this process is not rank 0 + if ( + self.comm.rank != 0 + and self.tile_process_map[tile_index] == self.comm.rank + ): + # Send the data to rank 0 + self.comm.Send(center_array, dest=0, tag=comm_tag) + + # Case 4: the tile is not on rank 0 and this process is not rank 0 + if self.comm.rank != 0 and self.tile_process_map[tile_index] != 0: + pass + + # Update the communication tag + comm_tag += 1 + + return array diff --git a/xlb/experimental/ooc/out_of_core.py b/xlb/experimental/ooc/out_of_core.py new file mode 100644 index 0000000..5faa422 --- /dev/null +++ b/xlb/experimental/ooc/out_of_core.py @@ -0,0 +1,110 @@ +# Out-of-core decorator for functions that take a lot of memory + +import functools +import warp as wp +import cupy as cp +import jax.dlpack as jdlpack +import jax +import numpy as np + +from xlb.experimental.ooc.ooc_array import OOCArray +from xlb.experimental.ooc.utils import _cupy_to_backend, _backend_to_cupy, _stream_to_backend + + +def OOCmap(comm, ref_args, add_index=False, backend="jax"): + """Decorator for out-of-core functions. + + Parameters + ---------- + comm : MPI communicator + The MPI communicator. (TODO add functionality) + ref_args : List[int] + The indices of the arguments that are OOC arrays to be written to by outputs of the function. 
+ add_index : bool, optional + Whether to add the index of the global array to the arguments of the function. Default is False. + If true the function will take in a tuple of (array, index) instead of just the array. + backend : str, optional + The backend to use for the function. Default is 'jax'. + Options are 'jax' and 'warp'. + give_stream : bool, optional + Whether to give the function a stream to run on. Default is False. + If true the function will take in a last argument of the stream to run on. + """ + + def decorator(func): + def wrapper(*args): + # Get list of OOC arrays + ooc_array_args = [] + for arg in args: + if isinstance(arg, OOCArray): + ooc_array_args.append(arg) + + # Check that all ooc arrays are compatible + # TODO: Add better checks + for ooc_array in ooc_array_args: + if ooc_array_args[0].tile_dims != ooc_array.tile_dims: + raise ValueError( + f"Tile dimensions of ooc arrays do not match. {ooc_array_args[0].tile_dims} != {ooc_array.tile_dims}" + ) + + # Apply the function to each of the ooc arrays + for tile_index in ooc_array_args[0].tiles.keys(): + # Run through args and kwargs and replace ooc arrays with their compute arrays + new_args = [] + for arg in args: + if isinstance(arg, OOCArray): + # Get the compute array (this performs all the memory copies) + compute_array, global_index = arg.get_compute_array(tile_index) + + # Convert to backend array + compute_array = _cupy_to_backend(compute_array, backend) + + # Add index to the arguments if requested + if add_index: + compute_array = (compute_array, global_index) + + new_args.append(compute_array) + else: + new_args.append(arg) + + # Run the function + results = func(*new_args) + + # Convert the results to a tuple if not already + if not isinstance(results, tuple): + results = (results,) + + # Convert the results back to cupy arrays + results = tuple( + [_backend_to_cupy(result, backend) for result in results] + ) + + # Write the results back to the ooc array + for arg_index, result in 
zip(ref_args, results): + args[arg_index].set_tile(result, tile_index) + + # Update the ooc arrays compute tile index + for ooc_array in ooc_array_args: + ooc_array.update_compute_index() + + # Syncronize all processes + cp.cuda.Device().synchronize() + comm.Barrier() + + # Update the ooc arrays padding + for i, ooc_array in enumerate(ooc_array_args): + if i in ref_args: + ooc_array.update_padding() + + # Reset que + ooc_array.reset_queue_htd() + + # Return OOC arrays + if len(ref_args) == 1: + return ooc_array_args[ref_args[0]] + else: + return tuple([args[arg_index] for arg_index in ref_args]) + + return wrapper + + return decorator diff --git a/xlb/experimental/ooc/tiles/__init__.py b/xlb/experimental/ooc/tiles/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xlb/experimental/ooc/tiles/compressed_tile.py b/xlb/experimental/ooc/tiles/compressed_tile.py new file mode 100644 index 0000000..415f83b --- /dev/null +++ b/xlb/experimental/ooc/tiles/compressed_tile.py @@ -0,0 +1,273 @@ +import numpy as np +import cupy as cp +import itertools +from dataclasses import dataclass +import warnings +import time + +try: + from kvikio._lib.arr import asarray +except ImportError: + warnings.warn("kvikio not installed. Compression will not work.") + +from xlb.experimental.ooc.tiles.tile import Tile +from xlb.experimental.ooc.tiles.dense_tile import DenseGPUTile +from xlb.experimental.ooc.tiles.dynamic_array import DynamicPinnedArray + + +def _decode(comp_array, dest_array, codec): + """ + Decompresses comp_array into dest_array. + + Parameters + ---------- + comp_array : cupy array + The compressed array to be decompressed. Data type uint8. + dest_array : cupy array + The storage array for the decompressed data. + codec : Codec + The codec to use for decompression. For example, `kvikio.nvcomp.CascadedManager`. 
+ + """ + + # Store needed information + dtype = dest_array.dtype + shape = dest_array.shape + + # Reshape dest_array to match to make into buffer + dest_array = dest_array.view(cp.uint8).reshape(-1) + + # Decompress + codec._manager.decompress(asarray(dest_array), asarray(comp_array)) + + return dest_array.view(dtype).reshape(shape) + + +def _encode(array, dest_array, codec): + """ + Compresses array into dest_array. + + Parameters + ---------- + array : cupy array + The array to be compressed. + dest_array : cupy array + The storage array for the compressed data. Data type uint8. + codec : Codec + The codec to use for compression. For example, `kvikio.nvcomp.CascadedManager`. + """ + + # Make sure array is contiguous + array = cp.ascontiguousarray(array) + + # Configure compression + codec._manager.configure_compression(array.nbytes) + + # Compress + size = codec._manager.compress(asarray(array), asarray(dest_array)) + return size + + +class CompressedTile(Tile): + """A Tile where the data is stored in compressed form.""" + + def __init__(self, shape, dtype, padding, codec): + super().__init__(shape, dtype, padding, codec) + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + raise NotImplementedError + + def to_array(self, array): + """Copy a tile to a full array.""" + # Only implemented for GPU tiles + raise NotImplementedError + + def from_array(self, array): + """Copy a full array to tile.""" + # Only implemented for GPU tiles + raise NotImplementedError + + def compression_ratio(self): + """Returns the compression ratio of the tile.""" + # Get total number of bytes in tile + total_bytes = self._array.size() + for pad_ind in self.pad_ind: + total_bytes += self._padding[pad_ind].size() + + # Get total number of bytes in uncompressed tile + total_bytes_uncompressed = np.prod(self.shape) * self.dtype_itemsize + for pad_ind in self.pad_ind: + total_bytes_uncompressed += ( + np.prod(self._padding_shape[pad_ind]) * 
self.dtype_itemsize + ) + + # Return compression ratio + return total_bytes_uncompressed, total_bytes + + +class CompressedCPUTile(CompressedTile): + """A tile with cells on the CPU.""" + + def __init__(self, shape, dtype, padding, codec): + super().__init__(shape, dtype, padding, codec) + + def size(self): + """Returns the size of the tile in bytes.""" + size = self._array.size() + for pad_ind in self.pad_ind: + size += self._padding[pad_ind].size() + return size + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + # Make zero array + cp_array = cp.zeros(shape, dtype=self.dtype) + + # compress array + codec = self.codec() + compressed_cp_array = codec.compress(cp_array) + + # Allocate array on CPU + array = DynamicPinnedArray(compressed_cp_array.nbytes) + + # Copy array + compressed_cp_array.get(out=array.array) + + # add nbytes + self.nbytes += cp_array.nbytes + + # delete GPU arrays + del compressed_cp_array + del cp_array + + return array + + def to_gpu_tile(self, dst_gpu_tile): + """Copy tile to a GPU tile.""" + + # Check tile is Compressed + assert isinstance( + dst_gpu_tile, CompressedGPUTile + ), "Destination tile must be a CompressedGPUTile" + + # Copy array + dst_gpu_tile._array[: len(self._array.array)].set(self._array.array) + dst_gpu_tile._array_bytes = self._array.nbytes + + # Copy padding + for pad_ind in self.pad_ind: + dst_gpu_tile._padding[pad_ind][: len(self._padding[pad_ind].array)].set( + self._padding[pad_ind].array + ) + dst_gpu_tile._padding_bytes[pad_ind] = self._padding[pad_ind].nbytes + + +class CompressedGPUTile(CompressedTile): + """A sub-array with ghost cells on the GPU.""" + + def __init__(self, shape, dtype, padding, codec): + super().__init__(shape, dtype, padding, codec) + + # Allocate dense GPU tile + self.dense_gpu_tile = DenseGPUTile(shape, dtype, padding) + + # Set bytes for each array and padding + self._array_bytes = -1 + self._padding_bytes = {} + for pad_ind in self.pad_ind: + 
self._padding_bytes[pad_ind] = -1 + + # Set codec for each array and padding + self._array_codec = None + self._padding_codec = {} + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + nbytes = np.prod(shape) * self.dtype_itemsize + codec = self.codec() + max_compressed_buffer = codec._manager.configure_compression(nbytes)[ + "max_compressed_buffer_size" + ] + array = cp.zeros((max_compressed_buffer,), dtype=np.uint8) + return array + + def to_array(self, array): + """Copy a tile to a full array.""" + + # Copy center array + if self._array_codec is None: + self._array_codec = self.codec() + self._array_codec._manager.configure_decompression_with_compressed_buffer( + asarray(self._array[: self._array_bytes]) + ) + self._array_codec.decompression_config = self._array_codec._manager.configure_decompression_with_compressed_buffer( + asarray(self._array[: self._array_bytes]) + ) + self.dense_gpu_tile._array = _decode( + self._array[: self._array_bytes], + self.dense_gpu_tile._array, + self._array_codec, + ) + array[self._slice_center] = self.dense_gpu_tile._array + + # Copy padding + for pad_ind in self.pad_ind: + if pad_ind not in self._padding_codec: + self._padding_codec[pad_ind] = self.codec() + self._padding_codec[pad_ind].decompression_config = self._padding_codec[ + pad_ind + ]._manager.configure_decompression_with_compressed_buffer( + asarray(self._padding[pad_ind][: self._padding_bytes[pad_ind]]) + ) + self.dense_gpu_tile._padding[pad_ind] = _decode( + self._padding[pad_ind][: self._padding_bytes[pad_ind]], + self.dense_gpu_tile._padding[pad_ind], + self._padding_codec[pad_ind], + ) + array[self._slice_padding_to_array[pad_ind]] = self.dense_gpu_tile._padding[ + pad_ind + ] + + def from_array(self, array): + """Copy a full array to tile.""" + + # Copy center array + if self._array_codec is None: + self._array_codec = self.codec() + self._array_codec.configure_compression(self._array.nbytes) + self._array_bytes = _encode( + 
array[self._slice_center], self._array, self._array_codec + ) + + # Copy padding + for pad_ind in self.pad_ind: + if pad_ind not in self._padding_codec: + self._padding_codec[pad_ind] = self.codec() + self._padding_codec[pad_ind].configure_compression( + self._padding[pad_ind].nbytes + ) + self._padding_bytes[pad_ind] = _encode( + array[self._slice_array_to_padding[pad_ind]], + self._padding[pad_ind], + self._padding_codec[pad_ind], + ) + + def to_cpu_tile(self, dst_cpu_tile): + """Copy tile to a CPU tile.""" + + # Check tile is Compressed + assert isinstance( + dst_cpu_tile, CompressedCPUTile + ), "Destination tile must be a CompressedCPUTile" + + # Copy array + dst_cpu_tile._array.resize(self._array_bytes) + self._array[: self._array_bytes].get(out=dst_cpu_tile._array.array) + + # Copy padding + for pad_ind in self.pad_ind: + dst_cpu_tile._padding[pad_ind].resize(self._padding_bytes[pad_ind]) + self._padding[pad_ind][: self._padding_bytes[pad_ind]].get( + out=dst_cpu_tile._padding[pad_ind].array + ) diff --git a/xlb/experimental/ooc/tiles/dense_tile.py b/xlb/experimental/ooc/tiles/dense_tile.py new file mode 100644 index 0000000..8a303e4 --- /dev/null +++ b/xlb/experimental/ooc/tiles/dense_tile.py @@ -0,0 +1,96 @@ +import numpy as np +import cupy as cp +import itertools +from dataclasses import dataclass + +from xlb.experimental.ooc.tiles.tile import Tile + + +class DenseTile(Tile): + """A Tile where the data is stored in a dense array of the requested dtype.""" + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + raise NotImplementedError + + def to_array(self, array): + """Copy a tile to a full array.""" + # TODO: This can be done with a single kernel call, profile to see if it is faster and needs to be done. 
+ + # Copy center array + array[self._slice_center] = self._array + + # Copy padding + for pad_ind in self.pad_ind: + array[self._slice_padding_to_array[pad_ind]] = self._padding[pad_ind] + + def from_array(self, array): + """Copy a full array to tile.""" + # TODO: This can be done with a single kernel call, profile to see if it is faster and needs to be done. + + # Copy center array + self._array[...] = array[self._slice_center] + + # Copy padding + for pad_ind in self.pad_ind: + self._padding[pad_ind][...] = array[self._slice_array_to_padding[pad_ind]] + + +class DenseCPUTile(DenseTile): + """A dense tile with cells on the CPU.""" + + def __init__(self, shape, dtype, padding, codec=None): + super().__init__(shape, dtype, padding, None) + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + # TODO: Seems hacky, but it works. Is there a better way? + mem = cp.cuda.alloc_pinned_memory(np.prod(shape) * self.dtype_itemsize) + array = np.frombuffer(mem, dtype=self.dtype, count=np.prod(shape)).reshape( + shape + ) + self.nbytes += mem.size() + return array + + def to_gpu_tile(self, dst_gpu_tile): + """Copy tile to a GPU tile.""" + + # Check that the destination tile is on the GPU + assert isinstance(dst_gpu_tile, DenseGPUTile), "Destination tile must be on GPU" + + # Copy array + dst_gpu_tile._array.set(self._array) + + # Copy padding + for src_array, dst_gpu_array in zip( + self._padding.values(), dst_gpu_tile._padding.values() + ): + dst_gpu_array.set(src_array) + + +class DenseGPUTile(DenseTile): + """A sub-array with ghost cells on the GPU.""" + + def __init__(self, shape, dtype, padding, codec=None): + super().__init__(shape, dtype, padding, None) + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + array = cp.zeros(shape, dtype=self.dtype) + self.nbytes += array.nbytes + return array + + def to_cpu_tile(self, dst_cpu_tile): + """Copy tile to a CPU tile.""" + + # Check that the destination tile 
is on the CPU + assert isinstance(dst_cpu_tile, DenseCPUTile), "Destination tile must be on CPU" + + # Copy arra + self._array.get(out=dst_cpu_tile._array) + + # Copy padding + for src_array, dst_array in zip( + self._padding.values(), dst_cpu_tile._padding.values() + ): + src_array.get(out=dst_array) diff --git a/xlb/experimental/ooc/tiles/dynamic_array.py b/xlb/experimental/ooc/tiles/dynamic_array.py new file mode 100644 index 0000000..2b05b2e --- /dev/null +++ b/xlb/experimental/ooc/tiles/dynamic_array.py @@ -0,0 +1,72 @@ +# Dynamic array class for pinned memory allocation + +import math +import cupy as cp +import numpy as np +import time + + +class DynamicArray: + """ + Dynamic pinned memory array class. + + Attributes + ---------- + nbytes : int + The number of bytes in the array. + bytes_resize : int + The number of bytes to resize the array by if the number of bytes requested exceeds the allocated number of bytes. + """ + + def __init__(self, nbytes, bytes_resize_factor=0.025): + # Set the number of bytes + self.nbytes = nbytes + self.bytes_resize_factor = bytes_resize_factor + self.bytes_resize = math.ceil(bytes_resize_factor * nbytes) + + # Set the number of bytes + self.allocated_bytes = math.ceil(nbytes / self.bytes_resize) * self.bytes_resize + + +class DynamicPinnedArray(DynamicArray): + def __init__(self, nbytes, bytes_resize_factor=0.05): + super().__init__(nbytes, bytes_resize_factor) + + # Allocate the memory + self.mem = cp.cuda.alloc_pinned_memory(self.allocated_bytes) + + # Make np array that points to the pinned memory + self.array = np.frombuffer(self.mem, dtype=np.uint8, count=int(self.nbytes)) + + def size(self): + return self.mem.size() + + def resize(self, nbytes): + # Set the new number of bytes + self.nbytes = nbytes + + # Check if the number of bytes requested is less than 2xbytes_resize or if the number of bytes requested exceeds the allocated number of bytes + if ( + nbytes < (self.allocated_bytes - 2 * self.bytes_resize) + or nbytes 
> self.allocated_bytes + ): + ## Free the memory + # del self.mem + + # Set the new number of allocated bytes + self.allocated_bytes = ( + math.ceil(nbytes / self.bytes_resize) * self.bytes_resize + ) + + # Allocate the memory + self.mem = cp.cuda.alloc_pinned_memory(self.allocated_bytes) + + # Make np array that points to the pinned memory + self.array = np.frombuffer(self.mem, dtype=np.uint8, count=int(self.nbytes)) + + # Set new resize number of bytes + self.bytes_resize = math.ceil(self.bytes_resize_factor * nbytes) + + # Otherwise change numpy array size + else: + self.array = np.frombuffer(self.mem, dtype=np.uint8, count=int(self.nbytes)) diff --git a/xlb/experimental/ooc/tiles/tile.py b/xlb/experimental/ooc/tiles/tile.py new file mode 100644 index 0000000..90c3334 --- /dev/null +++ b/xlb/experimental/ooc/tiles/tile.py @@ -0,0 +1,115 @@ +import numpy as np +import cupy as cp +import itertools +from dataclasses import dataclass + + +class Tile: + """Base class for Tile with ghost cells. This tile is used to build a distributed array. + + Attributes + ---------- + shape : tuple + Shape of the tile. This will be the shape of the array without padding/ghost cells. + dtype : cp.dtype + Data type the tile represents. Note that the data data may be stored in a different + data type. For example, if it is stored in compressed form. + padding : tuple + Number of padding/ghost cells in each dimension. 
+ """ + + def __init__(self, shape, dtype, padding, codec=None): + # Store parameters + self.shape = shape + self.dtype = dtype + self.padding = padding + self.dtype_itemsize = cp.dtype(self.dtype).itemsize + self.nbytes = 0 # Updated when array is allocated + self.codec = ( + codec # Codec to use for compression TODO: Find better abstraction for this + ) + + # Make center array + self._array = self.allocate_array(self.shape) + + # Make padding indices + pad_dir = [] + for i in range(len(self.shape)): + if self.padding[i] == 0: + pad_dir.append((0,)) + else: + pad_dir.append((-1, 0, 1)) + self.pad_ind = list(itertools.product(*pad_dir)) + self.pad_ind.remove((0,) * len(self.shape)) + + # Make padding and padding buffer arrays + self._padding = {} + self._buf_padding = {} + for ind in self.pad_ind: + # determine array shape + shape = [] + for i in range(len(self.shape)): + if ind[i] == -1 or ind[i] == 1: + shape.append(self.padding[i]) + else: + shape.append(self.shape[i]) + + # Make padding and padding buffer + self._padding[ind] = self.allocate_array(shape) + self._buf_padding[ind] = self.allocate_array(shape) + + # Get slicing for array copies + self._slice_center = tuple( + [slice(pad, pad + shape) for (pad, shape) in zip(self.padding, self.shape)] + ) + self._slice_padding_to_array = {} + self._slice_array_to_padding = {} + self._padding_shape = {} + for pad_ind in self.pad_ind: + slice_padding_to_array = [] + slice_array_to_padding = [] + padding_shape = [] + for pad, ind, s in zip(self.padding, pad_ind, self.shape): + if ind == -1: + slice_padding_to_array.append(slice(0, pad)) + slice_array_to_padding.append(slice(pad, 2 * pad)) + padding_shape.append(pad) + elif ind == 0: + slice_padding_to_array.append(slice(pad, s + pad)) + slice_array_to_padding.append(slice(pad, s + pad)) + padding_shape.append(s) + else: + slice_padding_to_array.append(slice(s + pad, s + 2 * pad)) + slice_array_to_padding.append(slice(s, s + pad)) + padding_shape.append(pad) + 
self._slice_padding_to_array[pad_ind] = tuple(slice_padding_to_array) + self._slice_array_to_padding[pad_ind] = tuple(slice_array_to_padding) + self._padding_shape[pad_ind] = tuple(padding_shape) + + def size(self): + """Returns the number of bytes allocated for the tile.""" + raise NotImplementedError + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + raise NotImplementedError + + def copy_tile(self, dst_tile): + """Copy a tile from one tile to another.""" + raise NotImplementedError + + def to_array(self, array): + """Copy a tile to a full array.""" + raise NotImplementedError + + def from_array(self, array): + """Copy a full array to a tile.""" + raise NotImplementedError + + def swap_buf_padding(self): + """Swap the padding buffer pointer with the padding pointer.""" + for index in self.pad_ind: + (self._buf_padding[index], self._padding[index]) = ( + self._padding[index], + self._buf_padding[index], + ) diff --git a/xlb/experimental/ooc/utils.py b/xlb/experimental/ooc/utils.py new file mode 100644 index 0000000..f607128 --- /dev/null +++ b/xlb/experimental/ooc/utils.py @@ -0,0 +1,79 @@ +import warp as wp +import cupy as cp +import jax.dlpack as jdlpack +import jax + + +def _cupy_to_backend(cupy_array, backend): + """ + Convert cupy array to backend array + + Parameters + ---------- + cupy_array : cupy.ndarray + Input cupy array + backend : str + Backend to convert to. 
Options are "jax", "warp", or "cupy" + """ + + # Convert cupy array to backend array + dl_array = cupy_array.toDlpack() + if backend == "jax": + backend_array = jdlpack.from_dlpack(dl_array) + elif backend == "warp": + backend_array = wp.from_dlpack(dl_array) + elif backend == "cupy": + backend_array = cupy_array + else: + raise ValueError(f"Backend {backend} not supported") + return backend_array + + +def _backend_to_cupy(backend_array, backend): + """ + Convert backend array to cupy array + + Parameters + ---------- + backend_array : backend.ndarray + Input backend array + backend : str + Backend to convert from. Options are "jax", "warp", or "cupy" + """ + + # Convert backend array to cupy array + if backend == "jax": + (jax.device_put(0.0) + 0).block_until_ready() + dl_array = jdlpack.to_dlpack(backend_array) + elif backend == "warp": + dl_array = wp.to_dlpack(backend_array) + elif backend == "cupy": + return backend_array + else: + raise ValueError(f"Backend {backend} not supported") + cupy_array = cp.fromDlpack(dl_array) + return cupy_array + + +def _stream_to_backend(stream, backend): + """ + Convert cupy stream to backend stream + + Parameters + ---------- + stream : cupy.cuda.Stream + Input cupy stream + backend : str + Backend to convert to. 
Options are "jax", "warp", or "cupy" + """ + + # Convert stream to backend stream + if backend == "jax": + raise ValueError("Jax currently does not support streams") + elif backend == "warp": + backend_stream = wp.Stream(cuda_stream=stream.ptr) + elif backend == "cupy": + backend_stream = stream + else: + raise ValueError(f"Backend {backend} not supported") + return backend_stream diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py new file mode 100644 index 0000000..1cf32f9 --- /dev/null +++ b/xlb/operator/__init__.py @@ -0,0 +1 @@ +from xlb.operator.operator import Operator diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py new file mode 100644 index 0000000..9937bc3 --- /dev/null +++ b/xlb/operator/boundary_condition/__init__.py @@ -0,0 +1,5 @@ +from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition, ImplementationStep +from xlb.operator.boundary_condition.full_bounce_back import FullBounceBack +from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBack +from xlb.operator.boundary_condition.do_nothing import DoNothing +from xlb.operator.boundary_condition.equilibrium_boundary import EquilibriumBoundary diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py new file mode 100644 index 0000000..7e8909a --- /dev/null +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -0,0 +1,100 @@ +""" +Base class for boundary conditions in a LBM simulation. 
+""" + +import jax.numpy as jnp +from jax import jit, device_count +from functools import partial +import numpy as np +from enum import Enum + +from xlb.operator.operator import Operator +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend + +# Enum for implementation step +class ImplementationStep(Enum): + COLLISION = 1 + STREAMING = 2 + +class BoundaryCondition(Operator): + """ + Base class for boundary conditions in a LBM simulation. + """ + + def __init__( + self, + set_boundary, + implementation_step: ImplementationStep, + velocity_set: VelocitySet, + compute_backend: ComputeBackend.JAX, + ): + super().__init__(velocity_set, compute_backend) + + # Set implementation step + self.implementation_step = implementation_step + + # Set boundary function + if compute_backend == ComputeBackend.JAX: + self.set_boundary = set_boundary + else: + raise NotImplementedError + + @classmethod + def from_indices(cls, indices, implementation_step: ImplementationStep): + """ + Creates a boundary condition from a list of indices. + """ + raise NotImplementedError + + @partial(jit, static_argnums=(0,)) + def apply_jax(self, f_pre, f_post, mask, velocity_set: VelocitySet): + """ + Applies the boundary condition. + """ + pass + + @staticmethod + def _indices_to_tuple(indices): + """ + Converts a tensor of indices to a tuple for indexing + TODO: Might be better to index + """ + return tuple([indices[:, i] for i in range(indices.shape[1])]) + + @staticmethod + def _set_boundary_from_indices(indices): + """ + This create the standard set_boundary function from a list of indices. + `boundary_id` is set to `id_number` at the indices and `mask` is set to `True` at the indices. + Many boundary conditions can be created from this function however some may require a custom function such as + HalfwayBounceBack. 
+ """ + + # Create a mask function + def set_boundary(ijk, boundary_id, mask, id_number): + """ + Sets the mask id for the boundary condition. + + Parameters + ---------- + ijk : jnp.ndarray + Array of shape (N, N, N, 3) containing the meshgrid of lattice points. + boundary_id : jnp.ndarray + Array of shape (N, N, N) containing the boundary id. This will be modified in place and returned. + mask : jnp.ndarray + Array of shape (N, N, N, Q) containing the mask. This will be modified in place and returned. + """ + + # Get local indices from the meshgrid and the indices + local_indices = ijk[BoundaryCondition._indices_to_tuple(indices)] + + # Set the boundary id + boundary_id = boundary_id.at[BoundaryCondition._indices_to_tuple(indices)].set(id_number) + + # Set the mask + mask = mask.at[BoundaryCondition._indices_to_tuple(indices)].set(True) + + return boundary_id, mask + + return set_boundary diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py new file mode 100644 index 0000000..f8f28ed --- /dev/null +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -0,0 +1,56 @@ +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stream.stream import Stream +from xlb.operator.equilibrium.equilibrium import Equilibrium +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) + +class DoNothing(BoundaryCondition): + """ + A boundary condition that skips the streaming step. 
+ """ + + def __init__( + self, + set_boundary, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + super().__init__( + set_boundary=set_boundary, + implementation_step=ImplementationStep.STREAMING, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + @classmethod + def from_indices( + cls, + indices, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + """ + Creates a boundary condition from a list of indices. + """ + + return cls( + set_boundary=cls._set_boundary_from_indices(indices), + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + do_nothing = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) + f = lax.select(do_nothing, f_pre, f_post) + return f diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py new file mode 100644 index 0000000..f615a53 --- /dev/null +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -0,0 +1,69 @@ +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stream.stream import Stream +from xlb.operator.equilibrium.equilibrium import Equilibrium +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) + +class EquilibriumBoundary(BoundaryCondition): + """ + A boundary condition that skips the streaming step. 
+ """ + + def __init__( + self, + set_boundary, + rho: float, + u: tuple[float, float], + equilibrium: Equilibrium, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + super().__init__( + set_boundary=set_boundary, + implementation_step=ImplementationStep.STREAMING, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + self.f = equilibrium(rho, u) + + @classmethod + def from_indices( + cls, + indices, + rho: float, + u: tuple[float, float], + equilibrium: Equilibrium, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + """ + Creates a boundary condition from a list of indices. + """ + + return cls( + set_boundary=cls._set_boundary_from_indices(indices), + rho=rho, + u=u, + equilibrium=equilibrium, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + equilibrium_mask = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) + equilibrium_f = jnp.repeat(self.f[None, ...], boundary.shape[0], axis=0) + equilibrium_f = jnp.repeat(equilibrium_f[:, None], boundary.shape[1], axis=1) + equilibrium_f = jnp.repeat(equilibrium_f[:, :, None], boundary.shape[2], axis=2) + f = lax.select(equilibrium_mask, equilibrium_f, f_post) + return f diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py new file mode 100644 index 0000000..fc883c8 --- /dev/null +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -0,0 +1,57 @@ +""" +Base class for boundary conditions in a LBM simulation. 
+""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) + +class FullBounceBack(BoundaryCondition): + """ + Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + + def __init__( + self, + set_boundary, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + super().__init__( + set_boundary=set_boundary, + implementation_step=ImplementationStep.COLLISION, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + @classmethod + def from_indices( + cls, + indices, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + """ + Creates a boundary condition from a list of indices. 
+ """ + + return cls( + set_boundary=cls._set_boundary_from_indices(indices), + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + flip = jnp.repeat(boundary[..., jnp.newaxis], self.velocity_set.q, axis=-1) + flipped_f = lax.select(flip, f_pre[..., self.velocity_set.opp_indices], f_post) + return flipped_f diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py new file mode 100644 index 0000000..3b5b6de --- /dev/null +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -0,0 +1,97 @@ +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stream.stream import Stream +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) + +class HalfwayBounceBack(BoundaryCondition): + """ + Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + + def __init__( + self, + set_boundary, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + super().__init__( + set_boundary=set_boundary, + implementation_step=ImplementationStep.STREAMING, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + @classmethod + def from_indices( + cls, + indices, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + """ + Creates a boundary condition from a list of indices. + """ + + # Make stream operator to get edge points + stream = Stream(velocity_set=velocity_set) + + # Create a mask function + def set_boundary(ijk, boundary_id, mask, id_number): + """ + Sets the mask id for the boundary condition. 
+ Halfway bounce-back is implemented by setting the mask to True for points in the boundary, + then streaming the mask to get the points on the surface. + + Parameters + ---------- + ijk : jnp.ndarray + Array of shape (N, N, N, 3) containing the meshgrid of lattice points. + boundary_id : jnp.ndarray + Array of shape (N, N, N) containing the boundary id. This will be modified in place and returned. + mask : jnp.ndarray + Array of shape (N, N, N, Q) containing the mask. This will be modified in place and returned. + """ + + # Get local indices from the meshgrid and the indices + local_indices = ijk[tuple(s[:, 0] for s in jnp.split(indices, velocity_set.d, axis=1))] + + # Make mask then stream to get the edge points + pre_stream_mask = jnp.zeros_like(mask) + pre_stream_mask = pre_stream_mask.at[tuple([s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)])].set(True) + post_stream_mask = stream(pre_stream_mask) + + # Set false for points inside the boundary + post_stream_mask = post_stream_mask.at[post_stream_mask[..., 0] == True].set(False) + + # Get indices on edges + edge_indices = jnp.argwhere(post_stream_mask) + + # Set the boundary id + boundary_id = boundary_id.at[tuple([s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)])].set(id_number) + + # Set the mask + mask = mask.at[edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], :].set(post_stream_mask[edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], :]) + + return boundary_id, mask + + return cls( + set_boundary=set_boundary, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + flip_mask = boundary[..., jnp.newaxis] & mask + flipped_f = lax.select(flip_mask, f_pre[..., self.velocity_set.opp_indices], f_post) + return flipped_f diff --git a/xlb/operator/collision/__init__.py b/xlb/operator/collision/__init__.py new file mode 100644 
class BGK(Collision):
    """
    BGK collision operator for LBM.

    Single-relaxation-time (Bhatnagar-Gross-Krook) model: the distribution
    function is relaxed towards its equilibrium at a rate set by omega.
    Reference: https://en.wikipedia.org/wiki/Bhatnagar%E2%80%93Gross%E2%80%93Krook_operator
    """

    def __init__(
        self,
        omega: float,
        velocity_set: VelocitySet,
        compute_backend=ComputeBackend.JAX,
    ):
        super().__init__(
            omega=omega,
            velocity_set=velocity_set,
            compute_backend=compute_backend,
        )

    @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4))
    def apply_jax(
        self,
        f: jnp.ndarray,
        feq: jnp.ndarray,
        rho: jnp.ndarray,
        u: jnp.ndarray,
    ):
        """
        BGK collision step: relax f towards feq by the factor omega.

        Parameters
        ----------
        f : jax.numpy.ndarray
            The distribution function.
        feq : jax.numpy.ndarray
            The equilibrium distribution function.
        rho : jax.numpy.ndarray
            The macroscopic density (unused by BGK; kept for a uniform
            collision-operator interface).
        u : jax.numpy.ndarray
            The macroscopic velocity (unused by BGK; kept for a uniform
            collision-operator interface).
        """
        return f - self.omega * (f - feq)

    def construct_numba(self):
        """
        Build the Numba CUDA device function for the BGK collision.

        Returns
        -------
        numba.cuda device function
            Writes the post-collision populations into ``fout`` and returns it.
        """
        # Capture omega as a float32 constant for the device function.
        relaxation = float32(self.omega)

        @cuda.jit(device=True)
        def _collision(f, feq, rho, u, fout):
            # Relax every population towards its equilibrium value.
            for q in range(f.shape[0]):
                fout[q] = f[q] - relaxation * (f[q] - feq[q])
            return fout

        return _collision
class KBC(Collision):
    """
    KBC collision operator for LBM.

    Implements the Karlin-Bösch-Chikatamarla (KBC) entropic model for the
    collision step: the non-equilibrium part of the distribution is split
    into a shear part (delta_s) and a remainder (delta_h), and the relaxation
    of delta_h is modulated by an entropic stabilizer gamma.
    """

    def __init__(
        self,
        omega,
        velocity_set: VelocitySet,
        compute_backend=ComputeBackend.JAX,
    ):
        super().__init__(
            omega=omega,
            velocity_set=velocity_set,
            compute_backend=compute_backend,
        )
        # Guards against division by zero in the entropic scalar product.
        self.epsilon = 1e-32
        self.beta = self.omega * 0.5
        self.inv_beta = 1.0 / self.beta

    @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3, 4))
    def apply_jax(
        self,
        f: jnp.ndarray,
        feq: jnp.ndarray,
        rho: jnp.ndarray,
        u: jnp.ndarray,
    ):
        """
        KBC collision step for lattice.

        Parameters
        ----------
        f : jax.numpy.array
            Distribution function.
        feq : jax.numpy.array
            Equilibrium distribution function.
        rho : jax.numpy.array
            Density.
        u : jax.numpy.array
            Velocity.

        Raises
        ------
        NotImplementedError
            If the velocity set is neither D2Q9 nor D3Q27.
        """
        # Compute shear TODO: Generalize this and possibly make it an operator or something
        fneq = f - feq
        if isinstance(self.velocity_set, D2Q9):
            shear = self.decompose_shear_d2q9_jax(fneq)
            delta_s = shear * rho / 4.0  # TODO: Check this
        elif isinstance(self.velocity_set, D3Q27):
            shear = self.decompose_shear_d3q27_jax(fneq)
            delta_s = shear * rho
        else:
            # Bug fix: previously delta_s was silently left undefined for any
            # other velocity set, producing an UnboundLocalError further down.
            raise NotImplementedError(
                f"KBC shear decomposition is not implemented for {type(self.velocity_set).__name__}"
            )

        # Perform collision
        delta_h = fneq - delta_s
        gamma = self.inv_beta - (2.0 - self.inv_beta) * self.entropic_scalar_product(
            delta_s, delta_h, feq
        ) / (self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq))

        fout = f - self.beta * (2.0 * delta_s + gamma[..., None] * delta_h)

        return fout

    @partial(jit, static_argnums=(0,), inline=True)
    def entropic_scalar_product(
        self,
        x: jnp.ndarray,
        y: jnp.ndarray,
        feq: jnp.ndarray,
    ):
        """
        Compute the entropic scalar product of x and y to approximate gamma in KBC.

        Returns
        -------
        jax.numpy.array
            Entropic scalar product of x, y, and feq.
        """
        return jnp.sum(x * y / feq, axis=-1)

    # Bug fix: static_argnums previously was (0, 2), but this function only
    # has arguments (self, fneq) so index 2 does not exist and jax raises
    # when the jitted function is called. Only self (index 0) is static.
    @partial(jit, static_argnums=(0,), donate_argnums=(1,))
    def momentum_flux_jax(
        self,
        fneq: jnp.ndarray,
    ):
        """
        Compute the momentum flux: the contraction of the non-equilibrium
        distribution functions (fneq) with the lattice moments (cc).

        The momentum flux is used in the computation of the stress tensor in
        the Lattice Boltzmann Method (LBM).

        # TODO: probably move this to equilibrium calculation

        Parameters
        ----------
        fneq: jax.numpy.ndarray
            The non-equilibrium distribution functions.

        Returns
        -------
        jax.numpy.ndarray
            The computed momentum flux.
        """
        return jnp.dot(fneq, jnp.array(self.velocity_set.cc, dtype=fneq.dtype))

    # Bug fix: static_argnums trimmed to (0,) — see momentum_flux_jax.
    @partial(jit, static_argnums=(0,), inline=True)
    def decompose_shear_d3q27_jax(self, fneq):
        """
        Decompose fneq into shear components for D3Q27 lattice.

        Parameters
        ----------
        fneq : jax.numpy.ndarray
            Non-equilibrium distribution function.

        Returns
        -------
        jax.numpy.ndarray
            Shear components of fneq.
        """
        # Calculate the momentum flux
        Pi = self.momentum_flux_jax(fneq)
        Nxz = Pi[..., 0] - Pi[..., 5]
        Nyz = Pi[..., 3] - Pi[..., 5]

        # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k)
        s = jnp.zeros_like(fneq)
        s = s.at[..., 9].set((2.0 * Nxz - Nyz) / 6.0)
        s = s.at[..., 18].set((2.0 * Nxz - Nyz) / 6.0)
        s = s.at[..., 3].set((-Nxz + 2.0 * Nyz) / 6.0)
        s = s.at[..., 6].set((-Nxz + 2.0 * Nyz) / 6.0)
        s = s.at[..., 1].set((-Nxz - Nyz) / 6.0)
        s = s.at[..., 2].set((-Nxz - Nyz) / 6.0)

        # For c = (i, j, 0)
        s = s.at[..., 12].set(Pi[..., 1] / 4.0)
        s = s.at[..., 24].set(Pi[..., 1] / 4.0)
        s = s.at[..., 21].set(-Pi[..., 1] / 4.0)
        s = s.at[..., 15].set(-Pi[..., 1] / 4.0)

        # For c = (i, 0, k)
        s = s.at[..., 10].set(Pi[..., 2] / 4.0)
        s = s.at[..., 20].set(Pi[..., 2] / 4.0)
        s = s.at[..., 19].set(-Pi[..., 2] / 4.0)
        s = s.at[..., 11].set(-Pi[..., 2] / 4.0)

        # For c = (0, j, k)
        s = s.at[..., 8].set(Pi[..., 4] / 4.0)
        s = s.at[..., 4].set(Pi[..., 4] / 4.0)
        s = s.at[..., 7].set(-Pi[..., 4] / 4.0)
        s = s.at[..., 5].set(-Pi[..., 4] / 4.0)

        return s

    # Bug fix: static_argnums trimmed to (0,) — see momentum_flux_jax.
    @partial(jit, static_argnums=(0,), inline=True)
    def decompose_shear_d2q9_jax(self, fneq):
        """
        Decompose fneq into shear components for D2Q9 lattice.

        Parameters
        ----------
        fneq : jax.numpy.array
            Non-equilibrium distribution function.

        Returns
        -------
        jax.numpy.array
            Shear components of fneq.
        """
        Pi = self.momentum_flux_jax(fneq)
        N = Pi[..., 0] - Pi[..., 2]
        s = jnp.zeros_like(fneq)
        s = s.at[..., 6].set(N)
        s = s.at[..., 3].set(N)
        s = s.at[..., 2].set(-N)
        s = s.at[..., 1].set(-N)
        s = s.at[..., 8].set(Pi[..., 1])
        s = s.at[..., 4].set(-Pi[..., 1])
        s = s.at[..., 5].set(-Pi[..., 1])
        s = s.at[..., 7].set(Pi[..., 1])

        return s
+ + # TODO: This might be optimized using a for loop for because + # the compiler will remove 0 c terms. + """ + cu = 3.0 * jnp.dot(u, jnp.array(self.velocity_set.c, dtype=rho.dtype)) + usqr = 1.5 * jnp.sum(jnp.square(u), axis=-1, keepdims=True) + feq = ( + rho + * jnp.array(self.velocity_set.w, dtype=rho.dtype) + * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) + ) + return feq + + def construct_numba(self, velocity_set: VelocitySet, dtype=numba.float32): + """ + Numba implementation of the equilibrium distribution function. + """ + # Get needed values for numba functions + q = velocity_set.q + c = velocity_set.c.T + w = velocity_set.w + + # Make numba functions + @cuda.jit(device=True) + def _equilibrium(rho, u, feq): + # Compute the equilibrium distribution function + usqr = dtype(1.5) * (u[0] * u[0] + u[1] * u[1] + u[2] * u[2]) + for i in range(q): + cu = dtype(3.0) * ( + u[0] * dtype(c[i, 0]) + + u[1] * dtype(c[i, 1]) + + u[2] * dtype(c[i, 2]) + ) + feq[i] = ( + rho[0] + * dtype(w[i]) + * (dtype(1.0) + cu * (dtype(1.0) + dtype(0.5) * cu) - usqr) + ) + + # Return the equilibrium distribution function + return feq # comma is needed for numba to return a tuple, seems like a bug in numba + + return _equilibrium diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py new file mode 100644 index 0000000..91eb36c --- /dev/null +++ b/xlb/operator/macroscopic/__init__.py @@ -0,0 +1 @@ +from xlb.operator.macroscopic.macroscopic import Macroscopic diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py new file mode 100644 index 0000000..ff6ae83 --- /dev/null +++ b/xlb/operator/macroscopic/macroscopic.py @@ -0,0 +1,38 @@ +# Base class for all equilibriums + +from functools import partial +import jax.numpy as jnp +from jax import jit + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class 
class Macroscopic(Operator):
    """
    Computes the standard macroscopic moments (density rho and velocity u)
    of the lattice distribution function.

    TODO: In the future this should be extended to higher-order moments and
    other physics types (e.g. temperature, electromagnetism, etc...).
    """

    def __init__(
        self,
        velocity_set: VelocitySet,
        compute_backend=ComputeBackend.JAX,
    ):
        super().__init__(velocity_set, compute_backend)

    @partial(jit, static_argnums=(0), inline=True)
    def apply_jax(self, f):
        """
        Return (rho, u): the zeroth and first velocity moments of f.
        """
        # Zeroth moment: density (keepdims so it broadcasts against u).
        density = jnp.sum(f, axis=-1, keepdims=True)
        # First moment divided by density gives the velocity.
        lattice_velocities = jnp.array(self.velocity_set.c, dtype=f.dtype).T
        velocity = jnp.dot(f, lattice_velocities) / density
        return density, velocity
+ """ + if self.compute_backend == ComputeBackend.JAX: + return self.apply_jax(*args, **kwargs) + elif self.compute_backend == ComputeBackend.NUMBA: + return self.apply_numba(*args, **kwargs) + + def apply_jax(self, *args, **kwargs): + """ + Implement the operator using JAX. + If using the JAX backend, this method will then become + the self.__call__ method. + """ + raise NotImplementedError("Child class must implement apply_jax") + + def apply_numba(self, *args, **kwargs): + """ + Implement the operator using Numba. + If using the Numba backend, this method will then become + the self.__call__ method. + """ + raise NotImplementedError("Child class must implement apply_numba") + + def construct_numba(self): + """ + Constructs numba kernel for the operator + """ + raise NotImplementedError("Child class must implement apply_numba") + + @property + def supported_compute_backend(self): + """ + Returns the supported compute backend for the operator + """ + supported_backend = [] + if self._is_method_overridden("apply_jax"): + supported_backend.append(ComputeBackend.JAX) + elif self._is_method_overridden("apply_numba"): + supported_backend.append(ComputeBackend.NUMBA) + else: + raise NotImplementedError("No supported compute backend implemented") + return supported_backend + + def _is_method_overridden(self, method_name): + """ + Helper method to check if a method is overridden in a subclass. 
+ """ + method = getattr(self, method_name, None) + if method is None: + return False + return method.__func__ is not getattr(Operator, method_name, None).__func__ + + def __repr__(self): + return f"{self.__class__.__name__}()" diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py new file mode 100644 index 0000000..0469f13 --- /dev/null +++ b/xlb/operator/stepper/__init__.py @@ -0,0 +1,2 @@ +from xlb.operator.stepper.stepper import Stepper +from xlb.operator.stepper.nse import NSE diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py new file mode 100644 index 0000000..9ef1dbc --- /dev/null +++ b/xlb/operator/stepper/nse.py @@ -0,0 +1,93 @@ +# Base class for all stepper operators + +from functools import partial +from jax import jit + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stepper.stepper import Stepper +from xlb.operator.boundary_condition import ImplementationStep + + +class NSE(Stepper): + """ + Class that handles the construction of lattice boltzmann stepping operator for the Navier-Stokes equations + + TODO: Check that the given operators (collision, stream, equilibrium, macroscopic, ...) 
are compatible + with the Navier-Stokes equations + """ + + def __init__( + self, + collision, + stream, + equilibrium, + macroscopic, + boundary_conditions=[], + forcing=None, + precision_policy=None, + ): + super().__init__( + collision, + stream, + equilibrium, + macroscopic, + boundary_conditions, + forcing, + precision_policy, + ) + + #@partial(jit, static_argnums=(0, 5), donate_argnums=(1)) # TODO: This donate args seems to break out of core memory + @partial(jit, static_argnums=(0, 5)) + def apply_jax(self, f, boundary_id, mask, timestep): + """ + Perform a single step of the lattice boltzmann method + """ + + # Cast to compute precision + f_pre_collision = self.precision_policy.cast_to_compute_jax(f) + + # Compute the macroscopic variables + rho, u = self.macroscopic(f_pre_collision) + + # Compute equilibrium + feq = self.equilibrium(rho, u) + + # Apply collision + f_post_collision = self.collision( + f, + feq, + rho, + u, + ) + + # Apply collision type boundary conditions + for id_number, bc in self.collision_boundary_conditions.items(): + f_post_collision = bc( + f_pre_collision, + f_post_collision, + boundary_id == id_number, + mask, + ) + f_pre_streaming = f_post_collision + + ## Apply forcing + # if self.forcing_op is not None: + # f = self.forcing_op.apply_jax(f, timestep) + + # Apply streaming + f_post_streaming = self.stream(f_pre_streaming) + + # Apply boundary conditions + for id_number, bc in self.stream_boundary_conditions.items(): + f_post_streaming = bc( + f_pre_streaming, + f_post_streaming, + boundary_id == id_number, + mask, + ) + + # Copy back to store precision + f = self.precision_policy.cast_to_store_jax(f_post_streaming) + + return f diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py new file mode 100644 index 0000000..b5c0a44 --- /dev/null +++ b/xlb/operator/stepper/stepper.py @@ -0,0 +1,84 @@ +# Base class for all stepper operators + +import jax.numpy as jnp + +from xlb.velocity_set.velocity_set import 
VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition import ImplementationStep + + +class Stepper(Operator): + """ + Class that handles the construction of lattice boltzmann stepping operator + """ + + def __init__( + self, + collision, + stream, + equilibrium, + macroscopic, + boundary_conditions=[], + forcing=None, + precision_policy=None, + compute_backend=ComputeBackend.JAX, + ): + # Set parameters + self.collision = collision + self.stream = stream + self.equilibrium = equilibrium + self.macroscopic = macroscopic + self.boundary_conditions = boundary_conditions + self.forcing = forcing + self.precision_policy = precision_policy + self.compute_backend = compute_backend + + # Get collision and stream boundary conditions + self.collision_boundary_conditions = {} + self.stream_boundary_conditions = {} + for id_number, bc in enumerate(self.boundary_conditions): + bc_id = id_number + 1 + if bc.implementation_step == ImplementationStep.COLLISION: + self.collision_boundary_conditions[bc_id] = bc + elif bc.implementation_step == ImplementationStep.STREAMING: + self.stream_boundary_conditions[bc_id] = bc + else: + raise ValueError("Boundary condition step not recognized") + + # Get all operators for checking + self.operators = [ + collision, + stream, + equilibrium, + macroscopic, + *boundary_conditions, + ] + + # Get velocity set and backend + velocity_sets = set([op.velocity_set for op in self.operators]) + assert len(velocity_sets) == 1, "All velocity sets must be the same" + self.velocity_set = velocity_sets.pop() + compute_backends = set([op.compute_backend for op in self.operators]) + assert len(compute_backends) == 1, "All compute backends must be the same" + self.compute_backend = compute_backends.pop() + + # Initialize operator + super().__init__(self.velocity_set, self.compute_backend) + + def set_boundary(self, ijk): + """ + Set boundary condition arrays + These store the 
boundary condition information for each boundary + """ + # Empty boundary condition array + boundary_id = jnp.zeros(ijk.shape[:-1], dtype=jnp.uint8) + mask = jnp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=jnp.bool_) + + # Set boundary condition arrays + for id_number, bc in self.collision_boundary_conditions.items(): + boundary_id, mask = bc.set_boundary(ijk, boundary_id, mask, id_number) + for id_number, bc in self.stream_boundary_conditions.items(): + boundary_id, mask = bc.set_boundary(ijk, boundary_id, mask, id_number) + + return boundary_id, mask diff --git a/xlb/operator/stream/__init__.py b/xlb/operator/stream/__init__.py new file mode 100644 index 0000000..9093da7 --- /dev/null +++ b/xlb/operator/stream/__init__.py @@ -0,0 +1 @@ +from xlb.operator.stream.stream import Stream diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py new file mode 100644 index 0000000..7d31a3b --- /dev/null +++ b/xlb/operator/stream/stream.py @@ -0,0 +1,88 @@ +# Base class for all streaming operators + +from functools import partial +import jax.numpy as jnp +from jax import jit, vmap +import numba + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class Stream(Operator): + """ + Base class for all streaming operators. + + TODO: Currently only this one streaming operator is implemented but + in the future we may have more streaming operators. For example, + one might want a multi-step streaming operator. + """ + + def __init__( + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackend.JAX, + ): + super().__init__(velocity_set, compute_backend) + + @partial(jit, static_argnums=(0), donate_argnums=(1,)) + def apply_jax(self, f): + """ + JAX implementation of the streaming step. + + Parameters + ---------- + f: jax.numpy.ndarray + The distribution function. 
+ """ + + def _streaming(f, c): + """ + Perform individual streaming operation in a direction. + + Parameters + ---------- + f: The distribution function. + c: The streaming direction vector. + + Returns + ------- + jax.numpy.ndarray + The updated distribution function after streaming. + """ + if self.velocity_set.d == 2: + return jnp.roll(f, (c[0], c[1]), axis=(0, 1)) + elif self.velocity_set.d == 3: + return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) + + return vmap(_streaming, in_axes=(-1, 0), out_axes=-1)( + f, jnp.array(self.velocity_set.c).T + ) + + def construct_numba(self, dtype=numba.float32): + """ + Numba implementation of the streaming step. + """ + + # Get needed values for numba functions + d = velocity_set.d + q = velocity_set.q + c = velocity_set.c.T + + # Make numba functions + @cuda.jit(device=True) + def _streaming(f_array, f, ijk): + # Stream to the next node + for _ in range(q): + if d == 2: + i = (ijk[0] + int32(c[_, 0])) % f_array.shape[0] + j = (ijk[1] + int32(c[_, 1])) % f_array.shape[1] + f_array[i, j, _] = f[_] + else: + i = (ijk[0] + int32(c[_, 0])) % f_array.shape[0] + j = (ijk[1] + int32(c[_, 1])) % f_array.shape[1] + k = (ijk[2] + int32(c[_, 2])) % f_array.shape[2] + f_array[i, j, k, _] = f[_] + + return _streaming diff --git a/xlb/physics_type.py b/xlb/physics_type.py new file mode 100644 index 0000000..73d9e70 --- /dev/null +++ b/xlb/physics_type.py @@ -0,0 +1,7 @@ +# Enum used to keep track of the physics types supported by different operators + +from enum import Enum + +class PhysicsType(Enum): + NSE = 1 # Navier-Stokes Equations + ADE = 2 # Advection-Diffusion Equations diff --git a/xlb/precision_policy/__init__.py b/xlb/precision_policy/__init__.py new file mode 100644 index 0000000..c228387 --- /dev/null +++ b/xlb/precision_policy/__init__.py @@ -0,0 +1,2 @@ +from xlb.precision_policy.precision_policy import PrecisionPolicy +from xlb.precision_policy.fp32fp32 import Fp32Fp32 diff --git a/xlb/precision_policy/fp32fp32.py 
b/xlb/precision_policy/fp32fp32.py new file mode 100644 index 0000000..1e37d2c --- /dev/null +++ b/xlb/precision_policy/fp32fp32.py @@ -0,0 +1,20 @@ +# Purpose: Precision policy for lattice Boltzmann method with computation and +# storage precision both set to float32. + +import jax.numpy as jnp + +from xlb.precision_policy.precision_policy import PrecisionPolicy + + +class Fp32Fp32(PrecisionPolicy): + """ + Precision policy for lattice Boltzmann method with computation and storage + precision both set to float32. + + Parameters + ---------- + None + """ + + def __init__(self): + super().__init__(jnp.float32, jnp.float32) diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py new file mode 100644 index 0000000..459ce74 --- /dev/null +++ b/xlb/precision_policy/precision_policy.py @@ -0,0 +1,159 @@ +# Precision policy for lattice Boltzmann method +# TODO: possibly refctor this to be more general + +from functools import partial +import jax.numpy as jnp +from jax import jit +import numba +from numba import cuda + + +class PrecisionPolicy(object): + """ + Base class for precision policy in lattice Boltzmann method. + Basic idea is to allow for storing the lattice in a different precision than the computation. + + Stores dtype in jax but also contains same information for other backends such as numba. + + Parameters + ---------- + compute_dtype: jax.numpy.dtype + The precision used for computation. + storage_dtype: jax.numpy.dtype + The precision used for storage. 
+ """ + + def __init__(self, compute_dtype, storage_dtype): + # Store the dtypes (jax) + self.compute_dtype = compute_dtype + self.storage_dtype = storage_dtype + + # Get the corresponding numba dtypes + self.compute_dtype_numba = self._get_numba_dtype(compute_dtype) + self.storage_dtype_numba = self._get_numba_dtype(storage_dtype) + + # Check that compute dtype is one of the supported dtypes (float16, float32, float64) + self.supported_compute_dtypes = [jnp.float16, jnp.float32, jnp.float64] + if self.compute_dtype not in self.supported_compute_dtypes: + raise ValueError( + f"Compute dtype {self.compute_dtype} is not supported. Supported dtypes are {self.supported_compute_dtypes}" + ) + + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def cast_to_compute_jax(self, array): + """ + Cast the array to the computation precision + + Parameters + ---------- + Array: jax.numpy.ndarray + The array to cast. + + Returns + ------- + jax.numpy.ndarray + The casted array + """ + return array.astype(self.compute_dtype) + + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def cast_to_store_jax(self, array): + """ + Cast the array to the storage precision + + Parameters + ---------- + Array: jax.numpy.ndarray + The array to cast. + + Returns + ------- + jax.numpy.ndarray + The casted array + """ + return array.astype(self.storage_dtype) + + def cast_to_compute_numba(self): + """ + Constructs a numba function to cast a value to the computation precision + + Parameters + ---------- + value: float + The value to cast. + + Returns + ------- + float + The casted value + """ + return self._cast_to_dtype_numba(self.compute_dtype_numba) + + def cast_to_store_numba(self): + """ + Constructs a numba function to cast a value to the storage precision + + Parameters + ---------- + value: float + The value to cast. 
+ + Returns + ------- + float + The casted value + """ + return self._cast_to_dtype_numba(self.storage_dtype_numba) + + def _cast_to_dytpe_numba(self, dtype): + """ + Constructs a numba function to cast a value to the computation precision + + Parameters + ---------- + value: float + The value to cast. + + Returns + ------- + float + The casted value + """ + + @cuda.jit(device=True) + def cast_to_dtype(value): + return dtype(value) + + def _get_numba_dtype(self, dtype): + """ + Get the corresponding numba dtype + + # TODO: Make this more general + + Parameters + ---------- + dtype: jax.numpy.dtype + The dtype to convert + + Returns + ------- + numba.dtype + The corresponding numba dtype + """ + if dtype == jnp.float16: + return numba.float16 + elif dtype == jnp.float32: + return numba.float32 + elif dtype == jnp.float64: + return numba.float64 + elif dtype == jnp.int32: + return numba.int32 + elif dtype == jnp.int64: + return numba.int64 + elif dtype == jnp.int16: + return numba.int16 + else: + raise ValueError(f"Unsupported dtype {dtype}") + + def __repr__(self): + return f"compute_dtype={self.compute_dtype}/{self.storage_dtype}" diff --git a/src/utils.py b/xlb/utils/utils.py similarity index 73% rename from src/utils.py rename to xlb/utils/utils.py index 1720ca6..d01c500 100644 --- a/src/utils.py +++ b/xlb/utils/utils.py @@ -15,7 +15,7 @@ @partial(jit, static_argnums=(1, 2)) -def downsample_field(field, factor, method='bicubic'): +def downsample_field(field, factor, method="bicubic"): """ Downsample a JAX array by a factor of `factor` along each axis. 
@@ -38,12 +38,15 @@ def downsample_field(field, factor, method='bicubic'): else: new_shape = tuple(dim // factor for dim in field.shape[:-1]) downsampled_components = [] - for i in range(field.shape[-1]): # Iterate over the last dimension (vector components) + for i in range( + field.shape[-1] + ): # Iterate over the last dimension (vector components) resized = resize(field[..., i], new_shape, method=method) downsampled_components.append(resized) return jnp.stack(downsampled_components, axis=-1) + def save_image(timestep, fld, prefix=None): """ Save an image of a field at a given timestep. @@ -78,16 +81,17 @@ def save_image(timestep, fld, prefix=None): fld = np.sqrt(fld[..., 0] ** 2 + fld[..., 1] ** 2) plt.clf() - plt.imsave(fname + '.png', fld.T, cmap=cm.nipy_spectral, origin='lower') + plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower") + -def save_fields_vtk(timestep, fields, output_dir='.', prefix='fields'): +def save_fields_vtk(timestep, fields, output_dir=".", prefix="fields"): """ Save VTK fields to the specified directory. Parameters ---------- timestep (int): The timestep number to be associated with the saved fields. - fields (Dict[str, np.ndarray]): A dictionary of fields to be saved. Each field must be an array-like object + fields (Dict[str, np.ndarray]): A dictionary of fields to be saved. Each field must be an array-like object with dimensions (nx, ny) for 2D fields or (nx, ny, nz) for 3D fields, where: - nx : int, number of grid points along the x-axis - ny : int, number of grid points along the y-axis @@ -112,9 +116,11 @@ def save_fields_vtk(timestep, fields, output_dir='.', prefix='fields'): if key == list(fields.keys())[0]: dimensions = value.shape else: - assert value.shape == dimensions, "All fields must have the same dimensions!" + assert ( + value.shape == dimensions + ), "All fields must have the same dimensions!" 
- output_filename = os.path.join(output_dir, prefix + "_" + f"{timestep:07d}.vtk") + output_filename = os.path.join(output_dir, prefix + "_" + f"{timestep:07d}.vtk") # Add 1 to the dimensions tuple as we store cell values dimensions = tuple([dim + 1 for dim in dimensions]) @@ -123,17 +129,18 @@ def save_fields_vtk(timestep, fields, output_dir='.', prefix='fields'): if value.ndim == 2: dimensions = dimensions + (1,) - grid = pv.ImageData(dimensions=dimensions) + grid = pv.UniformGrid(dimensions=dimensions) # Add the fields to the grid for key, value in fields.items(): - grid[key] = value.flatten(order='F') + grid[key] = value.flatten(order="F") # Save the grid to a VTK file start = time() grid.save(output_filename, binary=True) print(f"Saved {output_filename} in {time() - start:.6f} seconds.") + def live_volume_randering(timestep, field): # WORK IN PROGRESS """ @@ -157,29 +164,30 @@ def live_volume_randering(timestep, field): if field.ndim != 3: raise ValueError("The input field must be 3D!") dimensions = field.shape - grid = pv.ImageData(dimensions=dimensions) + grid = pv.UniformGrid(dimensions=dimensions) # Add the field to the grid - grid['field'] = field.flatten(order='F') + grid["field"] = field.flatten(order="F") # Create the rendering scene if timestep == 0: plt.ion() plt.figure(figsize=(10, 10)) - plt.axis('off') + plt.axis("off") plt.title("Live rendering of the field") pl = pv.Plotter(off_screen=True) - pl.add_volume(grid, cmap='nipy_spectral', opacity='sigmoid_10', shade=False) + pl.add_volume(grid, cmap="nipy_spectral", opacity="sigmoid_10", shade=False) plt.imshow(pl.screenshot()) else: pl = pv.Plotter(off_screen=True) - pl.add_volume(grid, cmap='nipy_spectral', opacity='sigmoid_10', shade=False) + pl.add_volume(grid, cmap="nipy_spectral", opacity="sigmoid_10", shade=False) # Update the rendering scene every 0.1 seconds plt.imshow(pl.screenshot()) plt.pause(0.1) -def save_BCs_vtk(timestep, BCs, gridInfo, output_dir='.'): + +def save_BCs_vtk(timestep, 
BCs, gridInfo, output_dir="."): """ Save boundary conditions as VTK format to the specified directory. @@ -200,14 +208,14 @@ def save_BCs_vtk(timestep, BCs, gridInfo, output_dir='.'): """ # Create a uniform grid - if gridInfo['nz'] == 0: - gridDimensions = (gridInfo['nx'] + 1, gridInfo['ny'] + 1, 1) - fieldDimensions = (gridInfo['nx'], gridInfo['ny'], 1) + if gridInfo["nz"] == 0: + gridDimensions = (gridInfo["nx"] + 1, gridInfo["ny"] + 1, 1) + fieldDimensions = (gridInfo["nx"], gridInfo["ny"], 1) else: - gridDimensions = (gridInfo['nx'] + 1, gridInfo['ny'] + 1, gridInfo['nz'] + 1) - fieldDimensions = (gridInfo['nx'], gridInfo['ny'], gridInfo['nz']) + gridDimensions = (gridInfo["nx"] + 1, gridInfo["ny"] + 1, gridInfo["nz"] + 1) + fieldDimensions = (gridInfo["nx"], gridInfo["ny"], gridInfo["nz"]) - grid = pv.ImageData(dimensions=gridDimensions) + grid = pv.UniformGrid(dimensions=gridDimensions) # Dictionary to keep track of encountered BC names bcNamesCount = {} @@ -226,16 +234,16 @@ def save_BCs_vtk(timestep, BCs, gridInfo, output_dir='.'): bcIndices = bc.indices # Convert indices to 1D indices - if gridInfo['dim'] == 2: - bcIndices = np.ravel_multi_index(bcIndices, fieldDimensions[:-1], order='F') + if gridInfo["dim"] == 2: + bcIndices = np.ravel_multi_index(bcIndices, fieldDimensions[:-1], order="F") else: - bcIndices = np.ravel_multi_index(bcIndices, fieldDimensions, order='F') + bcIndices = np.ravel_multi_index(bcIndices, fieldDimensions, order="F") - grid[bcName] = np.zeros(fieldDimensions, dtype=bool).flatten(order='F') + grid[bcName] = np.zeros(fieldDimensions, dtype=bool).flatten(order="F") grid[bcName][bcIndices] = True # Save the grid to a VTK file - output_filename = os.path.join(output_dir, "BCs_" + f"{timestep:07d}.vtk") + output_filename = os.path.join(output_dir, "BCs_" + f"{timestep:07d}.vtk") start = time() grid.save(output_filename, binary=True) @@ -267,10 +275,15 @@ def rotate_geometry(indices, origin, axis, angle): This function rotates the mesh 
by applying a rotation matrix to the voxel indices. The rotation matrix is calculated using the axis-angle representation of rotations. The origin of the rotation axis is assumed to be at (0, 0, 0). """ - indices_rotated = (jnp.array(indices).T - origin) @ axangle2mat(axis, angle) + origin - return tuple(jnp.rint(indices_rotated).astype('int32').T) + indices_rotated = (jnp.array(indices).T - origin) @ axangle2mat( + axis, angle + ) + origin + return tuple(jnp.rint(indices_rotated).astype("int32").T) + -def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None): +def voxelize_stl( + stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None +): """ Converts an STL file to a voxelized mesh. @@ -309,7 +322,7 @@ def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None, def axangle2mat(axis, angle, is_normalized=False): - ''' Rotation matrix for rotation angle `angle` around `axis` + """Rotation matrix for rotation angle `angle` around `axis` Parameters ---------- axis : 3 element sequence @@ -326,7 +339,7 @@ def axangle2mat(axis, angle, is_normalized=False): ----- From : https://github.com/matthew-brett/transforms3d Ref : http://en.wikipedia.org/wiki/Rotation_matrix#Axis_and_angle - ''' + """ x, y, z = axis if not is_normalized: n = jnp.sqrt(x * x + y * y + z * z) @@ -345,70 +358,10 @@ def axangle2mat(axis, angle, is_normalized=False): xyC = x * yC yzC = y * zC zxC = z * xC - return jnp.array([ - [x * xC + c, xyC - zs, zxC + ys], - [xyC + zs, y * yC + c, yzC - xs], - [zxC - ys, yzC + xs, z * zC + c]]) - -@partial(jit) -def q_criterion(u): - # Compute derivatives - u_x = u[..., 0] - u_y = u[..., 1] - u_z = u[..., 2] - - # Compute derivatives - u_x_dx = (u_x[2:, 1:-1, 1:-1] - u_x[:-2, 1:-1, 1:-1]) / 2 - u_x_dy = (u_x[1:-1, 2:, 1:-1] - u_x[1:-1, :-2, 1:-1]) / 2 - u_x_dz = (u_x[1:-1, 1:-1, 2:] - u_x[1:-1, 1:-1, :-2]) / 2 - u_y_dx = (u_y[2:, 1:-1, 1:-1] - u_y[:-2, 1:-1, 1:-1]) / 2 - u_y_dy = (u_y[1:-1, 
2:, 1:-1] - u_y[1:-1, :-2, 1:-1]) / 2 - u_y_dz = (u_y[1:-1, 1:-1, 2:] - u_y[1:-1, 1:-1, :-2]) / 2 - u_z_dx = (u_z[2:, 1:-1, 1:-1] - u_z[:-2, 1:-1, 1:-1]) / 2 - u_z_dy = (u_z[1:-1, 2:, 1:-1] - u_z[1:-1, :-2, 1:-1]) / 2 - u_z_dz = (u_z[1:-1, 1:-1, 2:] - u_z[1:-1, 1:-1, :-2]) / 2 - - # Compute vorticity - mu_x = u_z_dy - u_y_dz - mu_y = u_x_dz - u_z_dx - mu_z = u_y_dx - u_x_dy - norm_mu = jnp.sqrt(mu_x ** 2 + mu_y ** 2 + mu_z ** 2) - - # Compute strain rate - s_0_0 = u_x_dx - s_0_1 = 0.5 * (u_x_dy + u_y_dx) - s_0_2 = 0.5 * (u_x_dz + u_z_dx) - s_1_0 = s_0_1 - s_1_1 = u_y_dy - s_1_2 = 0.5 * (u_y_dz + u_z_dy) - s_2_0 = s_0_2 - s_2_1 = s_1_2 - s_2_2 = u_z_dz - s_dot_s = ( - s_0_0 ** 2 + s_0_1 ** 2 + s_0_2 ** 2 + - s_1_0 ** 2 + s_1_1 ** 2 + s_1_2 ** 2 + - s_2_0 ** 2 + s_2_1 ** 2 + s_2_2 ** 2 + return jnp.array( + [ + [x * xC + c, xyC - zs, zxC + ys], + [xyC + zs, y * yC + c, yzC - xs], + [zxC - ys, yzC + xs, z * zC + c], + ] ) - - # Compute omega - omega_0_0 = 0.0 - omega_0_1 = 0.5 * (u_x_dy - u_y_dx) - omega_0_2 = 0.5 * (u_x_dz - u_z_dx) - omega_1_0 = -omega_0_1 - omega_1_1 = 0.0 - omega_1_2 = 0.5 * (u_y_dz - u_z_dy) - omega_2_0 = -omega_0_2 - omega_2_1 = -omega_1_2 - omega_2_2 = 0.0 - omega_dot_omega = ( - omega_0_0 ** 2 + omega_0_1 ** 2 + omega_0_2 ** 2 + - omega_1_0 ** 2 + omega_1_1 ** 2 + omega_1_2 ** 2 + - omega_2_0 ** 2 + omega_2_1 ** 2 + omega_2_2 ** 2 - ) - - # Compute q-criterion - q = 0.5 * (omega_dot_omega - s_dot_s) - - return norm_mu, q - - diff --git a/xlb/velocity_set/__init__.py b/xlb/velocity_set/__init__.py new file mode 100644 index 0000000..5b7b737 --- /dev/null +++ b/xlb/velocity_set/__init__.py @@ -0,0 +1,4 @@ +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.velocity_set.d2q9 import D2Q9 +from xlb.velocity_set.d3q19 import D3Q19 +from xlb.velocity_set.d3q27 import D3Q27 diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py new file mode 100644 index 0000000..d5f8793 --- /dev/null +++ b/xlb/velocity_set/d2q9.py @@ -0,0 +1,26 
@@ +# Description: Lattice class for 2D D2Q9 lattice. + +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet + + +class D2Q9(VelocitySet): + """ + Velocity Set for 2D D2Q9 lattice. + + D2Q9 stands for two-dimensional nine-velocity model. It is a common model used in the + Lat tice Boltzmann Method for simulating fluid flows in two dimensions. + """ + + def __init__(self): + # Construct the velocity vectors and weights + cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] + cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] + c = np.array(tuple(zip(cx, cy))).T + w = np.array( + [4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36] + ) + + # Call the parent constructor + super().__init__(2, 9, c, w) diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py new file mode 100644 index 0000000..5debdea --- /dev/null +++ b/xlb/velocity_set/d3q19.py @@ -0,0 +1,36 @@ +# Description: Lattice class for 3D D3Q19 lattice. + +import itertools +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet + + +class D3Q19(VelocitySet): + """ + Velocity Set for 3D D3Q19 lattice. + + D3Q19 stands for three-dimensional nineteen-velocity model. It is a common model used in the + Lattice Boltzmann Method for simulating fluid flows in three dimensions. + """ + + def __init__(self): + # Construct the velocity vectors and weights + c = np.array( + [ + ci + for ci in itertools.product([-1, 0, 1], repeat=3) + if np.sum(np.abs(ci)) <= 2 + ] + ).T + w = np.zeros(19) + for i in range(19): + if np.sum(np.abs(c[:, i])) == 0: + w[i] = 1.0 / 3.0 + elif np.sum(np.abs(c[:, i])) == 1: + w[i] = 1.0 / 18.0 + elif np.sum(np.abs(c[:, i])) == 2: + w[i] = 1.0 / 36.0 + + # Initialize the lattice + super().__init__(3, 19, c, w) diff --git a/xlb/velocity_set/d3q27.py b/xlb/velocity_set/d3q27.py new file mode 100644 index 0000000..702acf4 --- /dev/null +++ b/xlb/velocity_set/d3q27.py @@ -0,0 +1,32 @@ +# Description: Lattice class for 3D D3Q27 lattice. 
+ +import itertools +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet + + +class D3Q27(VelocitySet): + """ + Velocity Set for 3D D3Q27 lattice. + + D3Q27 stands for three-dimensional twenty-seven-velocity model. It is a common model used in the + Lattice Boltzmann Method for simulating fluid flows in three dimensions. + """ + + def __init__(self): + # Construct the velocity vectors and weights + c = np.array(list(itertools.product([0, -1, 1], repeat=3))).T + w = np.zeros(27) + for i in range(27): + if np.sum(np.abs(c[:, i])) == 0: + w[i] = 8.0 / 27.0 + elif np.sum(np.abs(c[:, i])) == 1: + w[i] = 2.0 / 27.0 + elif np.sum(np.abs(c[:, i])) == 2: + w[i] = 1.0 / 54.0 + elif np.sum(np.abs(c[:, i])) == 3: + w[i] = 1.0 / 216.0 + + # Initialize the Lattice + super().__init__(3, 27, c, w) diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py new file mode 100644 index 0000000..05ba2a2 --- /dev/null +++ b/xlb/velocity_set/velocity_set.py @@ -0,0 +1,200 @@ +# Base Velocity Set class + +import math +import numpy as np +from functools import partial +import jax.numpy as jnp +from jax import jit, vmap +import numba +from numba import cuda, float32, int32 + + +class VelocitySet(object): + """ + Base class for the velocity set of the Lattice Boltzmann Method (LBM), e.g. D2Q9, D3Q27, etc. + + Parameters + ---------- + d: int + The dimension of the lattice. + q: int + The number of velocities of the lattice. + c: numpy.ndarray + The velocity vectors of the lattice. Shape: (q, d) + w: numpy.ndarray + The weights of the lattice. 
Shape: (q,) + """ + + def __init__(self, d, q, c, w): + # Store the dimension and the number of velocities + self.d = d + self.q = q + + # Constants + self.cs = math.sqrt(3) / 3.0 + self.cs2 = 1.0 / 3.0 + self.inv_cs2 = 3.0 + + # Construct the properties of the lattice + self.c = c + self.w = w + self.cc = self._construct_lattice_moment() + self.opp_indices = self._construct_opposite_indices() + self.main_indices = self._construct_main_indices() + self.right_indices = self._construct_right_indices() + self.left_indices = self._construct_left_indices() + + @partial(jit, static_argnums=(0,)) + def momentum_flux_jax(self, fneq): + """ + This function computes the momentum flux, which is the product of the non-equilibrium + distribution functions (fneq) and the lattice moments (cc). + + The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann + Method (LBM). + + Parameters + ---------- + fneq: jax.numpy.ndarray + The non-equilibrium distribution functions. + + Returns + ------- + jax.numpy.ndarray + The computed momentum flux. + """ + + return jnp.dot(fneq, self.cc) + + def momentum_flux_numba(self): + """ + This function computes the momentum flux, which is the product of the non-equilibrium + """ + raise NotImplementedError + + @partial(jit, static_argnums=(0,)) + def decompose_shear_jax(self, fneq): + """ + Decompose fneq into shear components for D3Q27 lattice. + + TODO: add generali + + Parameters + ---------- + fneq : jax.numpy.ndarray + Non-equilibrium distribution function. + + Returns + ------- + jax.numpy.ndarray + Shear components of fneq. + """ + raise NotImplementedError + + def decompose_shear_numba(self): + """ + Decompose fneq into shear components for D3Q27 lattice. + """ + raise NotImplementedError + + def _construct_lattice_moment(self): + """ + This function constructs the moments of the lattice. 
+ + The moments are the products of the velocity vectors, which are used in the computation of + the equilibrium distribution functions and the collision operator in the Lattice Boltzmann + Method (LBM). + + Returns + ------- + cc: numpy.ndarray + The moments of the lattice. + """ + c = self.c.T + # Counter for the loop + cntr = 0 + + # nt: number of independent elements of a symmetric tensor + nt = self.d * (self.d + 1) // 2 + + cc = np.zeros((self.q, nt)) + for a in range(0, self.d): + for b in range(a, self.d): + cc[:, cntr] = c[:, a] * c[:, b] + cntr += 1 + + return cc + + def _construct_opposite_indices(self): + """ + This function constructs the indices of the opposite velocities for each velocity. + + The opposite velocity of a velocity is the velocity that has the same magnitude but the + opposite direction. + + Returns + ------- + opposite: numpy.ndarray + The indices of the opposite velocities. + """ + c = self.c.T + opposite = np.array([c.tolist().index((-c[i]).tolist()) for i in range(self.q)]) + return opposite + + def _construct_main_indices(self): + """ + This function constructs the indices of the main velocities. + + The main velocities are the velocities that have a magnitude of 1 in lattice units. + + Returns + ------- + numpy.ndarray + The indices of the main velocities. + """ + c = self.c.T + if self.d == 2: + return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] + + elif self.d == 3: + return np.nonzero( + (np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1) + )[0] + + def _construct_right_indices(self): + """ + This function constructs the indices of the velocities that point in the positive + x-direction. + + Returns + ------- + numpy.ndarray + The indices of the right velocities. + """ + c = self.c.T + return np.nonzero(c[:, 0] == 1)[0] + + def _construct_left_indices(self): + """ + This function constructs the indices of the velocities that point in the negative + x-direction. 
+ + Returns + ------- + numpy.ndarray + The indices of the left velocities. + """ + c = self.c.T + return np.nonzero(c[:, 0] == -1)[0] + + def __str__(self): + """ + This function returns the name of the lattice in the format of DxQy. + """ + return self.__repr__() + + def __repr__(self): + """ + This function returns the name of the lattice in the format of DxQy. + """ + return "D{}Q{}".format(self.d, self.q)