From b3acb4ac89fd58da608101c6525ac24a3d9534ba Mon Sep 17 00:00:00 2001 From: oliver Date: Fri, 15 Dec 2023 15:06:34 -0800 Subject: [PATCH 001/144] xlb refactor --- examples/backend_comparisons/README.md | 14 + .../backend_comparisons/lattice_boltzmann.py | 1136 ++++++++++++++++ examples/refactor/README.md | 31 + examples/refactor/example_jax.py | 107 ++ examples/refactor/example_jax_out_of_core.py | 336 +++++ examples/refactor/example_numba.py | 78 ++ setup.py | 11 + src/boundary_conditions.py | 1175 ----------------- src/lattice.py | 281 ---- src/models.py | 260 ---- xlb/__init__.py | 18 + {src => xlb}/base.py | 0 xlb/compute_backend.py | 9 + xlb/experimental/__init__.py | 1 + xlb/experimental/ooc/__init__.py | 2 + xlb/experimental/ooc/ooc_array.py | 485 +++++++ xlb/experimental/ooc/out_of_core.py | 110 ++ xlb/experimental/ooc/tiles/__init__.py | 0 xlb/experimental/ooc/tiles/compressed_tile.py | 273 ++++ xlb/experimental/ooc/tiles/dense_tile.py | 96 ++ xlb/experimental/ooc/tiles/dynamic_array.py | 72 + xlb/experimental/ooc/tiles/tile.py | 115 ++ xlb/experimental/ooc/utils.py | 79 ++ xlb/operator/__init__.py | 1 + xlb/operator/boundary_condition/__init__.py | 5 + .../boundary_condition/boundary_condition.py | 100 ++ xlb/operator/boundary_condition/do_nothing.py | 56 + .../equilibrium_boundary.py | 69 + .../boundary_condition/full_bounce_back.py | 57 + .../boundary_condition/halfway_bounce_back.py | 97 ++ xlb/operator/collision/__init__.py | 3 + xlb/operator/collision/bgk.py | 109 ++ xlb/operator/collision/collision.py | 48 + xlb/operator/collision/kbc.py | 203 +++ xlb/operator/equilibrium/__init__.py | 1 + xlb/operator/equilibrium/equilibrium.py | 88 ++ xlb/operator/macroscopic/__init__.py | 1 + xlb/operator/macroscopic/macroscopic.py | 38 + xlb/operator/operator.py | 81 ++ xlb/operator/stepper/__init__.py | 2 + xlb/operator/stepper/nse.py | 93 ++ xlb/operator/stepper/stepper.py | 84 ++ xlb/operator/stream/__init__.py | 1 + xlb/operator/stream/stream.py | 88 ++ xlb/physics_type.py | 7 + xlb/precision_policy/__init__.py | 2 + xlb/precision_policy/fp32fp32.py | 20 + xlb/precision_policy/precision_policy.py | 159 +++ {src => xlb/utils}/utils.py | 147 +-- xlb/velocity_set/__init__.py | 4 + xlb/velocity_set/d2q9.py | 26 + xlb/velocity_set/d3q19.py | 36 + xlb/velocity_set/d3q27.py | 32 + xlb/velocity_set/velocity_set.py | 200 +++ 54 files changed, 4734 insertions(+), 1813 deletions(-) create mode 100644 examples/backend_comparisons/README.md create mode 100644 examples/backend_comparisons/lattice_boltzmann.py create mode 100644 examples/refactor/README.md create mode 100644 examples/refactor/example_jax.py create mode 100644 examples/refactor/example_jax_out_of_core.py create mode 100644 examples/refactor/example_numba.py delete mode 100644 src/boundary_conditions.py delete mode 100644 src/lattice.py delete mode 100644 src/models.py create mode 100644 xlb/__init__.py rename {src => xlb}/base.py (100%) create mode 100644 xlb/compute_backend.py create mode 100644 xlb/experimental/__init__.py create mode 100644 xlb/experimental/ooc/__init__.py create mode 100644 xlb/experimental/ooc/ooc_array.py create mode 100644 xlb/experimental/ooc/out_of_core.py create mode 100644 xlb/experimental/ooc/tiles/__init__.py create mode 100644 xlb/experimental/ooc/tiles/compressed_tile.py create mode 100644 xlb/experimental/ooc/tiles/dense_tile.py create mode 100644 xlb/experimental/ooc/tiles/dynamic_array.py create mode 100644 xlb/experimental/ooc/tiles/tile.py create mode 100644 xlb/experimental/ooc/utils.py create mode 100644 xlb/operator/__init__.py create mode 100644 xlb/operator/boundary_condition/__init__.py create mode 100644 xlb/operator/boundary_condition/boundary_condition.py create mode 100644 xlb/operator/boundary_condition/do_nothing.py create mode 100644 xlb/operator/boundary_condition/equilibrium_boundary.py create mode 100644 xlb/operator/boundary_condition/full_bounce_back.py create mode 100644 xlb/operator/boundary_condition/halfway_bounce_back.py create mode 100644 xlb/operator/collision/__init__.py create mode 100644 xlb/operator/collision/bgk.py create mode 100644 xlb/operator/collision/collision.py create mode 100644 xlb/operator/collision/kbc.py create mode 100644 xlb/operator/equilibrium/__init__.py create mode 100644 xlb/operator/equilibrium/equilibrium.py create mode 100644 xlb/operator/macroscopic/__init__.py create mode 100644 xlb/operator/macroscopic/macroscopic.py create mode 100644 xlb/operator/operator.py create mode 100644 xlb/operator/stepper/__init__.py create mode 100644 xlb/operator/stepper/nse.py create mode 100644 xlb/operator/stepper/stepper.py create mode 100644 xlb/operator/stream/__init__.py create mode 100644 xlb/operator/stream/stream.py create mode 100644 xlb/physics_type.py create mode 100644 xlb/precision_policy/__init__.py create mode 100644 xlb/precision_policy/fp32fp32.py create mode 100644 xlb/precision_policy/precision_policy.py rename {src => xlb/utils}/utils.py (73%) create mode 100644 xlb/velocity_set/__init__.py create mode 100644 xlb/velocity_set/d2q9.py create mode 100644 xlb/velocity_set/d3q19.py create mode 100644 xlb/velocity_set/d3q27.py create mode 100644 xlb/velocity_set/velocity_set.py diff --git a/examples/backend_comparisons/README.md b/examples/backend_comparisons/README.md new file mode 100644 index 0000000..a198eb4 --- /dev/null +++ b/examples/backend_comparisons/README.md @@ -0,0 +1,14 @@ +# Performance Comparisons + +This directory contains a minimal LBM implementation in Warp, Numba, and Jax. The +code can be run with the following command: + +```bash +python3 lattice_boltzmann.py +``` + +This will give MLUPs numbers for each implementation. The Warp implementation +is the fastest, followed by Numba, and then Jax. + +This example should be used as a test for properly implementing more backends in +XLB. diff --git a/examples/backend_comparisons/lattice_boltzmann.py b/examples/backend_comparisons/lattice_boltzmann.py new file mode 100644 index 0000000..a8e1c39 --- /dev/null +++ b/examples/backend_comparisons/lattice_boltzmann.py @@ -0,0 +1,1136 @@ +# Description: This file contains a simple example of using the OOCmap +# decorator to apply a function to a distributed array. +# Solves Lattice Boltzmann Taylor Green vortex decay + +import time +import warp as wp +import matplotlib.pyplot as plt +from tqdm import tqdm +import numpy as np +import cupy as cp +import time +from tqdm import tqdm +from numba import cuda +import numba +import math +import jax.numpy as jnp +import jax +from jax import jit +from functools import partial + +# Initialize Warp +wp.init() + +@wp.func +def warp_set_f( + f: wp.array4d(dtype=float), + value: float, + q: int, + i: int, + j: int, + k: int, + width: int, + height: int, + length: int, +): + # Modulo + if i < 0: + i += width + if j < 0: + j += height + if k < 0: + k += length + if i >= width: + i -= width + if j >= height: + j -= height + if k >= length: + k -= length + f[q, i, j, k] = value + +@wp.kernel +def warp_collide_stream( + f0: wp.array4d(dtype=float), + f1: wp.array4d(dtype=float), + width: int, + height: int, + length: int, + tau: float, +): + + # get index + x, y, z = wp.tid() + + # sample needed points + f_1_1_1 = f0[0, x, y, z] + f_2_1_1 = f0[1, x, y, z] + f_0_1_1 = f0[2, x, y, z] + f_1_2_1 = f0[3, x, y, z] + f_1_0_1 = f0[4, x, y, z] + f_1_1_2 = f0[5, x, y, z] + f_1_1_0 = f0[6, x, y, z] + f_1_2_2 = f0[7, x, y, z] + f_1_0_0 = f0[8, x, y, z] + f_1_2_0 = f0[9, x, y, z] + f_1_0_2 = f0[10, x, y, z] + f_2_1_2 = f0[11, x, y, z] + f_0_1_0 = f0[12, x, y, z] + f_2_1_0 = f0[13, x, y, z] + f_0_1_2 = f0[14, x, y, z] + f_2_2_1 = f0[15, x, y, z] + f_0_0_1 = f0[16, x, y, z] + f_2_0_1 = f0[17, x, y, z] + f_0_2_1 = f0[18, x, y, z] + + # compute u and p + p = (f_1_1_1 + + f_2_1_1 + f_0_1_1 + + f_1_2_1 + f_1_0_1 + + f_1_1_2 + f_1_1_0 + + f_1_2_2 + f_1_0_0 + + f_1_2_0 + f_1_0_2 + + f_2_1_2 + f_0_1_0 + + f_2_1_0 + f_0_1_2 + + f_2_2_1 + f_0_0_1 + + f_2_0_1 + f_0_2_1) + u = (f_2_1_1 - f_0_1_1 + + f_2_1_2 - f_0_1_0 + + f_2_1_0 - f_0_1_2 + + f_2_2_1 - f_0_0_1 + + f_2_0_1 - f_0_2_1) + v = (f_1_2_1 - f_1_0_1 + + f_1_2_2 - f_1_0_0 + + f_1_2_0 - f_1_0_2 + + f_2_2_1 - f_0_0_1 + - f_2_0_1 + f_0_2_1) + w = (f_1_1_2 - f_1_1_0 + + f_1_2_2 - f_1_0_0 + - f_1_2_0 + f_1_0_2 + + f_2_1_2 - f_0_1_0 + - f_2_1_0 + f_0_1_2) + res_p = 1.0 / p + u = u * res_p + v = v * res_p + w = w * res_p + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = 0 + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = 1.5 + factor_2 = 4.5 + weight_0 = 0.33333333 + weight_1 = 0.05555555 + weight_2 = 0.02777777 + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + 1.0)) + f_eq_2_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + 1.0)) + f_eq_0_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + 1.0)) + f_eq_1_2_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) + f_eq_1_0_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) + f_eq_1_1_2 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + 1.0)) + f_eq_1_1_0 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + 1.0)) + f_eq_1_2_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + 1.0)) + f_eq_1_0_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_0 - uxu) + factor_2 * (exu_1_0_0 * exu_1_0_0) + 1.0)) + f_eq_1_2_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + 1.0)) + f_eq_1_0_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * exu_1_0_2) + 1.0)) + f_eq_2_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + 1.0)) + f_eq_0_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + 1.0)) + f_eq_2_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + 1.0)) + f_eq_0_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + 1.0)) + f_eq_2_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + 1.0)) + f_eq_0_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + 1.0)) + f_eq_2_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + 1.0)) + f_eq_0_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * exu_0_2_1) + 1.0)) + + # set next lattice state + inv_tau = (1.0 / tau) + warp_set_f(f1, f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1), 0, x, y, z, width, height, length) + warp_set_f(f1, f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1), 1, x + 1, y, z, width, height, length) + warp_set_f(f1, f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1), 2, x - 1, y, z, width, height, length) + warp_set_f(f1, f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1), 3, x, y + 1, z, width, height, length) + warp_set_f(f1, f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1), 4, x, y - 1, z, width, height, length) + warp_set_f(f1, f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2), 5, x, y, z + 1, width, height, length) + warp_set_f(f1, f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0), 6, x, y, z - 1, width, height, length) + warp_set_f(f1, f_1_2_2 - inv_tau * (f_1_2_2 - f_eq_1_2_2), 7, x, y + 1, z + 1, width, height, length) + warp_set_f(f1, f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0), 8, x, y - 1, z - 1, width, height, length) + warp_set_f(f1, f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0), 9, x, y + 1, z - 1, width, height, length) + warp_set_f(f1, f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2), 10, x, y - 1, z + 1, width, height, length) + warp_set_f(f1, f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2), 11, x + 1, y, z + 1, width, height, length) + warp_set_f(f1, f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0), 12, x - 1, y, z - 1, width, height, length) + warp_set_f(f1, f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0), 13, x + 1, y, z - 1, width, height, length) + warp_set_f(f1, f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2), 14, x - 1, y, z + 1, width, height, length) + warp_set_f(f1, f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1), 15, x + 1, y + 1, z, width, height, length) + warp_set_f(f1, f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1), 16, x - 1, y - 1, z, width, height, length) + warp_set_f(f1, f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1), 17, x + 1, y - 1, z, width, height, length) + warp_set_f(f1, f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1), 18, x - 1, y + 1, z, width, height, length) + +@wp.kernel +def warp_initialize_taylor_green( + f: wp.array4d(dtype=wp.float32), + dx: float, + vel: float, + start_x: int, + start_y: int, + start_z: int, +): + + # get index + i, j, k = wp.tid() + + # get real pos + x = wp.float(i + start_x) * dx + y = wp.float(j + start_y) * dx + z = wp.float(k + start_z) * dx + + # compute u + u = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) + v = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) + w = 0.0 + + # compute p + p = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * (wp.cos(2.0 * x) + wp.cos(2.0 * y) * (wp.cos(2.0 * z) + 2.0)) + + 1.0 + ) + + # compute u X u + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = 0.0 + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = 1.5 + factor_2 = 4.5 + weight_0 = 0.33333333 + weight_1 = 0.05555555 + weight_2 = 0.02777777 + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (-uxu) + 1.0)) + f_eq_2_1_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_1 - uxu) + + factor_2 * (exu_2_1_1 * exu_2_1_1) + + 1.0 + ) + ) + f_eq_0_1_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_1 - uxu) + + factor_2 * (exu_0_1_1 * exu_0_1_1) + + 1.0 + ) + ) + f_eq_1_2_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_1 - uxu) + + factor_2 * (exu_1_2_1 * exu_1_2_1) + + 1.0 + ) + ) + f_eq_1_0_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_1 - uxu) + + factor_2 * (exu_1_2_1 * exu_1_2_1) + + 1.0 + ) + ) + f_eq_1_1_2 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_1_2 - uxu) + + factor_2 * (exu_1_1_2 * exu_1_1_2) + + 1.0 + ) + ) + f_eq_1_1_0 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_1_0 - uxu) + + factor_2 * (exu_1_1_0 * exu_1_1_0) + + 1.0 + ) + ) + f_eq_1_2_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_2 - uxu) + + factor_2 * (exu_1_2_2 * exu_1_2_2) + + 1.0 + ) + ) + f_eq_1_0_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_0 - uxu) + + factor_2 * (exu_1_0_0 * exu_1_0_0) + + 1.0 + ) + ) + f_eq_1_2_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_0 - uxu) + + factor_2 * (exu_1_2_0 * exu_1_2_0) + + 1.0 + ) + ) + f_eq_1_0_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_2 - uxu) + + factor_2 * (exu_1_0_2 * exu_1_0_2) + + 1.0 + ) + ) + f_eq_2_1_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_2 - uxu) + + factor_2 * (exu_2_1_2 * exu_2_1_2) + + 1.0 + ) + ) + f_eq_0_1_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_0 - uxu) + + factor_2 * (exu_0_1_0 * exu_0_1_0) + + 1.0 + ) + ) + f_eq_2_1_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_0 - uxu) + + factor_2 * (exu_2_1_0 * exu_2_1_0) + + 1.0 + ) + ) + f_eq_0_1_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_2 - uxu) + + factor_2 * (exu_0_1_2 * exu_0_1_2) + + 1.0 + ) + ) + f_eq_2_2_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_2_1 - uxu) + + factor_2 * (exu_2_2_1 * exu_2_2_1) + + 1.0 + ) + ) + f_eq_0_0_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_0_1 - uxu) + + factor_2 * (exu_0_0_1 * exu_0_0_1) + + 1.0 + ) + ) + f_eq_2_0_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_0_1 - uxu) + + factor_2 * (exu_2_0_1 * exu_2_0_1) + + 1.0 + ) + ) + f_eq_0_2_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_2_1 - uxu) + + factor_2 * (exu_0_2_1 * exu_0_2_1) + + 1.0 + ) + ) + + # set next lattice state + f[0, i, j, k] = f_eq_1_1_1 + f[1, i, j, k] = f_eq_2_1_1 + f[2, i, j, k] = f_eq_0_1_1 + f[3, i, j, k] = f_eq_1_2_1 + f[4, i, j, k] = f_eq_1_0_1 + f[5, i, j, k] = f_eq_1_1_2 + f[6, i, j, k] = f_eq_1_1_0 + f[7, i, j, k] = f_eq_1_2_2 + f[8, i, j, k] = f_eq_1_0_0 + f[9, i, j, k] = f_eq_1_2_0 + f[10, i, j, k] = f_eq_1_0_2 + f[11, i, j, k] = f_eq_2_1_2 + f[12, i, j, k] = f_eq_0_1_0 + f[13, i, j, k] = f_eq_2_1_0 + f[14, i, j, k] = f_eq_0_1_2 + f[15, i, j, k] = f_eq_2_2_1 + f[16, i, j, k] = f_eq_0_0_1 + f[17, i, j, k] = f_eq_2_0_1 + f[18, i, j, k] = f_eq_0_2_1 + + +def warp_initialize_f(f, dx: float): + # Get inputs + cs = 1.0 / np.sqrt(3.0) + vel = 0.1 * cs + + # Launch kernel + wp.launch( + kernel=warp_initialize_taylor_green, + dim=list(f.shape[1:]), + inputs=[f, dx, vel, 0, 0, 0], + device=f.device, + ) + + return f + + +def warp_apply_collide_stream(f0, f1, tau: float): + # Apply streaming and collision step + wp.launch( + kernel=warp_collide_stream, + dim=list(f0.shape[1:]), + inputs=[f0, f1, f0.shape[1], f0.shape[2], f0.shape[3], tau], + device=f0.device, + ) + + return f1, f0 + + +@cuda.jit("void(float32[:,:,:,::1], float32, int32, int32, int32, int32, int32, int32, int32)", device=True) +def numba_set_f( + f: numba.cuda.cudadrv.devicearray.DeviceNDArray, + value: float, + q: int, + i: int, + j: int, + k: int, + width: int, + height: int, + length: int, +): + # Modulo + if i < 0: + i += width + if j < 0: + j += height + if k < 0: + k += length + if i >= width: + i -= width + if j >= height: + j -= height + if k >= length: + k -= length + f[i, j, k, q] = value + +#@cuda.jit +@cuda.jit("void(float32[:,:,:,::1], float32[:,:,:,::1], int32, int32, int32, float32)") +def numba_collide_stream( + f0: numba.cuda.cudadrv.devicearray.DeviceNDArray, + f1: numba.cuda.cudadrv.devicearray.DeviceNDArray, + width: int, + height: int, + length: int, + tau: float, +): + + x, y, z = cuda.grid(3) + + # sample needed points + f_1_1_1 = f0[x, y, z, 0] + f_2_1_1 = f0[x, y, z, 1] + f_0_1_1 = f0[x, y, z, 2] + f_1_2_1 = f0[x, y, z, 3] + f_1_0_1 = f0[x, y, z, 4] + f_1_1_2 = f0[x, y, z, 5] + f_1_1_0 = f0[x, y, z, 6] + f_1_2_2 = f0[x, y, z, 7] + f_1_0_0 = f0[x, y, z, 8] + f_1_2_0 = f0[x, y, z, 9] + f_1_0_2 = f0[x, y, z, 10] + f_2_1_2 = f0[x, y, z, 11] + f_0_1_0 = f0[x, y, z, 12] + f_2_1_0 = f0[x, y, z, 13] + f_0_1_2 = f0[x, y, z, 14] + f_2_2_1 = f0[x, y, z, 15] + f_0_0_1 = f0[x, y, z, 16] + f_2_0_1 = f0[x, y, z, 17] + f_0_2_1 = f0[x, y, z, 18] + + # compute u and p + p = (f_1_1_1 + + f_2_1_1 + f_0_1_1 + + f_1_2_1 + f_1_0_1 + + f_1_1_2 + f_1_1_0 + + f_1_2_2 + f_1_0_0 + + f_1_2_0 + f_1_0_2 + + f_2_1_2 + f_0_1_0 + + f_2_1_0 + f_0_1_2 + + f_2_2_1 + f_0_0_1 + + f_2_0_1 + f_0_2_1) + u = (f_2_1_1 - f_0_1_1 + + f_2_1_2 - f_0_1_0 + + f_2_1_0 - f_0_1_2 + + f_2_2_1 - f_0_0_1 + + f_2_0_1 - f_0_2_1) + v = (f_1_2_1 - f_1_0_1 + + f_1_2_2 - f_1_0_0 + + f_1_2_0 - f_1_0_2 + + f_2_2_1 - f_0_0_1 + - f_2_0_1 + f_0_2_1) + w = (f_1_1_2 - f_1_1_0 + + f_1_2_2 - f_1_0_0 + - f_1_2_0 + f_1_0_2 + + f_2_1_2 - f_0_1_0 + - f_2_1_0 + f_0_1_2) + res_p = numba.float32(1.0) / p + u = u * res_p + v = v * res_p + w = w * res_p + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = numba.float32(0.0) + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = numba.float32(1.5) + factor_2 = numba.float32(4.5) + weight_0 = numba.float32(0.33333333) + weight_1 = numba.float32(0.05555555) + weight_2 = numba.float32(0.02777777) + + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + numba.float32(1.0))) + f_eq_2_1_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + numba.float32(1.0))) + f_eq_0_1_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + numba.float32(1.0))) + f_eq_1_2_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + numba.float32(1.0))) + f_eq_1_0_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + numba.float32(1.0))) + f_eq_1_1_2 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + numba.float32(1.0))) + f_eq_1_1_0 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + numba.float32(1.0))) + f_eq_1_2_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + numba.float32(1.0))) + f_eq_1_0_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_0 - uxu) + factor_2 * (exu_1_0_0 * exu_1_0_0) + numba.float32(1.0))) + f_eq_1_2_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + numba.float32(1.0))) + f_eq_1_0_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * exu_1_0_2) + numba.float32(1.0))) + f_eq_2_1_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + numba.float32(1.0))) + f_eq_0_1_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + numba.float32(1.0))) + f_eq_2_1_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + numba.float32(1.0))) + f_eq_0_1_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + numba.float32(1.0))) + f_eq_2_2_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + numba.float32(1.0))) + f_eq_0_0_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + numba.float32(1.0))) + f_eq_2_0_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + numba.float32(1.0))) + f_eq_0_2_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * exu_0_2_1) + numba.float32(1.0))) + + # set next lattice state + inv_tau = numba.float32((numba.float32(1.0) / tau)) + numba_set_f(f1, f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1), 0, x, y, z, width, height, length) + numba_set_f(f1, f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1), 1, x + 1, y, z, width, height, length) + numba_set_f(f1, f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1), 2, x - 1, y, z, width, height, length) + numba_set_f(f1, f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1), 3, x, y + 1, z, width, height, length) + numba_set_f(f1, f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1), 4, x, y - 1, z, width, height, length) + numba_set_f(f1, f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2), 5, x, y, z + 1, width, height, length) + numba_set_f(f1, f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0), 6, x, y, z - 1, width, height, length) + numba_set_f(f1, f_1_2_2 - inv_tau * (f_1_2_2 - f_eq_1_2_2), 7, x, y + 1, z + 1, width, height, length) + numba_set_f(f1, f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0), 8, x, y - 1, z - 1, width, height, length) + numba_set_f(f1, f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0), 9, x, y + 1, z - 1, width, height, length) + numba_set_f(f1, f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2), 10, x, y - 1, z + 1, width, height, length) + numba_set_f(f1, f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2), 11, x + 1, y, z + 1, width, height, length) + numba_set_f(f1, f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0), 12, x - 1, y, z - 1, width, height, length) + numba_set_f(f1, f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0), 13, x + 1, y, z - 1, width, height, length) + numba_set_f(f1, f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2), 14, x - 1, y, z + 1, width, height, length) + numba_set_f(f1, f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1), 15, x + 1, y + 1, z, width, height, length) + numba_set_f(f1, f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1), 16, x - 1, y - 1, z, width, height, length) + numba_set_f(f1, f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1), 17, x + 1, y - 1, z, width, height, length) + numba_set_f(f1, f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1), 18, x - 1, y + 1, z, width, height, length) + + +@cuda.jit +def numba_initialize_taylor_green( + f, + dx, + vel, + start_x, + start_y, + start_z, +): + + i, j, k = cuda.grid(3) + + # get real pos + x = numba.float32(i + start_x) * dx + y = numba.float32(j + start_y) * dx + z = numba.float32(k + start_z) * dx + + # compute u + u = vel * math.sin(x) * math.cos(y) * math.cos(z) + v = -vel * math.cos(x) * math.sin(y) * math.cos(z) + w = 0.0 + + # compute p + p = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * (math.cos(2.0 * x) + math.cos(2.0 * y) * (math.cos(2.0 * z) + 2.0)) + + 1.0 + ) + + # compute u X u + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = 0.0 + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = 1.5 + factor_2 = 4.5 + weight_0 = 0.33333333 + weight_1 = 0.05555555 + weight_2 = 0.02777777 + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (-uxu) + 1.0)) + f_eq_2_1_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_1 - uxu) + + factor_2 * (exu_2_1_1 * exu_2_1_1) + + 1.0 + ) + ) + f_eq_0_1_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_1 - uxu) + + factor_2 * (exu_0_1_1 * exu_0_1_1) + + 1.0 + ) + ) + f_eq_1_2_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_1 - uxu) + + factor_2 * (exu_1_2_1 * exu_1_2_1) + + 1.0 + ) + ) + f_eq_1_0_1 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_1 - uxu) + + factor_2 * (exu_1_2_1 * exu_1_2_1) + + 1.0 + ) + ) + f_eq_1_1_2 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_1_2 - uxu) + + factor_2 * (exu_1_1_2 * exu_1_1_2) + + 1.0 + ) + ) + f_eq_1_1_0 = weight_1 * ( + p + * ( + factor_1 * (2.0 * exu_1_1_0 - uxu) + + factor_2 * (exu_1_1_0 * exu_1_1_0) + + 1.0 + ) + ) + f_eq_1_2_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_2 - uxu) + + factor_2 * (exu_1_2_2 * exu_1_2_2) + + 1.0 + ) + ) + f_eq_1_0_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_0 - uxu) + + factor_2 * (exu_1_0_0 * exu_1_0_0) + + 1.0 + ) + ) + f_eq_1_2_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_2_0 - uxu) + + factor_2 * (exu_1_2_0 * exu_1_2_0) + + 1.0 + ) + ) + f_eq_1_0_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_1_0_2 - uxu) + + factor_2 * (exu_1_0_2 * exu_1_0_2) + + 1.0 + ) + ) + f_eq_2_1_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_2 - uxu) + + factor_2 * (exu_2_1_2 * exu_2_1_2) + + 1.0 + ) + ) + f_eq_0_1_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_0 - uxu) + + factor_2 * (exu_0_1_0 * exu_0_1_0) + + 1.0 + ) + ) + f_eq_2_1_0 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_1_0 - uxu) + + factor_2 * (exu_2_1_0 * exu_2_1_0) + + 1.0 + ) + ) + f_eq_0_1_2 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_1_2 - uxu) + + factor_2 * (exu_0_1_2 * exu_0_1_2) + + 1.0 + ) + ) + f_eq_2_2_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_2_1 - uxu) + + factor_2 * (exu_2_2_1 * exu_2_2_1) + + 1.0 + ) + ) + f_eq_0_0_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_0_1 - uxu) + + factor_2 * (exu_0_0_1 * exu_0_0_1) + + 1.0 + ) + ) + f_eq_2_0_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_2_0_1 - uxu) + + factor_2 * (exu_2_0_1 * exu_2_0_1) + + 1.0 + ) + ) + f_eq_0_2_1 = weight_2 * ( + p + * ( + factor_1 * (2.0 * exu_0_2_1 - uxu) + + factor_2 * (exu_0_2_1 * exu_0_2_1) + + 1.0 + ) + ) + + # set next lattice state + f[i, j, k, 0] = f_eq_1_1_1 + f[i, j, k, 1] = f_eq_2_1_1 + f[i, j, k, 2] = f_eq_0_1_1 + f[i, j, k, 3] = f_eq_1_2_1 + f[i, j, k, 4] = f_eq_1_0_1 + f[i, j, k, 5] = f_eq_1_1_2 + f[i, j, k, 6] = f_eq_1_1_0 + f[i, j, k, 7] = f_eq_1_2_2 + f[i, j, k, 8] = f_eq_1_0_0 + f[i, j, k, 9] = f_eq_1_2_0 + f[ i, j, k, 10] = f_eq_1_0_2 + f[ i, j, k, 11] = f_eq_2_1_2 + f[ i, j, k, 12] = f_eq_0_1_0 + f[ i, j, k, 13] = f_eq_2_1_0 + f[ i, j, k, 14] = f_eq_0_1_2 + f[ i, j, k, 15] = f_eq_2_2_1 + f[ i, j, k, 16] = f_eq_0_0_1 + f[ i, j, k, 17] = f_eq_2_0_1 + f[ i, j, k, 18] = f_eq_0_2_1 + + +def numba_initialize_f(f, dx: float): + # Get inputs + cs = 1.0 / np.sqrt(3.0) + vel = 0.1 * cs + + # Launch kernel + blockdim = (16, 16, 1) + griddim = ( + int(np.ceil(f.shape[0] / blockdim[0])), + int(np.ceil(f.shape[1] / blockdim[1])), + int(np.ceil(f.shape[2] / blockdim[2])), + ) + numba_initialize_taylor_green[griddim, blockdim]( + f, dx, vel, 0, 0, 0 + ) + + return f + +def numba_apply_collide_stream(f0, f1, tau: float): + # Apply streaming and collision step + blockdim = (8, 8, 8) + griddim = ( + int(np.ceil(f0.shape[0] / blockdim[0])), + int(np.ceil(f0.shape[1] / blockdim[1])), + int(np.ceil(f0.shape[2] / blockdim[2])), + ) + numba_collide_stream[griddim, blockdim]( + f0, f1, f0.shape[0], f0.shape[1], f0.shape[2], tau + ) + + return f1, f0 + +@partial(jit, static_argnums=(1), donate_argnums=(0)) +def jax_apply_collide_stream(f, tau: float): + + # Get f directions + f_1_1_1 = f[:, :, :, 0] + f_2_1_1 = f[:, :, :, 1] + f_0_1_1 = f[:, :, :, 2] + f_1_2_1 = f[:, :, :, 3] + f_1_0_1 = f[:, :, :, 4] + f_1_1_2 = f[:, :, :, 5] + f_1_1_0 = f[:, :, :, 6] + f_1_2_2 = f[:, :, :, 7] + f_1_0_0 = f[:, :, :, 8] + f_1_2_0 = f[:, :, :, 9] + f_1_0_2 = f[:, :, :, 10] + f_2_1_2 = f[:, :, :, 11] + f_0_1_0 = f[:, :, :, 12] + f_2_1_0 = f[:, :, :, 13] + f_0_1_2 = f[:, :, :, 14] + f_2_2_1 = f[:, :, :, 15] + f_0_0_1 = f[:, :, :, 16] + f_2_0_1 = f[:, :, :, 17] + f_0_2_1 = f[:, :, :, 18] + + # compute u and p + p = (f_1_1_1 + + f_2_1_1 + f_0_1_1 + + f_1_2_1 + f_1_0_1 + + f_1_1_2 + f_1_1_0 + + f_1_2_2 + f_1_0_0 + + f_1_2_0 + f_1_0_2 + + f_2_1_2 + f_0_1_0 + + f_2_1_0 + f_0_1_2 + + f_2_2_1 + f_0_0_1 + + f_2_0_1 + f_0_2_1) + u = (f_2_1_1 - f_0_1_1 + + f_2_1_2 - f_0_1_0 + + f_2_1_0 - f_0_1_2 + + f_2_2_1 - f_0_0_1 + + f_2_0_1 - f_0_2_1) + v = (f_1_2_1 - f_1_0_1 + + f_1_2_2 - f_1_0_0 + + f_1_2_0 - f_1_0_2 + + f_2_2_1 - f_0_0_1 + - f_2_0_1 + f_0_2_1) + w = (f_1_1_2 - f_1_1_0 + + f_1_2_2 - f_1_0_0 + - f_1_2_0 + f_1_0_2 + + f_2_1_2 - f_0_1_0 + - f_2_1_0 + f_0_1_2) + res_p = 1.0 / p + u = u * res_p + v = v * res_p + w = w * res_p + uxu = u * u + v * v + w * w + + # compute e dot u + exu_1_1_1 = 0 + exu_2_1_1 = u + exu_0_1_1 = -u + exu_1_2_1 = v + exu_1_0_1 = -v + exu_1_1_2 = w + exu_1_1_0 = -w + exu_1_2_2 = v + w + exu_1_0_0 = -v - w + exu_1_2_0 = v - w + exu_1_0_2 = -v + w + exu_2_1_2 = u + w + exu_0_1_0 = -u - w + exu_2_1_0 = u - w + exu_0_1_2 = -u + w + exu_2_2_1 = u + v + exu_0_0_1 = -u - v + exu_2_0_1 = u - v + exu_0_2_1 = -u + v + + # compute equilibrium dist + factor_1 = 1.5 + factor_2 = 4.5 + weight_0 = 0.33333333 + weight_1 = 0.05555555 + weight_2 = 0.02777777 + f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + 1.0)) + f_eq_2_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + 1.0)) + f_eq_0_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + 1.0)) + f_eq_1_2_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) + f_eq_1_0_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) + f_eq_1_1_2 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + 1.0)) + f_eq_1_1_0 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + 1.0)) + f_eq_1_2_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + 1.0)) + f_eq_1_0_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_0 - uxu) + factor_2 * (exu_1_0_0 * exu_1_0_0) + 1.0)) + f_eq_1_2_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + 1.0)) + f_eq_1_0_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * exu_1_0_2) + 1.0)) + f_eq_2_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + 1.0)) + f_eq_0_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + 1.0)) + f_eq_2_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + 1.0)) + f_eq_0_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + 1.0)) + f_eq_2_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + 1.0)) + f_eq_0_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + 1.0)) + f_eq_2_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + 1.0)) + f_eq_0_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * exu_0_2_1) + 1.0)) + + # set next lattice state + inv_tau = (1.0 / tau) + f_1_1_1 = f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1) + f_2_1_1 = f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1) + f_0_1_1 = f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1) + f_1_2_1 = f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1) + f_1_0_1 = f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1) + f_1_1_2 = f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2) + f_1_1_0 = f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0) + f_1_2_2 = f_1_2_2 - inv_tau * (f_1_2_2 - f_eq_1_2_2) + f_1_0_0 = f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0) + f_1_2_0 = f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0) + f_1_0_2 = f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2) + f_2_1_2 = f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2) + f_0_1_0 = f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0) + f_2_1_0 = f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0) + f_0_1_2 = f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2) + f_2_2_1 = f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1) + f_0_0_1 = f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1) + f_2_0_1 = f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1) + f_0_2_1 = f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1) + + # Roll fs and concatenate + f_2_1_1 = jnp.roll(f_2_1_1, -1, axis=0) + f_0_1_1 = jnp.roll(f_0_1_1, 1, axis=0) + f_1_2_1 = jnp.roll(f_1_2_1, -1, axis=1) + f_1_0_1 = jnp.roll(f_1_0_1, 1, axis=1) + f_1_1_2 = jnp.roll(f_1_1_2, -1, axis=2) + f_1_1_0 = jnp.roll(f_1_1_0, 1, axis=2) + f_1_2_2 = jnp.roll(jnp.roll(f_1_2_2, -1, axis=1), -1, axis=2) + f_1_0_0 = jnp.roll(jnp.roll(f_1_0_0, 1, axis=1), 1, axis=2) + f_1_2_0 = jnp.roll(jnp.roll(f_1_2_0, -1, axis=1), 1, axis=2) + f_1_0_2 = jnp.roll(jnp.roll(f_1_0_2, 1, axis=1), -1, axis=2) + f_2_1_2 = jnp.roll(jnp.roll(f_2_1_2, -1, axis=0), -1, axis=2) + f_0_1_0 = jnp.roll(jnp.roll(f_0_1_0, 1, axis=0), 1, axis=2) + f_2_1_0 = jnp.roll(jnp.roll(f_2_1_0, -1, axis=0), 1, axis=2) + f_0_1_2 = jnp.roll(jnp.roll(f_0_1_2, 1, axis=0), -1, axis=2) + f_2_2_1 = jnp.roll(jnp.roll(f_2_2_1, -1, axis=0), -1, axis=1) + f_0_0_1 = jnp.roll(jnp.roll(f_0_0_1, 1, axis=0), 1, axis=1) + f_2_0_1 = jnp.roll(jnp.roll(f_2_0_1, -1, axis=0), 1, axis=1) + f_0_2_1 = jnp.roll(jnp.roll(f_0_2_1, 1, axis=0), -1, axis=1) + + return jnp.stack( + [ + f_1_1_1, + f_2_1_1, + f_0_1_1, + f_1_2_1, + f_1_0_1, + f_1_1_2, + f_1_1_0, + f_1_2_2, + f_1_0_0, + f_1_2_0, + f_1_0_2, + f_2_1_2, + f_0_1_0, + f_2_1_0, + f_0_1_2, + f_2_2_1, + f_0_0_1, + f_2_0_1, + f_0_2_1, + ], + axis=-1, + ) + + + +if __name__ == "__main__": + + # Sim Parameters + n = 256 + tau = 0.505 + dx = 2.0 * np.pi / n + nr_steps = 128 + + # Bar plot + backend = [] + mlups = [] + + ######### Warp ######### + # Make f0, f1 + f0 = wp.empty((19, n, n, n), dtype=wp.float32, device="cuda:0") + f1 = wp.empty((19, n, n, n), dtype=wp.float32, device="cuda:0") + + # Initialize f0 + f0 = warp_initialize_f(f0, dx) + + # Apply streaming and collision + t0 = time.time() + for _ in tqdm(range(nr_steps)): + f0, f1 = warp_apply_collide_stream(f0, f1, tau) + wp.synchronize() + t1 = time.time() + + # Compute MLUPS + mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 + backend.append("Warp") + mlups.append(mlups) + + # Plot results + np_f = f0.numpy() + plt.imshow(np_f[3, :, :, 0]) + plt.colorbar() + plt.savefig("warp_f_.png") + plt.close() + + ######### Numba ######### + # Make f0, f1 + f0 = cp.ascontiguousarray(cp.empty((n, n, n, 19), dtype=np.float32)) + f1 = cp.ascontiguousarray(cp.empty((n, n, n, 19), dtype=np.float32)) + + # Initialize f0 + f0 = numba_initialize_f(f0, dx) + + # Apply streaming and collision + t0 = time.time() + for _ in tqdm(range(nr_steps)): + f0, f1 = numba_apply_collide_stream(f0, f1, tau) + cp.cuda.Device(0).synchronize() + t1 = time.time() + + # Compute MLUPS + mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 + backend.append("Numba") + mlups.append(mlups) + + # Plot results + np_f = f0 + plt.imshow(np_f[:, :, 0, 3].get()) + plt.colorbar() + plt.savefig("numba_f_.png") + plt.close() + + ######### Jax ######### + # Make f0, f1 + f = jnp.zeros((n, n, n, 19), dtype=jnp.float32) + + # Initialize f0 + # f = jax_initialize_f(f, dx) + + # Apply streaming and collision + t0 = time.time() + for _ in tqdm(range(nr_steps)): + f = jax_apply_collide_stream(f, tau) + t1 = time.time() + + # Compute MLUPS + mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 + backend.append("Jax") + mlups.append(mlups) + + # Plot results + np_f = f + plt.imshow(np_f[:, :, 0, 3]) + plt.colorbar() + plt.savefig("jax_f_.png") + plt.close() + + + + diff --git a/examples/refactor/README.md b/examples/refactor/README.md new file mode 100644 index 0000000..37b17e7 --- /dev/null +++ b/examples/refactor/README.md @@ -0,0 +1,31 @@ +# Refactor Examples + +This directory contains several example of using the refactored XLB library. + +These examples are not meant to be veiwed as the new interface to XLB but only how +to expose the compute kernels to a user. Development is still ongoing. + +## Examples + +### JAX Example + +The JAX example is a simple example of using the refactored XLB library +with JAX. The example is located in the `example_jax.py`. It shows +a very basic flow past a cyliner. + +### NUMBA Example + +TODO: Not working yet + +The NUMBA example is a simple example of using the refactored XLB library +with NUMBA. The example is located in the `example_numba.py`. It shows +a very basic flow past a cyliner. This example is not working yet though and +is still under development for numba backend. + +### Out of Core JAX Example + +This shoes how we can use out of core memory with JAX. The example is located +in the `example_jax_out_of_core.py`. It shows a very basic flow past a cyliner. +The basic idea is to create an out of core memory array using the implementation +in XLB. Then we run the simulation using the jax functions implementation obtained +from XLB. Some rendering is done using PhantomGaze. diff --git a/examples/refactor/example_jax.py b/examples/refactor/example_jax.py new file mode 100644 index 0000000..9092022 --- /dev/null +++ b/examples/refactor/example_jax.py @@ -0,0 +1,107 @@ +# from IPython import display +import numpy as np +import jax +import jax.numpy as jnp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt + +import xlb + +if __name__ == "__main__": + # Simulation parameters + nr = 128 + vel = 0.05 + visc = 0.00001 + omega = 1.0 / (3.0 * visc + 0.5) + length = 2 * np.pi + + # Geometry (sphere) + lin = np.linspace(0, length, nr) + X, Y, Z = np.meshgrid(lin, lin, lin, indexing="ij") + XYZ = np.stack([X, Y, Z], axis=-1) + radius = np.pi / 8.0 + + # XLB precision policy + precision_policy = xlb.precision_policy.Fp32Fp32() + + # XLB lattice + velocity_set = xlb.velocity_set.D3Q27() + + # XLB equilibrium + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(velocity_set=velocity_set) + + # XLB macroscopic + macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set) + + # XLB collision + collision = xlb.operator.collision.KBC(omega=omega, velocity_set=velocity_set) + + # XLB stream + stream = xlb.operator.stream.Stream(velocity_set=velocity_set) + + # XLB noslip boundary condition (sphere) + in_cylinder = ((X - np.pi/2.0)**2 + (Y - np.pi)**2 + (Z - np.pi)**2) < radius**2 + indices = np.argwhere(in_cylinder) + bounce_back = xlb.operator.boundary_condition.FullBounceBack.from_indices( + indices=indices, + velocity_set=velocity_set + ) + + # XLB outflow boundary condition + outflow = xlb.operator.boundary_condition.DoNothing.from_indices( + indices=np.argwhere(XYZ[..., 0] == length), + velocity_set=velocity_set + ) + + # XLB inflow boundary condition + inflow = xlb.operator.boundary_condition.EquilibriumBoundary.from_indices( + indices=np.argwhere(XYZ[..., 0] == 0.0), + velocity_set=velocity_set, + rho=1.0, + u=np.array([vel, 0.0, 0.0]), + equilibrium=equilibrium + ) + + # XLB stepper + stepper = xlb.operator.stepper.NSE( + collision=collision, + stream=stream, + equilibrium=equilibrium, + macroscopic=macroscopic, + boundary_conditions=[bounce_back, outflow, inflow], + precision_policy=precision_policy, + ) + + # Make initial condition + u = jnp.stack([vel * jnp.ones_like(X), jnp.zeros_like(X), jnp.zeros_like(X)], axis=-1) + rho = jnp.expand_dims(jnp.ones_like(X), axis=-1) + f = equilibrium(rho, u) + + # Get boundary id and mask + ijk = jnp.meshgrid(jnp.arange(nr), jnp.arange(nr), jnp.arange(nr), indexing="ij") + boundary_id, mask = stepper.set_boundary(jnp.stack(ijk, axis=-1)) + + # Run simulation + tic = time.time() + nr_iter = 4096 + for i in tqdm(range(nr_iter)): + f = stepper(f, boundary_id, mask, i) + + if i % 32 == 0: + # Get u, rho from f + rho, u = macroscopic(f) + norm_u = jnp.linalg.norm(u, axis=-1) + norm_u = (1.0 - jnp.minimum(boundary_id, 1.0)) * norm_u + + # Plot + plt.imshow(norm_u[..., nr//2], cmap="jet") + plt.colorbar() + plt.savefig(f"img_{str(i).zfill(5)}.png") + plt.close() + + # Sync to host + f = f.block_until_ready() + toc = time.time() + print(f"MLUPS: {(nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/examples/refactor/example_jax_out_of_core.py b/examples/refactor/example_jax_out_of_core.py new file mode 100644 index 0000000..aeb604f --- /dev/null +++ b/examples/refactor/example_jax_out_of_core.py @@ -0,0 +1,336 @@ +# from IPython import display +import os +os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.7' + +import numpy as np +import jax +import jax.numpy as jnp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt +from mpi4py import MPI +import cupy as cp + +import xlb +from xlb.experimental.ooc import OOCmap, OOCArray + +import phantomgaze as pg + +comm = MPI.COMM_WORLD + +@jax.jit +def q_criterion(u): + # Compute derivatives + u_x = u[..., 0] + u_y = u[..., 1] + u_z = u[..., 2] + + # Compute derivatives + u_x_dx = (u_x[2:, 1:-1, 1:-1] - u_x[:-2, 1:-1, 1:-1]) / 2 + u_x_dy = (u_x[1:-1, 2:, 1:-1] - u_x[1:-1, :-2, 1:-1]) / 2 + u_x_dz = (u_x[1:-1, 1:-1, 2:] - u_x[1:-1, 1:-1, :-2]) / 2 + u_y_dx = (u_y[2:, 1:-1, 1:-1] - u_y[:-2, 1:-1, 1:-1]) / 2 + u_y_dy = (u_y[1:-1, 2:, 1:-1] - u_y[1:-1, :-2, 1:-1]) / 2 + u_y_dz = (u_y[1:-1, 1:-1, 2:] - u_y[1:-1, 1:-1, :-2]) / 2 + u_z_dx = (u_z[2:, 1:-1, 1:-1] - u_z[:-2, 1:-1, 1:-1]) / 2 + u_z_dy = (u_z[1:-1, 2:, 1:-1] - u_z[1:-1, :-2, 1:-1]) / 2 + u_z_dz = (u_z[1:-1, 1:-1, 2:] - u_z[1:-1, 1:-1, :-2]) / 2 + + # Compute vorticity + mu_x = u_z_dy - u_y_dz + mu_y = u_x_dz - u_z_dx + mu_z = u_y_dx - u_x_dy + norm_mu = jnp.sqrt(mu_x ** 2 + mu_y ** 2 + mu_z ** 2) + + # Compute strain rate + s_0_0 = u_x_dx + s_0_1 = 0.5 * (u_x_dy + u_y_dx) + s_0_2 = 0.5 * (u_x_dz + u_z_dx) + s_1_0 = s_0_1 + s_1_1 = u_y_dy + s_1_2 = 0.5 * (u_y_dz + u_z_dy) + s_2_0 = s_0_2 + s_2_1 = s_1_2 + s_2_2 = u_z_dz + s_dot_s = ( + s_0_0 ** 2 + s_0_1 ** 2 + s_0_2 ** 2 + + s_1_0 ** 2 + s_1_1 ** 2 + s_1_2 ** 2 + + s_2_0 ** 2 + s_2_1 ** 2 + s_2_2 ** 2 + ) + + # Compute omega + omega_0_0 = 0.0 + omega_0_1 = 0.5 * (u_x_dy - u_y_dx) + omega_0_2 = 0.5 * (u_x_dz - u_z_dx) + omega_1_0 = -omega_0_1 + omega_1_1 = 0.0 + omega_1_2 = 0.5 * (u_y_dz - u_z_dy) + omega_2_0 = -omega_0_2 + omega_2_1 = -omega_1_2 + omega_2_2 = 0.0 + omega_dot_omega = ( + omega_0_0 ** 2 + omega_0_1 ** 2 + omega_0_2 ** 2 + + omega_1_0 ** 2 + omega_1_1 ** 2 + omega_1_2 ** 2 + + omega_2_0 ** 2 + omega_2_1 ** 2 + omega_2_2 ** 2 + ) + + # Compute q-criterion + q = 0.5 * (omega_dot_omega - s_dot_s) + + return norm_mu, q + + +if __name__ == "__main__": + # Simulation parameters + nr = 256 + nx = 3 * nr + ny = nr + nz = nr + vel = 0.05 + visc = 0.00001 + omega = 1.0 / (3.0 * visc + 0.5) + length = 2 * np.pi + dx = length / (ny - 1) + radius = np.pi / 3.0 + + # OOC parameters + sub_steps = 8 + sub_nr = 128 + padding = (sub_steps, sub_steps, sub_steps, 0) + + # XLB precision policy + precision_policy = xlb.precision_policy.Fp32Fp32() + + # XLB lattice + velocity_set = xlb.velocity_set.D3Q27() + + # XLB equilibrium + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(velocity_set=velocity_set) + + # XLB macroscopic + macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set) + + # XLB collision + collision = xlb.operator.collision.KBC(omega=omega, velocity_set=velocity_set) + + # XLB stream + stream = xlb.operator.stream.Stream(velocity_set=velocity_set) + + # XLB noslip boundary condition (sphere) + # Create a mask function + def set_boundary_sphere(ijk, boundary_id, mask, id_number): + # Get XYZ + XYZ = ijk * dx + sphere_mask = jnp.linalg.norm(XYZ - length / 2.0, axis=-1) < radius + boundary_id = boundary_id.at[sphere_mask].set(id_number) + mask = mask.at[sphere_mask].set(True) + return boundary_id, mask + bounce_back = xlb.operator.boundary_condition.FullBounceBack( + set_boundary=set_boundary_sphere, + velocity_set=velocity_set + ) + + # XLB outflow boundary condition + def set_boundary_outflow(ijk, boundary_id, mask, id_number): + # Get XYZ + XYZ = ijk * dx + outflow_mask = XYZ[..., 0] >= (length * 3.0) - dx + boundary_id = boundary_id.at[outflow_mask].set(id_number) + mask = mask.at[outflow_mask].set(True) + return boundary_id, mask + outflow = xlb.operator.boundary_condition.DoNothing( + set_boundary=set_boundary_outflow, + velocity_set=velocity_set + ) + + # XLB inflow boundary condition + def set_boundary_inflow(ijk, boundary_id, mask, id_number): + # Get XYZ + XYZ = ijk * dx + inflow_mask = XYZ[..., 0] == 0.0 + boundary_id = boundary_id.at[inflow_mask].set(id_number) + mask = mask.at[inflow_mask].set(True) + return boundary_id, mask + inflow = xlb.operator.boundary_condition.EquilibriumBoundary( + set_boundary=set_boundary_inflow, + velocity_set=velocity_set, + rho=1.0, + u=np.array([vel, 0.0, 0.0]), + equilibrium=equilibrium + ) + + # XLB stepper + stepper = xlb.operator.stepper.NSE( + collision=collision, + stream=stream, + equilibrium=equilibrium, + macroscopic=macroscopic, + boundary_conditions=[bounce_back, outflow, inflow], + precision_policy=precision_policy, + ) + + # Make OOC arrays + f = OOCArray( + shape=(nx, ny, nz, velocity_set.q), + dtype=np.float32, + tile_shape=(sub_nr, sub_nr, sub_nr, velocity_set.q), + padding=padding, + comm=comm, + devices=[cp.cuda.Device(0) for i in range(comm.size)], + codec=None, + nr_compute_tiles=1, + ) + + camera_radius = length * 2.0 + focal_point = (3.0 * length / 2.0, length / 2.0, length / 2.0) + angle = 1 * 0.0001 + camera_position = (focal_point[0] + camera_radius * np.sin(angle), focal_point[1], focal_point[2] + camera_radius * np.cos(angle)) + camera = pg.Camera( + position=camera_position, + focal_point=focal_point, + view_up=(0.0, 1.0, 0.0), + height=1440, + width=2560, + max_depth=6.0 * length, + ) + screen_buffer = pg.ScreenBuffer.from_camera(camera) + + + # Initialize f + @OOCmap(comm, (0,), backend="jax") + def initialize_f(f): + # Get inputs + shape = f.shape[:-1] + u = jnp.stack([vel * jnp.ones(shape), jnp.zeros(shape), jnp.zeros(shape)], axis=-1) + rho = jnp.expand_dims(jnp.ones(shape), axis=-1) + f = equilibrium(rho, u) + return f + f = initialize_f(f) + + # Stepping function + @OOCmap(comm, (0,), backend="jax", add_index=True) + def ooc_stepper(f): + + # Get tensors + f, global_index = f + + # Get ijk + lin_i = jnp.arange(global_index[0], global_index[0] + f.shape[0]) + lin_j = jnp.arange(global_index[1], global_index[1] + f.shape[1]) + lin_k = jnp.arange(global_index[2], global_index[2] + f.shape[2]) + ijk = jnp.meshgrid(lin_i, lin_j, lin_k, indexing="ij") + ijk = jnp.stack(ijk, axis=-1) + + # Set boundary_id and mask + boundary_id, mask = stepper.set_boundary(ijk) + + # Run stepper + for _ in range(sub_steps): + f = stepper(f, boundary_id, mask, _) + + # Wait till f is computed using jax + f = f.block_until_ready() + + return f + + # Make a render function + @OOCmap(comm, (0,), backend="jax", add_index=True) + def render(f, screen_buffer, camera): + + # Get tensors + f, global_index = f + + # Get ijk + lin_i = jnp.arange(global_index[0], global_index[0] + f.shape[0]) + lin_j = jnp.arange(global_index[1], global_index[1] + f.shape[1]) + lin_k = jnp.arange(global_index[2], global_index[2] + f.shape[2]) + ijk = jnp.meshgrid(lin_i, lin_j, lin_k, indexing="ij") + ijk = jnp.stack(ijk, axis=-1) + + # Set boundary_id and mask + boundary_id, mask = stepper.set_boundary(ijk) + sphere = (boundary_id == 1).astype(jnp.float32)[1:-1, 1:-1, 1:-1] + + # Get rho, u + rho, u = macroscopic(f) + + # Get q-cr + norm_mu, q = q_criterion(u) + + # Make volumes + origin = ((global_index[0] + 1) * dx, (global_index[1] + 1) * dx, (global_index[2] + 1) * dx) + q_volume = pg.objects.Volume( + q, spacing=(dx, dx, dx), origin=origin + ) + norm_mu_volume = pg.objects.Volume( + norm_mu, spacing=(dx, dx, dx), origin=origin + ) + sphere_volume = pg.objects.Volume( + sphere, spacing=(dx, dx, dx), origin=origin + ) + + # Render + screen_buffer = pg.render.contour( + q_volume, + threshold=0.000005, + color=norm_mu_volume, + colormap=pg.Colormap("jet", vmin=0.0, vmax=0.025), + camera=camera, + screen_buffer=screen_buffer, + ) + screen_buffer = pg.render.contour( + sphere_volume, + threshold=0.5, + camera=camera, + screen_buffer=screen_buffer, + ) + + return f + + # Run simulation + tic = time.time() + nr_iter = 128 * nr // sub_steps + nr_frames = 1024 + for i in tqdm(range(nr_iter)): + f = ooc_stepper(f) + + if i % (nr_iter // nr_frames) == 0: + # Rotate camera + camera_radius = length * 1.0 + focal_point = (length / 2.0, length / 2.0, length / 2.0) + angle = (np.pi / nr_iter) * i + camera_position = (focal_point[0] + camera_radius * np.sin(angle), focal_point[1], focal_point[2] + camera_radius * np.cos(angle)) + camera = pg.Camera( + position=camera_position, + focal_point=focal_point, + view_up=(0.0, 1.0, 0.0), + height=1080, + width=1920, + max_depth=6.0 * length, + ) + + # Render global setup + screen_buffer = pg.render.wireframe( + lower_bound=(0.0, 0.0, 0.0), + upper_bound=(3.0*length, length, length), + thickness=length/100.0, + camera=camera, + ) + screen_buffer = pg.render.axes( + size=length/30.0, + center=(0.0, 0.0, length*1.1), + camera=camera, + screen_buffer=screen_buffer + ) + + # Render + render(f, screen_buffer, camera) + + # Save image + plt.imsave('./q_criterion_' + str(i).zfill(7) + '.png', np.minimum(screen_buffer.image.get(), 1.0)) + + # Sync to host + cp.cuda.runtime.deviceSynchronize() + toc = time.time() + print(f"MLUPS: {(sub_steps * nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/examples/refactor/example_numba.py b/examples/refactor/example_numba.py new file mode 100644 index 0000000..9183067 --- /dev/null +++ b/examples/refactor/example_numba.py @@ -0,0 +1,78 @@ +# from IPython import display +import numpy as np +import cupy as cp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt +from numba import cuda, config + +import xlb + +config.CUDA_ARRAY_INTERFACE_SYNC = False + +if __name__ == "__main__": + # XLB precision policy + precision_policy = xlb.precision_policy.Fp32Fp32() + + # XLB lattice + lattice = xlb.lattice.D3Q19() + + # XLB collision model + collision = xlb.collision.BGK() + + # Make XLB compute kernels + compute = xlb.compute_constructor.NumbaConstructor( + lattice=lattice, + collision=collision, + boundary_conditions=[], + forcing=None, + precision_policy=precision_policy, + ) + + # Make taylor green vortex initial condition + tau = 0.505 + vel = 0.1 * 1.0 / np.sqrt(3.0) + nr = 256 + lin = cp.linspace(0, 2 * np.pi, nr, endpoint=False, dtype=cp.float32) + X, Y, Z = cp.meshgrid(lin, lin, lin, indexing="ij") + X = X[None, ...] + Y = Y[None, ...] + Z = Z[None, ...] + u = vel * cp.sin(X) * cp.cos(Y) * cp.cos(Z) + v = -vel * cp.cos(X) * cp.sin(Y) * cp.cos(Z) + w = cp.zeros_like(X) + rho = ( + 3.0 + * vel**2 + * (1.0 / 16.0) + * (cp.cos(2 * X) + cp.cos(2 * Y) + cp.cos(2 * Z)) + + 1.0) + u = cp.concatenate([u, v, w], axis=-1) + + # Allocate f + f0 = cp.zeros((19, nr, nr, nr), dtype=cp.float32) + f1 = cp.zeros((19, nr, nr, nr), dtype=cp.float32) + + # Get f from u, rho + compute.equilibrium(rho, u, f0) + + # Run compute kernel on f + tic = time.time() + nr_iter = 128 + for i in tqdm(range(nr_iter)): + compute.step(f0, f1, i) + f0, f1 = f1, f0 + + if i % 4 == 0: + ## Get u, rho from f + #rho, u = compute.macroscopic(f) + #norm_u = jnp.linalg.norm(u, axis=-1) + + # Plot + plt.imsave(f"img_{str(i).zfill(5)}.png", f0[8, nr // 2, :, :].get(), cmap="jet") + + # Sync to host + cp.cuda.stream.get_current_stream().synchronize() + toc = time.time() + print(f"MLUPS: {(nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/setup.py b/setup.py index e69de29..5ef3ed7 100644 --- a/setup.py +++ b/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup, find_packages + +setup( + name="XLB", + version="0.0.1", + author="", + packages=find_packages(), + install_requires=[ + ], + include_package_data=True, +) diff --git a/src/boundary_conditions.py b/src/boundary_conditions.py deleted file mode 100644 index cd98c91..0000000 --- a/src/boundary_conditions.py +++ /dev/null @@ -1,1175 +0,0 @@ -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np -class BoundaryCondition(object): - """ - Base class for boundary conditions in a LBM simulation. - - This class provides a general structure for implementing boundary conditions. It includes methods for preparing the - boundary attributes and for applying the boundary condition. Specific boundary conditions should be implemented as - subclasses of this class, with the `apply` method overridden as necessary. - - Attributes - ---------- - lattice : Lattice - The lattice used in the simulation. - nx: - The number of nodes in the x direction. - ny: - The number of nodes in the y direction. - nz: - The number of nodes in the z direction. - dim : int - The number of dimensions in the simulation (2 or 3). - precision_policy : PrecisionPolicy - The precision policy used in the simulation. - indices : array-like - The indices of the boundary nodes. - name : str or None - The name of the boundary condition. This should be set in subclasses. - isSolid : bool - Whether the boundary condition is for a solid boundary. This should be set in subclasses. - isDynamic : bool - Whether the boundary condition is dynamic (changes over time). This should be set in subclasses. - needsExtraConfiguration : bool - Whether the boundary condition requires extra configuration. This should be set in subclasses. - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. This should be set in subclasses. - """ - - def __init__(self, indices, gridInfo, precision_policy): - self.lattice = gridInfo["lattice"] - self.nx = gridInfo["nx"] - self.ny = gridInfo["ny"] - self.nz = gridInfo["nz"] - self.dim = gridInfo["dim"] - self.precisionPolicy = precision_policy - self.indices = indices - self.name = None - self.isSolid = False - self.isDynamic = False - self.needsExtraConfiguration = False - self.implementationStep = "PostStreaming" - - def create_local_mask_and_normal_arrays(self, grid_mask): - - """ - Creates local mask and normal arrays for the boundary condition. - - Parameters - ---------- - grid_mask : array-like - The grid mask for the lattice. - - Returns - ------- - None - - Notes - ----- - This method creates local mask and normal arrays for the boundary condition based on the grid mask. - If the boundary condition requires extra configuration, the `configure` method is called. - """ - - if self.needsExtraConfiguration: - boundaryMask = self.get_boundary_mask(grid_mask) - self.configure(boundaryMask) - self.needsExtraConfiguration = False - - boundaryMask = self.get_boundary_mask(grid_mask) - self.normals = self.get_normals(boundaryMask) - self.imissing, self.iknown = self.get_missing_indices(boundaryMask) - self.imissingMask, self.iknownMask, self.imiddleMask = self.get_missing_mask(boundaryMask) - - return - - def get_boundary_mask(self, grid_mask): - """ - Add jax.device_count() to the self.indices in x-direction, and 1 to the self.indices other directions - This is to make sure the boundary condition is applied to the correct nodes as grid_mask is - expanded by (jax.device_count(), 1, 1) - - Parameters - ---------- - grid_mask : array-like - The grid mask for the lattice. - - Returns - ------- - boundaryMask : array-like - """ - shifted_indices = np.array(self.indices) - shifted_indices[0] += device_count() - shifted_indices[1:] += 1 - # Convert back to tuple - shifted_indices = tuple(shifted_indices) - boundaryMask = np.array(grid_mask[shifted_indices]) - - return boundaryMask - - def configure(self, boundaryMask): - """ - Configures the boundary condition. - - Parameters - ---------- - boundaryMask : array-like - The grid mask for the boundary voxels. - - Returns - ------- - None - - Notes - ----- - This method should be overridden in subclasses if the boundary condition requires extra configuration. - """ - return - - @partial(jit, static_argnums=(0, 3), inline=True) - def prepare_populations(self, fout, fin, implementation_step): - """ - Prepares the distribution functions for the boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The incoming distribution functions. - fin : jax.numpy.ndarray - The outgoing distribution functions. - implementation_step : str - The step in the lattice Boltzmann method algorithm at which the preparation is applied. - - Returns - ------- - jax.numpy.ndarray - The prepared distribution functions. - - Notes - ----- - This method should be overridden in subclasses if the boundary condition requires preparation of the distribution functions during post-collision or post-streaming. See ExtrapolationBoundaryCondition for an example. - """ - return fout - - def get_normals(self, boundaryMask): - """ - Calculates the normal vectors at the boundary nodes. - - Parameters - ---------- - boundaryMask : array-like - The boundary mask for the lattice. - - Returns - ------- - array-like - The normal vectors at the boundary nodes. - - Notes - ----- - This method calculates the normal vectors by dotting the boundary mask with the main lattice directions. - """ - main_c = self.lattice.c.T[self.lattice.main_indices] - m = boundaryMask[..., self.lattice.main_indices] - normals = -np.dot(m, main_c) - return normals - - def get_missing_indices(self, boundaryMask): - """ - Returns two int8 arrays the same shape as boundaryMask. The non-zero entries of these arrays indicate missing - directions that require BCs (imissing) as well as their corresponding opposite directions (iknown). - - Parameters - ---------- - boundaryMask : array-like - The boundary mask for the lattice. - - Returns - ------- - tuple of array-like - The missing and known indices for the boundary condition. - - Notes - ----- - This method calculates the missing and known indices based on the boundary mask. The missing indices are the - non-zero entries of the boundary mask, and the known indices are their corresponding opposite directions. - """ - - # Find imissing, iknown 1-to-1 corresponding indices - # Note: the "zero" index is used as default value here and won't affect BC computations - nbd = len(self.indices[0]) - imissing = np.vstack([np.arange(self.lattice.q, dtype='uint8')] * nbd) - iknown = np.vstack([self.lattice.opp_indices] * nbd) - imissing[~boundaryMask] = 0 - iknown[~boundaryMask] = 0 - return imissing, iknown - - def get_missing_mask(self, boundaryMask): - """ - Returns three boolean arrays the same shape as boundaryMask. - Note: these boundary masks are useful for reduction (eg. summation) operators of selected q-directions. - - Parameters - ---------- - boundaryMask : array-like - The boundary mask for the lattice. - - Returns - ------- - tuple of array-like - The missing, known, and middle masks for the boundary condition. - - Notes - ----- - This method calculates the missing, known, and middle masks based on the boundary mask. The missing mask - is the boundary mask, the known mask is the opposite directions of the missing mask, and the middle mask - is the directions that are neither missing nor known. - """ - # Find masks for imissing, iknown and imiddle - imissingMask = boundaryMask - iknownMask = imissingMask[:, self.lattice.opp_indices] - imiddleMask = ~(imissingMask | iknownMask) - return imissingMask, iknownMask, imiddleMask - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - None - - Notes - ----- - This method should be overridden in subclasses to implement the specific boundary condition. The method should - modify the output distribution functions in place to apply the boundary condition. - """ - pass - - @partial(jit, static_argnums=(0,)) - def equilibrium(self, rho, u): - """ - Compute equilibrium distribution function. - - Parameters - ---------- - rho : jax.numpy.ndarray - The density at each node in the lattice. - u : jax.numpy.ndarray - The velocity at each node in the lattice. - - Returns - ------- - jax.numpy.ndarray - The equilibrium distribution function at each node in the lattice. - - Notes - ----- - This method computes the equilibrium distribution function based on the density and velocity. The computation is - performed in the compute precision specified by the precision policy. The result is not cast to the output precision as - this is function is used inside other functions that require the compute precision. - """ - rho, u = self.precisionPolicy.cast_to_compute((rho, u)) - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - cu = 3.0 * jnp.dot(u, c) - usqr = 1.5 * jnp.sum(u**2, axis=-1, keepdims=True) - feq = rho * self.lattice.w * (1.0 + 1.0 * cu + 0.5 * cu**2 - usqr) - - return feq - - @partial(jit, static_argnums=(0,)) - def momentum_flux(self, fneq): - """ - Compute the momentum flux. - - Parameters - ---------- - fneq : jax.numpy.ndarray - The non-equilibrium distribution function at each node in the lattice. - - Returns - ------- - jax.numpy.ndarray - The momentum flux at each node in the lattice. - - Notes - ----- - This method computes the momentum flux by dotting the non-equilibrium distribution function with the lattice - direction vectors. - """ - return jnp.dot(fneq, self.lattice.cc) - - @partial(jit, static_argnums=(0,)) - def momentum_exchange_force(self, f_poststreaming, f_postcollision): - """ - Using the momentum exchange method to compute the boundary force vector exerted on the solid geometry - based on [1] as described in [3]. Ref [2] shows how [1] is applicable to curved geometries only by using a - bounce-back method (e.g. Bouzidi) that accounts for curved boundaries. - NOTE: this function should be called after BC's are imposed. - [1] A.J.C. Ladd, Numerical simulations of particular suspensions via a discretized Boltzmann equation. - Part 2 (numerical results), J. Fluid Mech. 271 (1994) 311-339. - [2] R. Mei, D. Yu, W. Shyy, L.-S. Luo, Force evaluation in the lattice Boltzmann method involving - curved geometry, Phys. Rev. E 65 (2002) 041203. - [3] Caiazzo, A., & Junk, M. (2008). Boundary forces in lattice Boltzmann: Analysis of momentum exchange - algorithm. Computers & Mathematics with Applications, 55(7), 1415-1423. - - Parameters - ---------- - f_poststreaming : jax.numpy.ndarray - The post-streaming distribution function at each node in the lattice. - f_postcollision : jax.numpy.ndarray - The post-collision distribution function at each node in the lattice. - - Returns - ------- - jax.numpy.ndarray - The force exerted on the solid geometry at each boundary node. - - Notes - ----- - This method computes the force exerted on the solid geometry at each boundary node using the momentum exchange method. - The force is computed based on the post-streaming and post-collision distribution functions. This method - should be called after the boundary conditions are imposed. - """ - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - phi = f_postcollision[self.indices][bindex, self.iknown] + \ - f_poststreaming[self.indices][bindex, self.imissing] - force = jnp.sum(c[:, self.iknown] * phi, axis=-1).T - return force - -class BounceBack(BoundaryCondition): - """ - Bounce-back boundary condition for a lattice Boltzmann method simulation. - - This class implements a full-way bounce-back boundary condition, where particles hitting the boundary are reflected - back in the direction they came from. The boundary condition is applied after the collision step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "BounceBackFullway". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostCollision". - """ - def __init__(self, indices, gridInfo, precision_policy): - super().__init__(indices, gridInfo, precision_policy) - self.name = "BounceBackFullway" - self.implementationStep = "PostCollision" - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the bounce-back boundary condition by reflecting the input distribution functions at the - boundary nodes in the opposite direction. - """ - - return fin[self.indices][..., self.lattice.opp_indices] - -class BounceBackMoving(BoundaryCondition): - """ - Moving bounce-back boundary condition for a lattice Boltzmann method simulation. - - This class implements a moving bounce-back boundary condition, where particles hitting the boundary are reflected - back in the direction they came from, with an additional velocity due to the movement of the boundary. The boundary - condition is applied after the collision step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "BounceBackFullwayMoving". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostCollision". - isDynamic : bool - Whether the boundary condition is dynamic (changes over time). For this class, it is True. - update_function : function - A function that updates the boundary condition. For this class, it is a function that updates the boundary - condition based on the current time step. The signature of the function is `update_function(time) -> (indices, vel)`, - - """ - def __init__(self, gridInfo, precision_policy, update_function=None): - # We get the indices at time zero to pass to the parent class for initialization - indices, _ = update_function(0) - super().__init__(indices, gridInfo, precision_policy) - self.name = "BounceBackFullwayMoving" - self.implementationStep = "PostCollision" - self.isDynamic = True - self.update_function = jit(update_function) - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin, time): - """ - Applies the moving bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - time : int - The current time step. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - indices, vel = self.update_function(time) - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - cu = 6.0 * self.lattice.w * jnp.dot(vel, c) - return fout.at[indices].set(fin[indices][..., self.lattice.opp_indices] - cu) - - -class BounceBackHalfway(BoundaryCondition): - """ - Halfway bounce-back boundary condition for a lattice Boltzmann method simulation. - - This class implements a halfway bounce-back boundary condition. The boundary condition is applied after - the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "BounceBackHalfway". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostStreaming". - needsExtraConfiguration : bool - Whether the boundary condition needs extra configuration before it can be applied. For this class, it is True. - isSolid : bool - Whether the boundary condition represents a solid boundary. For this class, it is True. - vel : array-like - The prescribed value of velocity vector for the boundary condition. No-slip BC is assumed if vel=None (default). - """ - def __init__(self, indices, gridInfo, precision_policy, vel=None): - super().__init__(indices, gridInfo, precision_policy) - self.name = "BounceBackHalfway" - self.implementationStep = "PostStreaming" - self.needsExtraConfiguration = True - self.isSolid = True - self.vel = vel - - def configure(self, boundaryMask): - """ - Configures the boundary condition. - - Parameters - ---------- - boundaryMask : array-like - The grid mask for the boundary voxels. - - Returns - ------- - None - - Notes - ----- - This method performs an index shift for the halfway bounce-back boundary condition. It updates the indices of - the boundary nodes to be the indices of fluid nodes adjacent of the solid nodes. - """ - # Perform index shift for halfway BB. - hasFluidNeighbour = ~boundaryMask[:, self.lattice.opp_indices] - nbd_orig = len(self.indices[0]) - idx = np.array(self.indices).T - idx_trg = [] - for i in range(self.lattice.q): - idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) - indices_new = np.unique(np.vstack(idx_trg), axis=0) - self.indices = tuple(indices_new.T) - nbd_modified = len(self.indices[0]) - if (nbd_orig != nbd_modified) and self.vel is not None: - vel_avg = np.mean(self.vel, axis=0) - self.vel = jnp.zeros(indices_new.shape, dtype=self.precisionPolicy.compute_dtype) + vel_avg - print("WARNING: assuming a constant averaged velocity vector is imposed at all BC cells!") - - return - - @partial(jit, static_argnums=(0,)) - def impose_boundary_vel(self, fbd, bindex): - c = jnp.array(self.lattice.c, dtype=self.precisionPolicy.compute_dtype) - cu = 6.0 * self.lattice.w * jnp.dot(self.vel, c) - fbd = fbd.at[bindex, self.imissing].add(-cu[bindex, self.iknown]) - return fbd - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the halfway bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fout[self.indices] - - fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) - if self.vel is not None: - fbd = self.impose_boundary_vel(fbd, bindex) - return fbd - -class EquilibriumBC(BoundaryCondition): - """ - Equilibrium boundary condition for a lattice Boltzmann method simulation. - - This class implements an equilibrium boundary condition, where the distribution function at the boundary nodes is - set to the equilibrium distribution function. The boundary condition is applied after the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "EquilibriumBC". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostStreaming". - out : jax.numpy.ndarray - The equilibrium distribution function at the boundary nodes. - """ - - def __init__(self, indices, gridInfo, precision_policy, rho, u): - super().__init__(indices, gridInfo, precision_policy) - self.out = self.precisionPolicy.cast_to_output(self.equilibrium(rho, u)) - self.name = "EquilibriumBC" - self.implementationStep = "PostStreaming" - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the equilibrium boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the equilibrium boundary condition by setting the output distribution functions at the - boundary nodes to the equilibrium distribution function. - """ - return self.out - -class DoNothing(BoundaryCondition): - def __init__(self, indices, gridInfo, precision_policy): - """ - Do-nothing boundary condition for a lattice Boltzmann method simulation. - - This class implements a do-nothing boundary condition, where no action is taken at the boundary nodes. The boundary - condition is applied after the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "DoNothing". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostStreaming". - - Notes - ----- - This boundary condition enforces skipping of streaming altogether as it sets post-streaming equal to post-collision - populations (so no streaming at this BC voxels). The problem with returning post-streaming values or "fout[self.indices] - is that the information that exit the domain on the opposite side of this boundary, would "re-enter". This is because - we roll the entire array and so the boundary condition acts like a one-way periodic BC. If EquilibriumBC is used as - the BC for that opposite boundary, then the rolled-in values are taken from the initial condition at equilibrium. - Otherwise if ZouHe is used for example the simulation looks like a run-down simulation at low-Re. The opposite boundary - may be even a wall (consider pipebend example). If we correct imissing directions and assign "fin", this method becomes - much less stable and also one needs to correctly take care of corner cases. - """ - super().__init__(indices, gridInfo, precision_policy) - self.name = "DoNothing" - self.implementationStep = "PostStreaming" - - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the do-nothing boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the do-nothing boundary condition by simply returning the input distribution functions at the - boundary nodes. - """ - return fin[self.indices] - - -class ZouHe(BoundaryCondition): - """ - Zou-He boundary condition for a lattice Boltzmann method simulation. - - This class implements the Zou-He boundary condition, which is a non-equilibrium bounce-back boundary condition. - It can be used to set inflow and outflow boundary conditions with prescribed pressure or velocity. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "ZouHe". - implementationStep : str - The step in the lattice Boltzmann method algorithm at which the boundary condition is applied. For this class, - it is "PostStreaming". - type : str - The type of the boundary condition. It can be either 'velocity' for a prescribed velocity boundary condition, - or 'pressure' for a prescribed pressure boundary condition. - prescribed : float or array-like - The prescribed values for the boundary condition. It can be either the prescribed velocities for a 'velocity' - boundary condition, or the prescribed pressures for a 'pressure' boundary condition. - - References - ---------- - Zou, Q., & He, X. (1997). On pressure and velocity boundary conditions for the lattice Boltzmann BGK model. - Physics of Fluids, 9(6), 1591-1598. doi:10.1063/1.869307 - """ - def __init__(self, indices, gridInfo, precision_policy, type, prescribed): - super().__init__(indices, gridInfo, precision_policy) - self.name = "ZouHe" - self.implementationStep = "PostStreaming" - self.type = type - self.prescribed = prescribed - self.needsExtraConfiguration = True - - def configure(self, boundaryMask): - """ - Correct boundary indices to ensure that only voxelized surfaces with normal vectors along main cartesian axes - are assigned this type of BC. - """ - nv = np.dot(self.lattice.c, ~boundaryMask.T) - corner_voxels = np.count_nonzero(nv, axis=0) > 1 - # removed_voxels = np.array(self.indices)[:, corner_voxels] - self.indices = tuple(np.array(self.indices)[:, ~corner_voxels]) - self.prescribed = self.prescribed[~corner_voxels] - return - - @partial(jit, static_argnums=(0,), inline=True) - def calculate_vel(self, fpop, rho): - """ - Calculate velocity based on the prescribed pressure/density (Zou/He BC) - """ - unormal = -1. + 1. / rho * (jnp.sum(fpop[self.indices] * self.imiddleMask, axis=1) + - 2. * jnp.sum(fpop[self.indices] * self.iknownMask, axis=1)) - - # Return the above unormal as a normal vector which sets the tangential velocities to zero - vel = unormal[:, None] * self.normals - return vel - - @partial(jit, static_argnums=(0,), inline=True) - def calculate_rho(self, fpop, vel): - """ - Calculate density based on the prescribed velocity (Zou/He BC) - """ - unormal = np.sum(self.normals*vel, axis=1) - - rho = (1.0/(1.0 + unormal))[..., None] * (jnp.sum(fpop[self.indices] * self.imiddleMask, axis=1, keepdims=True) + - 2.*jnp.sum(fpop[self.indices] * self.iknownMask, axis=1, keepdims=True)) - return rho - - @partial(jit, static_argnums=(0,), inline=True) - def calculate_equilibrium(self, fpop): - """ - This is the ZouHe method of calculating the missing macroscopic variables at the boundary. - """ - if self.type == 'velocity': - vel = self.prescribed - rho = self.calculate_rho(fpop, vel) - elif self.type == 'pressure': - rho = self.prescribed - vel = self.calculate_vel(fpop, rho) - else: - raise ValueError(f"type = {self.type} not supported! Use \'pressure\' or \'velocity\'.") - - # compute feq at the boundary - feq = self.equilibrium(rho, vel) - return feq - - @partial(jit, static_argnums=(0,), inline=True) - def bounceback_nonequilibrium(self, fpop, feq): - """ - Calculate unknown populations using bounce-back of non-equilibrium populations - a la original Zou & He formulation - """ - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fpop[self.indices] - fknown = fpop[self.indices][bindex, self.iknown] + feq[bindex, self.imissing] - feq[bindex, self.iknown] - fbd = fbd.at[bindex, self.imissing].set(fknown) - return fbd - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, _): - """ - Applies the Zou-He boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - _ : jax.numpy.ndarray - The input distribution functions. This is not used in this method. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the Zou-He boundary condition by first computing the equilibrium distribution functions based - on the prescribed values and the type of boundary condition, and then setting the unknown distribution functions - based on the non-equilibrium bounce-back method. - Tangential velocity is not ensured to be zero by adding transverse contributions based on - Hecth & Harting (2010) (doi:10.1088/1742-5468/2010/01/P01018) as it caused numerical instabilities at higher - Reynolds numbers. One needs to use "Regularized" BC at higher Reynolds. - """ - # compute the equilibrium based on prescribed values and the type of BC - feq = self.calculate_equilibrium(fout) - - # set the unknown f populations based on the non-equilibrium bounce-back method - fbd = self.bounceback_nonequilibrium(fout, feq) - - - return fbd - -class Regularized(ZouHe): - """ - Regularized boundary condition for a lattice Boltzmann method simulation. - - This class implements the regularized boundary condition, which is a non-equilibrium bounce-back boundary condition - with additional regularization. It can be used to set inflow and outflow boundary conditions with prescribed pressure - or velocity. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "Regularized". - Qi : numpy.ndarray - The Qi tensor, which is used in the regularization of the distribution functions. - - References - ---------- - Latt, J. (2007). Hydrodynamic limit of lattice Boltzmann equations. PhD thesis, University of Geneva. - Latt, J., Chopard, B., Malaspinas, O., Deville, M., & Michler, A. (2008). Straight velocity boundaries in the - lattice Boltzmann method. Physical Review E, 77(5), 056703. doi:10.1103/PhysRevE.77.056703 - """ - - def __init__(self, indices, gridInfo, precision_policy, type, prescribed): - super().__init__(indices, gridInfo, precision_policy, type, prescribed) - self.name = "Regularized" - #TODO for Hesam: check to understand why corner cases cause instability here. - # self.needsExtraConfiguration = False - self.construct_symmetric_lattice_moment() - - def construct_symmetric_lattice_moment(self): - """ - Construct the symmetric lattice moment Qi. - - The Qi tensor is used in the regularization of the distribution functions. It is defined as Qi = cc - cs^2*I, - where cc is the tensor of lattice velocities, cs is the speed of sound, and I is the identity tensor. - """ - Qi = self.lattice.cc - if self.dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) - elif self.dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) - else: - raise ValueError(f"dim = {self.dim} not supported") - - # Qi = cc - cs^2*I - Qi = Qi.at[:, diagonal].set(self.lattice.cc[:, diagonal] - 1./3.) - - # multiply off-diagonal elements by 2 because the Q tensor is symmetric - Qi = Qi.at[:, offdiagonal].set(self.lattice.cc[:, offdiagonal] * 2.0) - - self.Qi = Qi.T - return - - @partial(jit, static_argnums=(0,), inline=True) - def regularize_fpop(self, fpop, feq): - """ - Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. - - Parameters - ---------- - fpop : jax.numpy.ndarray - The distribution functions. - feq : jax.numpy.ndarray - The equilibrium distribution functions. - - Returns - ------- - jax.numpy.ndarray - The regularized distribution functions. - """ - - # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} - f_neq = fpop - feq - PiNeq = self.momentum_flux(f_neq) - # PiNeq = self.momentum_flux(fpop) - self.momentum_flux(feq) - - # Compute double dot product Qi:Pi1 - # QiPi1 = np.zeros_like(fpop) - # Pi1 = PiNeq - # QiPi1 = jnp.dot(Qi, Pi1) - QiPi1 = jnp.dot(PiNeq, self.Qi) - - # assign all populations based on eq 45 of Latt et al (2008) - # fneq ~ f^1 - fpop1 = 9. / 2. * self.lattice.w[None, :] * QiPi1 - fpop_regularized = feq + fpop1 - - return fpop_regularized - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, _): - """ - Applies the regularized boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - _ : jax.numpy.ndarray - The input distribution functions. This is not used in this method. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - - Notes - ----- - This method applies the regularized boundary condition by first computing the equilibrium distribution functions based - on the prescribed values and the type of boundary condition, then setting the unknown distribution functions - based on the non-equilibrium bounce-back method, and finally regularizing the distribution functions. - """ - - # compute the equilibrium based on prescribed values and the type of BC - feq = self.calculate_equilibrium(fout) - - # set the unknown f populations based on the non-equilibrium bounce-back method - fbd = self.bounceback_nonequilibrium(fout, feq) - - # Regularize the boundary fpop - fbd = self.regularize_fpop(fbd, feq) - return fbd - - -class ExtrapolationOutflow(BoundaryCondition): - """ - Extrapolation outflow boundary condition for a lattice Boltzmann method simulation. - - This class implements the extrapolation outflow boundary condition, which is a type of outflow boundary condition - that uses extrapolation to avoid strong wave reflections. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "ExtrapolationOutflow". - sound_speed : float - The speed of sound in the simulation. - - References - ---------- - Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three - dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507–547. - doi:10.1016/j.camwa.2015.05.001. - """ - - def __init__(self, indices, gridInfo, precision_policy): - super().__init__(indices, gridInfo, precision_policy) - self.name = "ExtrapolationOutflow" - self.needsExtraConfiguration = True - self.sound_speed = 1./jnp.sqrt(3.) - - def configure(self, boundaryMask): - """ - Configure the boundary condition by finding neighbouring voxel indices. - - Parameters - ---------- - boundaryMask : np.ndarray - The grid mask for the boundary voxels. - """ - hasFluidNeighbour = ~boundaryMask[:, self.lattice.opp_indices] - idx = np.array(self.indices).T - idx_trg = [] - for i in range(self.lattice.q): - idx_trg.append(idx[hasFluidNeighbour[:, i], :] + self.lattice.c[:, i]) - indices_nbr = np.unique(np.vstack(idx_trg), axis=0) - self.indices_nbr = tuple(indices_nbr.T) - - return - - @partial(jit, static_argnums=(0, 3), inline=True) - def prepare_populations(self, fout, fin, implementation_step): - """ - Prepares the distribution functions for the boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The incoming distribution functions. - fin : jax.numpy.ndarray - The outgoing distribution functions. - implementation_step : str - The step in the lattice Boltzmann method algorithm at which the preparation is applied. - - Returns - ------- - jax.numpy.ndarray - The prepared distribution functions. - - Notes - ----- - Because this function is called "PostCollision", f_poststreaming refers to previous time step or t-1 - """ - f_postcollision = fout - f_poststreaming = fin - if implementation_step == 'PostStreaming': - return f_postcollision - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fps_bdr = f_poststreaming[self.indices] - fps_nbr = f_poststreaming[self.indices_nbr] - fpc_bdr = f_postcollision[self.indices] - fpop = fps_bdr[bindex, self.imissing] - fpop_neighbour = fps_nbr[bindex, self.imissing] - fpop_extrapolated = self.sound_speed * fpop_neighbour + (1. - self.sound_speed) * fpop - - # Use the iknown directions of f_postcollision that leave the domain during streaming to store the BC data - fpc_bdr = fpc_bdr.at[bindex, self.iknown].set(fpop_extrapolated) - f_postcollision = f_postcollision.at[self.indices].set(fpc_bdr) - return f_postcollision - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the extrapolation outflow boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fout[self.indices] - fbd = fbd.at[bindex, self.imissing].set(fin[self.indices][bindex, self.iknown]) - return fbd - - -class InterpolatedBounceBackBouzidi(BounceBackHalfway): - """ - A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice - Boltzmann method simulation. - - This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after - the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "InterpolatedBounceBackBouzidi". - implicit_distances : array-like - An array of shape (nx,ny,nz) indicating the signed-distance field from the solid walls - weights : array-like - An array of shape (number_of_bc_cells, q) initialized as None and constructed using implicit_distances array - during runtime. These "weights" are associated with the fractional distance of fluid cell to the boundary - position defined as: weights(dir_i) = |x_fluid - x_boundary(dir_i)| / |x_fluid - x_solid(dir_i)|. - """ - - def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None): - - super().__init__(indices, grid_info, precision_policy, vel=vel) - self.name = "InterpolatedBounceBackBouzidi" - self.implicit_distances = implicit_distances - self.weights = None - - def set_proximity_ratio(self): - """ - Creates the interpolation data needed for the boundary condition. - - Returns - ------- - None. The function updates the object's weights attribute in place. - """ - idx = np.array(self.indices).T - self.weights = np.full((idx.shape[0], self.lattice.q), 0.5) - c = np.array(self.lattice.c) - sdf_f = self.implicit_distances[self.indices] - for q in range(1, self.lattice.q): - solid_indices = idx + c[:, q] - solid_indices_tuple = tuple(map(tuple, solid_indices.T)) - sdf_s = self.implicit_distances[solid_indices_tuple] - mask = self.iknownMask[:, q] - self.weights[mask, q] = sdf_f[mask] / (sdf_f[mask] - sdf_s[mask]) - return - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the halfway bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - if self.weights is None: - self.set_proximity_ratio() - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fout[self.indices] - f_postcollision_iknown = fin[self.indices][bindex, self.iknown] - f_postcollision_imissing = fin[self.indices][bindex, self.imissing] - f_poststreaming_iknown = fout[self.indices][bindex, self.iknown] - - # if weights<0.5 - fs_near = 2. * self.weights * f_postcollision_iknown + \ - (1.0 - 2.0 * self.weights) * f_poststreaming_iknown - - # if weights>=0.5 - fs_far = 1.0 / (2. * self.weights) * f_postcollision_iknown + \ - (2.0 * self.weights - 1.0) / (2. * self.weights) * f_postcollision_imissing - - # combine near and far contributions - fmissing = jnp.where(self.weights < 0.5, fs_near, fs_far) - fbd = fbd.at[bindex, self.imissing].set(fmissing) - - if self.vel is not None: - fbd = self.impose_boundary_vel(fbd, bindex) - return fbd - - -class InterpolatedBounceBackDifferentiable(InterpolatedBounceBackBouzidi): - """ - A differentiable variant of the "InterpolatedBounceBackBouzidi" BC scheme. This BC is now differentiable at - self.weight = 0.5 unlike the original Bouzidi scheme which switches between 2 equations at weight=0.5. Refer to - [1] (their Appendix E) for more information. - - References - ---------- - [1] Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three - dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507–547. - doi:10.1016/j.camwa.2015.05.001. - - - This class implements a interpolated bounce-back boundary condition. The boundary condition is applied after - the streaming step. - - Attributes - ---------- - name : str - The name of the boundary condition. For this class, it is "InterpolatedBounceBackDifferentiable". - """ - - def __init__(self, indices, implicit_distances, grid_info, precision_policy, vel=None): - - super().__init__(indices, implicit_distances, grid_info, precision_policy, vel=vel) - self.name = "InterpolatedBounceBackDifferentiable" - - - @partial(jit, static_argnums=(0,)) - def apply(self, fout, fin): - """ - Applies the halfway bounce-back boundary condition. - - Parameters - ---------- - fout : jax.numpy.ndarray - The output distribution functions. - fin : jax.numpy.ndarray - The input distribution functions. - - Returns - ------- - jax.numpy.ndarray - The modified output distribution functions after applying the boundary condition. - """ - if self.weights is None: - self.set_proximity_ratio() - nbd = len(self.indices[0]) - bindex = np.arange(nbd)[:, None] - fbd = fout[self.indices] - f_postcollision_iknown = fin[self.indices][bindex, self.iknown] - f_postcollision_imissing = fin[self.indices][bindex, self.imissing] - f_poststreaming_iknown = fout[self.indices][bindex, self.iknown] - fmissing = ((1. - self.weights) * f_poststreaming_iknown + - self.weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + self.weights) - fbd = fbd.at[bindex, self.imissing].set(fmissing) - - if self.vel is not None: - fbd = self.impose_boundary_vel(fbd, bindex) - return fbd \ No newline at end of file diff --git a/src/lattice.py b/src/lattice.py deleted file mode 100644 index 788796b..0000000 --- a/src/lattice.py +++ /dev/null @@ -1,281 +0,0 @@ -import re -import numpy as np -import jax.numpy as jnp - - -class Lattice(object): - """ - This class represents a lattice in the Lattice Boltzmann Method. - - It stores the properties of the lattice, including the dimensions, the number of - velocities, the velocity vectors, the weights, the moments, and the indices of the - opposite, main, right, and left velocities. - - The class also provides methods to construct these properties based on the name of the - lattice. - - Parameters - ---------- - name: str - The name of the lattice, which specifies the dimensions and the number of velocities. - For example, "D2Q9" represents a 2D lattice with 9 velocities. - precision: str, optional - The precision of the computations. It can be "f32/f32", "f32/f16", "f64/f64", - "f64/f32", or "f64/f16". The first part before the slash is the precision of the - computations, and the second part after the slash is the precision of the outputs. - """ - def __init__(self, name, precision="f32/f32") -> None: - self.name = name - dq = re.findall(r"\d+", name) - self.precision = precision - self.d = int(dq[0]) - self.q = int(dq[1]) - if precision == "f32/f32" or precision == "f32/f16": - self.precisionPolicy = jnp.float32 - elif precision == "f64/f64" or precision == "f64/f32" or precision == "f64/f16": - self.precisionPolicy = jnp.float64 - elif precision == "f16/f16": - self.precisionPolicy = jnp.float16 - else: - raise ValueError("precision not supported") - - # Construct the properties of the lattice - self.c = jnp.array(self.construct_lattice_velocity(), dtype=jnp.int8) - self.w = jnp.array(self.construct_lattice_weight(), dtype=self.precisionPolicy) - self.cc = jnp.array(self.construct_lattice_moment(), dtype=self.precisionPolicy) - self.opp_indices = jnp.array(self.construct_opposite_indices(), dtype=jnp.int8) - self.main_indices = jnp.array(self.construct_main_indices(), dtype=jnp.int8) - self.right_indices = np.array(self.construct_right_indices(), dtype=jnp.int8) - self.left_indices = np.array(self.construct_left_indices(), dtype=jnp.int8) - - def construct_opposite_indices(self): - """ - This function constructs the indices of the opposite velocities for each velocity. - - The opposite velocity of a velocity is the velocity that has the same magnitude but the - opposite direction. - - Returns - ------- - opposite: numpy.ndarray - The indices of the opposite velocities. - """ - c = self.c.T - opposite = np.array([c.tolist().index((-c[i]).tolist()) for i in range(self.q)]) - return opposite - - def construct_right_indices(self): - """ - This function constructs the indices of the velocities that point in the positive - x-direction. - - Returns - ------- - numpy.ndarray - The indices of the right velocities. - """ - c = self.c.T - return np.nonzero(c[:, 0] == 1)[0] - - def construct_left_indices(self): - """ - This function constructs the indices of the velocities that point in the negative - x-direction. - - Returns - ------- - numpy.ndarray - The indices of the left velocities. - """ - c = self.c.T - return np.nonzero(c[:, 0] == -1)[0] - - def construct_main_indices(self): - """ - This function constructs the indices of the main velocities. - - The main velocities are the velocities that have a magnitude of 1 in lattice units. - - Returns - ------- - numpy.ndarray - The indices of the main velocities. - """ - c = self.c.T - if self.d == 2: - return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] - - elif self.d == 3: - return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] - - def construct_lattice_velocity(self): - """ - This function constructs the velocity vectors of the lattice. - - The velocity vectors are defined based on the name of the lattice. For example, for a D2Q9 - lattice, there are 9 velocities: (0,0), (1,0), (-1,0), (0,1), (0,-1), (1,1), (-1,-1), - (1,-1), and (-1,1). - - Returns - ------- - c.T: numpy.ndarray - The velocity vectors of the lattice. - """ - if self.name == "D2Q9": # D2Q9 - cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] - cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] - c = np.array(tuple(zip(cx, cy))) - elif self.name == "D3Q19": # D3Q19 - c = [(x, y, z) for x in [0, -1, 1] for y in [0, -1, 1] for z in [0, -1, 1]] - c = np.array([ci for ci in c if np.linalg.norm(ci) < 1.5]) - elif self.name == "D3Q27": # D3Q27 - c = [(x, y, z) for x in [0, -1, 1] for y in [0, -1, 1] for z in [0, -1, 1]] - # c = np.array([ci for ci in c if np.linalg.norm(ci) < 1.5]) - c = np.array(c) - else: - raise ValueError("Supported Lattice types are D2Q9, D3Q19 and D3Q27") - - return c.T - - def construct_lattice_weight(self): - """ - This function constructs the weights of the lattice. - - The weights are defined based on the name of the lattice. For example, for a D2Q9 lattice, - the weights are 4/9 for the rest velocity, 1/9 for the main velocities, and 1/36 for the - diagonal velocities. - - Returns - ------- - w: numpy.ndarray - The weights of the lattice. - """ - # Get the transpose of the lattice vector - c = self.c.T - - # Initialize the weights to be 1/36 - w = 1.0 / 36.0 * np.ones(self.q) - - # Update the weights for 2D and 3D lattices - if self.name == "D2Q9": - w[np.linalg.norm(c, axis=1) < 1.1] = 1.0 / 9.0 - w[0] = 4.0 / 9.0 - elif self.name == "D3Q19": - w[np.linalg.norm(c, axis=1) < 1.1] = 2.0 / 36.0 - w[0] = 1.0 / 3.0 - elif self.name == "D3Q27": - cl = np.linalg.norm(c, axis=1) - w[np.isclose(cl, 1.0, atol=1e-8)] = 2.0 / 27.0 - w[(cl > 1) & (cl <= np.sqrt(2))] = 1.0 / 54.0 - w[(cl > np.sqrt(2)) & (cl <= np.sqrt(3))] = 1.0 / 216.0 - w[0] = 8.0 / 27.0 - else: - raise ValueError("Supported Lattice types are D2Q9, D3Q19 and D3Q27") - - # Return the weights - return w - - def construct_lattice_moment(self): - """ - This function constructs the moments of the lattice. - - The moments are the products of the velocity vectors, which are used in the computation of - the equilibrium distribution functions and the collision operator in the Lattice Boltzmann - Method (LBM). - - Returns - ------- - cc: numpy.ndarray - The moments of the lattice. - """ - c = self.c.T - # Counter for the loop - cntr = 0 - - # nt: number of independent elements of a symmetric tensor - nt = self.d * (self.d + 1) // 2 - - cc = np.zeros((self.q, nt)) - for a in range(0, self.d): - for b in range(a, self.d): - cc[:, cntr] = c[:, a] * c[:, b] - cntr += 1 - - return cc - - def __str__(self): - return self.name - -class LatticeD2Q9(Lattice): - """ - Lattice class for 2D D2Q9 lattice. - - D2Q9 stands for two-dimensional nine-velocity model. It is a common model used in the - Lat tice Boltzmann Method for simulating fluid flows in two dimensions. - - Parameters - ---------- - precision: str, optional - The precision of the lattice. The default is "f32/f32" - """ - def __init__(self, precision="f32/f32"): - super().__init__("D2Q9", precision) - self._set_constants() - - def _set_constants(self): - self.cs = jnp.sqrt(3) / 3.0 - self.cs2 = 1.0 / 3.0 - self.inv_cs2 = 3.0 - self.i_s = jnp.asarray(list(range(9))) - self.im = 3 # Number of imiddles (includes center) - self.ik = 3 # Number of iknowns or iunknowns - - -class LatticeD3Q19(Lattice): - """ - Lattice class for 3D D3Q19 lattice. - - D3Q19 stands for three-dimensional nineteen-velocity model. It is a common model used in the - Lattice Boltzmann Method for simulating fluid flows in three dimensions. - - Parameters - ---------- - precision: str, optional - The precision of the lattice. The default is "f32/f32" - """ - def __init__(self, precision="f32/f32"): - super().__init__("D3Q19", precision) - self._set_constants() - - def _set_constants(self): - self.cs = jnp.sqrt(3) / 3.0 - self.cs2 = 1.0 / 3.0 - self.inv_cs2 = 3.0 - self.i_s = jnp.asarray(list(range(19)), dtype=jnp.int8) - - self.im = 9 # Number of imiddles (includes center) - self.ik = 5 # Number of iknowns or iunknowns - - -class LatticeD3Q27(Lattice): - """ - Lattice class for 3D D3Q27 lattice. - - D3Q27 stands for three-dimensional twenty-seven-velocity model. It is a common model used in the - Lattice Boltzmann Method for simulating fluid flows in three dimensions. - - Parameters - ---------- - precision: str, optional - The precision of the lattice. The default is "f32/f32" - """ - - def __init__(self, precision="f32/f32"): - super().__init__("D3Q27", precision) - self._set_constants() - - def _set_constants(self): - self.cs = jnp.sqrt(3) / 3.0 - self.cs2 = 1.0 / 3.0 - self.inv_cs2 = 3.0 - self.i_s = jnp.asarray(list(range(27)), dtype=jnp.int8) \ No newline at end of file diff --git a/src/models.py b/src/models.py deleted file mode 100644 index a0500c8..0000000 --- a/src/models.py +++ /dev/null @@ -1,260 +0,0 @@ -import jax.numpy as jnp -from jax import jit -from functools import partial -from src.base import LBMBase -""" -Collision operators are defined in this file for different models. -""" - -class BGKSim(LBMBase): - """ - BGK simulation class. - - This class implements the Bhatnagar-Gross-Krook (BGK) approximation for the collision step in the Lattice Boltzmann Method. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision(self, f): - """ - BGK collision step for lattice. - - The collision step is where the main physics of the LBM is applied. In the BGK approximation, - the distribution function is relaxed towards the equilibrium distribution function. - """ - f = self.precisionPolicy.cast_to_compute(f) - rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, cast_output=False) - fneq = f - feq - fout = f - self.omega * fneq - if self.force is not None: - fout = self.apply_force(fout, feq, rho, u) - return self.precisionPolicy.cast_to_output(fout) - -class KBCSim(LBMBase): - """ - KBC simulation class. - - This class implements the Karlin-Bösch-Chikatamarla (KBC) model for the collision step in the Lattice Boltzmann Method. - """ - def __init__(self, **kwargs): - if kwargs.get('lattice').name != 'D3Q27' and kwargs.get('nz') > 0: - raise ValueError("KBC collision operator in 3D must only be used with D3Q27 lattice.") - super().__init__(**kwargs) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision(self, f): - """ - KBC collision step for lattice. - """ - f = self.precisionPolicy.cast_to_compute(f) - tiny = 1e-32 - beta = self.omega * 0.5 - rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, cast_output=False) - fneq = f - feq - if self.dim == 2: - deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0 - else: - deltaS = self.fdecompose_shear_d3q27(fneq) * rho - deltaH = fneq - deltaS - invBeta = 1.0 / beta - gamma = invBeta - (2.0 - invBeta) * self.entropic_scalar_product(deltaS, deltaH, feq) / (tiny + self.entropic_scalar_product(deltaH, deltaH, feq)) - - fout = f - beta * (2.0 * deltaS + gamma[..., None] * deltaH) - - # add external force - if self.force is not None: - fout = self.apply_force(fout, feq, rho, u) - return self.precisionPolicy.cast_to_output(fout) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision_modified(self, f): - """ - Alternative KBC collision step for lattice. - Note: - At low Reynolds number the orignal KBC collision above produces inaccurate results because - it does not check for the entropy increase/decrease. The KBC stabalizations should only be - applied in principle to cells whose entropy decrease after a regular BGK collision. This is - the case in most cells at higher Reynolds numbers and hence a check may not be needed. - Overall the following alternative collision is more reliable and may replace the original - implementation. The issue at the moment is that it is about 60-80% slower than the above method. - """ - f = self.precisionPolicy.cast_to_compute(f) - tiny = 1e-32 - beta = self.omega * 0.5 - rho, u = self.update_macroscopic(f) - feq = self.equilibrium(rho, u, castOutput=False) - - # Alternative KBC: only stabalizes for voxels whose entropy decreases after BGK collision. - f_bgk = f - self.omega * (f - feq) - H_fin = jnp.sum(f * jnp.log(f / self.w), axis=-1, keepdims=True) - H_fout = jnp.sum(f_bgk * jnp.log(f_bgk / self.w), axis=-1, keepdims=True) - - # the rest is identical to collision_deprecated - fneq = f - feq - if self.dim == 2: - deltaS = self.fdecompose_shear_d2q9(fneq) * rho / 4.0 - else: - deltaS = self.fdecompose_shear_d3q27(fneq) * rho - deltaH = fneq - deltaS - invBeta = 1.0 / beta - gamma = invBeta - (2.0 - invBeta) * self.entropic_scalar_product(deltaS, deltaH, feq) / (tiny + self.entropic_scalar_product(deltaH, deltaH, feq)) - - f_kbc = f - beta * (2.0 * deltaS + gamma[..., None] * deltaH) - fout = jnp.where(H_fout > H_fin, f_kbc, f_bgk) - - # add external force - if self.force is not None: - fout = self.apply_force(fout, feq, rho, u) - return self.precisionPolicy.cast_to_output(fout) - - @partial(jit, static_argnums=(0,), inline=True) - def entropic_scalar_product(self, x, y, feq): - """ - Compute the entropic scalar product of x and y to approximate gamma in KBC. - - Returns - ------- - jax.numpy.array - Entropic scalar product of x, y, and feq. - """ - return jnp.sum(x * y / feq, axis=-1) - - @partial(jit, static_argnums=(0,), inline=True) - def fdecompose_shear_d2q9(self, fneq): - """ - Decompose fneq into shear components for D2Q9 lattice. - - Parameters - ---------- - fneq : jax.numpy.array - Non-equilibrium distribution function. - - Returns - ------- - jax.numpy.array - Shear components of fneq. - """ - Pi = self.momentum_flux(fneq) - N = Pi[..., 0] - Pi[..., 2] - s = jnp.zeros_like(fneq) - s = s.at[..., 6].set(N) - s = s.at[..., 3].set(N) - s = s.at[..., 2].set(-N) - s = s.at[..., 1].set(-N) - s = s.at[..., 8].set(Pi[..., 1]) - s = s.at[..., 4].set(-Pi[..., 1]) - s = s.at[..., 5].set(-Pi[..., 1]) - s = s.at[..., 7].set(Pi[..., 1]) - - return s - - @partial(jit, static_argnums=(0,), inline=True) - def fdecompose_shear_d3q27(self, fneq): - """ - Decompose fneq into shear components for D3Q27 lattice. - - Parameters - ---------- - fneq : jax.numpy.ndarray - Non-equilibrium distribution function. - - Returns - ------- - jax.numpy.ndarray - Shear components of fneq. - """ - # if self.grid.dim == 3: - # diagonal = (0, 3, 5) - # offdiagonal = (1, 2, 4) - # elif self.grid.dim == 2: - # diagonal = (0, 2) - # offdiagonal = (1,) - - # c= - # array([[0, 0, 0],-----0 - # [0, 0, -1],----1 - # [0, 0, 1],-----2 - # [0, -1, 0],----3 - # [0, -1, -1],---4 - # [0, -1, 1],----5 - # [0, 1, 0],-----6 - # [0, 1, -1],----7 - # [0, 1, 1],-----8 - # [-1, 0, 0],----9 - # [-1, 0, -1],--10 - # [-1, 0, 1],---11 - # [-1, -1, 0],--12 - # [-1, -1, -1],-13 - # [-1, -1, 1],--14 - # [-1, 1, 0],---15 - # [-1, 1, -1],--16 - # [-1, 1, 1],---17 - # [1, 0, 0],----18 - # [1, 0, -1],---19 - # [1, 0, 1],----20 - # [1, -1, 0],---21 - # [1, -1, -1],--22 - # [1, -1, 1],---23 - # [1, 1, 0],----24 - # [1, 1, -1],---25 - # [1, 1, 1]])---26 - Pi = self.momentum_flux(fneq) - Nxz = Pi[..., 0] - Pi[..., 5] - Nyz = Pi[..., 3] - Pi[..., 5] - - # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) - s = jnp.zeros_like(fneq) - s = s.at[..., 9].set((2.0 * Nxz - Nyz) / 6.0) - s = s.at[..., 18].set((2.0 * Nxz - Nyz) / 6.0) - s = s.at[..., 3].set((-Nxz + 2.0 * Nyz) / 6.0) - s = s.at[..., 6].set((-Nxz + 2.0 * Nyz) / 6.0) - s = s.at[..., 1].set((-Nxz - Nyz) / 6.0) - s = s.at[..., 2].set((-Nxz - Nyz) / 6.0) - - # For c = (i, j, 0) - s = s.at[..., 12].set(Pi[..., 1] / 4.0) - s = s.at[..., 24].set(Pi[..., 1] / 4.0) - s = s.at[..., 21].set(-Pi[..., 1] / 4.0) - s = s.at[..., 15].set(-Pi[..., 1] / 4.0) - - # For c = (i, 0, k) - s = s.at[..., 10].set(Pi[..., 2] / 4.0) - s = s.at[..., 20].set(Pi[..., 2] / 4.0) - s = s.at[..., 19].set(-Pi[..., 2] / 4.0) - s = s.at[..., 11].set(-Pi[..., 2] / 4.0) - - # For c = (0, j, k) - s = s.at[..., 8].set(Pi[..., 4] / 4.0) - s = s.at[..., 4].set(Pi[..., 4] / 4.0) - s = s.at[..., 7].set(-Pi[..., 4] / 4.0) - s = s.at[..., 5].set(-Pi[..., 4] / 4.0) - - return s - - -class AdvectionDiffusionBGK(LBMBase): - """ - Advection Diffusion Model based on the BGK model. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.vel = kwargs.get("vel", None) - if self.vel is None: - raise ValueError("Velocity must be specified for AdvectionDiffusionBGK.") - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision(self, f): - """ - BGK collision step for lattice. - """ - f = self.precisionPolicy.cast_to_compute(f) - rho =jnp.sum(f, axis=-1, keepdims=True) - feq = self.equilibrium(rho, self.vel, cast_output=False) - fneq = f - feq - fout = f - self.omega * fneq - return self.precisionPolicy.cast_to_output(fout) \ No newline at end of file diff --git a/xlb/__init__.py b/xlb/__init__.py new file mode 100644 index 0000000..8f49baa --- /dev/null +++ b/xlb/__init__.py @@ -0,0 +1,18 @@ +# Enum classes +from xlb.compute_backend import ComputeBackend +from xlb.physics_type import PhysicsType + +# Precision policy +import xlb.precision_policy + +# Velocity Set +import xlb.velocity_set + +# Operators +import xlb.operator.collision +import xlb.operator.stream +import xlb.operator.boundary_condition +# import xlb.operator.force +import xlb.operator.equilibrium +import xlb.operator.macroscopic +import xlb.operator.stepper diff --git a/src/base.py b/xlb/base.py similarity index 100% rename from src/base.py rename to xlb/base.py diff --git a/xlb/compute_backend.py b/xlb/compute_backend.py new file mode 100644 index 0000000..dee998f --- /dev/null +++ b/xlb/compute_backend.py @@ -0,0 +1,9 @@ +# Enum used to keep track of the compute backends + +from enum import Enum + +class ComputeBackend(Enum): + JAX = 1 + NUMBA = 2 + PYTORCH = 3 + WARP = 4 diff --git a/xlb/experimental/__init__.py b/xlb/experimental/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/xlb/experimental/__init__.py @@ -0,0 +1 @@ + diff --git a/xlb/experimental/ooc/__init__.py b/xlb/experimental/ooc/__init__.py new file mode 100644 index 0000000..5206cc1 --- /dev/null +++ b/xlb/experimental/ooc/__init__.py @@ -0,0 +1,2 @@ +from xlb.experimental.ooc.out_of_core import OOCmap +from xlb.experimental.ooc.ooc_array import OOCArray diff --git a/xlb/experimental/ooc/ooc_array.py b/xlb/experimental/ooc/ooc_array.py new file mode 100644 index 0000000..f5e2c34 --- /dev/null +++ b/xlb/experimental/ooc/ooc_array.py @@ -0,0 +1,485 @@ +import numpy as np +import cupy as cp +#from mpi4py import MPI +import itertools +from dataclasses import dataclass + +from xlb.experimental.ooc.tiles.dense_tile import DenseTile, DenseGPUTile, DenseCPUTile +from xlb.experimental.ooc.tiles.compressed_tile import CompressedTile, CompressedGPUTile, CompressedCPUTile + + +class OOCArray: + """An out-of-core distributed array class. + + Parameters + ---------- + shape : tuple + The shape of the array. + dtype : cp.dtype + The data type of the array. + tile_shape : tuple + The shape of the tiles. Should be a factor of the shape. + padding : int or tuple + The padding of the tiles. + comm : MPI communicator + The MPI communicator. + devices : list of cp.cuda.Device + The list of GPU devices to use. + codec : Codec + The codec to use for compression. None for no compression (Dense tiles). + nr_compute_tiles : int + The number of compute tiles used for asynchronous copies. + TODO currently only 1 is supported when using JAX. + """ + + def __init__( + self, + shape, + dtype, + tile_shape, + padding=1, + comm=None, + devices=[cp.cuda.Device(0)], + codec=None, + nr_compute_tiles=1, + ): + self.shape = shape + self.tile_shape = tile_shape + self.dtype = dtype + if isinstance(padding, int): + padding = (padding,) * len(shape) + self.padding = padding + self.comm = comm + self.devices = devices + self.codec = codec + self.nr_compute_tiles = nr_compute_tiles + + # Set tile class + if self.codec is None: + self.Tile = DenseTile + self.DeviceTile = DenseGPUTile + self.HostTile = ( + DenseCPUTile # TODO: Possibly make HardDiskTile or something + ) + + else: + self.Tile = CompressedTile + self.DeviceTile = CompressedGPUTile + self.HostTile = CompressedCPUTile + + # Get process id and number of processes + self.pid = self.comm.Get_rank() + self.nr_proc = self.comm.Get_size() + + # Check that the tile shape divides the array shape. + if any([shape[i] % tile_shape[i] != 0 for i in range(len(shape))]): + raise ValueError(f"Tile shape {tile_shape} does not divide shape {shape}.") + self.tile_dims = tuple([shape[i] // tile_shape[i] for i in range(len(shape))]) + self.nr_tiles = np.prod(self.tile_dims) + + # Get number of tiles per process + if self.nr_tiles % self.nr_proc != 0: + raise ValueError( + f"Number of tiles {self.nr_tiles} does not divide number of processes {self.nr_proc}." + ) + self.nr_tiles_per_proc = self.nr_tiles // self.nr_proc + + # Make the tile mapppings + self.tile_process_map = {} + self.tile_device_map = {} + for i, tile_index in enumerate( + itertools.product(*[range(n) for n in self.tile_dims]) + ): + self.tile_process_map[tile_index] = i % self.nr_proc + self.tile_device_map[tile_index] = devices[ + i % len(devices) + ] # Checkoboard pattern, TODO: may not be optimal + + # Get my device + if self.nr_proc != len(self.devices): + raise ValueError( + f"Number of processes {self.nr_proc} does not equal number of devices {len(self.devices)}." + ) + self.device = self.devices[self.pid] + + # Make the tiles + self.tiles = {} + for tile_index in self.tile_process_map.keys(): + if self.pid == self.tile_process_map[tile_index]: + self.tiles[tile_index] = self.HostTile( + self.tile_shape, self.dtype, self.padding, self.codec + ) + + # Make GPU tiles for copying data between CPU and GPU + if self.nr_tiles % self.nr_compute_tiles != 0: + raise ValueError( + f"Number of tiles {self.nr_tiles} does not divide number of compute tiles {self.nr_compute_tiles}. This is used for asynchronous copies." + ) + compute_array_shape = [ + s + 2 * p for (s, p) in zip(self.tile_shape, self.padding) + ] + self.compute_tiles_htd = [] + self.compute_tiles_dth = [] + self.compute_streams_htd = [] + self.compute_streams_dth = [] + self.compute_arrays = [] + self.current_compute_index = 0 + with cp.cuda.Device(self.device): + for i in range(self.nr_compute_tiles): + # Make compute tiles for copying data + compute_tile = self.DeviceTile( + self.tile_shape, self.dtype, self.padding, self.codec + ) + self.compute_tiles_htd.append(compute_tile) + compute_tile = self.DeviceTile( + self.tile_shape, self.dtype, self.padding, self.codec + ) + self.compute_tiles_dth.append(compute_tile) + + # Make cupy stream + self.compute_streams_htd.append(cp.cuda.Stream(non_blocking=True)) + self.compute_streams_dth.append(cp.cuda.Stream(non_blocking=True)) + + # Make compute array + + self.compute_arrays.append(cp.empty(compute_array_shape, self.dtype)) + + # Make compute tile mappings + self.compute_tile_mapping_htd = {} + self.compute_tile_mapping_dth = {} + self.compute_stream_mapping_htd = {} + + def size(self): + """Return number of allocated bytes for all host tiles.""" + return sum([tile.size() for tile in self.tiles.values()]) + + def nbytes(self): + """Return number of bytes for all host tiles.""" + return sum([tile.nbytes for tile in self.tiles.values()]) + + def compression_ratio(self): + """Return the compression ratio for all host tiles.""" + return self.nbytes() / self.size() + + def compression_ratio(self): + """Return the compression ratio aggregated over all tiles.""" + + if self.codec is None: + return 1.0 + else: + total_bytes = 0 + total_uncompressed_bytes = 0 + for tile in self.tiles.values(): + ( + tile_total_bytes_uncompressed, + tile_total_bytes_compressed, + ) = tile.compression_ratio() + total_bytes += tile_total_bytes_compressed + total_uncompressed_bytes += tile_total_bytes_uncompressed + return total_uncompressed_bytes / total_bytes + + def update_compute_index(self): + """Update the current compute index.""" + self.current_compute_index = ( + self.current_compute_index + 1 + ) % self.nr_compute_tiles + + def _guess_next_tile_index(self, tile_index): + """Guess the next tile index to use for the compute array.""" + # TODO: This assumes access is sequential + tile_indices = list(self.tiles.keys()) + current_ind = tile_indices.index(tile_index) + next_ind = current_ind + 1 + if next_ind >= len(tile_indices): + return None + else: + return tile_indices[next_ind] + + def reset_queue_htd(self): + """Reset the queue for host to device copies.""" + + self.compute_tile_mapping_htd = {} + self.compute_stream_mapping_htd = {} + self.current_compute_index = 0 + + def managed_compute_tiles_htd(self, tile_index): + """Get the compute tiles needed for computation. + + Parameters + ---------- + tile_index : tuple + The tile index. + + Returns + ------- + compute_tile : ComputeTile + The compute tile needed for computation. + """ + + ################################################### + # TODO: This assumes access is sequential for tiles + ################################################### + + # Que up the next tiles + cur_tile_index = tile_index + cur_compute_index = self.current_compute_index + for i in range(self.nr_compute_tiles): + # Check if already in compute tile map and if not que it + if cur_tile_index not in self.compute_tile_mapping_htd.keys(): + # Get the store tile + tile = self.tiles[cur_tile_index] + + # Get the compute tile + compute_tile = self.compute_tiles_htd[cur_compute_index] + + # Get the compute stream + compute_stream = self.compute_streams_htd[cur_compute_index] + + # Copy the tile to the compute tile using the compute stream + with compute_stream: + tile.to_gpu_tile(compute_tile) + tile.to_gpu_tile(compute_tile) + + # Set the compute tile mapping + self.compute_tile_mapping_htd[cur_tile_index] = compute_tile + self.compute_stream_mapping_htd[cur_tile_index] = compute_stream + + # Update the tile index and compute index + cur_tile_index = self._guess_next_tile_index(cur_tile_index) + if cur_tile_index is None: + break + cur_compute_index = (cur_compute_index + 1) % self.nr_compute_tiles + + # Get the compute tile + self.compute_stream_mapping_htd[tile_index].synchronize() + compute_tile = self.compute_tile_mapping_htd[tile_index] + + # Pop the tile from the compute tile map + self.compute_tile_mapping_htd.pop(tile_index) + self.compute_stream_mapping_htd.pop(tile_index) + + # Return the compute tile + return compute_tile + + def get_compute_array(self, tile_index): + """Given a tile index, copy the tile to the compute array. + + Parameters + ---------- + tile_index : tuple + The tile index. + + Returns + ------- + compute_array : array + The compute array. + global_index : tuple + The lower bound index that the compute array corresponds to in the global array. + For example, if the compute array is the 0th tile and has padding 1, then the + global index will be (-1, -1, ..., -1). + """ + + # Get the compute tile + compute_tile = self.managed_compute_tiles_htd(tile_index) + + # Concatenate the sub-arrays to make the compute array + compute_tile.to_array(self.compute_arrays[self.current_compute_index]) + + # Return the compute array index in global array + global_index = tuple( + [i * s - p for (i, s, p) in zip(tile_index, self.tile_shape, self.padding)] + ) + + return self.compute_arrays[self.current_compute_index], global_index + + def set_tile(self, compute_array, tile_index): + """Given a tile index, copy the compute array to the tile. + + Parameters + ---------- + compute_array : array + The compute array. + tile_index : tuple + The tile index. + """ + + # Syncronize the current stream dth stream + stream = self.compute_streams_dth[self.current_compute_index] + stream.synchronize() + cp.cuda.get_current_stream().synchronize() + + # Set the compute tile to the correct one + compute_tile = self.compute_tiles_dth[self.current_compute_index] + + # Split the compute array into a tile + compute_tile.from_array(compute_array) + + # Syncronize the current stream and the compute stream + cp.cuda.get_current_stream().synchronize() + + # Copy the tile from the compute tile to the store tile + with stream: + compute_tile.to_cpu_tile(self.tiles[tile_index]) + compute_tile.to_cpu_tile(self.tiles[tile_index]) + + def update_padding(self): + """Perform a padding swap between neighboring tiles.""" + + # Get padding indices + pad_ind = self.compute_tiles_htd[0].pad_ind + + # Loop over tiles + comm_tag = 0 + for tile_index in self.tile_process_map.keys(): + # Loop over all padding + for pad_index in pad_ind: + # Get neighboring tile index + neigh_tile_index = tuple( + [ + (i + p) % s + for (i, p, s) in zip(tile_index, pad_index, self.tile_dims) + ] + ) + neigh_pad_index = tuple([-p for p in pad_index]) # flip + + # 4 cases: + # 1. the tile and neighboring tile are on the same process + # 2. the tile is on this process and the neighboring tile is on another process + # 3. the tile is on another process and the neighboring tile is on this process + # 4. the tile and neighboring tile are on different processes + + # Case 1: the tile and neighboring tile are on the same process + if ( + self.pid == self.tile_process_map[tile_index] + and self.pid == self.tile_process_map[neigh_tile_index] + ): + # Get the tile and neighboring tile + tile = self.tiles[tile_index] + neigh_tile = self.tiles[neigh_tile_index] + + # Get pointer to padding and neighboring padding + padding = tile._padding[pad_index] + neigh_padding = neigh_tile._buf_padding[neigh_pad_index] + + # Swap padding + tile._padding[pad_index] = neigh_padding + neigh_tile._buf_padding[neigh_pad_index] = padding + + # Case 2: the tile is on this process and the neighboring tile is on another process + if ( + self.pid == self.tile_process_map[tile_index] + and self.pid != self.tile_process_map[neigh_tile_index] + ): + # Get the tile and padding + tile = self.tiles[tile_index] + padding = tile._padding[pad_index] + + # Send padding to neighboring process + self.comm.Send( + padding, + dest=self.tile_process_map[neigh_tile_index], + tag=comm_tag, + ) + + # Case 3: the tile is on another process and the neighboring tile is on this process + if ( + self.pid != self.tile_process_map[tile_index] + and self.pid == self.tile_process_map[neigh_tile_index] + ): + # Get the neighboring tile and padding + neigh_tile = self.tiles[neigh_tile_index] + neigh_padding = neigh_tile._buf_padding[neigh_pad_index] + + # Receive padding from neighboring process + self.comm.Recv( + neigh_padding, + source=self.tile_process_map[tile_index], + tag=comm_tag, + ) + + # Case 4: the tile and neighboring tile are on different processes + if ( + self.pid != self.tile_process_map[tile_index] + and self.pid != self.tile_process_map[neigh_tile_index] + ): + pass + + # Increment the communication tag + comm_tag += 1 + + # Shuffle padding with buffers + for tile in self.tiles.values(): + tile.swap_buf_padding() + + def get_array(self): + """Get the full array out from all the sub-arrays. This should only be used for testing.""" + + # Get the full array + if self.comm.rank == 0: + array = np.ones(self.shape, dtype=self.dtype) + else: + array = None + + # Loop over tiles + comm_tag = 0 + for tile_index in self.tile_process_map.keys(): + # Set the center array in the full array + slice_index = tuple( + [ + slice(i * s, (i + 1) * s) + for (i, s) in zip(tile_index, self.tile_shape) + ] + ) + + # if tile on this process compute the center array + if self.comm.rank == self.tile_process_map[tile_index]: + # Get the tile + tile = self.tiles[tile_index] + + # Copy the tile to the compute tile + tile.to_gpu_tile(self.compute_tiles_htd[0]) + + # Get the compute array + self.compute_tiles_htd[0].to_array(self.compute_arrays[0]) + + # Get the center array + center_array = self.compute_arrays[0][tile._slice_center].get() + + # 4 cases: + # 1. the tile is on rank 0 and this process is rank 0 + # 2. the tile is on another rank and this process is rank 0 + # 3. the tile is on this rank and this process is not rank 0 + # 4. the tile is not on rank 0 and this process is not rank 0 + + # Case 1: the tile is on rank 0 + if self.comm.rank == 0 and self.tile_process_map[tile_index] == 0: + # Set the center array in the full array + array[slice_index] = center_array + + # Case 2: the tile is on another rank and this process is rank 0 + if self.comm.rank == 0 and self.tile_process_map[tile_index] != 0: + # Get the data from the other rank + center_array = np.empty(self.tile_shape, dtype=self.dtype) + self.comm.Recv( + center_array, source=self.tile_process_map[tile_index], tag=comm_tag + ) + + # Set the center array in the full array + array[slice_index] = center_array + + # Case 3: the tile is on this rank and this process is not rank 0 + if ( + self.comm.rank != 0 + and self.tile_process_map[tile_index] == self.comm.rank + ): + # Send the data to rank 0 + self.comm.Send(center_array, dest=0, tag=comm_tag) + + # Case 4: the tile is not on rank 0 and this process is not rank 0 + if self.comm.rank != 0 and self.tile_process_map[tile_index] != 0: + pass + + # Update the communication tag + comm_tag += 1 + + return array diff --git a/xlb/experimental/ooc/out_of_core.py b/xlb/experimental/ooc/out_of_core.py new file mode 100644 index 0000000..5faa422 --- /dev/null +++ b/xlb/experimental/ooc/out_of_core.py @@ -0,0 +1,110 @@ +# Out-of-core decorator for functions that take a lot of memory + +import functools +import warp as wp +import cupy as cp +import jax.dlpack as jdlpack +import jax +import numpy as np + +from xlb.experimental.ooc.ooc_array import OOCArray +from xlb.experimental.ooc.utils import _cupy_to_backend, _backend_to_cupy, _stream_to_backend + + +def OOCmap(comm, ref_args, add_index=False, backend="jax"): + """Decorator for out-of-core functions. + + Parameters + ---------- + comm : MPI communicator + The MPI communicator. (TODO add functionality) + ref_args : List[int] + The indices of the arguments that are OOC arrays to be written to by outputs of the function. + add_index : bool, optional + Whether to add the index of the global array to the arguments of the function. Default is False. + If true the function will take in a tuple of (array, index) instead of just the array. + backend : str, optional + The backend to use for the function. Default is 'jax'. + Options are 'jax' and 'warp'. + give_stream : bool, optional + Whether to give the function a stream to run on. Default is False. + If true the function will take in a last argument of the stream to run on. + """ + + def decorator(func): + def wrapper(*args): + # Get list of OOC arrays + ooc_array_args = [] + for arg in args: + if isinstance(arg, OOCArray): + ooc_array_args.append(arg) + + # Check that all ooc arrays are compatible + # TODO: Add better checks + for ooc_array in ooc_array_args: + if ooc_array_args[0].tile_dims != ooc_array.tile_dims: + raise ValueError( + f"Tile dimensions of ooc arrays do not match. {ooc_array_args[0].tile_dims} != {ooc_array.tile_dims}" + ) + + # Apply the function to each of the ooc arrays + for tile_index in ooc_array_args[0].tiles.keys(): + # Run through args and kwargs and replace ooc arrays with their compute arrays + new_args = [] + for arg in args: + if isinstance(arg, OOCArray): + # Get the compute array (this performs all the memory copies) + compute_array, global_index = arg.get_compute_array(tile_index) + + # Convert to backend array + compute_array = _cupy_to_backend(compute_array, backend) + + # Add index to the arguments if requested + if add_index: + compute_array = (compute_array, global_index) + + new_args.append(compute_array) + else: + new_args.append(arg) + + # Run the function + results = func(*new_args) + + # Convert the results to a tuple if not already + if not isinstance(results, tuple): + results = (results,) + + # Convert the results back to cupy arrays + results = tuple( + [_backend_to_cupy(result, backend) for result in results] + ) + + # Write the results back to the ooc array + for arg_index, result in zip(ref_args, results): + args[arg_index].set_tile(result, tile_index) + + # Update the ooc arrays compute tile index + for ooc_array in ooc_array_args: + ooc_array.update_compute_index() + + # Syncronize all processes + cp.cuda.Device().synchronize() + comm.Barrier() + + # Update the ooc arrays padding + for i, ooc_array in enumerate(ooc_array_args): + if i in ref_args: + ooc_array.update_padding() + + # Reset que + ooc_array.reset_queue_htd() + + # Return OOC arrays + if len(ref_args) == 1: + return ooc_array_args[ref_args[0]] + else: + return tuple([args[arg_index] for arg_index in ref_args]) + + return wrapper + + return decorator diff --git a/xlb/experimental/ooc/tiles/__init__.py b/xlb/experimental/ooc/tiles/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xlb/experimental/ooc/tiles/compressed_tile.py b/xlb/experimental/ooc/tiles/compressed_tile.py new file mode 100644 index 0000000..415f83b --- /dev/null +++ b/xlb/experimental/ooc/tiles/compressed_tile.py @@ -0,0 +1,273 @@ +import numpy as np +import cupy as cp +import itertools +from dataclasses import dataclass +import warnings +import time + +try: + from kvikio._lib.arr import asarray +except ImportError: + warnings.warn("kvikio not installed. Compression will not work.") + +from xlb.experimental.ooc.tiles.tile import Tile +from xlb.experimental.ooc.tiles.dense_tile import DenseGPUTile +from xlb.experimental.ooc.tiles.dynamic_array import DynamicPinnedArray + + +def _decode(comp_array, dest_array, codec): + """ + Decompresses comp_array into dest_array. + + Parameters + ---------- + comp_array : cupy array + The compressed array to be decompressed. Data type uint8. + dest_array : cupy array + The storage array for the decompressed data. + codec : Codec + The codec to use for decompression. For example, `kvikio.nvcomp.CascadedManager`. + + """ + + # Store needed information + dtype = dest_array.dtype + shape = dest_array.shape + + # Reshape dest_array to match to make into buffer + dest_array = dest_array.view(cp.uint8).reshape(-1) + + # Decompress + codec._manager.decompress(asarray(dest_array), asarray(comp_array)) + + return dest_array.view(dtype).reshape(shape) + + +def _encode(array, dest_array, codec): + """ + Compresses array into dest_array. + + Parameters + ---------- + array : cupy array + The array to be compressed. + dest_array : cupy array + The storage array for the compressed data. Data type uint8. + codec : Codec + The codec to use for compression. For example, `kvikio.nvcomp.CascadedManager`. + """ + + # Make sure array is contiguous + array = cp.ascontiguousarray(array) + + # Configure compression + codec._manager.configure_compression(array.nbytes) + + # Compress + size = codec._manager.compress(asarray(array), asarray(dest_array)) + return size + + +class CompressedTile(Tile): + """A Tile where the data is stored in compressed form.""" + + def __init__(self, shape, dtype, padding, codec): + super().__init__(shape, dtype, padding, codec) + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + raise NotImplementedError + + def to_array(self, array): + """Copy a tile to a full array.""" + # Only implemented for GPU tiles + raise NotImplementedError + + def from_array(self, array): + """Copy a full array to tile.""" + # Only implemented for GPU tiles + raise NotImplementedError + + def compression_ratio(self): + """Returns the compression ratio of the tile.""" + # Get total number of bytes in tile + total_bytes = self._array.size() + for pad_ind in self.pad_ind: + total_bytes += self._padding[pad_ind].size() + + # Get total number of bytes in uncompressed tile + total_bytes_uncompressed = np.prod(self.shape) * self.dtype_itemsize + for pad_ind in self.pad_ind: + total_bytes_uncompressed += ( + np.prod(self._padding_shape[pad_ind]) * self.dtype_itemsize + ) + + # Return compression ratio + return total_bytes_uncompressed, total_bytes + + +class CompressedCPUTile(CompressedTile): + """A tile with cells on the CPU.""" + + def __init__(self, shape, dtype, padding, codec): + super().__init__(shape, dtype, padding, codec) + + def size(self): + """Returns the size of the tile in bytes.""" + size = self._array.size() + for pad_ind in self.pad_ind: + size += self._padding[pad_ind].size() + return size + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + # Make zero array + cp_array = cp.zeros(shape, dtype=self.dtype) + + # compress array + codec = self.codec() + compressed_cp_array = codec.compress(cp_array) + + # Allocate array on CPU + array = DynamicPinnedArray(compressed_cp_array.nbytes) + + # Copy array + compressed_cp_array.get(out=array.array) + + # add nbytes + self.nbytes += cp_array.nbytes + + # delete GPU arrays + del compressed_cp_array + del cp_array + + return array + + def to_gpu_tile(self, dst_gpu_tile): + """Copy tile to a GPU tile.""" + + # Check tile is Compressed + assert isinstance( + dst_gpu_tile, CompressedGPUTile + ), "Destination tile must be a CompressedGPUTile" + + # Copy array + dst_gpu_tile._array[: len(self._array.array)].set(self._array.array) + dst_gpu_tile._array_bytes = self._array.nbytes + + # Copy padding + for pad_ind in self.pad_ind: + dst_gpu_tile._padding[pad_ind][: len(self._padding[pad_ind].array)].set( + self._padding[pad_ind].array + ) + dst_gpu_tile._padding_bytes[pad_ind] = self._padding[pad_ind].nbytes + + +class CompressedGPUTile(CompressedTile): + """A sub-array with ghost cells on the GPU.""" + + def __init__(self, shape, dtype, padding, codec): + super().__init__(shape, dtype, padding, codec) + + # Allocate dense GPU tile + self.dense_gpu_tile = DenseGPUTile(shape, dtype, padding) + + # Set bytes for each array and padding + self._array_bytes = -1 + self._padding_bytes = {} + for pad_ind in self.pad_ind: + self._padding_bytes[pad_ind] = -1 + + # Set codec for each array and padding + self._array_codec = None + self._padding_codec = {} + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + nbytes = np.prod(shape) * self.dtype_itemsize + codec = self.codec() + max_compressed_buffer = codec._manager.configure_compression(nbytes)[ + "max_compressed_buffer_size" + ] + array = cp.zeros((max_compressed_buffer,), dtype=np.uint8) + return array + + def to_array(self, array): + """Copy a tile to a full array.""" + + # Copy center array + if self._array_codec is None: + self._array_codec = self.codec() + self._array_codec._manager.configure_decompression_with_compressed_buffer( + asarray(self._array[: self._array_bytes]) + ) + self._array_codec.decompression_config = self._array_codec._manager.configure_decompression_with_compressed_buffer( + asarray(self._array[: self._array_bytes]) + ) + self.dense_gpu_tile._array = _decode( + self._array[: self._array_bytes], + self.dense_gpu_tile._array, + self._array_codec, + ) + array[self._slice_center] = self.dense_gpu_tile._array + + # Copy padding + for pad_ind in self.pad_ind: + if pad_ind not in self._padding_codec: + self._padding_codec[pad_ind] = self.codec() + self._padding_codec[pad_ind].decompression_config = self._padding_codec[ + pad_ind + ]._manager.configure_decompression_with_compressed_buffer( + asarray(self._padding[pad_ind][: self._padding_bytes[pad_ind]]) + ) + self.dense_gpu_tile._padding[pad_ind] = _decode( + self._padding[pad_ind][: self._padding_bytes[pad_ind]], + self.dense_gpu_tile._padding[pad_ind], + self._padding_codec[pad_ind], + ) + array[self._slice_padding_to_array[pad_ind]] = self.dense_gpu_tile._padding[ + pad_ind + ] + + def from_array(self, array): + """Copy a full array to tile.""" + + # Copy center array + if self._array_codec is None: + self._array_codec = self.codec() + self._array_codec.configure_compression(self._array.nbytes) + self._array_bytes = _encode( + array[self._slice_center], self._array, self._array_codec + ) + + # Copy padding + for pad_ind in self.pad_ind: + if pad_ind not in self._padding_codec: + self._padding_codec[pad_ind] = self.codec() + self._padding_codec[pad_ind].configure_compression( + self._padding[pad_ind].nbytes + ) + self._padding_bytes[pad_ind] = _encode( + array[self._slice_array_to_padding[pad_ind]], + self._padding[pad_ind], + self._padding_codec[pad_ind], + ) + + def to_cpu_tile(self, dst_cpu_tile): + """Copy tile to a CPU tile.""" + + # Check tile is Compressed + assert isinstance( + dst_cpu_tile, CompressedCPUTile + ), "Destination tile must be a CompressedCPUTile" + + # Copy array + dst_cpu_tile._array.resize(self._array_bytes) + self._array[: self._array_bytes].get(out=dst_cpu_tile._array.array) + + # Copy padding + for pad_ind in self.pad_ind: + dst_cpu_tile._padding[pad_ind].resize(self._padding_bytes[pad_ind]) + self._padding[pad_ind][: self._padding_bytes[pad_ind]].get( + out=dst_cpu_tile._padding[pad_ind].array + ) diff --git a/xlb/experimental/ooc/tiles/dense_tile.py b/xlb/experimental/ooc/tiles/dense_tile.py new file mode 100644 index 0000000..8a303e4 --- /dev/null +++ b/xlb/experimental/ooc/tiles/dense_tile.py @@ -0,0 +1,96 @@ +import numpy as np +import cupy as cp +import itertools +from dataclasses import dataclass + +from xlb.experimental.ooc.tiles.tile import Tile + + +class DenseTile(Tile): + """A Tile where the data is stored in a dense array of the requested dtype.""" + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + raise NotImplementedError + + def to_array(self, array): + """Copy a tile to a full array.""" + # TODO: This can be done with a single kernel call, profile to see if it is faster and needs to be done. + + # Copy center array + array[self._slice_center] = self._array + + # Copy padding + for pad_ind in self.pad_ind: + array[self._slice_padding_to_array[pad_ind]] = self._padding[pad_ind] + + def from_array(self, array): + """Copy a full array to tile.""" + # TODO: This can be done with a single kernel call, profile to see if it is faster and needs to be done. + + # Copy center array + self._array[...] = array[self._slice_center] + + # Copy padding + for pad_ind in self.pad_ind: + self._padding[pad_ind][...] = array[self._slice_array_to_padding[pad_ind]] + + +class DenseCPUTile(DenseTile): + """A dense tile with cells on the CPU.""" + + def __init__(self, shape, dtype, padding, codec=None): + super().__init__(shape, dtype, padding, None) + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + # TODO: Seems hacky, but it works. Is there a better way? + mem = cp.cuda.alloc_pinned_memory(np.prod(shape) * self.dtype_itemsize) + array = np.frombuffer(mem, dtype=self.dtype, count=np.prod(shape)).reshape( + shape + ) + self.nbytes += mem.size() + return array + + def to_gpu_tile(self, dst_gpu_tile): + """Copy tile to a GPU tile.""" + + # Check that the destination tile is on the GPU + assert isinstance(dst_gpu_tile, DenseGPUTile), "Destination tile must be on GPU" + + # Copy array + dst_gpu_tile._array.set(self._array) + + # Copy padding + for src_array, dst_gpu_array in zip( + self._padding.values(), dst_gpu_tile._padding.values() + ): + dst_gpu_array.set(src_array) + + +class DenseGPUTile(DenseTile): + """A sub-array with ghost cells on the GPU.""" + + def __init__(self, shape, dtype, padding, codec=None): + super().__init__(shape, dtype, padding, None) + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + array = cp.zeros(shape, dtype=self.dtype) + self.nbytes += array.nbytes + return array + + def to_cpu_tile(self, dst_cpu_tile): + """Copy tile to a CPU tile.""" + + # Check that the destination tile is on the CPU + assert isinstance(dst_cpu_tile, DenseCPUTile), "Destination tile must be on CPU" + + # Copy arra + self._array.get(out=dst_cpu_tile._array) + + # Copy padding + for src_array, dst_array in zip( + self._padding.values(), dst_cpu_tile._padding.values() + ): + src_array.get(out=dst_array) diff --git a/xlb/experimental/ooc/tiles/dynamic_array.py b/xlb/experimental/ooc/tiles/dynamic_array.py new file mode 100644 index 0000000..2b05b2e --- /dev/null +++ b/xlb/experimental/ooc/tiles/dynamic_array.py @@ -0,0 +1,72 @@ +# Dynamic array class for pinned memory allocation + +import math +import cupy as cp +import numpy as np +import time + + +class DynamicArray: + """ + Dynamic pinned memory array class. + + Attributes + ---------- + nbytes : int + The number of bytes in the array. + bytes_resize : int + The number of bytes to resize the array by if the number of bytes requested exceeds the allocated number of bytes. + """ + + def __init__(self, nbytes, bytes_resize_factor=0.025): + # Set the number of bytes + self.nbytes = nbytes + self.bytes_resize_factor = bytes_resize_factor + self.bytes_resize = math.ceil(bytes_resize_factor * nbytes) + + # Set the number of bytes + self.allocated_bytes = math.ceil(nbytes / self.bytes_resize) * self.bytes_resize + + +class DynamicPinnedArray(DynamicArray): + def __init__(self, nbytes, bytes_resize_factor=0.05): + super().__init__(nbytes, bytes_resize_factor) + + # Allocate the memory + self.mem = cp.cuda.alloc_pinned_memory(self.allocated_bytes) + + # Make np array that points to the pinned memory + self.array = np.frombuffer(self.mem, dtype=np.uint8, count=int(self.nbytes)) + + def size(self): + return self.mem.size() + + def resize(self, nbytes): + # Set the new number of bytes + self.nbytes = nbytes + + # Check if the number of bytes requested is less than 2xbytes_resize or if the number of bytes requested exceeds the allocated number of bytes + if ( + nbytes < (self.allocated_bytes - 2 * self.bytes_resize) + or nbytes > self.allocated_bytes + ): + ## Free the memory + # del self.mem + + # Set the new number of allocated bytes + self.allocated_bytes = ( + math.ceil(nbytes / self.bytes_resize) * self.bytes_resize + ) + + # Allocate the memory + self.mem = cp.cuda.alloc_pinned_memory(self.allocated_bytes) + + # Make np array that points to the pinned memory + self.array = np.frombuffer(self.mem, dtype=np.uint8, count=int(self.nbytes)) + + # Set new resize number of bytes + self.bytes_resize = math.ceil(self.bytes_resize_factor * nbytes) + + # Otherwise change numpy array size + else: + self.array = np.frombuffer(self.mem, dtype=np.uint8, count=int(self.nbytes)) diff --git a/xlb/experimental/ooc/tiles/tile.py b/xlb/experimental/ooc/tiles/tile.py new file mode 100644 index 0000000..90c3334 --- /dev/null +++ b/xlb/experimental/ooc/tiles/tile.py @@ -0,0 +1,115 @@ +import numpy as np +import cupy as cp +import itertools +from dataclasses import dataclass + + +class Tile: + """Base class for Tile with ghost cells. This tile is used to build a distributed array. + + Attributes + ---------- + shape : tuple + Shape of the tile. This will be the shape of the array without padding/ghost cells. + dtype : cp.dtype + Data type the tile represents. Note that the data data may be stored in a different + data type. For example, if it is stored in compressed form. + padding : tuple + Number of padding/ghost cells in each dimension. + """ + + def __init__(self, shape, dtype, padding, codec=None): + # Store parameters + self.shape = shape + self.dtype = dtype + self.padding = padding + self.dtype_itemsize = cp.dtype(self.dtype).itemsize + self.nbytes = 0 # Updated when array is allocated + self.codec = ( + codec # Codec to use for compression TODO: Find better abstraction for this + ) + + # Make center array + self._array = self.allocate_array(self.shape) + + # Make padding indices + pad_dir = [] + for i in range(len(self.shape)): + if self.padding[i] == 0: + pad_dir.append((0,)) + else: + pad_dir.append((-1, 0, 1)) + self.pad_ind = list(itertools.product(*pad_dir)) + self.pad_ind.remove((0,) * len(self.shape)) + + # Make padding and padding buffer arrays + self._padding = {} + self._buf_padding = {} + for ind in self.pad_ind: + # determine array shape + shape = [] + for i in range(len(self.shape)): + if ind[i] == -1 or ind[i] == 1: + shape.append(self.padding[i]) + else: + shape.append(self.shape[i]) + + # Make padding and padding buffer + self._padding[ind] = self.allocate_array(shape) + self._buf_padding[ind] = self.allocate_array(shape) + + # Get slicing for array copies + self._slice_center = tuple( + [slice(pad, pad + shape) for (pad, shape) in zip(self.padding, self.shape)] + ) + self._slice_padding_to_array = {} + self._slice_array_to_padding = {} + self._padding_shape = {} + for pad_ind in self.pad_ind: + slice_padding_to_array = [] + slice_array_to_padding = [] + padding_shape = [] + for pad, ind, s in zip(self.padding, pad_ind, self.shape): + if ind == -1: + slice_padding_to_array.append(slice(0, pad)) + slice_array_to_padding.append(slice(pad, 2 * pad)) + padding_shape.append(pad) + elif ind == 0: + slice_padding_to_array.append(slice(pad, s + pad)) + slice_array_to_padding.append(slice(pad, s + pad)) + padding_shape.append(s) + else: + slice_padding_to_array.append(slice(s + pad, s + 2 * pad)) + slice_array_to_padding.append(slice(s, s + pad)) + padding_shape.append(pad) + self._slice_padding_to_array[pad_ind] = tuple(slice_padding_to_array) + self._slice_array_to_padding[pad_ind] = tuple(slice_array_to_padding) + self._padding_shape[pad_ind] = tuple(padding_shape) + + def size(self): + """Returns the number of bytes allocated for the tile.""" + raise NotImplementedError + + def allocate_array(self, shape): + """Returns a cupy array with the given shape.""" + raise NotImplementedError + + def copy_tile(self, dst_tile): + """Copy a tile from one tile to another.""" + raise NotImplementedError + + def to_array(self, array): + """Copy a tile to a full array.""" + raise NotImplementedError + + def from_array(self, array): + """Copy a full array to a tile.""" + raise NotImplementedError + + def swap_buf_padding(self): + """Swap the padding buffer pointer with the padding pointer.""" + for index in self.pad_ind: + (self._buf_padding[index], self._padding[index]) = ( + self._padding[index], + self._buf_padding[index], + ) diff --git a/xlb/experimental/ooc/utils.py b/xlb/experimental/ooc/utils.py new file mode 100644 index 0000000..f607128 --- /dev/null +++ b/xlb/experimental/ooc/utils.py @@ -0,0 +1,79 @@ +import warp as wp +import cupy as cp +import jax.dlpack as jdlpack +import jax + + +def _cupy_to_backend(cupy_array, backend): + """ + Convert cupy array to backend array + + Parameters + ---------- + cupy_array : cupy.ndarray + Input cupy array + backend : str + Backend to convert to. Options are "jax", "warp", or "cupy" + """ + + # Convert cupy array to backend array + dl_array = cupy_array.toDlpack() + if backend == "jax": + backend_array = jdlpack.from_dlpack(dl_array) + elif backend == "warp": + backend_array = wp.from_dlpack(dl_array) + elif backend == "cupy": + backend_array = cupy_array + else: + raise ValueError(f"Backend {backend} not supported") + return backend_array + + +def _backend_to_cupy(backend_array, backend): + """ + Convert backend array to cupy array + + Parameters + ---------- + backend_array : backend.ndarray + Input backend array + backend : str + Backend to convert from. Options are "jax", "warp", or "cupy" + """ + + # Convert backend array to cupy array + if backend == "jax": + (jax.device_put(0.0) + 0).block_until_ready() + dl_array = jdlpack.to_dlpack(backend_array) + elif backend == "warp": + dl_array = wp.to_dlpack(backend_array) + elif backend == "cupy": + return backend_array + else: + raise ValueError(f"Backend {backend} not supported") + cupy_array = cp.fromDlpack(dl_array) + return cupy_array + + +def _stream_to_backend(stream, backend): + """ + Convert cupy stream to backend stream + + Parameters + ---------- + stream : cupy.cuda.Stream + Input cupy stream + backend : str + Backend to convert to. Options are "jax", "warp", or "cupy" + """ + + # Convert stream to backend stream + if backend == "jax": + raise ValueError("Jax currently does not support streams") + elif backend == "warp": + backend_stream = wp.Stream(cuda_stream=stream.ptr) + elif backend == "cupy": + backend_stream = stream + else: + raise ValueError(f"Backend {backend} not supported") + return backend_stream diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py new file mode 100644 index 0000000..1cf32f9 --- /dev/null +++ b/xlb/operator/__init__.py @@ -0,0 +1 @@ +from xlb.operator.operator import Operator diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py new file mode 100644 index 0000000..9937bc3 --- /dev/null +++ b/xlb/operator/boundary_condition/__init__.py @@ -0,0 +1,5 @@ +from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition, ImplementationStep +from xlb.operator.boundary_condition.full_bounce_back import FullBounceBack +from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBack +from xlb.operator.boundary_condition.do_nothing import DoNothing +from xlb.operator.boundary_condition.equilibrium_boundary import EquilibriumBoundary diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py new file mode 100644 index 0000000..7e8909a --- /dev/null +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -0,0 +1,100 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +from functools import partial +import numpy as np +from enum import Enum + +from xlb.operator.operator import Operator +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend + +# Enum for implementation step +class ImplementationStep(Enum): + COLLISION = 1 + STREAMING = 2 + +class BoundaryCondition(Operator): + """ + Base class for boundary conditions in a LBM simulation. + """ + + def __init__( + self, + set_boundary, + implementation_step: ImplementationStep, + velocity_set: VelocitySet, + compute_backend: ComputeBackend.JAX, + ): + super().__init__(velocity_set, compute_backend) + + # Set implementation step + self.implementation_step = implementation_step + + # Set boundary function + if compute_backend == ComputeBackend.JAX: + self.set_boundary = set_boundary + else: + raise NotImplementedError + + @classmethod + def from_indices(cls, indices, implementation_step: ImplementationStep): + """ + Creates a boundary condition from a list of indices. + """ + raise NotImplementedError + + @partial(jit, static_argnums=(0,)) + def apply_jax(self, f_pre, f_post, mask, velocity_set: VelocitySet): + """ + Applies the boundary condition. + """ + pass + + @staticmethod + def _indices_to_tuple(indices): + """ + Converts a tensor of indices to a tuple for indexing + TODO: Might be better to index + """ + return tuple([indices[:, i] for i in range(indices.shape[1])]) + + @staticmethod + def _set_boundary_from_indices(indices): + """ + This create the standard set_boundary function from a list of indices. + `boundary_id` is set to `id_number` at the indices and `mask` is set to `True` at the indices. + Many boundary conditions can be created from this function however some may require a custom function such as + HalfwayBounceBack. + """ + + # Create a mask function + def set_boundary(ijk, boundary_id, mask, id_number): + """ + Sets the mask id for the boundary condition. + + Parameters + ---------- + ijk : jnp.ndarray + Array of shape (N, N, N, 3) containing the meshgrid of lattice points. + boundary_id : jnp.ndarray + Array of shape (N, N, N) containing the boundary id. This will be modified in place and returned. + mask : jnp.ndarray + Array of shape (N, N, N, Q) containing the mask. This will be modified in place and returned. + """ + + # Get local indices from the meshgrid and the indices + local_indices = ijk[BoundaryCondition._indices_to_tuple(indices)] + + # Set the boundary id + boundary_id = boundary_id.at[BoundaryCondition._indices_to_tuple(indices)].set(id_number) + + # Set the mask + mask = mask.at[BoundaryCondition._indices_to_tuple(indices)].set(True) + + return boundary_id, mask + + return set_boundary diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py new file mode 100644 index 0000000..f8f28ed --- /dev/null +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -0,0 +1,56 @@ +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stream.stream import Stream +from xlb.operator.equilibrium.equilibrium import Equilibrium +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) + +class DoNothing(BoundaryCondition): + """ + A boundary condition that skips the streaming step. + """ + + def __init__( + self, + set_boundary, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + super().__init__( + set_boundary=set_boundary, + implementation_step=ImplementationStep.STREAMING, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + @classmethod + def from_indices( + cls, + indices, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + """ + Creates a boundary condition from a list of indices. + """ + + return cls( + set_boundary=cls._set_boundary_from_indices(indices), + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + do_nothing = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) + f = lax.select(do_nothing, f_pre, f_post) + return f diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py new file mode 100644 index 0000000..f615a53 --- /dev/null +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -0,0 +1,69 @@ +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stream.stream import Stream +from xlb.operator.equilibrium.equilibrium import Equilibrium +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) + +class EquilibriumBoundary(BoundaryCondition): + """ + A boundary condition that skips the streaming step. + """ + + def __init__( + self, + set_boundary, + rho: float, + u: tuple[float, float], + equilibrium: Equilibrium, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + super().__init__( + set_boundary=set_boundary, + implementation_step=ImplementationStep.STREAMING, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + self.f = equilibrium(rho, u) + + @classmethod + def from_indices( + cls, + indices, + rho: float, + u: tuple[float, float], + equilibrium: Equilibrium, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + """ + Creates a boundary condition from a list of indices. + """ + + return cls( + set_boundary=cls._set_boundary_from_indices(indices), + rho=rho, + u=u, + equilibrium=equilibrium, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + equilibrium_mask = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) + equilibrium_f = jnp.repeat(self.f[None, ...], boundary.shape[0], axis=0) + equilibrium_f = jnp.repeat(equilibrium_f[:, None], boundary.shape[1], axis=1) + equilibrium_f = jnp.repeat(equilibrium_f[:, :, None], boundary.shape[2], axis=2) + f = lax.select(equilibrium_mask, equilibrium_f, f_post) + return f diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py new file mode 100644 index 0000000..fc883c8 --- /dev/null +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -0,0 +1,57 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) + +class FullBounceBack(BoundaryCondition): + """ + Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + + def __init__( + self, + set_boundary, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + super().__init__( + set_boundary=set_boundary, + implementation_step=ImplementationStep.COLLISION, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + @classmethod + def from_indices( + cls, + indices, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + """ + Creates a boundary condition from a list of indices. + """ + + return cls( + set_boundary=cls._set_boundary_from_indices(indices), + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + flip = jnp.repeat(boundary[..., jnp.newaxis], self.velocity_set.q, axis=-1) + flipped_f = lax.select(flip, f_pre[..., self.velocity_set.opp_indices], f_post) + return flipped_f diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py new file mode 100644 index 0000000..3b5b6de --- /dev/null +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -0,0 +1,97 @@ +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stream.stream import Stream +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) + +class HalfwayBounceBack(BoundaryCondition): + """ + Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + + def __init__( + self, + set_boundary, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + super().__init__( + set_boundary=set_boundary, + implementation_step=ImplementationStep.STREAMING, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + @classmethod + def from_indices( + cls, + indices, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): + """ + Creates a boundary condition from a list of indices. + """ + + # Make stream operator to get edge points + stream = Stream(velocity_set=velocity_set) + + # Create a mask function + def set_boundary(ijk, boundary_id, mask, id_number): + """ + Sets the mask id for the boundary condition. + Halfway bounce-back is implemented by setting the mask to True for points in the boundary, + then streaming the mask to get the points on the surface. + + Parameters + ---------- + ijk : jnp.ndarray + Array of shape (N, N, N, 3) containing the meshgrid of lattice points. + boundary_id : jnp.ndarray + Array of shape (N, N, N) containing the boundary id. This will be modified in place and returned. + mask : jnp.ndarray + Array of shape (N, N, N, Q) containing the mask. This will be modified in place and returned. + """ + + # Get local indices from the meshgrid and the indices + local_indices = ijk[tuple(s[:, 0] for s in jnp.split(indices, velocity_set.d, axis=1))] + + # Make mask then stream to get the edge points + pre_stream_mask = jnp.zeros_like(mask) + pre_stream_mask = pre_stream_mask.at[tuple([s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)])].set(True) + post_stream_mask = stream(pre_stream_mask) + + # Set false for points inside the boundary + post_stream_mask = post_stream_mask.at[post_stream_mask[..., 0] == True].set(False) + + # Get indices on edges + edge_indices = jnp.argwhere(post_stream_mask) + + # Set the boundary id + boundary_id = boundary_id.at[tuple([s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)])].set(id_number) + + # Set the mask + mask = mask.at[edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], :].set(post_stream_mask[edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], :]) + + return boundary_id, mask + + return cls( + set_boundary=set_boundary, + velocity_set=velocity_set, + compute_backend=compute_backend, + ) + + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + flip_mask = boundary[..., jnp.newaxis] & mask + flipped_f = lax.select(flip_mask, f_pre[..., self.velocity_set.opp_indices], f_post) + return flipped_f diff --git a/xlb/operator/collision/__init__.py b/xlb/operator/collision/__init__.py new file mode 100644 index 0000000..77395e6 --- /dev/null +++ b/xlb/operator/collision/__init__.py @@ -0,0 +1,3 @@ +from xlb.operator.collision.collision import Collision +from xlb.operator.collision.bgk import BGK +from xlb.operator.collision.kbc import KBC diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py new file mode 100644 index 0000000..19b5846 --- /dev/null +++ b/xlb/operator/collision/bgk.py @@ -0,0 +1,109 @@ +""" +BGK collision operator for LBM. +""" + +import jax.numpy as jnp +from jax import jit +from functools import partial +from numba import cuda, float32 + +from xlb.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.collision.collision import Collision + + +class BGK(Collision): + """ + BGK collision operator for LBM. + + The BGK collision operator is the simplest collision operator for LBM. + It is based on the Bhatnagar-Gross-Krook approximation to the Boltzmann equation. + Reference: https://en.wikipedia.org/wiki/Bhatnagar%E2%80%93Gross%E2%80%93Krook_operator + """ + + def __init__( + self, + omega: float, + velocity_set: VelocitySet, + compute_backend=ComputeBackend.JAX, + ): + super().__init__( + omega=omega, + velocity_set=velocity_set, + compute_backend=compute_backend + ) + + @partial(jit, static_argnums=(0), donate_argnums=(1,2,3,4)) + def apply_jax( + self, + f: jnp.ndarray, + feq: jnp.ndarray, + rho: jnp.ndarray, + u : jnp.ndarray, + ): + """ + BGK collision step for lattice. + + The collision step is where the main physics of the LBM is applied. In the BGK approximation, + the distribution function is relaxed towards the equilibrium distribution function. + + Parameters + ---------- + f : jax.numpy.ndarray + The distribution function + feq : jax.numpy.ndarray + The equilibrium distribution function + rho : jax.numpy.ndarray + The macroscopic density + u : jax.numpy.ndarray + The macroscopic velocity + + """ + fneq = f - feq + fout = f - self.omega * fneq + return fout + + def construct_numba(self): + """ + Numba implementation of the collision operator. + + Returns + ------- + _collision : numba.cuda.jit + The compiled numba function for the collision operator. + """ + + # Get needed parameters for numba function + omega = self.omega + omega = float32(omega) + + # Make numba function + @cuda.jit(device=True) + def _collision(f, feq, rho, u, fout): + """ + Numba BGK collision step for lattice. + + The collision step is where the main physics of the LBM is applied. In the BGK approximation, + the distribution function is relaxed towards the equilibrium distribution function. + + Parameters + ---------- + f : cuda.local.array + The distribution function + feq : cuda.local.array + The equilibrium distribution function + rho : cuda.local.array + The macroscopic density + u : cuda.local.array + The macroscopic velocity + fout : cuda.local.array + The output distribution function + """ + + # Relaxation + for i in range(f.shape[0]): + fout[i] = f[i] - omega * (f[i] - feq[i]) + + return fout + + return _collision diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py new file mode 100644 index 0000000..728f40c --- /dev/null +++ b/xlb/operator/collision/collision.py @@ -0,0 +1,48 @@ +""" +Base class for Collision operators +""" + +import jax.numpy as jnp +from jax import jit +from functools import partial +import numba + +from xlb.compute_backend import ComputeBackend +from xlb.velocity_set import VelocitySet +from xlb.operator import Operator + + +class Collision(Operator): + """ + Base class for collision operators. + + This class defines the collision step for the Lattice Boltzmann Method. + + Parameters + ---------- + omega : float + Relaxation parameter for collision step. Default value is 0.6. + shear : bool + Flag to indicate whether the collision step requires the shear stress. + """ + + def __init__( + self, + omega: float, + velocity_set: VelocitySet, + compute_backend=ComputeBackend.JAX, + ): + super().__init__(velocity_set, compute_backend) + self.omega = omega + + def apply_jax(self, f, feq, rho, u): + """ + Jax implementation of collision step. + """ + raise NotImplementedError("Child class must implement apply_jax.") + + def construct_numba(self, velocity_set: VelocitySet, dtype=numba.float32): + """ + Construct numba implementation of collision step. + """ + raise NotImplementedError("Child class must implement construct_numba.") diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py new file mode 100644 index 0000000..b5567bb --- /dev/null +++ b/xlb/operator/collision/kbc.py @@ -0,0 +1,203 @@ +""" +KBC collision operator for LBM. +""" + +import jax.numpy as jnp +from jax import jit +from functools import partial +from numba import cuda, float32 + +from xlb.velocity_set import VelocitySet, D2Q9, D3Q27 +from xlb.compute_backend import ComputeBackend +from xlb.operator.collision.collision import Collision + + +class KBC(Collision): + """ + KBC collision operator for LBM. + + This class implements the Karlin-Bösch-Chikatamarla (KBC) model for the collision step in the Lattice Boltzmann Method. + """ + + def __init__( + self, + omega, + velocity_set: VelocitySet, + compute_backend=ComputeBackend.JAX, + ): + super().__init__( + omega=omega, + velocity_set=velocity_set, + compute_backend=compute_backend + ) + self.epsilon = 1e-32 + self.beta = self.omega * 0.5 + self.inv_beta = 1.0 / self.beta + + @partial(jit, static_argnums=(0,), donate_argnums=(1,2,3,4)) + def apply_jax( + self, + f: jnp.ndarray, + feq: jnp.ndarray, + rho: jnp.ndarray, + u: jnp.ndarray, + ): + """ + KBC collision step for lattice. + + Parameters + ---------- + f : jax.numpy.array + Distribution function. + feq : jax.numpy.array + Equilibrium distribution function. + rho : jax.numpy.array + Density. + u : jax.numpy.array + Velocity. + """ + + # Compute shear TODO: Generalize this and possibly make it an operator or something + fneq = f - feq + if isinstance(self.velocity_set, D2Q9): + shear = self.decompose_shear_d2q9_jax(fneq) + delta_s = shear * rho / 4.0 # TODO: Check this + elif isinstance(self.velocity_set, D3Q27): + shear = self.decompose_shear_d3q27_jax(fneq) + delta_s = shear * rho + + # Perform collision + delta_h = fneq - delta_s + gamma = self.inv_beta - (2.0 - self.inv_beta) * self.entropic_scalar_product( + delta_s, delta_h, feq + ) / (self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq)) + + fout = f - self.beta * (2.0 * delta_s + gamma[..., None] * delta_h) + + return fout + + @partial(jit, static_argnums=(0,), inline=True) + def entropic_scalar_product( + self, + x: jnp.ndarray, + y: jnp.ndarray, + feq: jnp.ndarray + ): + """ + Compute the entropic scalar product of x and y to approximate gamma in KBC. + + Returns + ------- + jax.numpy.array + Entropic scalar product of x, y, and feq. + """ + return jnp.sum(x * y / feq, axis=-1) + + @partial(jit, static_argnums=(0, 2), donate_argnums=(1,)) + def momentum_flux_jax( + self, + fneq: jnp.ndarray, + ): + """ + This function computes the momentum flux, which is the product of the non-equilibrium + distribution functions (fneq) and the lattice moments (cc). + + The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann + Method (LBM). + + # TODO: probably move this to equilibrium calculation + + Parameters + ---------- + fneq: jax.numpy.ndarray + The non-equilibrium distribution functions. + + Returns + ------- + jax.numpy.ndarray + The computed momentum flux. + """ + + return jnp.dot(fneq, jnp.array(self.velocity_set.cc, dtype=fneq.dtype)) + + + @partial(jit, static_argnums=(0, 2), inline=True) + def decompose_shear_d3q27_jax(self, fneq): + """ + Decompose fneq into shear components for D3Q27 lattice. + + Parameters + ---------- + fneq : jax.numpy.ndarray + Non-equilibrium distribution function. + + Returns + ------- + jax.numpy.ndarray + Shear components of fneq. + """ + + # Calculate the momentum flux + Pi = self.momentum_flux_jax(fneq) + Nxz = Pi[..., 0] - Pi[..., 5] + Nyz = Pi[..., 3] - Pi[..., 5] + + # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) + s = jnp.zeros_like(fneq) + s = s.at[..., 9].set((2.0 * Nxz - Nyz) / 6.0) + s = s.at[..., 18].set((2.0 * Nxz - Nyz) / 6.0) + s = s.at[..., 3].set((-Nxz + 2.0 * Nyz) / 6.0) + s = s.at[..., 6].set((-Nxz + 2.0 * Nyz) / 6.0) + s = s.at[..., 1].set((-Nxz - Nyz) / 6.0) + s = s.at[..., 2].set((-Nxz - Nyz) / 6.0) + + # For c = (i, j, 0) + s = s.at[..., 12].set(Pi[..., 1] / 4.0) + s = s.at[..., 24].set(Pi[..., 1] / 4.0) + s = s.at[..., 21].set(-Pi[..., 1] / 4.0) + s = s.at[..., 15].set(-Pi[..., 1] / 4.0) + + # For c = (i, 0, k) + s = s.at[..., 10].set(Pi[..., 2] / 4.0) + s = s.at[..., 20].set(Pi[..., 2] / 4.0) + s = s.at[..., 19].set(-Pi[..., 2] / 4.0) + s = s.at[..., 11].set(-Pi[..., 2] / 4.0) + + # For c = (0, j, k) + s = s.at[..., 8].set(Pi[..., 4] / 4.0) + s = s.at[..., 4].set(Pi[..., 4] / 4.0) + s = s.at[..., 7].set(-Pi[..., 4] / 4.0) + s = s.at[..., 5].set(-Pi[..., 4] / 4.0) + + return s + + @partial(jit, static_argnums=(0, 2), inline=True) + def decompose_shear_d2q9_jax(self, fneq): + """ + Decompose fneq into shear components for D2Q9 lattice. + + Parameters + ---------- + fneq : jax.numpy.array + Non-equilibrium distribution function. + + Returns + ------- + jax.numpy.array + Shear components of fneq. + """ + Pi = self.momentum_flux_jax(fneq) + N = Pi[..., 0] - Pi[..., 2] + s = jnp.zeros_like(fneq) + s = s.at[..., 6].set(N) + s = s.at[..., 3].set(N) + s = s.at[..., 2].set(-N) + s = s.at[..., 1].set(-N) + s = s.at[..., 8].set(Pi[..., 1]) + s = s.at[..., 4].set(-Pi[..., 1]) + s = s.at[..., 5].set(-Pi[..., 1]) + s = s.at[..., 7].set(Pi[..., 1]) + + return s + + diff --git a/xlb/operator/equilibrium/__init__.py b/xlb/operator/equilibrium/__init__.py new file mode 100644 index 0000000..587f673 --- /dev/null +++ b/xlb/operator/equilibrium/__init__.py @@ -0,0 +1 @@ +from xlb.operator.equilibrium.equilibrium import Equilibrium, QuadraticEquilibrium diff --git a/xlb/operator/equilibrium/equilibrium.py b/xlb/operator/equilibrium/equilibrium.py new file mode 100644 index 0000000..9de736f --- /dev/null +++ b/xlb/operator/equilibrium/equilibrium.py @@ -0,0 +1,88 @@ +# Base class for all equilibriums + +from functools import partial +import jax.numpy as jnp +from jax import jit +import numba +from numba import cuda + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class Equilibrium(Operator): + """ + Base class for all equilibriums + """ + + def __init__( + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackend.JAX, + ): + super().__init__(velocity_set, compute_backend) + + +class QuadraticEquilibrium(Equilibrium): + """ + Quadratic equilibrium of Boltzmann equation using hermite polynomials. + Standard equilibrium model for LBM. + + TODO: move this to a separate file and lower and higher order equilibriums + """ + + def __init__( + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackend.JAX, + ): + super().__init__(velocity_set, compute_backend) + + @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + def apply_jax(self, rho, u): + """ + JAX implementation of the equilibrium distribution function. + + # TODO: This might be optimized using a for loop for because + # the compiler will remove 0 c terms. + """ + cu = 3.0 * jnp.dot(u, jnp.array(self.velocity_set.c, dtype=rho.dtype)) + usqr = 1.5 * jnp.sum(jnp.square(u), axis=-1, keepdims=True) + feq = ( + rho + * jnp.array(self.velocity_set.w, dtype=rho.dtype) + * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) + ) + return feq + + def construct_numba(self, velocity_set: VelocitySet, dtype=numba.float32): + """ + Numba implementation of the equilibrium distribution function. + """ + # Get needed values for numba functions + q = velocity_set.q + c = velocity_set.c.T + w = velocity_set.w + + # Make numba functions + @cuda.jit(device=True) + def _equilibrium(rho, u, feq): + # Compute the equilibrium distribution function + usqr = dtype(1.5) * (u[0] * u[0] + u[1] * u[1] + u[2] * u[2]) + for i in range(q): + cu = dtype(3.0) * ( + u[0] * dtype(c[i, 0]) + + u[1] * dtype(c[i, 1]) + + u[2] * dtype(c[i, 2]) + ) + feq[i] = ( + rho[0] + * dtype(w[i]) + * (dtype(1.0) + cu * (dtype(1.0) + dtype(0.5) * cu) - usqr) + ) + + # Return the equilibrium distribution function + return feq # comma is needed for numba to return a tuple, seems like a bug in numba + + return _equilibrium diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py new file mode 100644 index 0000000..91eb36c --- /dev/null +++ b/xlb/operator/macroscopic/__init__.py @@ -0,0 +1 @@ +from xlb.operator.macroscopic.macroscopic import Macroscopic diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py new file mode 100644 index 0000000..ff6ae83 --- /dev/null +++ b/xlb/operator/macroscopic/macroscopic.py @@ -0,0 +1,38 @@ +# Base class for all equilibriums + +from functools import partial +import jax.numpy as jnp +from jax import jit + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class Macroscopic(Operator): + """ + Base class for all macroscopic operators + + TODO: Currently this is only used for the standard rho and u moments. + In the future, this should be extended to include higher order moments + and other physic types (e.g. temperature, electromagnetism, etc...) + """ + + def __init__( + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackend.JAX, + ): + super().__init__(velocity_set, compute_backend) + + @partial(jit, static_argnums=(0), inline=True) + def apply_jax(self, f): + """ + Apply the macroscopic operator to the lattice distribution function + """ + + rho = jnp.sum(f, axis=-1, keepdims=True) + c = jnp.array(self.velocity_set.c, dtype=f.dtype).T + u = jnp.dot(f, c) / rho + + return rho, u diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py new file mode 100644 index 0000000..c6ef3d7 --- /dev/null +++ b/xlb/operator/operator.py @@ -0,0 +1,81 @@ +# Base class for all operators, (collision, streaming, equilibrium, etc.) + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend + + +class Operator: + """ + Base class for all operators, collision, streaming, equilibrium, etc. + + This class is responsible for handling compute backends. + """ + + def __init__(self, velocity_set, compute_backend): + self.velocity_set = velocity_set + self.compute_backend = compute_backend + + # Check if compute backend is supported + # TODO: Fix check for compute backend + #if self.compute_backend not in self.supported_compute_backends: + # raise ValueError( + # f"Compute backend {self.compute_backend} not supported by {self.__class__.__name__}" + # ) + + def __call__(self, *args, **kwargs): + """ + Apply the operator to a input. This method will call the + appropriate apply method based on the compute backend. + """ + if self.compute_backend == ComputeBackend.JAX: + return self.apply_jax(*args, **kwargs) + elif self.compute_backend == ComputeBackend.NUMBA: + return self.apply_numba(*args, **kwargs) + + def apply_jax(self, *args, **kwargs): + """ + Implement the operator using JAX. + If using the JAX backend, this method will then become + the self.__call__ method. + """ + raise NotImplementedError("Child class must implement apply_jax") + + def apply_numba(self, *args, **kwargs): + """ + Implement the operator using Numba. + If using the Numba backend, this method will then become + the self.__call__ method. + """ + raise NotImplementedError("Child class must implement apply_numba") + + def construct_numba(self): + """ + Constructs numba kernel for the operator + """ + raise NotImplementedError("Child class must implement apply_numba") + + @property + def supported_compute_backend(self): + """ + Returns the supported compute backend for the operator + """ + supported_backend = [] + if self._is_method_overridden("apply_jax"): + supported_backend.append(ComputeBackend.JAX) + elif self._is_method_overridden("apply_numba"): + supported_backend.append(ComputeBackend.NUMBA) + else: + raise NotImplementedError("No supported compute backend implemented") + return supported_backend + + def _is_method_overridden(self, method_name): + """ + Helper method to check if a method is overridden in a subclass. + """ + method = getattr(self, method_name, None) + if method is None: + return False + return method.__func__ is not getattr(Operator, method_name, None).__func__ + + def __repr__(self): + return f"{self.__class__.__name__}()" diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py new file mode 100644 index 0000000..0469f13 --- /dev/null +++ b/xlb/operator/stepper/__init__.py @@ -0,0 +1,2 @@ +from xlb.operator.stepper.stepper import Stepper +from xlb.operator.stepper.nse import NSE diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py new file mode 100644 index 0000000..9ef1dbc --- /dev/null +++ b/xlb/operator/stepper/nse.py @@ -0,0 +1,93 @@ +# Base class for all stepper operators + +from functools import partial +from jax import jit + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stepper.stepper import Stepper +from xlb.operator.boundary_condition import ImplementationStep + + +class NSE(Stepper): + """ + Class that handles the construction of lattice boltzmann stepping operator for the Navier-Stokes equations + + TODO: Check that the given operators (collision, stream, equilibrium, macroscopic, ...) are compatible + with the Navier-Stokes equations + """ + + def __init__( + self, + collision, + stream, + equilibrium, + macroscopic, + boundary_conditions=[], + forcing=None, + precision_policy=None, + ): + super().__init__( + collision, + stream, + equilibrium, + macroscopic, + boundary_conditions, + forcing, + precision_policy, + ) + + #@partial(jit, static_argnums=(0, 5), donate_argnums=(1)) # TODO: This donate args seems to break out of core memory + @partial(jit, static_argnums=(0, 5)) + def apply_jax(self, f, boundary_id, mask, timestep): + """ + Perform a single step of the lattice boltzmann method + """ + + # Cast to compute precision + f_pre_collision = self.precision_policy.cast_to_compute_jax(f) + + # Compute the macroscopic variables + rho, u = self.macroscopic(f_pre_collision) + + # Compute equilibrium + feq = self.equilibrium(rho, u) + + # Apply collision + f_post_collision = self.collision( + f, + feq, + rho, + u, + ) + + # Apply collision type boundary conditions + for id_number, bc in self.collision_boundary_conditions.items(): + f_post_collision = bc( + f_pre_collision, + f_post_collision, + boundary_id == id_number, + mask, + ) + f_pre_streaming = f_post_collision + + ## Apply forcing + # if self.forcing_op is not None: + # f = self.forcing_op.apply_jax(f, timestep) + + # Apply streaming + f_post_streaming = self.stream(f_pre_streaming) + + # Apply boundary conditions + for id_number, bc in self.stream_boundary_conditions.items(): + f_post_streaming = bc( + f_pre_streaming, + f_post_streaming, + boundary_id == id_number, + mask, + ) + + # Copy back to store precision + f = self.precision_policy.cast_to_store_jax(f_post_streaming) + + return f diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py new file mode 100644 index 0000000..b5c0a44 --- /dev/null +++ b/xlb/operator/stepper/stepper.py @@ -0,0 +1,84 @@ +# Base class for all stepper operators + +import jax.numpy as jnp + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition import ImplementationStep + + +class Stepper(Operator): + """ + Class that handles the construction of lattice boltzmann stepping operator + """ + + def __init__( + self, + collision, + stream, + equilibrium, + macroscopic, + boundary_conditions=[], + forcing=None, + precision_policy=None, + compute_backend=ComputeBackend.JAX, + ): + # Set parameters + self.collision = collision + self.stream = stream + self.equilibrium = equilibrium + self.macroscopic = macroscopic + self.boundary_conditions = boundary_conditions + self.forcing = forcing + self.precision_policy = precision_policy + self.compute_backend = compute_backend + + # Get collision and stream boundary conditions + self.collision_boundary_conditions = {} + self.stream_boundary_conditions = {} + for id_number, bc in enumerate(self.boundary_conditions): + bc_id = id_number + 1 + if bc.implementation_step == ImplementationStep.COLLISION: + self.collision_boundary_conditions[bc_id] = bc + elif bc.implementation_step == ImplementationStep.STREAMING: + self.stream_boundary_conditions[bc_id] = bc + else: + raise ValueError("Boundary condition step not recognized") + + # Get all operators for checking + self.operators = [ + collision, + stream, + equilibrium, + macroscopic, + *boundary_conditions, + ] + + # Get velocity set and backend + velocity_sets = set([op.velocity_set for op in self.operators]) + assert len(velocity_sets) == 1, "All velocity sets must be the same" + self.velocity_set = velocity_sets.pop() + compute_backends = set([op.compute_backend for op in self.operators]) + assert len(compute_backends) == 1, "All compute backends must be the same" + self.compute_backend = compute_backends.pop() + + # Initialize operator + super().__init__(self.velocity_set, self.compute_backend) + + def set_boundary(self, ijk): + """ + Set boundary condition arrays + These store the boundary condition information for each boundary + """ + # Empty boundary condition array + boundary_id = jnp.zeros(ijk.shape[:-1], dtype=jnp.uint8) + mask = jnp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=jnp.bool_) + + # Set boundary condition arrays + for id_number, bc in self.collision_boundary_conditions.items(): + boundary_id, mask = bc.set_boundary(ijk, boundary_id, mask, id_number) + for id_number, bc in self.stream_boundary_conditions.items(): + boundary_id, mask = bc.set_boundary(ijk, boundary_id, mask, id_number) + + return boundary_id, mask diff --git a/xlb/operator/stream/__init__.py b/xlb/operator/stream/__init__.py new file mode 100644 index 0000000..9093da7 --- /dev/null +++ b/xlb/operator/stream/__init__.py @@ -0,0 +1 @@ +from xlb.operator.stream.stream import Stream diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py new file mode 100644 index 0000000..7d31a3b --- /dev/null +++ b/xlb/operator/stream/stream.py @@ -0,0 +1,88 @@ +# Base class for all streaming operators + +from functools import partial +import jax.numpy as jnp +from jax import jit, vmap +import numba + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class Stream(Operator): + """ + Base class for all streaming operators. + + TODO: Currently only this one streaming operator is implemented but + in the future we may have more streaming operators. For example, + one might want a multi-step streaming operator. + """ + + def __init__( + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackend.JAX, + ): + super().__init__(velocity_set, compute_backend) + + @partial(jit, static_argnums=(0), donate_argnums=(1,)) + def apply_jax(self, f): + """ + JAX implementation of the streaming step. + + Parameters + ---------- + f: jax.numpy.ndarray + The distribution function. + """ + + def _streaming(f, c): + """ + Perform individual streaming operation in a direction. + + Parameters + ---------- + f: The distribution function. + c: The streaming direction vector. + + Returns + ------- + jax.numpy.ndarray + The updated distribution function after streaming. + """ + if self.velocity_set.d == 2: + return jnp.roll(f, (c[0], c[1]), axis=(0, 1)) + elif self.velocity_set.d == 3: + return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) + + return vmap(_streaming, in_axes=(-1, 0), out_axes=-1)( + f, jnp.array(self.velocity_set.c).T + ) + + def construct_numba(self, dtype=numba.float32): + """ + Numba implementation of the streaming step. + """ + + # Get needed values for numba functions + d = velocity_set.d + q = velocity_set.q + c = velocity_set.c.T + + # Make numba functions + @cuda.jit(device=True) + def _streaming(f_array, f, ijk): + # Stream to the next node + for _ in range(q): + if d == 2: + i = (ijk[0] + int32(c[_, 0])) % f_array.shape[0] + j = (ijk[1] + int32(c[_, 1])) % f_array.shape[1] + f_array[i, j, _] = f[_] + else: + i = (ijk[0] + int32(c[_, 0])) % f_array.shape[0] + j = (ijk[1] + int32(c[_, 1])) % f_array.shape[1] + k = (ijk[2] + int32(c[_, 2])) % f_array.shape[2] + f_array[i, j, k, _] = f[_] + + return _streaming diff --git a/xlb/physics_type.py b/xlb/physics_type.py new file mode 100644 index 0000000..73d9e70 --- /dev/null +++ b/xlb/physics_type.py @@ -0,0 +1,7 @@ +# Enum used to keep track of the physics types supported by different operators + +from enum import Enum + +class PhysicsType(Enum): + NSE = 1 # Navier-Stokes Equations + ADE = 2 # Advection-Diffusion Equations diff --git a/xlb/precision_policy/__init__.py b/xlb/precision_policy/__init__.py new file mode 100644 index 0000000..c228387 --- /dev/null +++ b/xlb/precision_policy/__init__.py @@ -0,0 +1,2 @@ +from xlb.precision_policy.precision_policy import PrecisionPolicy +from xlb.precision_policy.fp32fp32 import Fp32Fp32 diff --git a/xlb/precision_policy/fp32fp32.py b/xlb/precision_policy/fp32fp32.py new file mode 100644 index 0000000..1e37d2c --- /dev/null +++ b/xlb/precision_policy/fp32fp32.py @@ -0,0 +1,20 @@ +# Purpose: Precision policy for lattice Boltzmann method with computation and +# storage precision both set to float32. + +import jax.numpy as jnp + +from xlb.precision_policy.precision_policy import PrecisionPolicy + + +class Fp32Fp32(PrecisionPolicy): + """ + Precision policy for lattice Boltzmann method with computation and storage + precision both set to float32. + + Parameters + ---------- + None + """ + + def __init__(self): + super().__init__(jnp.float32, jnp.float32) diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py new file mode 100644 index 0000000..459ce74 --- /dev/null +++ b/xlb/precision_policy/precision_policy.py @@ -0,0 +1,159 @@ +# Precision policy for lattice Boltzmann method +# TODO: possibly refctor this to be more general + +from functools import partial +import jax.numpy as jnp +from jax import jit +import numba +from numba import cuda + + +class PrecisionPolicy(object): + """ + Base class for precision policy in lattice Boltzmann method. + Basic idea is to allow for storing the lattice in a different precision than the computation. + + Stores dtype in jax but also contains same information for other backends such as numba. + + Parameters + ---------- + compute_dtype: jax.numpy.dtype + The precision used for computation. + storage_dtype: jax.numpy.dtype + The precision used for storage. + """ + + def __init__(self, compute_dtype, storage_dtype): + # Store the dtypes (jax) + self.compute_dtype = compute_dtype + self.storage_dtype = storage_dtype + + # Get the corresponding numba dtypes + self.compute_dtype_numba = self._get_numba_dtype(compute_dtype) + self.storage_dtype_numba = self._get_numba_dtype(storage_dtype) + + # Check that compute dtype is one of the supported dtypes (float16, float32, float64) + self.supported_compute_dtypes = [jnp.float16, jnp.float32, jnp.float64] + if self.compute_dtype not in self.supported_compute_dtypes: + raise ValueError( + f"Compute dtype {self.compute_dtype} is not supported. Supported dtypes are {self.supported_compute_dtypes}" + ) + + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def cast_to_compute_jax(self, array): + """ + Cast the array to the computation precision + + Parameters + ---------- + Array: jax.numpy.ndarray + The array to cast. + + Returns + ------- + jax.numpy.ndarray + The casted array + """ + return array.astype(self.compute_dtype) + + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def cast_to_store_jax(self, array): + """ + Cast the array to the storage precision + + Parameters + ---------- + Array: jax.numpy.ndarray + The array to cast. + + Returns + ------- + jax.numpy.ndarray + The casted array + """ + return array.astype(self.storage_dtype) + + def cast_to_compute_numba(self): + """ + Constructs a numba function to cast a value to the computation precision + + Parameters + ---------- + value: float + The value to cast. + + Returns + ------- + float + The casted value + """ + return self._cast_to_dtype_numba(self.compute_dtype_numba) + + def cast_to_store_numba(self): + """ + Constructs a numba function to cast a value to the storage precision + + Parameters + ---------- + value: float + The value to cast. + + Returns + ------- + float + The casted value + """ + return self._cast_to_dtype_numba(self.storage_dtype_numba) + + def _cast_to_dytpe_numba(self, dtype): + """ + Constructs a numba function to cast a value to the computation precision + + Parameters + ---------- + value: float + The value to cast. + + Returns + ------- + float + The casted value + """ + + @cuda.jit(device=True) + def cast_to_dtype(value): + return dtype(value) + + def _get_numba_dtype(self, dtype): + """ + Get the corresponding numba dtype + + # TODO: Make this more general + + Parameters + ---------- + dtype: jax.numpy.dtype + The dtype to convert + + Returns + ------- + numba.dtype + The corresponding numba dtype + """ + if dtype == jnp.float16: + return numba.float16 + elif dtype == jnp.float32: + return numba.float32 + elif dtype == jnp.float64: + return numba.float64 + elif dtype == jnp.int32: + return numba.int32 + elif dtype == jnp.int64: + return numba.int64 + elif dtype == jnp.int16: + return numba.int16 + else: + raise ValueError(f"Unsupported dtype {dtype}") + + def __repr__(self): + return f"compute_dtype={self.compute_dtype}/{self.storage_dtype}" diff --git a/src/utils.py b/xlb/utils/utils.py similarity index 73% rename from src/utils.py rename to xlb/utils/utils.py index 1720ca6..d01c500 100644 --- a/src/utils.py +++ b/xlb/utils/utils.py @@ -15,7 +15,7 @@ @partial(jit, static_argnums=(1, 2)) -def downsample_field(field, factor, method='bicubic'): +def downsample_field(field, factor, method="bicubic"): """ Downsample a JAX array by a factor of `factor` along each axis. @@ -38,12 +38,15 @@ def downsample_field(field, factor, method='bicubic'): else: new_shape = tuple(dim // factor for dim in field.shape[:-1]) downsampled_components = [] - for i in range(field.shape[-1]): # Iterate over the last dimension (vector components) + for i in range( + field.shape[-1] + ): # Iterate over the last dimension (vector components) resized = resize(field[..., i], new_shape, method=method) downsampled_components.append(resized) return jnp.stack(downsampled_components, axis=-1) + def save_image(timestep, fld, prefix=None): """ Save an image of a field at a given timestep. @@ -78,16 +81,17 @@ def save_image(timestep, fld, prefix=None): fld = np.sqrt(fld[..., 0] ** 2 + fld[..., 1] ** 2) plt.clf() - plt.imsave(fname + '.png', fld.T, cmap=cm.nipy_spectral, origin='lower') + plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower") + -def save_fields_vtk(timestep, fields, output_dir='.', prefix='fields'): +def save_fields_vtk(timestep, fields, output_dir=".", prefix="fields"): """ Save VTK fields to the specified directory. Parameters ---------- timestep (int): The timestep number to be associated with the saved fields. - fields (Dict[str, np.ndarray]): A dictionary of fields to be saved. Each field must be an array-like object + fields (Dict[str, np.ndarray]): A dictionary of fields to be saved. Each field must be an array-like object with dimensions (nx, ny) for 2D fields or (nx, ny, nz) for 3D fields, where: - nx : int, number of grid points along the x-axis - ny : int, number of grid points along the y-axis @@ -112,9 +116,11 @@ def save_fields_vtk(timestep, fields, output_dir='.', prefix='fields'): if key == list(fields.keys())[0]: dimensions = value.shape else: - assert value.shape == dimensions, "All fields must have the same dimensions!" + assert ( + value.shape == dimensions + ), "All fields must have the same dimensions!" - output_filename = os.path.join(output_dir, prefix + "_" + f"{timestep:07d}.vtk") + output_filename = os.path.join(output_dir, prefix + "_" + f"{timestep:07d}.vtk") # Add 1 to the dimensions tuple as we store cell values dimensions = tuple([dim + 1 for dim in dimensions]) @@ -123,17 +129,18 @@ def save_fields_vtk(timestep, fields, output_dir='.', prefix='fields'): if value.ndim == 2: dimensions = dimensions + (1,) - grid = pv.ImageData(dimensions=dimensions) + grid = pv.UniformGrid(dimensions=dimensions) # Add the fields to the grid for key, value in fields.items(): - grid[key] = value.flatten(order='F') + grid[key] = value.flatten(order="F") # Save the grid to a VTK file start = time() grid.save(output_filename, binary=True) print(f"Saved {output_filename} in {time() - start:.6f} seconds.") + def live_volume_randering(timestep, field): # WORK IN PROGRESS """ @@ -157,29 +164,30 @@ def live_volume_randering(timestep, field): if field.ndim != 3: raise ValueError("The input field must be 3D!") dimensions = field.shape - grid = pv.ImageData(dimensions=dimensions) + grid = pv.UniformGrid(dimensions=dimensions) # Add the field to the grid - grid['field'] = field.flatten(order='F') + grid["field"] = field.flatten(order="F") # Create the rendering scene if timestep == 0: plt.ion() plt.figure(figsize=(10, 10)) - plt.axis('off') + plt.axis("off") plt.title("Live rendering of the field") pl = pv.Plotter(off_screen=True) - pl.add_volume(grid, cmap='nipy_spectral', opacity='sigmoid_10', shade=False) + pl.add_volume(grid, cmap="nipy_spectral", opacity="sigmoid_10", shade=False) plt.imshow(pl.screenshot()) else: pl = pv.Plotter(off_screen=True) - pl.add_volume(grid, cmap='nipy_spectral', opacity='sigmoid_10', shade=False) + pl.add_volume(grid, cmap="nipy_spectral", opacity="sigmoid_10", shade=False) # Update the rendering scene every 0.1 seconds plt.imshow(pl.screenshot()) plt.pause(0.1) -def save_BCs_vtk(timestep, BCs, gridInfo, output_dir='.'): + +def save_BCs_vtk(timestep, BCs, gridInfo, output_dir="."): """ Save boundary conditions as VTK format to the specified directory. @@ -200,14 +208,14 @@ def save_BCs_vtk(timestep, BCs, gridInfo, output_dir='.'): """ # Create a uniform grid - if gridInfo['nz'] == 0: - gridDimensions = (gridInfo['nx'] + 1, gridInfo['ny'] + 1, 1) - fieldDimensions = (gridInfo['nx'], gridInfo['ny'], 1) + if gridInfo["nz"] == 0: + gridDimensions = (gridInfo["nx"] + 1, gridInfo["ny"] + 1, 1) + fieldDimensions = (gridInfo["nx"], gridInfo["ny"], 1) else: - gridDimensions = (gridInfo['nx'] + 1, gridInfo['ny'] + 1, gridInfo['nz'] + 1) - fieldDimensions = (gridInfo['nx'], gridInfo['ny'], gridInfo['nz']) + gridDimensions = (gridInfo["nx"] + 1, gridInfo["ny"] + 1, gridInfo["nz"] + 1) + fieldDimensions = (gridInfo["nx"], gridInfo["ny"], gridInfo["nz"]) - grid = pv.ImageData(dimensions=gridDimensions) + grid = pv.UniformGrid(dimensions=gridDimensions) # Dictionary to keep track of encountered BC names bcNamesCount = {} @@ -226,16 +234,16 @@ def save_BCs_vtk(timestep, BCs, gridInfo, output_dir='.'): bcIndices = bc.indices # Convert indices to 1D indices - if gridInfo['dim'] == 2: - bcIndices = np.ravel_multi_index(bcIndices, fieldDimensions[:-1], order='F') + if gridInfo["dim"] == 2: + bcIndices = np.ravel_multi_index(bcIndices, fieldDimensions[:-1], order="F") else: - bcIndices = np.ravel_multi_index(bcIndices, fieldDimensions, order='F') + bcIndices = np.ravel_multi_index(bcIndices, fieldDimensions, order="F") - grid[bcName] = np.zeros(fieldDimensions, dtype=bool).flatten(order='F') + grid[bcName] = np.zeros(fieldDimensions, dtype=bool).flatten(order="F") grid[bcName][bcIndices] = True # Save the grid to a VTK file - output_filename = os.path.join(output_dir, "BCs_" + f"{timestep:07d}.vtk") + output_filename = os.path.join(output_dir, "BCs_" + f"{timestep:07d}.vtk") start = time() grid.save(output_filename, binary=True) @@ -267,10 +275,15 @@ def rotate_geometry(indices, origin, axis, angle): This function rotates the mesh by applying a rotation matrix to the voxel indices. The rotation matrix is calculated using the axis-angle representation of rotations. The origin of the rotation axis is assumed to be at (0, 0, 0). """ - indices_rotated = (jnp.array(indices).T - origin) @ axangle2mat(axis, angle) + origin - return tuple(jnp.rint(indices_rotated).astype('int32').T) + indices_rotated = (jnp.array(indices).T - origin) @ axangle2mat( + axis, angle + ) + origin + return tuple(jnp.rint(indices_rotated).astype("int32").T) + -def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None): +def voxelize_stl( + stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None +): """ Converts an STL file to a voxelized mesh. @@ -309,7 +322,7 @@ def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None, def axangle2mat(axis, angle, is_normalized=False): - ''' Rotation matrix for rotation angle `angle` around `axis` + """Rotation matrix for rotation angle `angle` around `axis` Parameters ---------- axis : 3 element sequence @@ -326,7 +339,7 @@ def axangle2mat(axis, angle, is_normalized=False): ----- From : https://github.com/matthew-brett/transforms3d Ref : http://en.wikipedia.org/wiki/Rotation_matrix#Axis_and_angle - ''' + """ x, y, z = axis if not is_normalized: n = jnp.sqrt(x * x + y * y + z * z) @@ -345,70 +358,10 @@ def axangle2mat(axis, angle, is_normalized=False): xyC = x * yC yzC = y * zC zxC = z * xC - return jnp.array([ - [x * xC + c, xyC - zs, zxC + ys], - [xyC + zs, y * yC + c, yzC - xs], - [zxC - ys, yzC + xs, z * zC + c]]) - -@partial(jit) -def q_criterion(u): - # Compute derivatives - u_x = u[..., 0] - u_y = u[..., 1] - u_z = u[..., 2] - - # Compute derivatives - u_x_dx = (u_x[2:, 1:-1, 1:-1] - u_x[:-2, 1:-1, 1:-1]) / 2 - u_x_dy = (u_x[1:-1, 2:, 1:-1] - u_x[1:-1, :-2, 1:-1]) / 2 - u_x_dz = (u_x[1:-1, 1:-1, 2:] - u_x[1:-1, 1:-1, :-2]) / 2 - u_y_dx = (u_y[2:, 1:-1, 1:-1] - u_y[:-2, 1:-1, 1:-1]) / 2 - u_y_dy = (u_y[1:-1, 2:, 1:-1] - u_y[1:-1, :-2, 1:-1]) / 2 - u_y_dz = (u_y[1:-1, 1:-1, 2:] - u_y[1:-1, 1:-1, :-2]) / 2 - u_z_dx = (u_z[2:, 1:-1, 1:-1] - u_z[:-2, 1:-1, 1:-1]) / 2 - u_z_dy = (u_z[1:-1, 2:, 1:-1] - u_z[1:-1, :-2, 1:-1]) / 2 - u_z_dz = (u_z[1:-1, 1:-1, 2:] - u_z[1:-1, 1:-1, :-2]) / 2 - - # Compute vorticity - mu_x = u_z_dy - u_y_dz - mu_y = u_x_dz - u_z_dx - mu_z = u_y_dx - u_x_dy - norm_mu = jnp.sqrt(mu_x ** 2 + mu_y ** 2 + mu_z ** 2) - - # Compute strain rate - s_0_0 = u_x_dx - s_0_1 = 0.5 * (u_x_dy + u_y_dx) - s_0_2 = 0.5 * (u_x_dz + u_z_dx) - s_1_0 = s_0_1 - s_1_1 = u_y_dy - s_1_2 = 0.5 * (u_y_dz + u_z_dy) - s_2_0 = s_0_2 - s_2_1 = s_1_2 - s_2_2 = u_z_dz - s_dot_s = ( - s_0_0 ** 2 + s_0_1 ** 2 + s_0_2 ** 2 + - s_1_0 ** 2 + s_1_1 ** 2 + s_1_2 ** 2 + - s_2_0 ** 2 + s_2_1 ** 2 + s_2_2 ** 2 + return jnp.array( + [ + [x * xC + c, xyC - zs, zxC + ys], + [xyC + zs, y * yC + c, yzC - xs], + [zxC - ys, yzC + xs, z * zC + c], + ] ) - - # Compute omega - omega_0_0 = 0.0 - omega_0_1 = 0.5 * (u_x_dy - u_y_dx) - omega_0_2 = 0.5 * (u_x_dz - u_z_dx) - omega_1_0 = -omega_0_1 - omega_1_1 = 0.0 - omega_1_2 = 0.5 * (u_y_dz - u_z_dy) - omega_2_0 = -omega_0_2 - omega_2_1 = -omega_1_2 - omega_2_2 = 0.0 - omega_dot_omega = ( - omega_0_0 ** 2 + omega_0_1 ** 2 + omega_0_2 ** 2 + - omega_1_0 ** 2 + omega_1_1 ** 2 + omega_1_2 ** 2 + - omega_2_0 ** 2 + omega_2_1 ** 2 + omega_2_2 ** 2 - ) - - # Compute q-criterion - q = 0.5 * (omega_dot_omega - s_dot_s) - - return norm_mu, q - - diff --git a/xlb/velocity_set/__init__.py b/xlb/velocity_set/__init__.py new file mode 100644 index 0000000..5b7b737 --- /dev/null +++ b/xlb/velocity_set/__init__.py @@ -0,0 +1,4 @@ +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.velocity_set.d2q9 import D2Q9 +from xlb.velocity_set.d3q19 import D3Q19 +from xlb.velocity_set.d3q27 import D3Q27 diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py new file mode 100644 index 0000000..d5f8793 --- /dev/null +++ b/xlb/velocity_set/d2q9.py @@ -0,0 +1,26 @@ +# Description: Lattice class for 2D D2Q9 lattice. + +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet + + +class D2Q9(VelocitySet): + """ + Velocity Set for 2D D2Q9 lattice. + + D2Q9 stands for two-dimensional nine-velocity model. It is a common model used in the + Lat tice Boltzmann Method for simulating fluid flows in two dimensions. + """ + + def __init__(self): + # Construct the velocity vectors and weights + cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] + cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] + c = np.array(tuple(zip(cx, cy))).T + w = np.array( + [4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36] + ) + + # Call the parent constructor + super().__init__(2, 9, c, w) diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py new file mode 100644 index 0000000..5debdea --- /dev/null +++ b/xlb/velocity_set/d3q19.py @@ -0,0 +1,36 @@ +# Description: Lattice class for 3D D3Q19 lattice. + +import itertools +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet + + +class D3Q19(VelocitySet): + """ + Velocity Set for 3D D3Q19 lattice. + + D3Q19 stands for three-dimensional nineteen-velocity model. It is a common model used in the + Lattice Boltzmann Method for simulating fluid flows in three dimensions. + """ + + def __init__(self): + # Construct the velocity vectors and weights + c = np.array( + [ + ci + for ci in itertools.product([-1, 0, 1], repeat=3) + if np.sum(np.abs(ci)) <= 2 + ] + ).T + w = np.zeros(19) + for i in range(19): + if np.sum(np.abs(c[:, i])) == 0: + w[i] = 1.0 / 3.0 + elif np.sum(np.abs(c[:, i])) == 1: + w[i] = 1.0 / 18.0 + elif np.sum(np.abs(c[:, i])) == 2: + w[i] = 1.0 / 36.0 + + # Initialize the lattice + super().__init__(3, 19, c, w) diff --git a/xlb/velocity_set/d3q27.py b/xlb/velocity_set/d3q27.py new file mode 100644 index 0000000..702acf4 --- /dev/null +++ b/xlb/velocity_set/d3q27.py @@ -0,0 +1,32 @@ +# Description: Lattice class for 3D D3Q27 lattice. + +import itertools +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet + + +class D3Q27(VelocitySet): + """ + Velocity Set for 3D D3Q27 lattice. + + D3Q27 stands for three-dimensional twenty-seven-velocity model. It is a common model used in the + Lattice Boltzmann Method for simulating fluid flows in three dimensions. + """ + + def __init__(self): + # Construct the velocity vectors and weights + c = np.array(list(itertools.product([0, -1, 1], repeat=3))).T + w = np.zeros(27) + for i in range(27): + if np.sum(np.abs(c[:, i])) == 0: + w[i] = 8.0 / 27.0 + elif np.sum(np.abs(c[:, i])) == 1: + w[i] = 2.0 / 27.0 + elif np.sum(np.abs(c[:, i])) == 2: + w[i] = 1.0 / 54.0 + elif np.sum(np.abs(c[:, i])) == 3: + w[i] = 1.0 / 216.0 + + # Initialize the Lattice + super().__init__(3, 27, c, w) diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py new file mode 100644 index 0000000..05ba2a2 --- /dev/null +++ b/xlb/velocity_set/velocity_set.py @@ -0,0 +1,200 @@ +# Base Velocity Set class + +import math +import numpy as np +from functools import partial +import jax.numpy as jnp +from jax import jit, vmap +import numba +from numba import cuda, float32, int32 + + +class VelocitySet(object): + """ + Base class for the velocity set of the Lattice Boltzmann Method (LBM), e.g. D2Q9, D3Q27, etc. + + Parameters + ---------- + d: int + The dimension of the lattice. + q: int + The number of velocities of the lattice. + c: numpy.ndarray + The velocity vectors of the lattice. Shape: (q, d) + w: numpy.ndarray + The weights of the lattice. Shape: (q,) + """ + + def __init__(self, d, q, c, w): + # Store the dimension and the number of velocities + self.d = d + self.q = q + + # Constants + self.cs = math.sqrt(3) / 3.0 + self.cs2 = 1.0 / 3.0 + self.inv_cs2 = 3.0 + + # Construct the properties of the lattice + self.c = c + self.w = w + self.cc = self._construct_lattice_moment() + self.opp_indices = self._construct_opposite_indices() + self.main_indices = self._construct_main_indices() + self.right_indices = self._construct_right_indices() + self.left_indices = self._construct_left_indices() + + @partial(jit, static_argnums=(0,)) + def momentum_flux_jax(self, fneq): + """ + This function computes the momentum flux, which is the product of the non-equilibrium + distribution functions (fneq) and the lattice moments (cc). + + The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann + Method (LBM). + + Parameters + ---------- + fneq: jax.numpy.ndarray + The non-equilibrium distribution functions. + + Returns + ------- + jax.numpy.ndarray + The computed momentum flux. + """ + + return jnp.dot(fneq, self.cc) + + def momentum_flux_numba(self): + """ + This function computes the momentum flux, which is the product of the non-equilibrium + """ + raise NotImplementedError + + @partial(jit, static_argnums=(0,)) + def decompose_shear_jax(self, fneq): + """ + Decompose fneq into shear components for D3Q27 lattice. + + TODO: add generali + + Parameters + ---------- + fneq : jax.numpy.ndarray + Non-equilibrium distribution function. + + Returns + ------- + jax.numpy.ndarray + Shear components of fneq. + """ + raise NotImplementedError + + def decompose_shear_numba(self): + """ + Decompose fneq into shear components for D3Q27 lattice. + """ + raise NotImplementedError + + def _construct_lattice_moment(self): + """ + This function constructs the moments of the lattice. + + The moments are the products of the velocity vectors, which are used in the computation of + the equilibrium distribution functions and the collision operator in the Lattice Boltzmann + Method (LBM). + + Returns + ------- + cc: numpy.ndarray + The moments of the lattice. + """ + c = self.c.T + # Counter for the loop + cntr = 0 + + # nt: number of independent elements of a symmetric tensor + nt = self.d * (self.d + 1) // 2 + + cc = np.zeros((self.q, nt)) + for a in range(0, self.d): + for b in range(a, self.d): + cc[:, cntr] = c[:, a] * c[:, b] + cntr += 1 + + return cc + + def _construct_opposite_indices(self): + """ + This function constructs the indices of the opposite velocities for each velocity. + + The opposite velocity of a velocity is the velocity that has the same magnitude but the + opposite direction. + + Returns + ------- + opposite: numpy.ndarray + The indices of the opposite velocities. + """ + c = self.c.T + opposite = np.array([c.tolist().index((-c[i]).tolist()) for i in range(self.q)]) + return opposite + + def _construct_main_indices(self): + """ + This function constructs the indices of the main velocities. + + The main velocities are the velocities that have a magnitude of 1 in lattice units. + + Returns + ------- + numpy.ndarray + The indices of the main velocities. + """ + c = self.c.T + if self.d == 2: + return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] + + elif self.d == 3: + return np.nonzero( + (np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1) + )[0] + + def _construct_right_indices(self): + """ + This function constructs the indices of the velocities that point in the positive + x-direction. + + Returns + ------- + numpy.ndarray + The indices of the right velocities. + """ + c = self.c.T + return np.nonzero(c[:, 0] == 1)[0] + + def _construct_left_indices(self): + """ + This function constructs the indices of the velocities that point in the negative + x-direction. + + Returns + ------- + numpy.ndarray + The indices of the left velocities. + """ + c = self.c.T + return np.nonzero(c[:, 0] == -1)[0] + + def __str__(self): + """ + This function returns the name of the lattice in the format of DxQy. + """ + return self.__repr__() + + def __repr__(self): + """ + This function returns the name of the lattice in the format of DxQy. + """ + return "D{}Q{}".format(self.d, self.q) From 1c8e75ca399b4a0426a4b013b5a5b3331c83fea3 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Sat, 16 Dec 2023 03:19:22 -0500 Subject: [PATCH 002/144] Added pyc to gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index d774451..2fad3f7 100644 --- a/.gitignore +++ b/.gitignore @@ -39,7 +39,7 @@ Thumbs.db __pycache__/ *.py[cod] *$py.class - +**pyc # C extensions *.so From 9fc7d203cf440819fc2a6efe3eaf3789270e0538 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Wed, 10 Jan 2024 17:41:52 -0500 Subject: [PATCH 003/144] Refactoring backend initiation and data structure: collison and eq done. --- xlb/__init__.py | 4 +- ...compute_backend.py => compute_backends.py} | 6 +- .../boundary_condition/boundary_condition.py | 6 +- xlb/operator/boundary_condition/do_nothing.py | 8 +- .../equilibrium_boundary.py | 6 +- .../boundary_condition/full_bounce_back.py | 6 +- .../boundary_condition/halfway_bounce_back.py | 6 +- xlb/operator/collision/bgk.py | 106 ++---------- xlb/operator/collision/collision.py | 30 +--- xlb/operator/collision/kbc.py | 126 +++++++------- xlb/operator/equilibrium/__init__.py | 2 +- xlb/operator/equilibrium/equilibrium.py | 81 +-------- .../equilibrium/quadratic_equilibrium.py | 37 +++++ xlb/operator/macroscopic/macroscopic.py | 4 +- xlb/operator/operator.py | 70 ++++---- xlb/operator/stepper/nse.py | 7 +- xlb/operator/stepper/stepper.py | 4 +- xlb/operator/stream/stream.py | 4 +- xlb/precision_policy/__init__.py | 3 +- xlb/precision_policy/fp32fp32.py | 20 --- .../jax_precision_policy/___init__.py | 1 + .../jax_precision_policy.py | 72 ++++++++ xlb/precision_policy/precision_policy.py | 155 +----------------- xlb/velocity_set/d2q9.py | 2 +- 24 files changed, 266 insertions(+), 500 deletions(-) rename xlb/{compute_backend.py => compute_backends.py} (54%) create mode 100644 xlb/operator/equilibrium/quadratic_equilibrium.py delete mode 100644 xlb/precision_policy/fp32fp32.py create mode 100644 xlb/precision_policy/jax_precision_policy/___init__.py create mode 100644 xlb/precision_policy/jax_precision_policy/jax_precision_policy.py diff --git a/xlb/__init__.py b/xlb/__init__.py index 8f49baa..e4edc5c 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -1,5 +1,5 @@ # Enum classes -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.physics_type import PhysicsType # Precision policy @@ -9,10 +9,10 @@ import xlb.velocity_set # Operators +import xlb.operator.equilibrium import xlb.operator.collision import xlb.operator.stream import xlb.operator.boundary_condition # import xlb.operator.force -import xlb.operator.equilibrium import xlb.operator.macroscopic import xlb.operator.stepper diff --git a/xlb/compute_backend.py b/xlb/compute_backends.py similarity index 54% rename from xlb/compute_backend.py rename to xlb/compute_backends.py index dee998f..f60073a 100644 --- a/xlb/compute_backend.py +++ b/xlb/compute_backends.py @@ -2,8 +2,6 @@ from enum import Enum -class ComputeBackend(Enum): +class ComputeBackends(Enum): JAX = 1 - NUMBA = 2 - PYTORCH = 3 - WARP = 4 + WARP = 2 diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 7e8909a..cca6ed0 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -10,7 +10,7 @@ from xlb.operator.operator import Operator from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends # Enum for implementation step class ImplementationStep(Enum): @@ -27,7 +27,7 @@ def __init__( set_boundary, implementation_step: ImplementationStep, velocity_set: VelocitySet, - compute_backend: ComputeBackend.JAX, + compute_backend: ComputeBackends.JAX, ): super().__init__(velocity_set, compute_backend) @@ -35,7 +35,7 @@ def __init__( self.implementation_step = implementation_step # Set boundary function - if compute_backend == ComputeBackend.JAX: + if compute_backend == ComputeBackends.JAX: self.set_boundary = set_boundary else: raise NotImplementedError diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py index f8f28ed..39fa4d8 100644 --- a/xlb/operator/boundary_condition/do_nothing.py +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -5,9 +5,7 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend -from xlb.operator.stream.stream import Stream -from xlb.operator.equilibrium.equilibrium import Equilibrium +from xlb.compute_backends import ComputeBackends from xlb.operator.boundary_condition.boundary_condition import ( BoundaryCondition, ImplementationStep, @@ -22,7 +20,7 @@ def __init__( self, set_boundary, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + compute_backend: ComputeBackends = ComputeBackends.JAX, ): super().__init__( set_boundary=set_boundary, @@ -36,7 +34,7 @@ def from_indices( cls, indices, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + compute_backend: ComputeBackends = ComputeBackends.JAX, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py index f615a53..b3f124f 100644 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -5,7 +5,7 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.stream.stream import Stream from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator.boundary_condition.boundary_condition import ( @@ -25,7 +25,7 @@ def __init__( u: tuple[float, float], equilibrium: Equilibrium, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + compute_backend: ComputeBackends = ComputeBackends.JAX, ): super().__init__( set_boundary=set_boundary, @@ -43,7 +43,7 @@ def from_indices( u: tuple[float, float], equilibrium: Equilibrium, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + compute_backend: ComputeBackends = ComputeBackends.JAX, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index fc883c8..311c73f 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -9,7 +9,7 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.boundary_condition.boundary_condition import ( BoundaryCondition, ImplementationStep, @@ -24,7 +24,7 @@ def __init__( self, set_boundary, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + compute_backend: ComputeBackends = ComputeBackends.JAX, ): super().__init__( set_boundary=set_boundary, @@ -38,7 +38,7 @@ def from_indices( cls, indices, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + compute_backend: ComputeBackends = ComputeBackends.JAX, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index 3b5b6de..8eb14fb 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -5,7 +5,7 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.stream.stream import Stream from xlb.operator.boundary_condition.boundary_condition import ( BoundaryCondition, @@ -21,7 +21,7 @@ def __init__( self, set_boundary, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + compute_backend: ComputeBackends = ComputeBackends.JAX, ): super().__init__( set_boundary=set_boundary, @@ -35,7 +35,7 @@ def from_indices( cls, indices, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + compute_backend: ComputeBackends = ComputeBackends.JAX, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 19b5846..90943b8 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -1,109 +1,35 @@ -""" -BGK collision operator for LBM. -""" - import jax.numpy as jnp from jax import jit -from functools import partial -from numba import cuda, float32 - from xlb.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.collision.collision import Collision +from xlb.operator import Operator +from functools import partial class BGK(Collision): """ BGK collision operator for LBM. - - The BGK collision operator is the simplest collision operator for LBM. - It is based on the Bhatnagar-Gross-Krook approximation to the Boltzmann equation. - Reference: https://en.wikipedia.org/wiki/Bhatnagar%E2%80%93Gross%E2%80%93Krook_operator """ def __init__( - self, - omega: float, - velocity_set: VelocitySet, - compute_backend=ComputeBackend.JAX, - ): - super().__init__( - omega=omega, - velocity_set=velocity_set, - compute_backend=compute_backend - ) - - @partial(jit, static_argnums=(0), donate_argnums=(1,2,3,4)) - def apply_jax( self, - f: jnp.ndarray, - feq: jnp.ndarray, - rho: jnp.ndarray, - u : jnp.ndarray, + omega: float, + velocity_set: VelocitySet, + compute_backend=ComputeBackends.JAX, ): - """ - BGK collision step for lattice. - - The collision step is where the main physics of the LBM is applied. In the BGK approximation, - the distribution function is relaxed towards the equilibrium distribution function. - - Parameters - ---------- - f : jax.numpy.ndarray - The distribution function - feq : jax.numpy.ndarray - The equilibrium distribution function - rho : jax.numpy.ndarray - The macroscopic density - u : jax.numpy.ndarray - The macroscopic velocity + super().__init__( + omega=omega, velocity_set=velocity_set, compute_backend=compute_backend + ) - """ + @Operator.register_backend(ComputeBackends.JAX) + @partial(jit, static_argnums=(0,)) + def jax_implementation_2(self, f: jnp.ndarray, feq: jnp.ndarray): fneq = f - feq fout = f - self.omega * fneq return fout - def construct_numba(self): - """ - Numba implementation of the collision operator. - - Returns - ------- - _collision : numba.cuda.jit - The compiled numba function for the collision operator. - """ - - # Get needed parameters for numba function - omega = self.omega - omega = float32(omega) - - # Make numba function - @cuda.jit(device=True) - def _collision(f, feq, rho, u, fout): - """ - Numba BGK collision step for lattice. - - The collision step is where the main physics of the LBM is applied. In the BGK approximation, - the distribution function is relaxed towards the equilibrium distribution function. - - Parameters - ---------- - f : cuda.local.array - The distribution function - feq : cuda.local.array - The equilibrium distribution function - rho : cuda.local.array - The macroscopic density - u : cuda.local.array - The macroscopic velocity - fout : cuda.local.array - The output distribution function - """ - - # Relaxation - for i in range(f.shape[0]): - fout[i] = f[i] - omega * (f[i] - feq[i]) - - return fout - - return _collision + @Operator.register_backend(ComputeBackends.WARP) + def warp_implementation(self, *args, **kwargs): + # Implementation for the Warp backend + raise NotImplementedError diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py index 728f40c..9c22895 100644 --- a/xlb/operator/collision/collision.py +++ b/xlb/operator/collision/collision.py @@ -1,13 +1,7 @@ """ Base class for Collision operators """ - -import jax.numpy as jnp -from jax import jit -from functools import partial -import numba - -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.velocity_set import VelocitySet from xlb.operator import Operator @@ -27,22 +21,10 @@ class Collision(Operator): """ def __init__( - self, - omega: float, - velocity_set: VelocitySet, - compute_backend=ComputeBackend.JAX, - ): + self, + omega: float, + velocity_set: VelocitySet, + compute_backend=ComputeBackends.JAX, + ): super().__init__(velocity_set, compute_backend) self.omega = omega - - def apply_jax(self, f, feq, rho, u): - """ - Jax implementation of collision step. - """ - raise NotImplementedError("Child class must implement apply_jax.") - - def construct_numba(self, velocity_set: VelocitySet, dtype=numba.float32): - """ - Construct numba implementation of collision step. - """ - raise NotImplementedError("Child class must implement construct_numba.") diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index b5567bb..e79b6bd 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -6,9 +6,9 @@ from jax import jit from functools import partial from numba import cuda, float32 - +from xlb.operator import Operator from xlb.velocity_set import VelocitySet, D2Q9, D3Q27 -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.collision.collision import Collision @@ -20,27 +20,25 @@ class KBC(Collision): """ def __init__( - self, - omega, - velocity_set: VelocitySet, - compute_backend=ComputeBackend.JAX, - ): + self, + omega, + velocity_set: VelocitySet, + compute_backend=ComputeBackends.JAX, + ): super().__init__( - omega=omega, - velocity_set=velocity_set, - compute_backend=compute_backend + omega=omega, velocity_set=velocity_set, compute_backend=compute_backend ) self.epsilon = 1e-32 self.beta = self.omega * 0.5 self.inv_beta = 1.0 / self.beta - @partial(jit, static_argnums=(0,), donate_argnums=(1,2,3,4)) - def apply_jax( + @Operator.register_backend(ComputeBackends.JAX) + @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) + def jax_implementation( self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, - u: jnp.ndarray, ): """ KBC collision step for lattice. @@ -53,18 +51,19 @@ def apply_jax( Equilibrium distribution function. rho : jax.numpy.array Density. - u : jax.numpy.array - Velocity. """ - - # Compute shear TODO: Generalize this and possibly make it an operator or something fneq = f - feq + print(self.velocity_set) if isinstance(self.velocity_set, D2Q9): shear = self.decompose_shear_d2q9_jax(fneq) - delta_s = shear * rho / 4.0 # TODO: Check this + delta_s = shear * rho / 4.0 # TODO: Check this elif isinstance(self.velocity_set, D3Q27): shear = self.decompose_shear_d3q27_jax(fneq) delta_s = shear * rho + else: + raise NotImplementedError( + "Velocity set not supported: {}".format(type(self.velocity_set)) + ) # Perform collision delta_h = fneq - delta_s @@ -72,17 +71,12 @@ def apply_jax( delta_s, delta_h, feq ) / (self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq)) - fout = f - self.beta * (2.0 * delta_s + gamma[..., None] * delta_h) + fout = f - self.beta * (2.0 * delta_s + gamma[None, ...] * delta_h) return fout @partial(jit, static_argnums=(0,), inline=True) - def entropic_scalar_product( - self, - x: jnp.ndarray, - y: jnp.ndarray, - feq: jnp.ndarray - ): + def entropic_scalar_product(self, x: jnp.ndarray, y: jnp.ndarray, feq: jnp.ndarray): """ Compute the entropic scalar product of x and y to approximate gamma in KBC. @@ -91,18 +85,18 @@ def entropic_scalar_product( jax.numpy.array Entropic scalar product of x, y, and feq. """ - return jnp.sum(x * y / feq, axis=-1) + return jnp.sum(x * y / feq, axis=0) - @partial(jit, static_argnums=(0, 2), donate_argnums=(1,)) + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) def momentum_flux_jax( - self, - fneq: jnp.ndarray, - ): + self, + fneq: jnp.ndarray, + ): """ - This function computes the momentum flux, which is the product of the non-equilibrium + This function computes the momentum flux, which is the product of the non-equilibrium distribution functions (fneq) and the lattice moments (cc). - The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann + The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann Method (LBM). # TODO: probably move this to equilibrium calculation @@ -117,11 +111,10 @@ def momentum_flux_jax( jax.numpy.ndarray The computed momentum flux. """ - - return jnp.dot(fneq, jnp.array(self.velocity_set.cc, dtype=fneq.dtype)) + return jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0)) - @partial(jit, static_argnums=(0, 2), inline=True) + @partial(jit, static_argnums=(0,), inline=True) def decompose_shear_d3q27_jax(self, fneq): """ Decompose fneq into shear components for D3Q27 lattice. @@ -139,39 +132,40 @@ def decompose_shear_d3q27_jax(self, fneq): # Calculate the momentum flux Pi = self.momentum_flux_jax(fneq) - Nxz = Pi[..., 0] - Pi[..., 5] - Nyz = Pi[..., 3] - Pi[..., 5] + # Calculating Nxz and Nyz with indices moved to the first dimension + Nxz = Pi[0, ...] - Pi[5, ...] + Nyz = Pi[3, ...] - Pi[5, ...] # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) s = jnp.zeros_like(fneq) - s = s.at[..., 9].set((2.0 * Nxz - Nyz) / 6.0) - s = s.at[..., 18].set((2.0 * Nxz - Nyz) / 6.0) - s = s.at[..., 3].set((-Nxz + 2.0 * Nyz) / 6.0) - s = s.at[..., 6].set((-Nxz + 2.0 * Nyz) / 6.0) - s = s.at[..., 1].set((-Nxz - Nyz) / 6.0) - s = s.at[..., 2].set((-Nxz - Nyz) / 6.0) + s = s.at[9, ...].set((2.0 * Nxz - Nyz) / 6.0) + s = s.at[18, ...].set((2.0 * Nxz - Nyz) / 6.0) + s = s.at[3, ...].set((-Nxz + 2.0 * Nyz) / 6.0) + s = s.at[6, ...].set((-Nxz + 2.0 * Nyz) / 6.0) + s = s.at[1, ...].set((-Nxz - Nyz) / 6.0) + s = s.at[2, ...].set((-Nxz - Nyz) / 6.0) # For c = (i, j, 0) - s = s.at[..., 12].set(Pi[..., 1] / 4.0) - s = s.at[..., 24].set(Pi[..., 1] / 4.0) - s = s.at[..., 21].set(-Pi[..., 1] / 4.0) - s = s.at[..., 15].set(-Pi[..., 1] / 4.0) + s = s.at[12, ...].set(Pi[1, ...] / 4.0) + s = s.at[24, ...].set(Pi[1, ...] / 4.0) + s = s.at[21, ...].set(-Pi[1, ...] / 4.0) + s = s.at[15, ...].set(-Pi[1, ...] / 4.0) # For c = (i, 0, k) - s = s.at[..., 10].set(Pi[..., 2] / 4.0) - s = s.at[..., 20].set(Pi[..., 2] / 4.0) - s = s.at[..., 19].set(-Pi[..., 2] / 4.0) - s = s.at[..., 11].set(-Pi[..., 2] / 4.0) + s = s.at[10, ...].set(Pi[2, ...] / 4.0) + s = s.at[20, ...].set(Pi[2, ...] / 4.0) + s = s.at[19, ...].set(-Pi[2, ...] / 4.0) + s = s.at[11, ...].set(-Pi[2, ...] / 4.0) # For c = (0, j, k) - s = s.at[..., 8].set(Pi[..., 4] / 4.0) - s = s.at[..., 4].set(Pi[..., 4] / 4.0) - s = s.at[..., 7].set(-Pi[..., 4] / 4.0) - s = s.at[..., 5].set(-Pi[..., 4] / 4.0) + s = s.at[8, ...].set(Pi[4, ...] / 4.0) + s = s.at[4, ...].set(Pi[4, ...] / 4.0) + s = s.at[7, ...].set(-Pi[4, ...] / 4.0) + s = s.at[5, ...].set(-Pi[4, ...] / 4.0) return s - @partial(jit, static_argnums=(0, 2), inline=True) + @partial(jit, static_argnums=(0,), inline=True) def decompose_shear_d2q9_jax(self, fneq): """ Decompose fneq into shear components for D2Q9 lattice. @@ -187,17 +181,15 @@ def decompose_shear_d2q9_jax(self, fneq): Shear components of fneq. """ Pi = self.momentum_flux_jax(fneq) - N = Pi[..., 0] - Pi[..., 2] + N = Pi[0, ...] - Pi[2, ...] s = jnp.zeros_like(fneq) - s = s.at[..., 6].set(N) - s = s.at[..., 3].set(N) - s = s.at[..., 2].set(-N) - s = s.at[..., 1].set(-N) - s = s.at[..., 8].set(Pi[..., 1]) - s = s.at[..., 4].set(-Pi[..., 1]) - s = s.at[..., 5].set(-Pi[..., 1]) - s = s.at[..., 7].set(Pi[..., 1]) + s = s.at[3, ...].set(N) + s = s.at[6, ...].set(N) + s = s.at[2, ...].set(-N) + s = s.at[1, ...].set(-N) + s = s.at[8, ...].set(Pi[1, ...]) + s = s.at[4, ...].set(-Pi[1, ...]) + s = s.at[5, ...].set(-Pi[1, ...]) + s = s.at[7, ...].set(Pi[1, ...]) return s - - diff --git a/xlb/operator/equilibrium/__init__.py b/xlb/operator/equilibrium/__init__.py index 587f673..1cf7459 100644 --- a/xlb/operator/equilibrium/__init__.py +++ b/xlb/operator/equilibrium/__init__.py @@ -1 +1 @@ -from xlb.operator.equilibrium.equilibrium import Equilibrium, QuadraticEquilibrium +from xlb.operator.equilibrium.quadratic_equilibrium import QuadraticEquilibrium diff --git a/xlb/operator/equilibrium/equilibrium.py b/xlb/operator/equilibrium/equilibrium.py index 9de736f..f60d9ee 100644 --- a/xlb/operator/equilibrium/equilibrium.py +++ b/xlb/operator/equilibrium/equilibrium.py @@ -1,13 +1,6 @@ # Base class for all equilibriums - -from functools import partial -import jax.numpy as jnp -from jax import jit -import numba -from numba import cuda - from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator @@ -17,72 +10,8 @@ class Equilibrium(Operator): """ def __init__( - self, - velocity_set: VelocitySet, - compute_backend=ComputeBackend.JAX, - ): - super().__init__(velocity_set, compute_backend) - - -class QuadraticEquilibrium(Equilibrium): - """ - Quadratic equilibrium of Boltzmann equation using hermite polynomials. - Standard equilibrium model for LBM. - - TODO: move this to a separate file and lower and higher order equilibriums - """ - - def __init__( - self, - velocity_set: VelocitySet, - compute_backend=ComputeBackend.JAX, - ): + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackends.JAX, + ): super().__init__(velocity_set, compute_backend) - - @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) - def apply_jax(self, rho, u): - """ - JAX implementation of the equilibrium distribution function. - - # TODO: This might be optimized using a for loop for because - # the compiler will remove 0 c terms. - """ - cu = 3.0 * jnp.dot(u, jnp.array(self.velocity_set.c, dtype=rho.dtype)) - usqr = 1.5 * jnp.sum(jnp.square(u), axis=-1, keepdims=True) - feq = ( - rho - * jnp.array(self.velocity_set.w, dtype=rho.dtype) - * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) - ) - return feq - - def construct_numba(self, velocity_set: VelocitySet, dtype=numba.float32): - """ - Numba implementation of the equilibrium distribution function. - """ - # Get needed values for numba functions - q = velocity_set.q - c = velocity_set.c.T - w = velocity_set.w - - # Make numba functions - @cuda.jit(device=True) - def _equilibrium(rho, u, feq): - # Compute the equilibrium distribution function - usqr = dtype(1.5) * (u[0] * u[0] + u[1] * u[1] + u[2] * u[2]) - for i in range(q): - cu = dtype(3.0) * ( - u[0] * dtype(c[i, 0]) - + u[1] * dtype(c[i, 1]) - + u[2] * dtype(c[i, 2]) - ) - feq[i] = ( - rho[0] - * dtype(w[i]) - * (dtype(1.0) + cu * (dtype(1.0) + dtype(0.5) * cu) - usqr) - ) - - # Return the equilibrium distribution function - return feq # comma is needed for numba to return a tuple, seems like a bug in numba - - return _equilibrium diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py new file mode 100644 index 0000000..bc4282b --- /dev/null +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -0,0 +1,37 @@ +import jax.numpy as jnp +from jax import jit +from xlb.velocity_set import VelocitySet +from xlb.compute_backends import ComputeBackends +from xlb.operator.equilibrium.equilibrium import Equilibrium +from functools import partial +from xlb.operator import Operator + + +class QuadraticEquilibrium(Equilibrium): + """ + Quadratic equilibrium of Boltzmann equation using hermite polynomials. + Standard equilibrium model for LBM. + + TODO: move this to a separate file and lower and higher order equilibriums + """ + + def __init__( + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackends.JAX, + ): + super().__init__(velocity_set, compute_backend) + + @Operator.register_backend(ComputeBackends.JAX) + # @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + def jax_implementation(self, rho, u): + cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0)) + usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True) + w = self.velocity_set.w.reshape(-1, 1, 1) + + feq = ( + rho + * w + * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) + ) + return feq diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index ff6ae83..55121e6 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -5,7 +5,7 @@ from jax import jit from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator @@ -21,7 +21,7 @@ class Macroscopic(Operator): def __init__( self, velocity_set: VelocitySet, - compute_backend=ComputeBackend.JAX, + compute_backend=ComputeBackends.JAX, ): super().__init__(velocity_set, compute_backend) diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index c6ef3d7..592961d 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,7 +1,7 @@ # Base class for all operators, (collision, streaming, equilibrium, etc.) from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends class Operator: @@ -11,62 +11,46 @@ class Operator: This class is responsible for handling compute backends. """ + _backends = {} + def __init__(self, velocity_set, compute_backend): self.velocity_set = velocity_set self.compute_backend = compute_backend + if compute_backend not in ComputeBackends: + raise ValueError(f"Compute backend {compute_backend} is not supported") - # Check if compute backend is supported - # TODO: Fix check for compute backend - #if self.compute_backend not in self.supported_compute_backends: - # raise ValueError( - # f"Compute backend {self.compute_backend} not supported by {self.__class__.__name__}" - # ) - - def __call__(self, *args, **kwargs): + @classmethod + def register_backend(cls, backend_name): """ - Apply the operator to a input. This method will call the - appropriate apply method based on the compute backend. + Decorator to register a backend for the operator. """ - if self.compute_backend == ComputeBackend.JAX: - return self.apply_jax(*args, **kwargs) - elif self.compute_backend == ComputeBackend.NUMBA: - return self.apply_numba(*args, **kwargs) - def apply_jax(self, *args, **kwargs): - """ - Implement the operator using JAX. - If using the JAX backend, this method will then become - the self.__call__ method. - """ - raise NotImplementedError("Child class must implement apply_jax") + def decorator(func): + # Use the combination of operator name and backend name as the key + subclass_name = func.__qualname__.split(".")[0] + key = (subclass_name, backend_name) + cls._backends[key] = func + return func - def apply_numba(self, *args, **kwargs): - """ - Implement the operator using Numba. - If using the Numba backend, this method will then become - the self.__call__ method. - """ - raise NotImplementedError("Child class must implement apply_numba") + return decorator - def construct_numba(self): + def __call__(self, *args, **kwargs): """ - Constructs numba kernel for the operator + Calls the operator with the compute backend specified in the constructor. """ - raise NotImplementedError("Child class must implement apply_numba") + key = (self.__class__.__name__, self.compute_backend) + backend_method = self._backends.get(key) + if backend_method: + return backend_method(self, *args, **kwargs) + else: + raise NotImplementedError(f"Backend {self.compute_backend} not implemented") @property def supported_compute_backend(self): """ Returns the supported compute backend for the operator """ - supported_backend = [] - if self._is_method_overridden("apply_jax"): - supported_backend.append(ComputeBackend.JAX) - elif self._is_method_overridden("apply_numba"): - supported_backend.append(ComputeBackend.NUMBA) - else: - raise NotImplementedError("No supported compute backend implemented") - return supported_backend + return list(self._backends.keys()) def _is_method_overridden(self, method_name): """ @@ -79,3 +63,9 @@ def _is_method_overridden(self, method_name): def __repr__(self): return f"{self.__class__.__name__}()" + + def data_handler(self, *args, **kwargs): + """ + Handles data for the operator. + """ + raise NotImplementedError("Child class must implement data_handler") diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index 9ef1dbc..a8863b2 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -4,7 +4,7 @@ from jax import jit from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.stepper.stepper import Stepper from xlb.operator.boundary_condition import ImplementationStep @@ -37,8 +37,7 @@ def __init__( precision_policy, ) - #@partial(jit, static_argnums=(0, 5), donate_argnums=(1)) # TODO: This donate args seems to break out of core memory - @partial(jit, static_argnums=(0, 5)) + @partial(jit, static_argnums=(0,)) def apply_jax(self, f, boundary_id, mask, timestep): """ Perform a single step of the lattice boltzmann method @@ -90,4 +89,4 @@ def apply_jax(self, f, boundary_id, mask, timestep): # Copy back to store precision f = self.precision_policy.cast_to_store_jax(f_post_streaming) - return f + return f \ No newline at end of file diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index b5c0a44..8730478 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -3,7 +3,7 @@ import jax.numpy as jnp from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator from xlb.operator.boundary_condition import ImplementationStep @@ -22,7 +22,7 @@ def __init__( boundary_conditions=[], forcing=None, precision_policy=None, - compute_backend=ComputeBackend.JAX, + compute_backend=ComputeBackends.JAX, ): # Set parameters self.collision = collision diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 7d31a3b..7a4a1d8 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -6,7 +6,7 @@ import numba from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend +from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator @@ -22,7 +22,7 @@ class Stream(Operator): def __init__( self, velocity_set: VelocitySet, - compute_backend=ComputeBackend.JAX, + compute_backend=ComputeBackends.JAX, ): super().__init__(velocity_set, compute_backend) diff --git a/xlb/precision_policy/__init__.py b/xlb/precision_policy/__init__.py index c228387..996d8e9 100644 --- a/xlb/precision_policy/__init__.py +++ b/xlb/precision_policy/__init__.py @@ -1,2 +1 @@ -from xlb.precision_policy.precision_policy import PrecisionPolicy -from xlb.precision_policy.fp32fp32 import Fp32Fp32 +from xlb.precision_policy.precision_policy import PrecisionPolicy \ No newline at end of file diff --git a/xlb/precision_policy/fp32fp32.py b/xlb/precision_policy/fp32fp32.py deleted file mode 100644 index 1e37d2c..0000000 --- a/xlb/precision_policy/fp32fp32.py +++ /dev/null @@ -1,20 +0,0 @@ -# Purpose: Precision policy for lattice Boltzmann method with computation and -# storage precision both set to float32. - -import jax.numpy as jnp - -from xlb.precision_policy.precision_policy import PrecisionPolicy - - -class Fp32Fp32(PrecisionPolicy): - """ - Precision policy for lattice Boltzmann method with computation and storage - precision both set to float32. - - Parameters - ---------- - None - """ - - def __init__(self): - super().__init__(jnp.float32, jnp.float32) diff --git a/xlb/precision_policy/jax_precision_policy/___init__.py b/xlb/precision_policy/jax_precision_policy/___init__.py new file mode 100644 index 0000000..1955b79 --- /dev/null +++ b/xlb/precision_policy/jax_precision_policy/___init__.py @@ -0,0 +1 @@ +from xlb.precision_policy.jax_precision_policy import Fp32Fp16, Fp32Fp32, Fp64Fp16, Fp64Fp32, Fp64Fp64 \ No newline at end of file diff --git a/xlb/precision_policy/jax_precision_policy/jax_precision_policy.py b/xlb/precision_policy/jax_precision_policy/jax_precision_policy.py new file mode 100644 index 0000000..c9063c2 --- /dev/null +++ b/xlb/precision_policy/jax_precision_policy/jax_precision_policy.py @@ -0,0 +1,72 @@ +from xlb.precision_policy.precision_policy import PrecisionPolicy +from jax import jit +from functools import partial +import jax.numpy as jnp + + +class JaxPrecisionPolicy(PrecisionPolicy): + """ + JAX-specific precision policy. + """ + + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def cast_to_compute(self, array): + return array.astype(self.compute_dtype) + + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def cast_to_store(self, array): + return array.astype(self.storage_dtype) + + +class Fp32Fp32(JaxPrecisionPolicy): + """ + Precision policy for lattice Boltzmann method with computation and storage + precision both set to float32. + + Parameters + ---------- + None + """ + + def __init__(self): + super().__init__(jnp.float32, jnp.float32) + + +class Fp64Fp64(JaxPrecisionPolicy): + """ + Precision policy for lattice Boltzmann method with computation and storage + precision both set to float64. + """ + + def __init__(self): + super().__init__(jnp.float64, jnp.float64) + + +class Fp64Fp32(JaxPrecisionPolicy): + """ + Precision policy for lattice Boltzmann method with computation precision + set to float64 and storage precision set to float32. + """ + + def __init__(self): + super().__init__(jnp.float64, jnp.float32) + + +class Fp64Fp16(JaxPrecisionPolicy): + """ + Precision policy for lattice Boltzmann method with computation precision + set to float64 and storage precision set to float16. + """ + + def __init__(self): + super().__init__(jnp.float64, jnp.float16) + + +class Fp32Fp16(JaxPrecisionPolicy): + """ + Precision policy for lattice Boltzmann method with computation precision + set to float32 and storage precision set to float16. + """ + + def __init__(self): + super().__init__(jnp.float32, jnp.float16) diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index 459ce74..9a4bdb0 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -1,159 +1,22 @@ -# Precision policy for lattice Boltzmann method -# TODO: possibly refctor this to be more general - -from functools import partial -import jax.numpy as jnp -from jax import jit -import numba -from numba import cuda - - class PrecisionPolicy(object): """ Base class for precision policy in lattice Boltzmann method. - Basic idea is to allow for storing the lattice in a different precision than the computation. - - Stores dtype in jax but also contains same information for other backends such as numba. - - Parameters - ---------- - compute_dtype: jax.numpy.dtype - The precision used for computation. - storage_dtype: jax.numpy.dtype - The precision used for storage. + Stores dtype information and provides an interface for casting operations. """ - def __init__(self, compute_dtype, storage_dtype): - # Store the dtypes (jax) self.compute_dtype = compute_dtype self.storage_dtype = storage_dtype - # Get the corresponding numba dtypes - self.compute_dtype_numba = self._get_numba_dtype(compute_dtype) - self.storage_dtype_numba = self._get_numba_dtype(storage_dtype) - - # Check that compute dtype is one of the supported dtypes (float16, float32, float64) - self.supported_compute_dtypes = [jnp.float16, jnp.float32, jnp.float64] - if self.compute_dtype not in self.supported_compute_dtypes: - raise ValueError( - f"Compute dtype {self.compute_dtype} is not supported. Supported dtypes are {self.supported_compute_dtypes}" - ) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def cast_to_compute_jax(self, array): - """ - Cast the array to the computation precision - - Parameters - ---------- - Array: jax.numpy.ndarray - The array to cast. - - Returns - ------- - jax.numpy.ndarray - The casted array + def cast_to_compute(self, array): """ - return array.astype(self.compute_dtype) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def cast_to_store_jax(self, array): + Cast the array to the computation precision. + To be implemented by subclass. """ - Cast the array to the storage precision + raise NotImplementedError - Parameters - ---------- - Array: jax.numpy.ndarray - The array to cast. - - Returns - ------- - jax.numpy.ndarray - The casted array + def cast_to_store(self, array): """ - return array.astype(self.storage_dtype) - - def cast_to_compute_numba(self): + Cast the array to the storage precision. + To be implemented by subclass. """ - Constructs a numba function to cast a value to the computation precision - - Parameters - ---------- - value: float - The value to cast. - - Returns - ------- - float - The casted value - """ - return self._cast_to_dtype_numba(self.compute_dtype_numba) - - def cast_to_store_numba(self): - """ - Constructs a numba function to cast a value to the storage precision - - Parameters - ---------- - value: float - The value to cast. - - Returns - ------- - float - The casted value - """ - return self._cast_to_dtype_numba(self.storage_dtype_numba) - - def _cast_to_dytpe_numba(self, dtype): - """ - Constructs a numba function to cast a value to the computation precision - - Parameters - ---------- - value: float - The value to cast. - - Returns - ------- - float - The casted value - """ - - @cuda.jit(device=True) - def cast_to_dtype(value): - return dtype(value) - - def _get_numba_dtype(self, dtype): - """ - Get the corresponding numba dtype - - # TODO: Make this more general - - Parameters - ---------- - dtype: jax.numpy.dtype - The dtype to convert - - Returns - ------- - numba.dtype - The corresponding numba dtype - """ - if dtype == jnp.float16: - return numba.float16 - elif dtype == jnp.float32: - return numba.float32 - elif dtype == jnp.float64: - return numba.float64 - elif dtype == jnp.int32: - return numba.int32 - elif dtype == jnp.int64: - return numba.int64 - elif dtype == jnp.int16: - return numba.int16 - else: - raise ValueError(f"Unsupported dtype {dtype}") - - def __repr__(self): - return f"compute_dtype={self.compute_dtype}/{self.storage_dtype}" + raise NotImplementedError \ No newline at end of file diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py index d5f8793..9b09dd4 100644 --- a/xlb/velocity_set/d2q9.py +++ b/xlb/velocity_set/d2q9.py @@ -10,7 +10,7 @@ class D2Q9(VelocitySet): Velocity Set for 2D D2Q9 lattice. D2Q9 stands for two-dimensional nine-velocity model. It is a common model used in the - Lat tice Boltzmann Method for simulating fluid flows in two dimensions. + Lattice Boltzmann Method for simulating fluid flows in two dimensions. """ def __init__(self): From c0b30fc212a48a1d90bc0d145d0a80f3d7790219 Mon Sep 17 00:00:00 2001 From: Oliver Date: Tue, 23 Jan 2024 14:00:20 -0800 Subject: [PATCH 004/144] small LBM example for functional programing --- examples/backend_comparisons/small_example.py | 300 ++++++++++++++++++ 1 file changed, 300 insertions(+) create mode 100644 examples/backend_comparisons/small_example.py diff --git a/examples/backend_comparisons/small_example.py b/examples/backend_comparisons/small_example.py new file mode 100644 index 0000000..23ca639 --- /dev/null +++ b/examples/backend_comparisons/small_example.py @@ -0,0 +1,300 @@ +# Simple example of functions to generate a warp kernel for LBM + +import warp as wp +import numpy as np + +# Initialize Warp +wp.init() + +def make_warp_kernel( + velocity_weight, + velocity_set, + dtype=wp.float32, + dim=3, # slightly hard coded for 3d right now + q=19, +): + + # Make needed vector classes + lattice_vec = wp.vec(q, dtype=dtype) + velocity_vec = wp.vec(dim, dtype=dtype) + + # Make array type + if dim == 2: + array_type = wp.array3d(dtype=dtype) + elif dim == 3: + array_type = wp.array4d(dtype=dtype) + + # Make everything constant + velocity_weight = wp.constant(velocity_weight) + velocity_set = wp.constant(velocity_set) + q = wp.constant(q) + dim = wp.constant(dim) + + # Make function for computing exu + @wp.func + def compute_exu(u: velocity_vec): + exu = lattice_vec() + for _ in range(q): + for d in range(dim): + if velocity_set[_, d] == 1: + exu[_] += u[d] + elif velocity_set[_, d] == -1: + exu[_] -= u[d] + return exu + + # Make function for computing feq + @wp.func + def compute_feq( + p: dtype, + uxu: dtype, + exu: lattice_vec, + ): + factor_1 = 1.5 + factor_2 = 4.5 + feq = lattice_vec() + for _ in range(q): + feq[_] = ( + velocity_weight[_] * p * ( + 1.0 + + factor_1 * (2.0 * exu[_] - uxu) + + factor_2 * exu[_] * exu[_] + ) + ) + return feq + + # Make function for computing u and p + @wp.func + def compute_u_and_p(f: lattice_vec): + p = wp.float32(0.0) + u = velocity_vec() + for d in range(dim): + u[d] = wp.float32(0.0) + for _ in range(q): + p += f[_] + for d in range(dim): + if velocity_set[_, d] == 1: + u[d] += f[_] + elif velocity_set[_, d] == -1: + u[d] -= f[_] + u /= p + return u, p + + # Make function for getting stream index + @wp.func + def get_streamed_index( + i: int, + x: int, + y: int, + z: int, + width: int, + height: int, + length: int, + ): + streamed_x = x + velocity_set[i, 0] + streamed_y = y + velocity_set[i, 1] + streamed_z = z + velocity_set[i, 2] + if streamed_x == -1: # TODO hacky + streamed_x = width - 1 + if streamed_y == -1: + streamed_y = height - 1 + if streamed_z == -1: + streamed_z = length - 1 + if streamed_x == width: + streamed_x = 0 + if streamed_y == height: + streamed_y = 0 + if streamed_z == length: + streamed_z = 0 + return streamed_x, streamed_y, streamed_z + + # Make kernel for stream and collide + @wp.kernel + def collide_stream( + f0: array_type, + f1: array_type, + width: int, + height: int, + length: int, + tau: float, + ): + + # Get indices (TODO: no good way to do variable dimension indexing) + f = lattice_vec() + x, y, z = wp.tid() + for i in range(q): + f[i] = f0[i, x, y, z] + + # Compute p and u + u, p = compute_u_and_p(f) + + # get uxu + uxu = wp.dot(u, u) + + # Compute velocity_set dot u + exu = compute_exu(u) + + # Compute equilibrium + feq = compute_feq(p, uxu, exu) + + # Set value + new_f = f - (f - feq) / tau + for i in range(q): + (streamed_x, streamed_y, streamed_z) = get_streamed_index( + i, x, y, z, width, height, length + ) + f1[i, streamed_x, streamed_y, streamed_z] = new_f[i] + + # make kernel for initialization + @wp.kernel + def initialize_taylor_green( + f0: array_type, + dx: float, + vel: float, + width: int, + height: int, + length: int, + tau: float, + ): + + # Get indices (TODO: no good way to do variable dimension indexing) + i, j, k = wp.tid() + + # Get real coordinates + x = wp.float(i) * dx + y = wp.float(j) * dx + z = wp.float(k) * dx + + # Compute velocity + u = velocity_vec() + u[0] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) + u[1] = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) + u[2] = 0.0 + + # Compute p + p = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * ( + wp.cos(2.0 * x) + + wp.cos(2.0 * y) + + wp.cos(2.0 * z) + ) + + 1.0 + ) + + # Compute uxu + uxu = wp.dot(u, u) + + # Compute velocity_set dot u + exu = compute_exu(u) + + # Compute equilibrium + feq = compute_feq(p, uxu, exu) + + # Set value + for _ in range(q): + f0[_, i, j, k] = feq[_] + + return collide_stream, initialize_taylor_green + +def plt_f(f): + import matplotlib.pyplot as plt + plt.imshow(f.numpy()[3, :, :, f.shape[3] // 4]) + plt.show() + +if __name__ == "__main__": + + # Parameters + n = 256 + tau = 0.505 + dim = 3 + q = 19 + lattice_dtype = wp.float32 + lattice_vec = wp.vec(q, dtype=lattice_dtype) + + # Make arrays + f0 = wp.empty((q, n, n, n), dtype=lattice_dtype, device="cuda:0") + f1 = wp.empty((q, n, n, n), dtype=lattice_dtype, device="cuda:0") + + # Make velocity set + velocity_weight = wp.vec(q, dtype=lattice_dtype)( + [1.0/3.0] + [1.0/18.0] * 6 + [1.0/36.0] * 12 + ) + velocity_set = wp.mat((q, dim), dtype=wp.int32)( + [ + [0, 0, 0], + [1, 0, 0], + [-1, 0, 0], + [0, 1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, -1], + [0, 1, 1], + [0, -1, -1], + [0, 1, -1], + [0, -1, 1], + [1, 0, 1], + [-1, 0, -1], + [1, 0, -1], + [-1, 0, 1], + [1, 1, 0], + [-1, -1, 0], + [1, -1, 0], + [-1, 1, 0], + ] + ) + + # Make kernel + collide_stream, initialize = make_warp_kernel( + velocity_weight, + velocity_set, + dtype=lattice_dtype, + dim=dim, + q=q, + ) + + # Initialize + cs = 1.0 / np.sqrt(3.0) + vel = 0.1 * cs + dx = 2.0 * np.pi / n + wp.launch( + initialize, + inputs=[ + f0, + dx, + vel, + n, + n, + n, + tau, + ], + dim=(n, n, n), + ) + + # Compute MLUPS + import time + import tqdm + nr_iterations = 128 + start = time.time() + for i in tqdm.tqdm(range(nr_iterations)): + #if i % 10 == 0: + # plt_f(f0) + + wp.launch( + collide_stream, + inputs=[ + f0, + f1, + n, + n, + n, + tau, + ], + dim=(n, n, n), + ) + f0, f1 = f1, f0 + wp.synchronize() + end = time.time() + print("MLUPS: ", (nr_iterations * n * n * n) / (end - start) / 1e6) From a05441f173a557bcd22e3fe5cffedb4b600bcce5 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Fri, 26 Jan 2024 17:21:38 -0500 Subject: [PATCH 005/144] example_mehdi works with no error --- examples/refactor/example_mehdi.py | 37 ++++++ test.py | 40 +++++++ xlb/__init__.py | 12 +- xlb/global_config.py | 10 ++ xlb/grid/__init__.py | 1 + xlb/grid/grid.py | 34 ++++++ xlb/grid/jax_grid.py | 46 ++++++++ xlb/operator/collision/bgk.py | 6 +- xlb/operator/collision/collision.py | 5 +- xlb/operator/collision/kbc.py | 14 ++- xlb/operator/equilibrium/equilibrium.py | 5 +- .../equilibrium/quadratic_equilibrium.py | 16 +-- xlb/operator/initializer/__init__.py | 2 + xlb/operator/initializer/const_init.py | 28 +++++ xlb/operator/initializer/equilibrium_init.py | 28 +++++ xlb/operator/macroscopic/macroscopic.py | 17 ++- xlb/operator/operator.py | 33 +++--- xlb/operator/stepper/__init__.py | 2 - xlb/operator/stepper/nse.py | 92 --------------- xlb/operator/stepper/stepper.py | 84 -------------- xlb/operator/stream/stream.py | 95 ++++++++------- xlb/precision_policy/__init__.py | 2 +- xlb/precision_policy/base_precision_policy.py | 14 +++ .../jax_precision_policy.py | 12 +- .../jax_precision_policy/___init__.py | 1 - xlb/precision_policy/precision_policy.py | 81 +++++++++---- xlb/solver/__init__.py | 1 + xlb/solver/nse.py | 108 ++++++++++++++++++ xlb/solver/solver.py | 37 ++++++ 29 files changed, 573 insertions(+), 290 deletions(-) create mode 100644 examples/refactor/example_mehdi.py create mode 100644 test.py create mode 100644 xlb/global_config.py create mode 100644 xlb/grid/__init__.py create mode 100644 xlb/grid/grid.py create mode 100644 xlb/grid/jax_grid.py create mode 100644 xlb/operator/initializer/__init__.py create mode 100644 xlb/operator/initializer/const_init.py create mode 100644 xlb/operator/initializer/equilibrium_init.py delete mode 100644 xlb/operator/stepper/__init__.py delete mode 100644 xlb/operator/stepper/nse.py delete mode 100644 xlb/operator/stepper/stepper.py create mode 100644 xlb/precision_policy/base_precision_policy.py rename xlb/precision_policy/{jax_precision_policy => }/jax_precision_policy.py (85%) delete mode 100644 xlb/precision_policy/jax_precision_policy/___init__.py create mode 100644 xlb/solver/__init__.py create mode 100644 xlb/solver/nse.py create mode 100644 xlb/solver/solver.py diff --git a/examples/refactor/example_mehdi.py b/examples/refactor/example_mehdi.py new file mode 100644 index 0000000..b758d3e --- /dev/null +++ b/examples/refactor/example_mehdi.py @@ -0,0 +1,37 @@ +import xlb +from xlb.compute_backends import ComputeBackends +from xlb.precision_policy import Fp32Fp32 + +from xlb.solver import IncompressibleNavierStokes +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.stream import Stream +from xlb.global_config import GlobalConfig +from xlb.grid import Grid +from xlb.operator.initializer import EquilibriumInitializer, ConstInitializer + +import numpy as np +import jax.numpy as jnp + +xlb.init(precision_policy=Fp32Fp32, compute_backend=ComputeBackends.JAX, velocity_set=xlb.velocity_set.D2Q9) + +grid_shape = (100, 100) +grid = Grid.create(grid_shape) + +f_init = grid.create_field(cardinality=9, callback=EquilibriumInitializer(grid)) + +u_init = grid.create_field(cardinality=2, callback=ConstInitializer(grid, cardinality=2, const_value=0.0)) +rho_init = grid.create_field(cardinality=1, callback=ConstInitializer(grid, cardinality=1, const_value=1.0)) + + +st = Stream(grid) + +f_init = st(f_init) +print("here") +solver = IncompressibleNavierStokes(grid) + +num_steps = 100 +f = f_init +for step in range(num_steps): + f = solver.step(f, timestep=step) + print(f"Step {step+1}/{num_steps} complete") + diff --git a/test.py b/test.py new file mode 100644 index 0000000..5529ac3 --- /dev/null +++ b/test.py @@ -0,0 +1,40 @@ +import jax.numpy as jnp +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.stream import Stream +from xlb.velocity_set import D2Q9, D3Q27 +from xlb.operator.collision.kbc import KBC, BGK +from xlb.compute_backends import ComputeBackends +from xlb.grid import Grid +import xlb + + +xlb.init(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +collision = BGK(omega=0.6) + +# eq = QuadraticEquilibrium(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +# macro = Macroscopic(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +# s = Stream(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +Q = 19 +# create random jnp arrays +f = jnp.ones((Q, 10, 10)) +rho = jnp.ones((1, 10, 10)) +u = jnp.zeros((2, 10, 10)) +# feq = eq(rho, u) + +print(collision(f, f)) + +grid = Grid.create(grid_shape=(10, 10), velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +def advection_result(index): + return 1.0 + + +f = grid.initialize_pop(advection_result) + +print(f) +print(f.sharding) \ No newline at end of file diff --git a/xlb/__init__.py b/xlb/__init__.py index e4edc5c..7f13a8e 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -2,6 +2,11 @@ from xlb.compute_backends import ComputeBackends from xlb.physics_type import PhysicsType + +# Config +from .global_config import init + + # Precision policy import xlb.precision_policy @@ -15,4 +20,9 @@ import xlb.operator.boundary_condition # import xlb.operator.force import xlb.operator.macroscopic -import xlb.operator.stepper + +# Grids +import xlb.grid + +# Solvers +import xlb.solver \ No newline at end of file diff --git a/xlb/global_config.py b/xlb/global_config.py new file mode 100644 index 0000000..dd3e705 --- /dev/null +++ b/xlb/global_config.py @@ -0,0 +1,10 @@ +class GlobalConfig: + precision_policy = None + velocity_set = None + compute_backend = None + + +def init(velocity_set, compute_backend, precision_policy): + GlobalConfig.velocity_set = velocity_set() + GlobalConfig.compute_backend = compute_backend + GlobalConfig.precision_policy = precision_policy() diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py new file mode 100644 index 0000000..583b72e --- /dev/null +++ b/xlb/grid/__init__.py @@ -0,0 +1 @@ +from xlb.grid.grid import Grid \ No newline at end of file diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py new file mode 100644 index 0000000..7a03950 --- /dev/null +++ b/xlb/grid/grid.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from xlb.compute_backends import ComputeBackends +from xlb.global_config import GlobalConfig +from xlb.velocity_set import VelocitySet + + +class Grid(ABC): + def __init__(self, grid_shape, velocity_set, compute_backend): + self.velocity_set: VelocitySet = velocity_set + self.compute_backend = compute_backend + self.grid_shape = grid_shape + self.pop_shape = (self.velocity_set.q, *grid_shape) + self.u_shape = (self.velocity_set.d, *grid_shape) + self.rho_shape = (1, *grid_shape) + self.dim = self.velocity_set.d + + @abstractmethod + def create_field(self, cardinality, callback=None): + pass + + @staticmethod + def create(grid_shape, velocity_set=None, compute_backend=None): + compute_backend = compute_backend or GlobalConfig.compute_backend + velocity_set = velocity_set or GlobalConfig.velocity_set + + if compute_backend == ComputeBackends.JAX: + from xlb.grid.jax_grid import JaxGrid # Avoids circular import + + return JaxGrid(grid_shape, velocity_set, compute_backend) + raise ValueError(f"Compute backend {compute_backend} is not supported") + + @abstractmethod + def global_to_local_shape(self, shape): + pass diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py new file mode 100644 index 0000000..57270c5 --- /dev/null +++ b/xlb/grid/jax_grid.py @@ -0,0 +1,46 @@ +from xlb.grid.grid import Grid +from xlb.compute_backends import ComputeBackends +from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding, Mesh +from jax.experimental import mesh_utils +from xlb.operator.initializer import ConstInitializer +import jax + + +class JaxGrid(Grid): + def __init__(self, grid_shape, velocity_set, compute_backend): + super().__init__(grid_shape, velocity_set, compute_backend) + self.initialize_jax_backend() + + def initialize_jax_backend(self): + self.nDevices = jax.device_count() + self.backend = jax.default_backend() + device_mesh = ( + mesh_utils.create_device_mesh((1, self.nDevices, 1)) + if self.dim == 2 + else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) + ) + self.global_mesh = ( + Mesh(device_mesh, axis_names=("cardinality", "x", "y")) + if self.dim == 2 + else Mesh(self.devices, axis_names=("cardinality", "x", "y", "z")) + ) + self.sharding = ( + NamedSharding(self.global_mesh, P("cardinality", "x", "y")) + if self.dim == 2 + else NamedSharding(self.global_mesh, P("cardinality", "x", "y", "z")) + ) + + def global_to_local_shape(self, shape): + if len(shape) < 2: + raise ValueError("Shape must have at least two dimensions") + + new_second_index = shape[1] // self.nDevices + + return shape[:1] + (new_second_index,) + shape[2:] + + def create_field(self, cardinality, callback=None): + if callback is None: + callback = ConstInitializer(self, cardinality, const_value=0.0) + shape = (cardinality,) + (self.grid_shape) + return jax.make_array_from_callback(shape, self.sharding, callback) diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 90943b8..9dfdf33 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -15,8 +15,8 @@ class BGK(Collision): def __init__( self, omega: float, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): super().__init__( omega=omega, velocity_set=velocity_set, compute_backend=compute_backend @@ -24,7 +24,7 @@ def __init__( @Operator.register_backend(ComputeBackends.JAX) @partial(jit, static_argnums=(0,)) - def jax_implementation_2(self, f: jnp.ndarray, feq: jnp.ndarray): + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fneq = f - feq fout = f - self.omega * fneq return fout diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py index 9c22895..3243e6f 100644 --- a/xlb/operator/collision/collision.py +++ b/xlb/operator/collision/collision.py @@ -1,7 +1,6 @@ """ Base class for Collision operators """ -from xlb.compute_backends import ComputeBackends from xlb.velocity_set import VelocitySet from xlb.operator import Operator @@ -23,8 +22,8 @@ class Collision(Operator): def __init__( self, omega: float, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): super().__init__(velocity_set, compute_backend) self.omega = omega diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index e79b6bd..1b24a6b 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -22,8 +22,8 @@ class KBC(Collision): def __init__( self, omega, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): super().__init__( omega=omega, velocity_set=velocity_set, compute_backend=compute_backend @@ -75,6 +75,16 @@ def jax_implementation( return fout + @Operator.register_backend(ComputeBackends.WARP) + @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) + def warp_implementation( + self, + f: jnp.ndarray, + feq: jnp.ndarray, + rho: jnp.ndarray, + ): + raise NotImplementedError("Warp implementation not yet implemented") + @partial(jit, static_argnums=(0,), inline=True) def entropic_scalar_product(self, x: jnp.ndarray, y: jnp.ndarray, feq: jnp.ndarray): """ diff --git a/xlb/operator/equilibrium/equilibrium.py b/xlb/operator/equilibrium/equilibrium.py index f60d9ee..4bed600 100644 --- a/xlb/operator/equilibrium/equilibrium.py +++ b/xlb/operator/equilibrium/equilibrium.py @@ -1,6 +1,5 @@ # Base class for all equilibriums from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator @@ -11,7 +10,7 @@ class Equilibrium(Operator): def __init__( self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): super().__init__(velocity_set, compute_backend) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index bc4282b..1fd7458 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -5,6 +5,7 @@ from xlb.operator.equilibrium.equilibrium import Equilibrium from functools import partial from xlb.operator import Operator +from xlb.global_config import GlobalConfig class QuadraticEquilibrium(Equilibrium): @@ -17,21 +18,20 @@ class QuadraticEquilibrium(Equilibrium): def __init__( self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): + velocity_set = velocity_set or GlobalConfig.velocity_set + compute_backend = compute_backend or GlobalConfig.compute_backend + super().__init__(velocity_set, compute_backend) @Operator.register_backend(ComputeBackends.JAX) - # @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) def jax_implementation(self, rho, u): cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0)) usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True) w = self.velocity_set.w.reshape(-1, 1, 1) - feq = ( - rho - * w - * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) - ) + feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq diff --git a/xlb/operator/initializer/__init__.py b/xlb/operator/initializer/__init__.py new file mode 100644 index 0000000..4d3f07d --- /dev/null +++ b/xlb/operator/initializer/__init__.py @@ -0,0 +1,2 @@ +from xlb.operator.initializer.equilibrium_init import EquilibriumInitializer +from xlb.operator.initializer.const_init import ConstInitializer \ No newline at end of file diff --git a/xlb/operator/initializer/const_init.py b/xlb/operator/initializer/const_init.py new file mode 100644 index 0000000..d3b13cb --- /dev/null +++ b/xlb/operator/initializer/const_init.py @@ -0,0 +1,28 @@ +from xlb.velocity_set import VelocitySet +from xlb.global_config import GlobalConfig +from xlb.compute_backends import ComputeBackends +from xlb.operator.operator import Operator +from xlb.grid.grid import Grid +import numpy as np +import jax + + +class ConstInitializer(Operator): + def __init__( + self, + grid: Grid, + cardinality, + const_value=0.0, + velocity_set: VelocitySet = None, + compute_backend: ComputeBackends = None, + ): + velocity_set = velocity_set or GlobalConfig.velocity_set + compute_backend = compute_backend or GlobalConfig.compute_backend + shape = (cardinality,) + (grid.grid_shape) + self.init_values = np.zeros(grid.global_to_local_shape(shape)) + const_value + + super().__init__(velocity_set, compute_backend) + + @Operator.register_backend(ComputeBackends.JAX) + def jax_implementation(self, index): + return self.init_values diff --git a/xlb/operator/initializer/equilibrium_init.py b/xlb/operator/initializer/equilibrium_init.py new file mode 100644 index 0000000..9d9dc56 --- /dev/null +++ b/xlb/operator/initializer/equilibrium_init.py @@ -0,0 +1,28 @@ +from xlb.velocity_set import VelocitySet +from xlb.global_config import GlobalConfig +from xlb.compute_backends import ComputeBackends +from xlb.operator.operator import Operator +from xlb.grid.grid import Grid +import numpy as np +import jax + + +class EquilibriumInitializer(Operator): + def __init__( + self, + grid: Grid, + velocity_set: VelocitySet = None, + compute_backend: ComputeBackends = None, + ): + velocity_set = velocity_set or GlobalConfig.velocity_set + compute_backend = compute_backend or GlobalConfig.compute_backend + local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1) + self.init_values = np.zeros( + grid.global_to_local_shape(grid.pop_shape) + ) + velocity_set.w.reshape(local_shape) + + super().__init__(velocity_set, compute_backend) + + @Operator.register_backend(ComputeBackends.JAX) + def jax_implementation(self, index): + return self.init_values diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 55121e6..fc04db2 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -19,20 +19,19 @@ class Macroscopic(Operator): """ def __init__( - self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, - ): + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackends.JAX, + ): super().__init__(velocity_set, compute_backend) + @Operator.register_backend(ComputeBackends.JAX) @partial(jit, static_argnums=(0), inline=True) - def apply_jax(self, f): + def jax_implementation(self, f): """ Apply the macroscopic operator to the lattice distribution function """ - - rho = jnp.sum(f, axis=-1, keepdims=True) - c = jnp.array(self.velocity_set.c, dtype=f.dtype).T - u = jnp.dot(f, c) / rho + rho = jnp.sum(f, axis=0, keepdims=True) + u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho return rho, u diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 592961d..700de7e 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,9 +1,7 @@ # Base class for all operators, (collision, streaming, equilibrium, etc.) -from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backends import ComputeBackends - - +from xlb.global_config import GlobalConfig class Operator: """ Base class for all operators, collision, streaming, equilibrium, etc. @@ -14,9 +12,11 @@ class Operator: _backends = {} def __init__(self, velocity_set, compute_backend): - self.velocity_set = velocity_set - self.compute_backend = compute_backend - if compute_backend not in ComputeBackends: + # Set the default values from the global config + self.velocity_set = velocity_set or GlobalConfig.velocity_set + self.compute_backend = compute_backend or GlobalConfig.compute_backend + + if self.compute_backend not in ComputeBackends: raise ValueError(f"Compute backend {compute_backend} is not supported") @classmethod @@ -34,14 +34,25 @@ def decorator(func): return decorator - def __call__(self, *args, **kwargs): + def __call__(self, *args, callback=None, **kwargs): """ Calls the operator with the compute backend specified in the constructor. + If a callback is provided, it is called either with the result of the operation + or with the original arguments and keyword arguments if the backend modifies them by reference. """ key = (self.__class__.__name__, self.compute_backend) backend_method = self._backends.get(key) + if backend_method: - return backend_method(self, *args, **kwargs) + result = backend_method(self, *args, **kwargs) + + # Determine what to pass to the callback based on the backend behavior + callback_arg = result if result is not None else (args, kwargs) + + if callback and callable(callback): + callback(callback_arg) + + return result else: raise NotImplementedError(f"Backend {self.compute_backend} not implemented") @@ -63,9 +74,3 @@ def _is_method_overridden(self, method_name): def __repr__(self): return f"{self.__class__.__name__}()" - - def data_handler(self, *args, **kwargs): - """ - Handles data for the operator. - """ - raise NotImplementedError("Child class must implement data_handler") diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py deleted file mode 100644 index 0469f13..0000000 --- a/xlb/operator/stepper/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from xlb.operator.stepper.stepper import Stepper -from xlb.operator.stepper.nse import NSE diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py deleted file mode 100644 index a8863b2..0000000 --- a/xlb/operator/stepper/nse.py +++ /dev/null @@ -1,92 +0,0 @@ -# Base class for all stepper operators - -from functools import partial -from jax import jit - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends -from xlb.operator.stepper.stepper import Stepper -from xlb.operator.boundary_condition import ImplementationStep - - -class NSE(Stepper): - """ - Class that handles the construction of lattice boltzmann stepping operator for the Navier-Stokes equations - - TODO: Check that the given operators (collision, stream, equilibrium, macroscopic, ...) are compatible - with the Navier-Stokes equations - """ - - def __init__( - self, - collision, - stream, - equilibrium, - macroscopic, - boundary_conditions=[], - forcing=None, - precision_policy=None, - ): - super().__init__( - collision, - stream, - equilibrium, - macroscopic, - boundary_conditions, - forcing, - precision_policy, - ) - - @partial(jit, static_argnums=(0,)) - def apply_jax(self, f, boundary_id, mask, timestep): - """ - Perform a single step of the lattice boltzmann method - """ - - # Cast to compute precision - f_pre_collision = self.precision_policy.cast_to_compute_jax(f) - - # Compute the macroscopic variables - rho, u = self.macroscopic(f_pre_collision) - - # Compute equilibrium - feq = self.equilibrium(rho, u) - - # Apply collision - f_post_collision = self.collision( - f, - feq, - rho, - u, - ) - - # Apply collision type boundary conditions - for id_number, bc in self.collision_boundary_conditions.items(): - f_post_collision = bc( - f_pre_collision, - f_post_collision, - boundary_id == id_number, - mask, - ) - f_pre_streaming = f_post_collision - - ## Apply forcing - # if self.forcing_op is not None: - # f = self.forcing_op.apply_jax(f, timestep) - - # Apply streaming - f_post_streaming = self.stream(f_pre_streaming) - - # Apply boundary conditions - for id_number, bc in self.stream_boundary_conditions.items(): - f_post_streaming = bc( - f_pre_streaming, - f_post_streaming, - boundary_id == id_number, - mask, - ) - - # Copy back to store precision - f = self.precision_policy.cast_to_store_jax(f_post_streaming) - - return f \ No newline at end of file diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py deleted file mode 100644 index 8730478..0000000 --- a/xlb/operator/stepper/stepper.py +++ /dev/null @@ -1,84 +0,0 @@ -# Base class for all stepper operators - -import jax.numpy as jnp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from xlb.operator.boundary_condition import ImplementationStep - - -class Stepper(Operator): - """ - Class that handles the construction of lattice boltzmann stepping operator - """ - - def __init__( - self, - collision, - stream, - equilibrium, - macroscopic, - boundary_conditions=[], - forcing=None, - precision_policy=None, - compute_backend=ComputeBackends.JAX, - ): - # Set parameters - self.collision = collision - self.stream = stream - self.equilibrium = equilibrium - self.macroscopic = macroscopic - self.boundary_conditions = boundary_conditions - self.forcing = forcing - self.precision_policy = precision_policy - self.compute_backend = compute_backend - - # Get collision and stream boundary conditions - self.collision_boundary_conditions = {} - self.stream_boundary_conditions = {} - for id_number, bc in enumerate(self.boundary_conditions): - bc_id = id_number + 1 - if bc.implementation_step == ImplementationStep.COLLISION: - self.collision_boundary_conditions[bc_id] = bc - elif bc.implementation_step == ImplementationStep.STREAMING: - self.stream_boundary_conditions[bc_id] = bc - else: - raise ValueError("Boundary condition step not recognized") - - # Get all operators for checking - self.operators = [ - collision, - stream, - equilibrium, - macroscopic, - *boundary_conditions, - ] - - # Get velocity set and backend - velocity_sets = set([op.velocity_set for op in self.operators]) - assert len(velocity_sets) == 1, "All velocity sets must be the same" - self.velocity_set = velocity_sets.pop() - compute_backends = set([op.compute_backend for op in self.operators]) - assert len(compute_backends) == 1, "All compute backends must be the same" - self.compute_backend = compute_backends.pop() - - # Initialize operator - super().__init__(self.velocity_set, self.compute_backend) - - def set_boundary(self, ijk): - """ - Set boundary condition arrays - These store the boundary condition information for each boundary - """ - # Empty boundary condition array - boundary_id = jnp.zeros(ijk.shape[:-1], dtype=jnp.uint8) - mask = jnp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=jnp.bool_) - - # Set boundary condition arrays - for id_number, bc in self.collision_boundary_conditions.items(): - boundary_id, mask = bc.set_boundary(ijk, boundary_id, mask, id_number) - for id_number, bc in self.stream_boundary_conditions.items(): - boundary_id, mask = bc.set_boundary(ijk, boundary_id, mask, id_number) - - return boundary_id, mask diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 7a4a1d8..277e9a1 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -2,32 +2,27 @@ from functools import partial import jax.numpy as jnp -from jax import jit, vmap -import numba +from jax import jit, vmap, lax from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P class Stream(Operator): """ Base class for all streaming operators. - - TODO: Currently only this one streaming operator is implemented but - in the future we may have more streaming operators. For example, - one might want a multi-step streaming operator. """ - def __init__( - self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, - ): + def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None): + self.grid = grid super().__init__(velocity_set, compute_backend) - @partial(jit, static_argnums=(0), donate_argnums=(1,)) - def apply_jax(self, f): + @Operator.register_backend(ComputeBackends.JAX) + # @partial(jit, static_argnums=(0)) + def jax_implementation(self, f): """ JAX implementation of the streaming step. @@ -36,8 +31,18 @@ def apply_jax(self, f): f: jax.numpy.ndarray The distribution function. """ - - def _streaming(f, c): + in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,))) + out_specs = in_specs + return shard_map( + self._streaming_jax_m, + mesh=self.grid.global_mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + )(f) + + def _streaming_jax_p(self, f): + def _streaming_jax_i(f, c): """ Perform individual streaming operation in a direction. @@ -56,33 +61,45 @@ def _streaming(f, c): elif self.velocity_set.d == 3: return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) - return vmap(_streaming, in_axes=(-1, 0), out_axes=-1)( + return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( f, jnp.array(self.velocity_set.c).T ) - def construct_numba(self, dtype=numba.float32): + def _streaming_jax_m(self, f): """ - Numba implementation of the streaming step. + This function performs the streaming step in the Lattice Boltzmann Method, which is + the propagation of the distribution functions in the lattice. + + To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the + distribution functions that need to be communicated to the neighboring processes. + + The function then sends the left boundary slice to the right neighboring process and the right + boundary slice to the left neighboring process. The received data is then set to the + corresponding indices in the receiving domain. + + Parameters + ---------- + f: jax.numpy.ndarray + The array holding the distribution functions for the simulation. + + Returns + ------- + jax.numpy.ndarray + The distribution functions after the streaming operation. """ + rightPerm = [(i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices)] + leftPerm = [((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices)] + + f = self._streaming_jax_p(f) + left_comm, right_comm = ( + f[self.velocity_set.right_indices, :1, ...], + f[self.velocity_set.left_indices, -1:, ...], + ) + left_comm, right_comm = ( + lax.ppermute(left_comm, perm=rightPerm, axis_name='x'), + lax.ppermute(right_comm, perm=leftPerm, axis_name='x'), + ) + f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm) + f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm) - # Get needed values for numba functions - d = velocity_set.d - q = velocity_set.q - c = velocity_set.c.T - - # Make numba functions - @cuda.jit(device=True) - def _streaming(f_array, f, ijk): - # Stream to the next node - for _ in range(q): - if d == 2: - i = (ijk[0] + int32(c[_, 0])) % f_array.shape[0] - j = (ijk[1] + int32(c[_, 1])) % f_array.shape[1] - f_array[i, j, _] = f[_] - else: - i = (ijk[0] + int32(c[_, 0])) % f_array.shape[0] - j = (ijk[1] + int32(c[_, 1])) % f_array.shape[1] - k = (ijk[2] + int32(c[_, 2])) % f_array.shape[2] - f_array[i, j, k, _] = f[_] - - return _streaming + return f diff --git a/xlb/precision_policy/__init__.py b/xlb/precision_policy/__init__.py index 996d8e9..b1555aa 100644 --- a/xlb/precision_policy/__init__.py +++ b/xlb/precision_policy/__init__.py @@ -1 +1 @@ -from xlb.precision_policy.precision_policy import PrecisionPolicy \ No newline at end of file +from xlb.precision_policy.precision_policy import Fp64Fp64, Fp64Fp32, Fp32Fp32, Fp64Fp16, Fp32Fp16 \ No newline at end of file diff --git a/xlb/precision_policy/base_precision_policy.py b/xlb/precision_policy/base_precision_policy.py new file mode 100644 index 0000000..397e48d --- /dev/null +++ b/xlb/precision_policy/base_precision_policy.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod + +class PrecisionPolicy(ABC): + def __init__(self, compute_dtype, storage_dtype): + self.compute_dtype = compute_dtype + self.storage_dtype = storage_dtype + + @abstractmethod + def cast_to_compute(self, array): + pass + + @abstractmethod + def cast_to_store(self, array): + pass diff --git a/xlb/precision_policy/jax_precision_policy/jax_precision_policy.py b/xlb/precision_policy/jax_precision_policy.py similarity index 85% rename from xlb/precision_policy/jax_precision_policy/jax_precision_policy.py rename to xlb/precision_policy/jax_precision_policy.py index c9063c2..bd63adf 100644 --- a/xlb/precision_policy/jax_precision_policy/jax_precision_policy.py +++ b/xlb/precision_policy/jax_precision_policy.py @@ -1,4 +1,4 @@ -from xlb.precision_policy.precision_policy import PrecisionPolicy +from xlb.precision_policy.base_precision_policy import PrecisionPolicy from jax import jit from functools import partial import jax.numpy as jnp @@ -18,7 +18,7 @@ def cast_to_store(self, array): return array.astype(self.storage_dtype) -class Fp32Fp32(JaxPrecisionPolicy): +class JaxFp32Fp32(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation and storage precision both set to float32. @@ -32,7 +32,7 @@ def __init__(self): super().__init__(jnp.float32, jnp.float32) -class Fp64Fp64(JaxPrecisionPolicy): +class JaxFp64Fp64(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation and storage precision both set to float64. @@ -42,7 +42,7 @@ def __init__(self): super().__init__(jnp.float64, jnp.float64) -class Fp64Fp32(JaxPrecisionPolicy): +class JaxFp64Fp32(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation precision set to float64 and storage precision set to float32. @@ -52,7 +52,7 @@ def __init__(self): super().__init__(jnp.float64, jnp.float32) -class Fp64Fp16(JaxPrecisionPolicy): +class JaxFp64Fp16(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation precision set to float64 and storage precision set to float16. @@ -62,7 +62,7 @@ def __init__(self): super().__init__(jnp.float64, jnp.float16) -class Fp32Fp16(JaxPrecisionPolicy): +class JaxFp32Fp16(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation precision set to float32 and storage precision set to float16. diff --git a/xlb/precision_policy/jax_precision_policy/___init__.py b/xlb/precision_policy/jax_precision_policy/___init__.py deleted file mode 100644 index 1955b79..0000000 --- a/xlb/precision_policy/jax_precision_policy/___init__.py +++ /dev/null @@ -1 +0,0 @@ -from xlb.precision_policy.jax_precision_policy import Fp32Fp16, Fp32Fp32, Fp64Fp16, Fp64Fp32, Fp64Fp64 \ No newline at end of file diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index 9a4bdb0..75f4434 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -1,22 +1,59 @@ -class PrecisionPolicy(object): - """ - Base class for precision policy in lattice Boltzmann method. - Stores dtype information and provides an interface for casting operations. - """ - def __init__(self, compute_dtype, storage_dtype): - self.compute_dtype = compute_dtype - self.storage_dtype = storage_dtype - - def cast_to_compute(self, array): - """ - Cast the array to the computation precision. - To be implemented by subclass. - """ - raise NotImplementedError - - def cast_to_store(self, array): - """ - Cast the array to the storage precision. - To be implemented by subclass. - """ - raise NotImplementedError \ No newline at end of file +from abc import ABC, abstractmethod +from xlb.compute_backends import ComputeBackends +from xlb.global_config import GlobalConfig + +from xlb.precision_policy.jax_precision_policy import ( + JaxFp32Fp32, + JaxFp32Fp16, + JaxFp64Fp64, + JaxFp64Fp32, + JaxFp64Fp16, +) + + +class Fp64Fp64: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp64Fp64() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) + +class Fp64Fp32: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp64Fp32() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) + + +class Fp32Fp32: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp32Fp32() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) + + +class Fp64Fp16: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp64Fp16() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) + +class Fp32Fp16: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp32Fp16() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) \ No newline at end of file diff --git a/xlb/solver/__init__.py b/xlb/solver/__init__.py new file mode 100644 index 0000000..62dfc30 --- /dev/null +++ b/xlb/solver/__init__.py @@ -0,0 +1 @@ +from xlb.solver.nse import IncompressibleNavierStokes diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py new file mode 100644 index 0000000..96a529c --- /dev/null +++ b/xlb/solver/nse.py @@ -0,0 +1,108 @@ +# Base class for all stepper operators + +from functools import partial +from jax import jit + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backends import ComputeBackends +from xlb.operator.boundary_condition import ImplementationStep +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.collision import BGK, KBC +from xlb.operator.stream import Stream +from xlb.operator.macroscopic import Macroscopic +from xlb.solver.solver import Solver +from xlb.operator import Operator + + +class IncompressibleNavierStokes(Solver): + def __init__( + self, + grid, + velocity_set: VelocitySet = None, + compute_backend=None, + precision_policy=None, + boundary_conditions=[], + collision_kernel="BGK", + ): + self.grid = grid + self.collision_kernel = collision_kernel + super().__init__(velocity_set=velocity_set, compute_backend=compute_backend, precision_policy=precision_policy, boundary_conditions=boundary_conditions) + self.create_operators() + + # Operators + def create_operators(self): + self.macroscopic = Macroscopic( + velocity_set=self.velocity_set, compute_backend=self.compute_backend + ) + self.equilibrium = QuadraticEquilibrium( + velocity_set=self.velocity_set, compute_backend=self.compute_backend + ) + self.collision = ( + KBC( + omega=1.0, + velocity_set=self.velocity_set, + compute_backend=self.compute_backend, + ) + if self.collision_kernel == "KBC" + else BGK( + omega=1.0, + velocity_set=self.velocity_set, + compute_backend=self.compute_backend, + ) + ) + self.stream = Stream(self.grid, + velocity_set=self.velocity_set, compute_backend=self.compute_backend + ) + + @Operator.register_backend(ComputeBackends.JAX) + @partial(jit, static_argnums=(0,)) + def step(self, f, timestep): + """ + Perform a single step of the lattice boltzmann method + """ + + # Cast to compute precision + f = self.precision_policy.cast_to_compute(f) + + # Compute the macroscopic variables + rho, u = self.macroscopic(f) + + # Compute equilibrium + feq = self.equilibrium(rho, u) + + # Apply collision + f_post_collision = self.collision( + f, + feq, + ) + + # # Apply collision type boundary conditions + # for id_number, bc in self.collision_boundary_conditions.items(): + # f_post_collision = bc( + # f_pre_collision, + # f_post_collision, + # boundary_id == id_number, + # mask, + # ) + f_pre_streaming = f_post_collision + + ## Apply forcing + # if self.forcing_op is not None: + # f = self.forcing_op.apply_jax(f, timestep) + + # Apply streaming + f_post_streaming = self.stream(f_pre_streaming) + + # Apply boundary conditions + # for id_number, bc in self.stream_boundary_conditions.items(): + # f_post_streaming = bc( + # f_pre_streaming, + # f_post_streaming, + # boundary_id == id_number, + # mask, + # ) + + # Copy back to store precision + f = self.precision_policy.cast_to_store(f_post_streaming) + + return f diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py new file mode 100644 index 0000000..9f4deaa --- /dev/null +++ b/xlb/solver/solver.py @@ -0,0 +1,37 @@ +# Base class for all stepper operators + +from xlb.compute_backends import ComputeBackends +from xlb.operator.boundary_condition import ImplementationStep +from xlb.global_config import GlobalConfig +from xlb.operator import Operator + + +class Solver(Operator): + """ + Abstract class for the construction of lattice boltzmann solvers + """ + + def __init__( + self, + velocity_set=None, + compute_backend=None, + precision_policy=None, + boundary_conditions=[], + ): + # Set parameters + self.velocity_set = velocity_set or GlobalConfig.velocity_set + self.compute_backend = compute_backend or GlobalConfig.compute_backend + self.precision_policy = precision_policy or GlobalConfig.precision_policy + self.boundary_conditions = boundary_conditions + + # Get collision and stream boundary conditions + self.collision_boundary_conditions = {} + self.stream_boundary_conditions = {} + for id_number, bc in enumerate(self.boundary_conditions): + bc_id = id_number + 1 + if bc.implementation_step == ImplementationStep.COLLISION: + self.collision_boundary_conditions[bc_id] = bc + elif bc.implementation_step == ImplementationStep.STREAMING: + self.stream_boundary_conditions[bc_id] = bc + else: + raise ValueError("Boundary condition step not recognized") From 71c952b709ca66e24cc95712ea60445fc0b7fc3c Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Mon, 29 Jan 2024 10:16:58 -0500 Subject: [PATCH 006/144] Added multi-GPU support and mlups computation --- examples/refactor/example_basic.py | 62 +++++++++++++++++++ examples/refactor/example_mehdi.py | 37 ----------- examples/refactor/mlups3d.py | 53 ++++++++++++++++ xlb/__init__.py | 7 ++- xlb/global_config.py | 4 ++ xlb/grid/jax_grid.py | 2 +- .../equilibrium/quadratic_equilibrium.py | 2 +- xlb/operator/initializer/equilibrium_init.py | 1 + xlb/operator/macroscopic/macroscopic.py | 30 +++++++-- xlb/operator/stream/stream.py | 3 +- xlb/solver/nse.py | 6 +- xlb/utils/__init__.py | 1 + xlb/utils/utils.py | 55 ++-------------- 13 files changed, 161 insertions(+), 102 deletions(-) create mode 100644 examples/refactor/example_basic.py delete mode 100644 examples/refactor/example_mehdi.py create mode 100644 examples/refactor/mlups3d.py create mode 100644 xlb/utils/__init__.py diff --git a/examples/refactor/example_basic.py b/examples/refactor/example_basic.py new file mode 100644 index 0000000..9c9b033 --- /dev/null +++ b/examples/refactor/example_basic.py @@ -0,0 +1,62 @@ +import xlb +from xlb.compute_backends import ComputeBackends +from xlb.precision_policy import Fp32Fp32 + +from xlb.solver import IncompressibleNavierStokes +from xlb.grid import Grid +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.utils import save_fields_vtk, save_image + +xlb.init( + precision_policy=Fp32Fp32, + compute_backend=ComputeBackends.JAX, + velocity_set=xlb.velocity_set.D2Q9, +) + +grid_shape = (1000, 1000) +grid = Grid.create(grid_shape) + + +def initializer(): + rho = grid.create_field(cardinality=1) + 1.0 + u = grid.create_field(cardinality=2) + + circle_center = (grid_shape[0] // 2, grid_shape[1] // 2) + circle_radius = 10 + + for x in range(grid_shape[0]): + for y in range(grid_shape[1]): + if (x - circle_center[0]) ** 2 + ( + y - circle_center[1] + ) ** 2 <= circle_radius**2: + rho = rho.at[0, x, y].add(0.001) + + func_eq = QuadraticEquilibrium() + f_eq = func_eq(rho, u) + + return f_eq + + +f = initializer() + +compute_macro = Macroscopic() + +solver = IncompressibleNavierStokes(grid, omega=1.0) + + +def perform_io(f, step): + rho, u = compute_macro(f) + fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1]} + save_fields_vtk(fields, step) + save_image(rho[0], step) + print(f"Step {step + 1} complete") + + +num_steps = 1000 +io_rate = 100 +for step in range(num_steps): + f = solver.step(f, timestep=step) + + if step % io_rate == 0: + perform_io(f, step) diff --git a/examples/refactor/example_mehdi.py b/examples/refactor/example_mehdi.py deleted file mode 100644 index b758d3e..0000000 --- a/examples/refactor/example_mehdi.py +++ /dev/null @@ -1,37 +0,0 @@ -import xlb -from xlb.compute_backends import ComputeBackends -from xlb.precision_policy import Fp32Fp32 - -from xlb.solver import IncompressibleNavierStokes -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.stream import Stream -from xlb.global_config import GlobalConfig -from xlb.grid import Grid -from xlb.operator.initializer import EquilibriumInitializer, ConstInitializer - -import numpy as np -import jax.numpy as jnp - -xlb.init(precision_policy=Fp32Fp32, compute_backend=ComputeBackends.JAX, velocity_set=xlb.velocity_set.D2Q9) - -grid_shape = (100, 100) -grid = Grid.create(grid_shape) - -f_init = grid.create_field(cardinality=9, callback=EquilibriumInitializer(grid)) - -u_init = grid.create_field(cardinality=2, callback=ConstInitializer(grid, cardinality=2, const_value=0.0)) -rho_init = grid.create_field(cardinality=1, callback=ConstInitializer(grid, cardinality=1, const_value=1.0)) - - -st = Stream(grid) - -f_init = st(f_init) -print("here") -solver = IncompressibleNavierStokes(grid) - -num_steps = 100 -f = f_init -for step in range(num_steps): - f = solver.step(f, timestep=step) - print(f"Step {step+1}/{num_steps} complete") - diff --git a/examples/refactor/mlups3d.py b/examples/refactor/mlups3d.py new file mode 100644 index 0000000..f1207ff --- /dev/null +++ b/examples/refactor/mlups3d.py @@ -0,0 +1,53 @@ +import xlb +import time +import jax +import argparse +from xlb.compute_backends import ComputeBackends +from xlb.precision_policy import Fp32Fp32 +from xlb.operator.initializer import EquilibriumInitializer + +from xlb.solver import IncompressibleNavierStokes +from xlb.grid import Grid + +parser = argparse.ArgumentParser( + description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)" +) +parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") +parser.add_argument("num_steps", type=int, help="Timestep for the simulation") + +args = parser.parse_args() + +cube_edge = args.cube_edge +num_steps = args.num_steps + + +xlb.init( + precision_policy=Fp32Fp32, + compute_backend=ComputeBackends.JAX, + velocity_set=xlb.velocity_set.D3Q19, +) + +grid_shape = (cube_edge, cube_edge, cube_edge) +grid = Grid.create(grid_shape) + +f = grid.create_field(cardinality=19, callback=EquilibriumInitializer(grid)) + +solver = IncompressibleNavierStokes(grid, omega=1.0) + +# Ahead-of-Time Compilation to remove JIT overhead + + +if xlb.current_backend() == ComputeBackends.JAX: + lowered = jax.jit(solver.step).lower(f, timestep=0) + solver_step_compiled = lowered.compile() + +start_time = time.time() + +for step in range(num_steps): + f = solver_step_compiled(f, timestep=step) + +end_time = time.time() +total_lattice_updates = cube_edge**3 * num_steps +total_time_seconds = end_time - start_time +mlups = (total_lattice_updates / total_time_seconds) / 1e6 +print(f"MLUPS: {mlups}") diff --git a/xlb/__init__.py b/xlb/__init__.py index 7f13a8e..7845bb2 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -4,7 +4,7 @@ # Config -from .global_config import init +from .global_config import init, current_backend # Precision policy @@ -25,4 +25,7 @@ import xlb.grid # Solvers -import xlb.solver \ No newline at end of file +import xlb.solver + +# Utils +import xlb.utils \ No newline at end of file diff --git a/xlb/global_config.py b/xlb/global_config.py index dd3e705..c0047c9 100644 --- a/xlb/global_config.py +++ b/xlb/global_config.py @@ -8,3 +8,7 @@ def init(velocity_set, compute_backend, precision_policy): GlobalConfig.velocity_set = velocity_set() GlobalConfig.compute_backend = compute_backend GlobalConfig.precision_policy = precision_policy() + + +def current_backend(): + return GlobalConfig.compute_backend \ No newline at end of file diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 57270c5..af6d239 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -23,7 +23,7 @@ def initialize_jax_backend(self): self.global_mesh = ( Mesh(device_mesh, axis_names=("cardinality", "x", "y")) if self.dim == 2 - else Mesh(self.devices, axis_names=("cardinality", "x", "y", "z")) + else Mesh(device_mesh, axis_names=("cardinality", "x", "y", "z")) ) self.sharding = ( NamedSharding(self.global_mesh, P("cardinality", "x", "y")) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 1fd7458..3dc4993 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -31,7 +31,7 @@ def __init__( def jax_implementation(self, rho, u): cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0)) usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True) - w = self.velocity_set.w.reshape(-1, 1, 1) + w = self.velocity_set.w.reshape((-1,) + (1,) * (len(rho.shape) - 1)) feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq diff --git a/xlb/operator/initializer/equilibrium_init.py b/xlb/operator/initializer/equilibrium_init.py index 9d9dc56..bad7c85 100644 --- a/xlb/operator/initializer/equilibrium_init.py +++ b/xlb/operator/initializer/equilibrium_init.py @@ -17,6 +17,7 @@ def __init__( velocity_set = velocity_set or GlobalConfig.velocity_set compute_backend = compute_backend or GlobalConfig.compute_backend local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1) + self.init_values = np.zeros( grid.global_to_local_shape(grid.pop_shape) ) + velocity_set.w.reshape(local_shape) diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index fc04db2..733f8f7 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -1,13 +1,14 @@ # Base class for all equilibriums +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backends import ComputeBackends +from xlb.operator.operator import Operator + from functools import partial import jax.numpy as jnp from jax import jit -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator - class Macroscopic(Operator): """ @@ -20,9 +21,12 @@ class Macroscopic(Operator): def __init__( self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): + self.velocity_set = velocity_set or GlobalConfig.velocity_set + self.compute_backend = compute_backend or GlobalConfig.compute_backend + super().__init__(velocity_set, compute_backend) @Operator.register_backend(ComputeBackends.JAX) @@ -30,6 +34,20 @@ def __init__( def jax_implementation(self, f): """ Apply the macroscopic operator to the lattice distribution function + TODO: Check if the following implementation is more efficient ( + as the compiler may be able to remove operations resulting in zero) + c_x = tuple(self.velocity_set.c[0]) + c_y = tuple(self.velocity_set.c[1]) + + u_x = 0.0 + u_y = 0.0 + + rho = jnp.sum(f, axis=0, keepdims=True) + + for i in range(self.velocity_set.q): + u_x += c_x[i] * f[i, ...] + u_y += c_y[i] * f[i, ...] + return rho, jnp.stack((u_x, u_y)) """ rho = jnp.sum(f, axis=0, keepdims=True) u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 277e9a1..e942961 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -21,7 +21,7 @@ def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None) super().__init__(velocity_set, compute_backend) @Operator.register_backend(ComputeBackends.JAX) - # @partial(jit, static_argnums=(0)) + @partial(jit, static_argnums=(0)) def jax_implementation(self, f): """ JAX implementation of the streaming step. @@ -38,7 +38,6 @@ def jax_implementation(self, f): mesh=self.grid.global_mesh, in_specs=in_specs, out_specs=out_specs, - check_rep=False, )(f) def _streaming_jax_p(self, f): diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index 96a529c..251fe37 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -18,6 +18,7 @@ class IncompressibleNavierStokes(Solver): def __init__( self, grid, + omega, velocity_set: VelocitySet = None, compute_backend=None, precision_policy=None, @@ -25,6 +26,7 @@ def __init__( collision_kernel="BGK", ): self.grid = grid + self.omega = omega self.collision_kernel = collision_kernel super().__init__(velocity_set=velocity_set, compute_backend=compute_backend, precision_policy=precision_policy, boundary_conditions=boundary_conditions) self.create_operators() @@ -39,13 +41,13 @@ def create_operators(self): ) self.collision = ( KBC( - omega=1.0, + omega=self.omega, velocity_set=self.velocity_set, compute_backend=self.compute_backend, ) if self.collision_kernel == "KBC" else BGK( - omega=1.0, + omega=self.omega, velocity_set=self.velocity_set, compute_backend=self.compute_backend, ) diff --git a/xlb/utils/__init__.py b/xlb/utils/__init__.py new file mode 100644 index 0000000..2107fc8 --- /dev/null +++ b/xlb/utils/__init__.py @@ -0,0 +1 @@ +from .utils import downsample_field, save_image, save_fields_vtk, save_BCs_vtk, rotate_geometry, voxelize_stl, axangle2mat diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index d01c500..6d9c627 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -47,7 +47,7 @@ def downsample_field(field, factor, method="bicubic"): return jnp.stack(downsampled_components, axis=-1) -def save_image(timestep, fld, prefix=None): +def save_image(fld, timestep, prefix=None): """ Save an image of a field at a given timestep. @@ -78,13 +78,13 @@ def save_image(timestep, fld, prefix=None): if len(fld.shape) > 3: raise ValueError("The input field should be 2D!") elif len(fld.shape) == 3: - fld = np.sqrt(fld[..., 0] ** 2 + fld[..., 1] ** 2) + fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2) plt.clf() plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower") -def save_fields_vtk(timestep, fields, output_dir=".", prefix="fields"): +def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"): """ Save VTK fields to the specified directory. @@ -111,7 +111,7 @@ def save_fields_vtk(timestep, fields, output_dir=".", prefix="fields"): will be saved as 'fields_0000010.vtk'in the specified directory. """ - # Assert that all fields have the same dimensions except for the last dimension assuming fields is a dictionary + # Assert that all fields have the same dimensions for key, value in fields.items(): if key == list(fields.keys())[0]: dimensions = value.shape @@ -140,53 +140,6 @@ def save_fields_vtk(timestep, fields, output_dir=".", prefix="fields"): grid.save(output_filename, binary=True) print(f"Saved {output_filename} in {time() - start:.6f} seconds.") - -def live_volume_randering(timestep, field): - # WORK IN PROGRESS - """ - Live rendering of a 3D volume using pyvista. - - Parameters - ---------- - field (np.ndarray): A 3D array containing the field to be rendered. - - Returns - ------- - None - - Notes - ----- - This function uses pyvista to render a 3D volume. The volume is rendered with a colormap based on the field values. - The colormap is updated every 0.1 seconds to reflect changes to the field. - - """ - # Create a uniform grid (Note that the field must be 3D) otherwise raise error - if field.ndim != 3: - raise ValueError("The input field must be 3D!") - dimensions = field.shape - grid = pv.UniformGrid(dimensions=dimensions) - - # Add the field to the grid - grid["field"] = field.flatten(order="F") - - # Create the rendering scene - if timestep == 0: - plt.ion() - plt.figure(figsize=(10, 10)) - plt.axis("off") - plt.title("Live rendering of the field") - pl = pv.Plotter(off_screen=True) - pl.add_volume(grid, cmap="nipy_spectral", opacity="sigmoid_10", shade=False) - plt.imshow(pl.screenshot()) - - else: - pl = pv.Plotter(off_screen=True) - pl.add_volume(grid, cmap="nipy_spectral", opacity="sigmoid_10", shade=False) - # Update the rendering scene every 0.1 seconds - plt.imshow(pl.screenshot()) - plt.pause(0.1) - - def save_BCs_vtk(timestep, BCs, gridInfo, output_dir="."): """ Save boundary conditions as VTK format to the specified directory. From 2d8eb5773f71e4eee22ccda1271a2e232ab719f2 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Sun, 4 Feb 2024 19:46:52 -0500 Subject: [PATCH 007/144] Added inline --- examples/refactor/mlups3d.py | 2 -- xlb/operator/collision/bgk.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/refactor/mlups3d.py b/examples/refactor/mlups3d.py index f1207ff..ec88427 100644 --- a/examples/refactor/mlups3d.py +++ b/examples/refactor/mlups3d.py @@ -35,8 +35,6 @@ solver = IncompressibleNavierStokes(grid, omega=1.0) # Ahead-of-Time Compilation to remove JIT overhead - - if xlb.current_backend() == ComputeBackends.JAX: lowered = jax.jit(solver.step).lower(f, timestep=0) solver_step_compiled = lowered.compile() diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 9dfdf33..cbd2b36 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -23,7 +23,7 @@ def __init__( ) @Operator.register_backend(ComputeBackends.JAX) - @partial(jit, static_argnums=(0,)) + @partial(jit, static_argnums=(0,), inline=True) def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fneq = f - feq fout = f - self.omega * fneq From a48510cefc7af0cb965b67c86854a609b7d8d1d4 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Mon, 5 Feb 2024 23:30:06 -0500 Subject: [PATCH 008/144] Added pallas but there is a shape issue yet --- examples/refactor/example_pallas_3d.py | 73 ++++++++++++++++ examples/refactor/mlups3d.py | 17 ++-- examples/refactor/mlups_pallas_3d.py | 84 ++++++++++++++++++ xlb/__init__.py | 1 - xlb/compute_backends.py | 3 +- xlb/grid/grid.py | 4 +- xlb/grid/jax_grid.py | 8 +- xlb/operator/__init__.py | 1 + xlb/operator/collision/bgk.py | 6 ++ xlb/operator/collision/kbc.py | 4 +- .../equilibrium/quadratic_equilibrium.py | 28 ++++++ xlb/operator/initializer/const_init.py | 28 ++++-- xlb/operator/initializer/equilibrium_init.py | 6 +- xlb/operator/macroscopic/macroscopic.py | 27 ++++++ xlb/operator/parallel_operator.py | 86 +++++++++++++++++++ xlb/operator/stream/stream.py | 58 ++----------- xlb/precision_policy/precision_policy.py | 31 +++++-- xlb/solver/nse.py | 67 ++++++++++++++- xlb/utils/utils.py | 4 +- xlb/velocity_set/velocity_set.py | 2 - 20 files changed, 450 insertions(+), 88 deletions(-) create mode 100644 examples/refactor/example_pallas_3d.py create mode 100644 examples/refactor/mlups_pallas_3d.py create mode 100644 xlb/operator/parallel_operator.py diff --git a/examples/refactor/example_pallas_3d.py b/examples/refactor/example_pallas_3d.py new file mode 100644 index 0000000..09b0305 --- /dev/null +++ b/examples/refactor/example_pallas_3d.py @@ -0,0 +1,73 @@ +import xlb +from xlb.compute_backends import ComputeBackends +from xlb.precision_policy import Fp32Fp32 + +from xlb.solver import IncompressibleNavierStokes +from xlb.grid import Grid +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.utils import save_fields_vtk, save_image +import numpy as np +import jax.numpy as jnp + +# Initialize XLB with Pallas backend for 3D simulation +xlb.init( + precision_policy=Fp32Fp32, + compute_backend=ComputeBackends.PALLAS, # Changed to Pallas backend + velocity_set=xlb.velocity_set.D3Q19, # Changed to D3Q19 for 3D +) + +grid_shape = (128, 128, 128) # Adjusted for 3D grid +grid = Grid.create(grid_shape) + + +def initializer(): + rho = grid.create_field(cardinality=1) + 1.0 + u = grid.create_field(cardinality=3) + + sphere_center = np.array([s // 2 for s in grid_shape]) + sphere_radius = 10 + + x, y, z = np.meshgrid( + np.arange(grid_shape[0]), + np.arange(grid_shape[1]), + np.arange(grid_shape[2]), + indexing="ij", + ) + + squared_dist = ( + (x - sphere_center[0]) ** 2 + + (y - sphere_center[1]) ** 2 + + (z - sphere_center[2]) ** 2 + ) + + inside_sphere = squared_dist <= sphere_radius**2 + + rho = jnp.where(inside_sphere, rho.at[0, x, y, z].add(0.001), rho) + + func_eq = QuadraticEquilibrium(compute_backend=ComputeBackends.JAX) + f_eq = func_eq(rho, u) + + return f_eq + + +f = initializer() + +compute_macro = Macroscopic(compute_backend=ComputeBackends.JAX) + +solver = IncompressibleNavierStokes(grid, omega=1.0) + +def perform_io(f, step): + rho, u = compute_macro(f) + fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_z": u[2]} + save_fields_vtk(fields, step) + # save_image function might not be suitable for 3D, consider alternative visualization + print(f"Step {step + 1} complete") + +num_steps = 1000 +io_rate = 100 +for step in range(num_steps): + f = solver.step(f, timestep=step) + + if step % io_rate == 0: + perform_io(f, step) diff --git a/examples/refactor/mlups3d.py b/examples/refactor/mlups3d.py index ec88427..2a37fdb 100644 --- a/examples/refactor/mlups3d.py +++ b/examples/refactor/mlups3d.py @@ -23,26 +23,31 @@ xlb.init( precision_policy=Fp32Fp32, - compute_backend=ComputeBackends.JAX, + compute_backend=ComputeBackends.PALLAS, velocity_set=xlb.velocity_set.D3Q19, ) grid_shape = (cube_edge, cube_edge, cube_edge) grid = Grid.create(grid_shape) -f = grid.create_field(cardinality=19, callback=EquilibriumInitializer(grid)) +f = grid.create_field(cardinality=19) + +print("f shape", f.shape) solver = IncompressibleNavierStokes(grid, omega=1.0) # Ahead-of-Time Compilation to remove JIT overhead -if xlb.current_backend() == ComputeBackends.JAX: - lowered = jax.jit(solver.step).lower(f, timestep=0) - solver_step_compiled = lowered.compile() +# if xlb.current_backend() == ComputeBackends.JAX or xlb.current_backend() == ComputeBackends.PALLAS: +# lowered = jax.jit(solver.step).lower(f, timestep=0) +# solver_step_compiled = lowered.compile() + +# Ahead-of-Time Compilation to remove JIT overhead +f = solver.step(f, timestep=0) start_time = time.time() for step in range(num_steps): - f = solver_step_compiled(f, timestep=step) + f = solver.step(f, timestep=step) end_time = time.time() total_lattice_updates = cube_edge**3 * num_steps diff --git a/examples/refactor/mlups_pallas_3d.py b/examples/refactor/mlups_pallas_3d.py new file mode 100644 index 0000000..4c8ff50 --- /dev/null +++ b/examples/refactor/mlups_pallas_3d.py @@ -0,0 +1,84 @@ +import xlb +import time +import argparse +from xlb.compute_backends import ComputeBackends +from xlb.precision_policy import Fp32Fp32 +from xlb.solver import IncompressibleNavierStokes +from xlb.grid import Grid +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.equilibrium import QuadraticEquilibrium +import numpy as np +import jax.numpy as jnp + +# Command line argument parsing +parser = argparse.ArgumentParser( + description="3D Lattice Boltzmann Method Simulation using XLB" +) +parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") +parser.add_argument( + "num_steps", type=int, help="Number of timesteps for the simulation" +) +args = parser.parse_args() + +# Initialize XLB +xlb.init( + precision_policy=Fp32Fp32, + compute_backend=ComputeBackends.PALLAS, + velocity_set=xlb.velocity_set.D3Q19, +) + +# Grid initialization +grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge) +grid = Grid.create(grid_shape) + + +def initializer(): + rho = grid.create_field(cardinality=1) + 1.0 + u = grid.create_field(cardinality=3) + + sphere_center = np.array([s // 2 for s in grid_shape]) + sphere_radius = 10 + + x, y, z = np.meshgrid( + np.arange(grid_shape[0]), + np.arange(grid_shape[1]), + np.arange(grid_shape[2]), + indexing="ij", + ) + + squared_dist = ( + (x - sphere_center[0]) ** 2 + + (y - sphere_center[1]) ** 2 + + (z - sphere_center[2]) ** 2 + ) + + inside_sphere = squared_dist <= sphere_radius**2 + + rho = jnp.where(inside_sphere, rho.at[0, x, y, z].add(0.001), rho) + + func_eq = QuadraticEquilibrium(compute_backend=ComputeBackends.JAX) + f_eq = func_eq(rho, u) + + return f_eq + + +f = initializer() + +solver = IncompressibleNavierStokes(grid, omega=1.0) + +# AoT compile +f = solver.step(f, timestep=0) + +# Start the simulation +start_time = time.time() + +for step in range(args.num_steps): + f = solver.step(f, timestep=step) + +end_time = time.time() + +# MLUPS calculation +total_lattice_updates = args.cube_edge**3 * args.num_steps +total_time_seconds = end_time - start_time +mlups = (total_lattice_updates / total_time_seconds) / 1e6 +print(f"MLUPS: {mlups}") diff --git a/xlb/__init__.py b/xlb/__init__.py index 7845bb2..4589bf5 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -18,7 +18,6 @@ import xlb.operator.collision import xlb.operator.stream import xlb.operator.boundary_condition -# import xlb.operator.force import xlb.operator.macroscopic # Grids diff --git a/xlb/compute_backends.py b/xlb/compute_backends.py index f60073a..aff65cc 100644 --- a/xlb/compute_backends.py +++ b/xlb/compute_backends.py @@ -4,4 +4,5 @@ class ComputeBackends(Enum): JAX = 1 - WARP = 2 + PALLAS = 2 + WARP = 3 diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 7a03950..66e4b68 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -23,12 +23,12 @@ def create(grid_shape, velocity_set=None, compute_backend=None): compute_backend = compute_backend or GlobalConfig.compute_backend velocity_set = velocity_set or GlobalConfig.velocity_set - if compute_backend == ComputeBackends.JAX: + if compute_backend == ComputeBackends.JAX or compute_backend == ComputeBackends.PALLAS: from xlb.grid.jax_grid import JaxGrid # Avoids circular import return JaxGrid(grid_shape, velocity_set, compute_backend) raise ValueError(f"Compute backend {compute_backend} is not supported") @abstractmethod - def global_to_local_shape(self, shape): + def field_global_to_local_shape(self, shape): pass diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index af6d239..6b95c25 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -30,8 +30,11 @@ def initialize_jax_backend(self): if self.dim == 2 else NamedSharding(self.global_mesh, P("cardinality", "x", "y", "z")) ) + self.grid_shape_per_gpu = ( + self.grid_shape[0] // self.nDevices, + ) + self.grid_shape[1:] - def global_to_local_shape(self, shape): + def field_global_to_local_shape(self, shape): if len(shape) < 2: raise ValueError("Shape must have at least two dimensions") @@ -41,6 +44,7 @@ def global_to_local_shape(self, shape): def create_field(self, cardinality, callback=None): if callback is None: - callback = ConstInitializer(self, cardinality, const_value=0.0) + f = ConstInitializer(self, cardinality=cardinality)(0.0) + return f shape = (cardinality,) + (self.grid_shape) return jax.make_array_from_callback(shape, self.sharding, callback) diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index 1cf32f9..02b8a59 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1 +1,2 @@ from xlb.operator.operator import Operator +from xlb.operator.parallel_operator import ParallelOperator diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index cbd2b36..f129811 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -28,6 +28,12 @@ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fneq = f - feq fout = f - self.omega * fneq return fout + + @Operator.register_backend(ComputeBackends.PALLAS) + def pallas_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): + fneq = f - feq + fout = f - self.omega * fneq + return fout @Operator.register_backend(ComputeBackends.WARP) def warp_implementation(self, *args, **kwargs): diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 1b24a6b..0fcc0af 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -5,7 +5,6 @@ import jax.numpy as jnp from jax import jit from functools import partial -from numba import cuda, float32 from xlb.operator import Operator from xlb.velocity_set import VelocitySet, D2Q9, D3Q27 from xlb.compute_backends import ComputeBackends @@ -21,12 +20,11 @@ class KBC(Collision): def __init__( self, - omega, velocity_set: VelocitySet = None, compute_backend=None, ): super().__init__( - omega=omega, velocity_set=velocity_set, compute_backend=compute_backend + velocity_set=velocity_set, compute_backend=compute_backend ) self.epsilon = 1e-32 self.beta = self.omega * 0.5 diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 3dc4993..1df695e 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -35,3 +35,31 @@ def jax_implementation(self, rho, u): feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq + + @Operator.register_backend(ComputeBackends.PALLAS) + def pallas_implementation(self, rho, u): + u0, u1, u2 = u[0], u[1], u[2] + usqr = 1.5 * (u0**2 + u1**2 + u2**2) + + eq = [ + rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u0 + 4.5 * u0 * u0 - usqr), + rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u1 + 4.5 * u1 * u1 - usqr), + rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u2 + 4.5 * u2 * u2 - usqr), + ] + + combined_velocities = [u0 + u1, u0 - u1, u0 + u2, u0 - u2, u1 + u2, u1 - u2] + + for vel in combined_velocities: + eq.append( + rho[0] * (1.0 / 36.0) * (1.0 - 3.0 * vel + 4.5 * vel * vel - usqr) + ) + + eq.append(rho[0] * (1.0 / 3.0) * (1.0 - usqr)) + + for i in range(3): + eq.append(eq[i] + rho[0] * (1.0 / 18.0) * 6.0 * u[i]) + + for i, vel in enumerate(combined_velocities, 3): + eq.append(eq[i] + rho[0] * (1.0 / 36.0) * 6.0 * vel) + + return jnp.array(eq) diff --git a/xlb/operator/initializer/const_init.py b/xlb/operator/initializer/const_init.py index d3b13cb..b12ec41 100644 --- a/xlb/operator/initializer/const_init.py +++ b/xlb/operator/initializer/const_init.py @@ -3,6 +3,7 @@ from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator from xlb.grid.grid import Grid +from functools import partial import numpy as np import jax @@ -12,17 +13,34 @@ def __init__( self, grid: Grid, cardinality, - const_value=0.0, + type=np.float32, velocity_set: VelocitySet = None, compute_backend: ComputeBackends = None, ): + self.type = type + self.grid = grid velocity_set = velocity_set or GlobalConfig.velocity_set compute_backend = compute_backend or GlobalConfig.compute_backend - shape = (cardinality,) + (grid.grid_shape) - self.init_values = np.zeros(grid.global_to_local_shape(shape)) + const_value + self.shape = (cardinality,) + (grid.grid_shape) super().__init__(velocity_set, compute_backend) @Operator.register_backend(ComputeBackends.JAX) - def jax_implementation(self, index): - return self.init_values + @partial(jax.jit, static_argnums=(0, 2)) + def jax_implementation(self, const_value, sharding=None): + if sharding is None: + sharding = self.grid.sharding + x = jax.numpy.full( + shape=self.shape, fill_value=const_value, dtype=self.type + ) + return jax.lax.with_sharding_constraint(x, sharding) + + @Operator.register_backend(ComputeBackends.PALLAS) + @partial(jax.jit, static_argnums=(0, 2)) + def jax_implementation(self, const_value, sharding=None): + if sharding is None: + sharding = self.grid.sharding + x = jax.numpy.full( + shape=self.shape, fill_value=const_value, dtype=self.type + ) + return jax.lax.with_sharding_constraint(x, sharding) diff --git a/xlb/operator/initializer/equilibrium_init.py b/xlb/operator/initializer/equilibrium_init.py index bad7c85..f44722f 100644 --- a/xlb/operator/initializer/equilibrium_init.py +++ b/xlb/operator/initializer/equilibrium_init.py @@ -19,7 +19,7 @@ def __init__( local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1) self.init_values = np.zeros( - grid.global_to_local_shape(grid.pop_shape) + grid.field_global_to_local_shape(grid.pop_shape) ) + velocity_set.w.reshape(local_shape) super().__init__(velocity_set, compute_backend) @@ -27,3 +27,7 @@ def __init__( @Operator.register_backend(ComputeBackends.JAX) def jax_implementation(self, index): return self.init_values + + @Operator.register_backend(ComputeBackends.PALLAS) + def jax_implementation(self, index): + return self.init_values diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 733f8f7..5a4ad93 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -53,3 +53,30 @@ def jax_implementation(self, f): u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho return rho, u + + @Operator.register_backend(ComputeBackends.PALLAS) + def pallas_implementation(self, f): + # TODO: Maybe this can be done with jnp.sum + rho = jnp.sum(f, axis=0, keepdims=True) + + u = jnp.zeros((3, *rho.shape[1:])) + u.at[0].set( + -f[9] + - f[10] + - f[11] + - f[12] + - f[13] + + f[14] + + f[15] + + f[16] + + f[17] + + f[18] + ) / rho + u.at[1].set( + -f[3] - f[4] - f[5] + f[6] + f[7] + f[8] - f[12] + f[13] - f[17] + f[18] + ) / rho + u.at[2].set( + -f[1] + f[2] - f[4] + f[5] - f[7] + f[8] - f[10] + f[11] - f[15] + f[16] + ) / rho + + return rho, jnp.array(u) diff --git a/xlb/operator/parallel_operator.py b/xlb/operator/parallel_operator.py new file mode 100644 index 0000000..9309b21 --- /dev/null +++ b/xlb/operator/parallel_operator.py @@ -0,0 +1,86 @@ +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P +from jax import lax + + +class ParallelOperator: + """ + A generic class for parallelizing operations across multiple GPUs/TPUs. + """ + + def __init__(self, grid, func, velocity_set): + """ + Initialize the ParallelOperator. + + Parameters + ---------- + grid : Grid object + The computational grid. + func : function + The function to be parallelized. + velocity_set : VelocitySet object + The velocity set used in the Lattice Boltzmann Method. + """ + self.grid = grid + self.func = func + self.velocity_set = velocity_set + + def __call__(self, f): + """ + Execute the parallel operation. + + Parameters + ---------- + f : jax.numpy.ndarray + The input data for the operation. + + Returns + ------- + jax.numpy.ndarray + The result after applying the parallel operation. + """ + in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,))) + out_specs = in_specs + + f = shard_map( + self._parallel_func, + mesh=self.grid.global_mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + )(f) + return f + + def _parallel_func(self, f): + """ + Internal function to handle data communication and apply the given function in parallel. + + Parameters + ---------- + f : jax.numpy.ndarray + The input data. + + Returns + ------- + jax.numpy.ndarray + The processed data. + """ + rightPerm = [ + (i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices) + ] + leftPerm = [ + ((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices) + ] + f = self.func(f) + left_comm, right_comm = ( + f[self.velocity_set.right_indices, :1, ...], + f[self.velocity_set.left_indices, -1:, ...], + ) + left_comm, right_comm = ( + lax.ppermute(left_comm, perm=rightPerm, axis_name="x"), + lax.ppermute(right_comm, perm=leftPerm, axis_name="x"), + ) + f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm) + f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm) + + return f diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index e942961..a441ca0 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -2,14 +2,11 @@ from functools import partial import jax.numpy as jnp -from jax import jit, vmap, lax - +from jax import jit, vmap from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from jax.experimental.shard_map import shard_map -from jax.sharding import PartitionSpec as P - +from xlb.operator import Operator +from xlb.operator import ParallelOperator class Stream(Operator): """ @@ -18,6 +15,7 @@ class Stream(Operator): def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None): self.grid = grid + self.parallel_operator = ParallelOperator(grid, self._streaming_jax_p, velocity_set) super().__init__(velocity_set, compute_backend) @Operator.register_backend(ComputeBackends.JAX) @@ -31,14 +29,7 @@ def jax_implementation(self, f): f: jax.numpy.ndarray The distribution function. """ - in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,))) - out_specs = in_specs - return shard_map( - self._streaming_jax_m, - mesh=self.grid.global_mesh, - in_specs=in_specs, - out_specs=out_specs, - )(f) + return self.parallel_operator(f) def _streaming_jax_p(self, f): def _streaming_jax_i(f, c): @@ -63,42 +54,3 @@ def _streaming_jax_i(f, c): return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( f, jnp.array(self.velocity_set.c).T ) - - def _streaming_jax_m(self, f): - """ - This function performs the streaming step in the Lattice Boltzmann Method, which is - the propagation of the distribution functions in the lattice. - - To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the - distribution functions that need to be communicated to the neighboring processes. - - The function then sends the left boundary slice to the right neighboring process and the right - boundary slice to the left neighboring process. The received data is then set to the - corresponding indices in the receiving domain. - - Parameters - ---------- - f: jax.numpy.ndarray - The array holding the distribution functions for the simulation. - - Returns - ------- - jax.numpy.ndarray - The distribution functions after the streaming operation. - """ - rightPerm = [(i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices)] - leftPerm = [((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices)] - - f = self._streaming_jax_p(f) - left_comm, right_comm = ( - f[self.velocity_set.right_indices, :1, ...], - f[self.velocity_set.left_indices, -1:, ...], - ) - left_comm, right_comm = ( - lax.ppermute(left_comm, perm=rightPerm, axis_name='x'), - lax.ppermute(right_comm, perm=leftPerm, axis_name='x'), - ) - f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm) - f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm) - - return f diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index 75f4434..aacc785 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -13,16 +13,23 @@ class Fp64Fp64: def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: + if ( + GlobalConfig.compute_backend == ComputeBackends.JAX + or GlobalConfig.compute_backend == ComputeBackends.PALLAS + ): return JaxFp64Fp64() else: raise ValueError( f"Unsupported compute backend: {GlobalConfig.compute_backend}" ) - + + class Fp64Fp32: def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: + if ( + GlobalConfig.compute_backend == ComputeBackends.JAX + or GlobalConfig.compute_backend == ComputeBackends.PALLAS + ): return JaxFp64Fp32() else: raise ValueError( @@ -32,7 +39,10 @@ def __new__(cls): class Fp32Fp32: def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: + if ( + GlobalConfig.compute_backend == ComputeBackends.JAX + or GlobalConfig.compute_backend == ComputeBackends.PALLAS + ): return JaxFp32Fp32() else: raise ValueError( @@ -42,18 +52,25 @@ def __new__(cls): class Fp64Fp16: def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: + if ( + GlobalConfig.compute_backend == ComputeBackends.JAX + or GlobalConfig.compute_backend == ComputeBackends.PALLAS + ): return JaxFp64Fp16() else: raise ValueError( f"Unsupported compute backend: {GlobalConfig.compute_backend}" ) + class Fp32Fp16: def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: + if ( + GlobalConfig.compute_backend == ComputeBackends.JAX + or GlobalConfig.compute_backend == ComputeBackends.PALLAS + ): return JaxFp32Fp16() else: raise ValueError( f"Unsupported compute backend: {GlobalConfig.compute_backend}" - ) \ No newline at end of file + ) diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index 251fe37..878f6ce 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -2,6 +2,7 @@ from functools import partial from jax import jit +import jax from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backends import ComputeBackends @@ -12,6 +13,7 @@ from xlb.operator.macroscopic import Macroscopic from xlb.solver.solver import Solver from xlb.operator import Operator +from jax.experimental import pallas as pl class IncompressibleNavierStokes(Solver): @@ -28,7 +30,12 @@ def __init__( self.grid = grid self.omega = omega self.collision_kernel = collision_kernel - super().__init__(velocity_set=velocity_set, compute_backend=compute_backend, precision_policy=precision_policy, boundary_conditions=boundary_conditions) + super().__init__( + velocity_set=velocity_set, + compute_backend=compute_backend, + precision_policy=precision_policy, + boundary_conditions=boundary_conditions, + ) self.create_operators() # Operators @@ -52,8 +59,10 @@ def create_operators(self): compute_backend=self.compute_backend, ) ) - self.stream = Stream(self.grid, - velocity_set=self.velocity_set, compute_backend=self.compute_backend + self.stream = Stream( + self.grid, + velocity_set=self.velocity_set, + compute_backend=self.compute_backend, ) @Operator.register_backend(ComputeBackends.JAX) @@ -108,3 +117,55 @@ def step(self, f, timestep): f = self.precision_policy.cast_to_store(f_post_streaming) return f + + @Operator.register_backend(ComputeBackends.PALLAS) + @partial(jit, static_argnums=(0,)) + def step(self, fin, timestep): + from xlb.operator.parallel_operator import ParallelOperator + + def _pallas_collide(fin, fout): + idx = pl.program_id(0) + + f = (pl.load(fin, (slice(None), idx, slice(None), slice(None)))) + + print("f shape", f.shape) + + rho, u = self.macroscopic(f) + + print("rho shape", rho.shape) + print("u shape", u.shape) + + feq = self.equilibrium(rho, u) + + print("feq shape", feq.shape) + + for i in range(self.velocity_set.q): + print("f shape", f[i].shape) + f_post_collision = self.collision(f[i], feq[i]) + print("f_post_collision shape", f_post_collision.shape) + pl.store(fout, (i, idx, slice(None), slice(None)), f_post_collision) + # f_post_collision = self.collision(f, feq) + # pl.store(fout, (i, idx, slice(None), slice(None)), f_post_collision) + + @jit + def _pallas_collide_kernel(fin): + return pl.pallas_call( + partial(_pallas_collide), + out_shape=jax.ShapeDtypeStruct( + ((self.velocity_set.q,) + (self.grid.grid_shape_per_gpu)), fin.dtype + ), + # grid=1, + grid=(self.grid.grid_shape_per_gpu[0], 1, 1), + )(fin) + + def _pallas_collide_and_stream(f): + f = _pallas_collide_kernel(f) + # f = self.stream._streaming_jax_p(f) + + return f + + fout = ParallelOperator( + self.grid, _pallas_collide_and_stream, self.velocity_set + )(fin) + + return fout diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index 6d9c627..c20dacc 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -129,7 +129,7 @@ def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"): if value.ndim == 2: dimensions = dimensions + (1,) - grid = pv.UniformGrid(dimensions=dimensions) + grid = pv.ImageData(dimensions=dimensions) # Add the fields to the grid for key, value in fields.items(): @@ -168,7 +168,7 @@ def save_BCs_vtk(timestep, BCs, gridInfo, output_dir="."): gridDimensions = (gridInfo["nx"] + 1, gridInfo["ny"] + 1, gridInfo["nz"] + 1) fieldDimensions = (gridInfo["nx"], gridInfo["ny"], gridInfo["nz"]) - grid = pv.UniformGrid(dimensions=gridDimensions) + grid = pv.ImageData(dimensions=gridDimensions) # Dictionary to keep track of encountered BC names bcNamesCount = {} diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 05ba2a2..c38333d 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -5,8 +5,6 @@ from functools import partial import jax.numpy as jnp from jax import jit, vmap -import numba -from numba import cuda, float32, int32 class VelocitySet(object): From b4e13c61a4a780d66af5f9d4a7b737c3c8ddd593 Mon Sep 17 00:00:00 2001 From: Oliver Date: Tue, 6 Feb 2024 15:00:00 -0800 Subject: [PATCH 009/144] first restructuring --- .../backend_comparisons/lattice_boltzmann.py | 2 + examples/warp_backend/equilibrium.py | 33 +++++++ xlb/__init__.py | 10 +-- xlb/compute_backend.py | 7 ++ xlb/compute_backends.py | 7 -- xlb/global_config.py | 5 +- xlb/grid/grid.py | 4 +- .../boundary_condition/boundary_condition.py | 6 +- xlb/operator/boundary_condition/do_nothing.py | 6 +- .../equilibrium_boundary.py | 6 +- .../boundary_condition/full_bounce_back.py | 6 +- .../boundary_condition/halfway_bounce_back.py | 6 +- xlb/operator/collision/bgk.py | 6 +- xlb/operator/collision/kbc.py | 6 +- xlb/operator/equilibrium/equilibrium.py | 4 +- .../equilibrium/quadratic_equilibrium.py | 90 +++++++++++++++++-- xlb/operator/macroscopic/macroscopic.py | 4 +- xlb/operator/operator.py | 80 ++++++++++++++++- xlb/operator/stream/stream.py | 4 +- xlb/physics_type.py | 6 +- xlb/precision_policy.py | 45 ++++++++++ xlb/precision_policy/__init__.py | 1 - xlb/precision_policy/base_precision_policy.py | 14 --- xlb/precision_policy/jax_precision_policy.py | 72 --------------- xlb/precision_policy/precision_policy.py | 59 ------------ xlb/solver/nse.py | 4 +- xlb/solver/solver.py | 2 +- xlb/velocity_set/velocity_set.py | 60 ++----------- 28 files changed, 295 insertions(+), 260 deletions(-) create mode 100644 examples/warp_backend/equilibrium.py create mode 100644 xlb/compute_backend.py delete mode 100644 xlb/compute_backends.py create mode 100644 xlb/precision_policy.py delete mode 100644 xlb/precision_policy/__init__.py delete mode 100644 xlb/precision_policy/base_precision_policy.py delete mode 100644 xlb/precision_policy/jax_precision_policy.py delete mode 100644 xlb/precision_policy/precision_policy.py diff --git a/examples/backend_comparisons/lattice_boltzmann.py b/examples/backend_comparisons/lattice_boltzmann.py index a8e1c39..57f0e12 100644 --- a/examples/backend_comparisons/lattice_boltzmann.py +++ b/examples/backend_comparisons/lattice_boltzmann.py @@ -1070,6 +1070,8 @@ def jax_apply_collide_stream(f, tau: float): # Compute MLUPS mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 backend.append("Warp") + print(mlups) + exit() mlups.append(mlups) # Plot results diff --git a/examples/warp_backend/equilibrium.py b/examples/warp_backend/equilibrium.py new file mode 100644 index 0000000..845bc90 --- /dev/null +++ b/examples/warp_backend/equilibrium.py @@ -0,0 +1,33 @@ +# from IPython import display +import numpy as np +import jax +import jax.numpy as jnp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt + +import warp as wp +wp.init() + +import xlb + +if __name__ == "__main__": + + # Make operator + precision_policy = xlb.PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q27() + compute_backend = xlb.ComputeBackend.WARP + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + + # Make warp arrays + nr = 128 + f = wp.zeros((27, nr, nr, nr), dtype=wp.float32) + u = wp.zeros((3, nr, nr, nr), dtype=wp.float32) + rho = wp.zeros((1, nr, nr, nr), dtype=wp.float32) + + # Run simulation + equilibrium(rho, u, f) diff --git a/xlb/__init__.py b/xlb/__init__.py index 7845bb2..e72c99e 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -1,15 +1,11 @@ # Enum classes -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy from xlb.physics_type import PhysicsType - # Config from .global_config import init, current_backend - -# Precision policy -import xlb.precision_policy - # Velocity Set import xlb.velocity_set @@ -28,4 +24,4 @@ import xlb.solver # Utils -import xlb.utils \ No newline at end of file +import xlb.utils diff --git a/xlb/compute_backend.py b/xlb/compute_backend.py new file mode 100644 index 0000000..15a2ea4 --- /dev/null +++ b/xlb/compute_backend.py @@ -0,0 +1,7 @@ +# Enum used to keep track of the compute backends + +from enum import Enum, auto + +class ComputeBackend(Enum): + JAX = auto() + WARP = auto() diff --git a/xlb/compute_backends.py b/xlb/compute_backends.py deleted file mode 100644 index f60073a..0000000 --- a/xlb/compute_backends.py +++ /dev/null @@ -1,7 +0,0 @@ -# Enum used to keep track of the compute backends - -from enum import Enum - -class ComputeBackends(Enum): - JAX = 1 - WARP = 2 diff --git a/xlb/global_config.py b/xlb/global_config.py index c0047c9..563a75d 100644 --- a/xlb/global_config.py +++ b/xlb/global_config.py @@ -7,8 +7,7 @@ class GlobalConfig: def init(velocity_set, compute_backend, precision_policy): GlobalConfig.velocity_set = velocity_set() GlobalConfig.compute_backend = compute_backend - GlobalConfig.precision_policy = precision_policy() - + GlobalConfig.precision_policy = precision_policy def current_backend(): - return GlobalConfig.compute_backend \ No newline at end of file + return GlobalConfig.compute_backend diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 7a03950..c969d18 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.global_config import GlobalConfig from xlb.velocity_set import VelocitySet @@ -23,7 +23,7 @@ def create(grid_shape, velocity_set=None, compute_backend=None): compute_backend = compute_backend or GlobalConfig.compute_backend velocity_set = velocity_set or GlobalConfig.velocity_set - if compute_backend == ComputeBackends.JAX: + if compute_backend == ComputeBackend.JAX: from xlb.grid.jax_grid import JaxGrid # Avoids circular import return JaxGrid(grid_shape, velocity_set, compute_backend) diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index cca6ed0..7e8909a 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -10,7 +10,7 @@ from xlb.operator.operator import Operator from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend # Enum for implementation step class ImplementationStep(Enum): @@ -27,7 +27,7 @@ def __init__( set_boundary, implementation_step: ImplementationStep, velocity_set: VelocitySet, - compute_backend: ComputeBackends.JAX, + compute_backend: ComputeBackend.JAX, ): super().__init__(velocity_set, compute_backend) @@ -35,7 +35,7 @@ def __init__( self.implementation_step = implementation_step # Set boundary function - if compute_backend == ComputeBackends.JAX: + if compute_backend == ComputeBackend.JAX: self.set_boundary = set_boundary else: raise NotImplementedError diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py index 39fa4d8..828390a 100644 --- a/xlb/operator/boundary_condition/do_nothing.py +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -5,7 +5,7 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.boundary_condition.boundary_condition import ( BoundaryCondition, ImplementationStep, @@ -20,7 +20,7 @@ def __init__( self, set_boundary, velocity_set: VelocitySet, - compute_backend: ComputeBackends = ComputeBackends.JAX, + compute_backend: ComputeBackend = ComputeBackend.JAX, ): super().__init__( set_boundary=set_boundary, @@ -34,7 +34,7 @@ def from_indices( cls, indices, velocity_set: VelocitySet, - compute_backend: ComputeBackends = ComputeBackends.JAX, + compute_backend: ComputeBackend = ComputeBackend.JAX, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py index b3f124f..f615a53 100644 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -5,7 +5,7 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.stream.stream import Stream from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator.boundary_condition.boundary_condition import ( @@ -25,7 +25,7 @@ def __init__( u: tuple[float, float], equilibrium: Equilibrium, velocity_set: VelocitySet, - compute_backend: ComputeBackends = ComputeBackends.JAX, + compute_backend: ComputeBackend = ComputeBackend.JAX, ): super().__init__( set_boundary=set_boundary, @@ -43,7 +43,7 @@ def from_indices( u: tuple[float, float], equilibrium: Equilibrium, velocity_set: VelocitySet, - compute_backend: ComputeBackends = ComputeBackends.JAX, + compute_backend: ComputeBackend = ComputeBackend.JAX, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index 311c73f..fc883c8 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -9,7 +9,7 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.boundary_condition.boundary_condition import ( BoundaryCondition, ImplementationStep, @@ -24,7 +24,7 @@ def __init__( self, set_boundary, velocity_set: VelocitySet, - compute_backend: ComputeBackends = ComputeBackends.JAX, + compute_backend: ComputeBackend = ComputeBackend.JAX, ): super().__init__( set_boundary=set_boundary, @@ -38,7 +38,7 @@ def from_indices( cls, indices, velocity_set: VelocitySet, - compute_backend: ComputeBackends = ComputeBackends.JAX, + compute_backend: ComputeBackend = ComputeBackend.JAX, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index 8eb14fb..3b5b6de 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -5,7 +5,7 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.stream.stream import Stream from xlb.operator.boundary_condition.boundary_condition import ( BoundaryCondition, @@ -21,7 +21,7 @@ def __init__( self, set_boundary, velocity_set: VelocitySet, - compute_backend: ComputeBackends = ComputeBackends.JAX, + compute_backend: ComputeBackend = ComputeBackend.JAX, ): super().__init__( set_boundary=set_boundary, @@ -35,7 +35,7 @@ def from_indices( cls, indices, velocity_set: VelocitySet, - compute_backend: ComputeBackends = ComputeBackends.JAX, + compute_backend: ComputeBackend = ComputeBackend.JAX, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 9dfdf33..51588ac 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from jax import jit from xlb.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision from xlb.operator import Operator from functools import partial @@ -22,14 +22,14 @@ def __init__( omega=omega, velocity_set=velocity_set, compute_backend=compute_backend ) - @Operator.register_backend(ComputeBackends.JAX) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fneq = f - feq fout = f - self.omega * fneq return fout - @Operator.register_backend(ComputeBackends.WARP) + @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, *args, **kwargs): # Implementation for the Warp backend raise NotImplementedError diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 1b24a6b..3f6a6b0 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -8,7 +8,7 @@ from numba import cuda, float32 from xlb.operator import Operator from xlb.velocity_set import VelocitySet, D2Q9, D3Q27 -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision @@ -32,7 +32,7 @@ def __init__( self.beta = self.omega * 0.5 self.inv_beta = 1.0 / self.beta - @Operator.register_backend(ComputeBackends.JAX) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) def jax_implementation( self, @@ -75,7 +75,7 @@ def jax_implementation( return fout - @Operator.register_backend(ComputeBackends.WARP) + @Operator.register_backend(ComputeBackend.WARP) @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) def warp_implementation( self, diff --git a/xlb/operator/equilibrium/equilibrium.py b/xlb/operator/equilibrium/equilibrium.py index 4bed600..bc61793 100644 --- a/xlb/operator/equilibrium/equilibrium.py +++ b/xlb/operator/equilibrium/equilibrium.py @@ -1,5 +1,6 @@ # Base class for all equilibriums from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy from xlb.operator.operator import Operator @@ -11,6 +12,7 @@ class Equilibrium(Operator): def __init__( self, velocity_set: VelocitySet = None, + presision_policy=None, compute_backend=None, ): - super().__init__(velocity_set, compute_backend) + super().__init__(velocity_set, presision_policy, compute_backend) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 3dc4993..e603db0 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -1,13 +1,14 @@ +from functools import partial import jax.numpy as jnp from jax import jit +import warp as wp + from xlb.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium.equilibrium import Equilibrium -from functools import partial from xlb.operator import Operator from xlb.global_config import GlobalConfig - class QuadraticEquilibrium(Equilibrium): """ Quadratic equilibrium of Boltzmann equation using hermite polynomials. @@ -19,19 +20,92 @@ class QuadraticEquilibrium(Equilibrium): def __init__( self, velocity_set: VelocitySet = None, + precision_policy=None, compute_backend=None, ): - velocity_set = velocity_set or GlobalConfig.velocity_set - compute_backend = compute_backend or GlobalConfig.compute_backend + super().__init__(velocity_set, precision_policy, compute_backend) - super().__init__(velocity_set, compute_backend) + # Construct the warp implementation + if self.compute_backend == ComputeBackend.WARP: + self._warp_functional = self._construct_warp_functional() + self._warp_kernel = self._construct_warp_kernel() - @Operator.register_backend(ComputeBackends.JAX) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) def jax_implementation(self, rho, u): cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0)) usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True) w = self.velocity_set.w.reshape((-1,) + (1,) * (len(rho.shape) - 1)) - feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq + + def _construct_warp_functional(self): + # Make constants for warp + _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) + _q = wp.constant(self.velocity_set.q) + _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) + + @wp.func + def equilibrium(rho: self.compute_dtype, u: self._warp_u_vec) -> self._warp_lattice_vec: + feq = self._warp_lattice_vec() # empty lattice vector + for i in range(_q): + # Compute cu + cu = self.compute_dtype(0.0) + for d in range(_q): + if _c[i, d] == 1: + cu += u[d] + elif _c[i, d] == -1: + cu -= u[d] + cu *= self.compute_dtype(3.0) + + # Compute usqr + usqr = 1.5 * wp.dot(u, u) + + # Compute feq + feq[i] = rho * _w[i] * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) + + return feq + + return equilibrium + + def _construct_warp_kernel(self): + # Make constants for warp + _d = wp.constant(self.velocity_set.d) + _q = wp.constant(self.velocity_set.q) + + @wp.kernel + def warp_kernel( + rho: self._warp_array_type, + u: self._warp_array_type, + f: self._warp_array_type + ): + # Get the global index + i, j, k = wp.tid() + + # Get the equilibrium + _u = self._warp_u_vec() + for d in range(_d): + _u[i] = u[d, i, j, k] + _rho = rho[0, i, j, k] + feq = self._warp_functional(_rho, _u) + + # Set the output + for l in range(_q): + f[l, i, j, k] = feq[l] + + return warp_kernel + + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, rho, u, f): + # Launch the warp kernel + wp.launch( + self._warp_kernel, + inputs=[ + rho, + u, + f, + ], + dim=rho.shape, + ) + return f diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 733f8f7..8fefde5 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -1,7 +1,7 @@ # Base class for all equilibriums from xlb.global_config import GlobalConfig from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator @@ -29,7 +29,7 @@ def __init__( super().__init__(velocity_set, compute_backend) - @Operator.register_backend(ComputeBackends.JAX) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0), inline=True) def jax_implementation(self, f): """ diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 700de7e..1de9b8c 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,7 +1,10 @@ # Base class for all operators, (collision, streaming, equilibrium, etc.) +import warp as wp -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy, Precision from xlb.global_config import GlobalConfig + class Operator: """ Base class for all operators, collision, streaming, equilibrium, etc. @@ -11,12 +14,13 @@ class Operator: _backends = {} - def __init__(self, velocity_set, compute_backend): + def __init__(self, velocity_set, precision_policy, compute_backend): # Set the default values from the global config self.velocity_set = velocity_set or GlobalConfig.velocity_set + self.precision_policy = precision_policy or GlobalConfig.precision_policy self.compute_backend = compute_backend or GlobalConfig.compute_backend - if self.compute_backend not in ComputeBackends: + if self.compute_backend not in ComputeBackend: raise ValueError(f"Compute backend {compute_backend} is not supported") @classmethod @@ -74,3 +78,73 @@ def _is_method_overridden(self, method_name): def __repr__(self): return f"{self.__class__.__name__}()" + + @property + def backend(self): + """ + Returns the compute backend object for the operator (e.g. jax, warp) + This should be used with caution as all backends may not have the same API. + """ + if self.compute_backend == ComputeBackend.JAX: + import jax as backend + elif self.compute_backend == ComputeBackend.WARP: + import warp as backend + return backend + + @property + def compute_dtype(self): + """ + Returns the compute dtype + """ + if self.precision_policy.compute_precision == Precision.FP64: + return self.backend.float64 + elif self.precision_policy.compute_precision == Precision.FP32: + return self.backend.float32 + elif self.precision_policy.compute_precision == Precision.FP16: + return self.backend.float16 + + @property + def store_dtype(self): + """ + Returns the store dtype + """ + if self.precision_policy.store_precision == Precision.FP64: + return self.backend.float64 + elif self.precision_policy.store_precision == Precision.FP32: + return self.backend.float32 + elif self.precision_policy.store_precision == Precision.FP16: + return self.backend.float16 + + ### WARP specific types ### + # These are used to define the types for the warp backend + # TODO: There might be a better place to put these + @property + def _warp_u_vec(self): + """ + Returns the warp type for velocity + """ + return wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + + @property + def _warp_lattice_vec(self): + """ + Returns the warp type for the lattice + """ + return wp.vec(len(self.velocity_set.w), dtype=self.compute_dtype) + + @property + def _warp_stream_mat(self): + """ + Returns the warp type for the streaming matrix (c) + """ + return wp.mat((self.velocity_set.d, self.velocity_set.q), dtype=self.compute_dtype) + + @property + def _warp_array_type(self): + """ + Returns the warp type for arrays + """ + if self.velocity_set.d == 2: + return wp.array3d(dtype=self.store_dtype) + elif self.velocity_set.d == 3: + return wp.array4d(dtype=self.store_dtype) diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index e942961..3ccf5cc 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -5,7 +5,7 @@ from jax import jit, vmap, lax from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator from jax.experimental.shard_map import shard_map from jax.sharding import PartitionSpec as P @@ -20,7 +20,7 @@ def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None) self.grid = grid super().__init__(velocity_set, compute_backend) - @Operator.register_backend(ComputeBackends.JAX) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) def jax_implementation(self, f): """ diff --git a/xlb/physics_type.py b/xlb/physics_type.py index 73d9e70..586eefe 100644 --- a/xlb/physics_type.py +++ b/xlb/physics_type.py @@ -1,7 +1,7 @@ # Enum used to keep track of the physics types supported by different operators -from enum import Enum +from enum import Enum, auto class PhysicsType(Enum): - NSE = 1 # Navier-Stokes Equations - ADE = 2 # Advection-Diffusion Equations + NSE = auto() # Navier-Stokes Equations + ADE = auto() # Advection-Diffusion Equations diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py new file mode 100644 index 0000000..7f3b075 --- /dev/null +++ b/xlb/precision_policy.py @@ -0,0 +1,45 @@ +# Enum for precision policy + +from enum import Enum, auto + +class Precision(Enum): + FP64 = auto() + FP32 = auto() + FP16 = auto() + +class PrecisionPolicy(Enum): + FP64FP64 = auto() + FP64FP32 = auto() + FP64FP16 = auto() + FP32FP32 = auto() + FP32FP16 = auto() + + @property + def compute_precision(self): + if self == PrecisionPolicy.FP64FP64: + return Precision.FP64 + elif self == PrecisionPolicy.FP64FP32: + return Precision.FP32 + elif self == PrecisionPolicy.FP64FP16: + return Precision.FP16 + elif self == PrecisionPolicy.FP32FP32: + return Precision.FP32 + elif self == PrecisionPolicy.FP32FP16: + return Precision.FP16 + else: + raise ValueError("Invalid precision policy") + + @property + def store_precision(self): + if self == PrecisionPolicy.FP64FP64: + return Precision.FP64 + elif self == PrecisionPolicy.FP64FP32: + return Precision.FP32 + elif self == PrecisionPolicy.FP64FP16: + return Precision.FP16 + elif self == PrecisionPolicy.FP32FP32: + return Precision.FP32 + elif self == PrecisionPolicy.FP32FP16: + return Precision.FP16 + else: + raise ValueError("Invalid precision policy") diff --git a/xlb/precision_policy/__init__.py b/xlb/precision_policy/__init__.py deleted file mode 100644 index b1555aa..0000000 --- a/xlb/precision_policy/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from xlb.precision_policy.precision_policy import Fp64Fp64, Fp64Fp32, Fp32Fp32, Fp64Fp16, Fp32Fp16 \ No newline at end of file diff --git a/xlb/precision_policy/base_precision_policy.py b/xlb/precision_policy/base_precision_policy.py deleted file mode 100644 index 397e48d..0000000 --- a/xlb/precision_policy/base_precision_policy.py +++ /dev/null @@ -1,14 +0,0 @@ -from abc import ABC, abstractmethod - -class PrecisionPolicy(ABC): - def __init__(self, compute_dtype, storage_dtype): - self.compute_dtype = compute_dtype - self.storage_dtype = storage_dtype - - @abstractmethod - def cast_to_compute(self, array): - pass - - @abstractmethod - def cast_to_store(self, array): - pass diff --git a/xlb/precision_policy/jax_precision_policy.py b/xlb/precision_policy/jax_precision_policy.py deleted file mode 100644 index bd63adf..0000000 --- a/xlb/precision_policy/jax_precision_policy.py +++ /dev/null @@ -1,72 +0,0 @@ -from xlb.precision_policy.base_precision_policy import PrecisionPolicy -from jax import jit -from functools import partial -import jax.numpy as jnp - - -class JaxPrecisionPolicy(PrecisionPolicy): - """ - JAX-specific precision policy. - """ - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def cast_to_compute(self, array): - return array.astype(self.compute_dtype) - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def cast_to_store(self, array): - return array.astype(self.storage_dtype) - - -class JaxFp32Fp32(JaxPrecisionPolicy): - """ - Precision policy for lattice Boltzmann method with computation and storage - precision both set to float32. - - Parameters - ---------- - None - """ - - def __init__(self): - super().__init__(jnp.float32, jnp.float32) - - -class JaxFp64Fp64(JaxPrecisionPolicy): - """ - Precision policy for lattice Boltzmann method with computation and storage - precision both set to float64. - """ - - def __init__(self): - super().__init__(jnp.float64, jnp.float64) - - -class JaxFp64Fp32(JaxPrecisionPolicy): - """ - Precision policy for lattice Boltzmann method with computation precision - set to float64 and storage precision set to float32. - """ - - def __init__(self): - super().__init__(jnp.float64, jnp.float32) - - -class JaxFp64Fp16(JaxPrecisionPolicy): - """ - Precision policy for lattice Boltzmann method with computation precision - set to float64 and storage precision set to float16. - """ - - def __init__(self): - super().__init__(jnp.float64, jnp.float16) - - -class JaxFp32Fp16(JaxPrecisionPolicy): - """ - Precision policy for lattice Boltzmann method with computation precision - set to float32 and storage precision set to float16. - """ - - def __init__(self): - super().__init__(jnp.float32, jnp.float16) diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py deleted file mode 100644 index 75f4434..0000000 --- a/xlb/precision_policy/precision_policy.py +++ /dev/null @@ -1,59 +0,0 @@ -from abc import ABC, abstractmethod -from xlb.compute_backends import ComputeBackends -from xlb.global_config import GlobalConfig - -from xlb.precision_policy.jax_precision_policy import ( - JaxFp32Fp32, - JaxFp32Fp16, - JaxFp64Fp64, - JaxFp64Fp32, - JaxFp64Fp16, -) - - -class Fp64Fp64: - def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: - return JaxFp64Fp64() - else: - raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" - ) - -class Fp64Fp32: - def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: - return JaxFp64Fp32() - else: - raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" - ) - - -class Fp32Fp32: - def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: - return JaxFp32Fp32() - else: - raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" - ) - - -class Fp64Fp16: - def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: - return JaxFp64Fp16() - else: - raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" - ) - -class Fp32Fp16: - def __new__(cls): - if GlobalConfig.compute_backend == ComputeBackends.JAX: - return JaxFp32Fp16() - else: - raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" - ) \ No newline at end of file diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index 251fe37..c627f51 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -4,7 +4,7 @@ from jax import jit from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.boundary_condition import ImplementationStep from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.collision import BGK, KBC @@ -56,7 +56,7 @@ def create_operators(self): velocity_set=self.velocity_set, compute_backend=self.compute_backend ) - @Operator.register_backend(ComputeBackends.JAX) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) def step(self, f, timestep): """ diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py index 9f4deaa..7c15344 100644 --- a/xlb/solver/solver.py +++ b/xlb/solver/solver.py @@ -1,6 +1,6 @@ # Base class for all stepper operators -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.operator.boundary_condition import ImplementationStep from xlb.global_config import GlobalConfig from xlb.operator import Operator diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 05ba2a2..ed4cae8 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -5,8 +5,8 @@ from functools import partial import jax.numpy as jnp from jax import jit, vmap -import numba -from numba import cuda, float32, int32 + +import warp as wp class VelocitySet(object): @@ -44,58 +44,14 @@ def __init__(self, d, q, c, w): self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() - @partial(jit, static_argnums=(0,)) - def momentum_flux_jax(self, fneq): - """ - This function computes the momentum flux, which is the product of the non-equilibrium - distribution functions (fneq) and the lattice moments (cc). - - The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann - Method (LBM). - - Parameters - ---------- - fneq: jax.numpy.ndarray - The non-equilibrium distribution functions. - - Returns - ------- - jax.numpy.ndarray - The computed momentum flux. - """ - - return jnp.dot(fneq, self.cc) - - def momentum_flux_numba(self): - """ - This function computes the momentum flux, which is the product of the non-equilibrium - """ - raise NotImplementedError - - @partial(jit, static_argnums=(0,)) - def decompose_shear_jax(self, fneq): - """ - Decompose fneq into shear components for D3Q27 lattice. - - TODO: add generali + def warp_lattice_vec(self, dtype): + return wp.vec(len(self.c), dtype=dtype) - Parameters - ---------- - fneq : jax.numpy.ndarray - Non-equilibrium distribution function. + def warp_u_vec(self, dtype): + return wp.vec(self.d, dtype=dtype) - Returns - ------- - jax.numpy.ndarray - Shear components of fneq. - """ - raise NotImplementedError - - def decompose_shear_numba(self): - """ - Decompose fneq into shear components for D3Q27 lattice. - """ - raise NotImplementedError + def warp_stream_mat(self, dtype): + return wp.mat((self.q, self.d), dtype=dtype) def _construct_lattice_moment(self): """ From 02fcbf5109700f901ff98947996e778ac7d2d341 Mon Sep 17 00:00:00 2001 From: Oliver Date: Mon, 12 Feb 2024 16:10:12 -0800 Subject: [PATCH 010/144] almost finished bc --- examples/warp_backend/equilibrium.py | 5 + xlb/operator/boundary_condition/__init__.py | 5 +- .../boundary_condition/boundary_condition.py | 112 +++++++--------- xlb/operator/boundary_condition/do_nothing.py | 24 ++-- .../equilibrium_boundary.py | 36 +++--- .../boundary_condition/full_bounce_back.py | 113 +++++++++++++--- .../boundary_condition/halfway_bounce_back.py | 56 +++++--- xlb/operator/collision/bgk.py | 65 ++++++++-- xlb/operator/collision/collision.py | 3 +- xlb/operator/collision/kbc.py | 7 +- xlb/operator/equilibrium/equilibrium.py | 4 +- .../equilibrium/quadratic_equilibrium.py | 59 ++++----- xlb/operator/initializer/__init__.py | 2 +- xlb/operator/initializer/equilibrium_init.py | 2 +- xlb/operator/macroscopic/macroscopic.py | 83 +++++++++--- xlb/operator/operator.py | 54 +++++++- xlb/operator/stream/stream.py | 121 ++++++++++-------- 17 files changed, 492 insertions(+), 259 deletions(-) diff --git a/examples/warp_backend/equilibrium.py b/examples/warp_backend/equilibrium.py index 845bc90..a99ace4 100644 --- a/examples/warp_backend/equilibrium.py +++ b/examples/warp_backend/equilibrium.py @@ -22,6 +22,10 @@ velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) + macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) # Make warp arrays nr = 128 @@ -31,3 +35,4 @@ # Run simulation equilibrium(rho, u, f) + macroscopic(f, rho, u) diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 9937bc3..3d10b59 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -1,4 +1,7 @@ -from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition, ImplementationStep +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) from xlb.operator.boundary_condition.full_bounce_back import FullBounceBack from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBack from xlb.operator.boundary_condition.do_nothing import DoNothing diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 7e8909a..f5ad4dd 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -10,91 +10,75 @@ from xlb.operator.operator import Operator from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend + # Enum for implementation step class ImplementationStep(Enum): COLLISION = 1 STREAMING = 2 + class BoundaryCondition(Operator): """ Base class for boundary conditions in a LBM simulation. """ def __init__( - self, - set_boundary, - implementation_step: ImplementationStep, - velocity_set: VelocitySet, - compute_backend: ComputeBackend.JAX, - ): - super().__init__(velocity_set, compute_backend) + self, + implementation_step: ImplementationStep, + boundary_masker: BoundaryMasker, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + super().__init__(velocity_set, precision_policy, compute_backend) # Set implementation step self.implementation_step = implementation_step - # Set boundary function - if compute_backend == ComputeBackend.JAX: - self.set_boundary = set_boundary - else: - raise NotImplementedError + # Set boundary masker + self.boundary_masker = boundary_masker @classmethod - def from_indices(cls, indices, implementation_step: ImplementationStep): - """ - Creates a boundary condition from a list of indices. - """ - raise NotImplementedError - - @partial(jit, static_argnums=(0,)) - def apply_jax(self, f_pre, f_post, mask, velocity_set: VelocitySet): - """ - Applies the boundary condition. - """ - pass - - @staticmethod - def _indices_to_tuple(indices): + def from_indices( + cls, + implementation_step: ImplementationStep, + indices: np.ndarray, + stream_indices: bool, + velocity_set, + precision_policy, + compute_backend, + ): """ - Converts a tensor of indices to a tuple for indexing - TODO: Might be better to index + Create a boundary condition from indices and boundary id. """ - return tuple([indices[:, i] for i in range(indices.shape[1])]) + # Create boundary mask + boundary_mask = IndicesBoundaryMask( + indices, stream_indices, velocity_set, precision_policy, compute_backend + ) + + # Create boundary condition + return cls( + implementation_step, + boundary_mask, + velocity_set, + precision_policy, + compute_backend, + ) - @staticmethod - def _set_boundary_from_indices(indices): + @classmethod + def from_stl( + cls, + implementation_step: ImplementationStep, + stl_file: str, + stream_indices: bool, + velocity_set, + precision_policy, + compute_backend, + ): """ - This create the standard set_boundary function from a list of indices. - `boundary_id` is set to `id_number` at the indices and `mask` is set to `True` at the indices. - Many boundary conditions can be created from this function however some may require a custom function such as - HalfwayBounceBack. + Create a boundary condition from an STL file. """ - - # Create a mask function - def set_boundary(ijk, boundary_id, mask, id_number): - """ - Sets the mask id for the boundary condition. - - Parameters - ---------- - ijk : jnp.ndarray - Array of shape (N, N, N, 3) containing the meshgrid of lattice points. - boundary_id : jnp.ndarray - Array of shape (N, N, N) containing the boundary id. This will be modified in place and returned. - mask : jnp.ndarray - Array of shape (N, N, N, Q) containing the mask. This will be modified in place and returned. - """ - - # Get local indices from the meshgrid and the indices - local_indices = ijk[BoundaryCondition._indices_to_tuple(indices)] - - # Set the boundary id - boundary_id = boundary_id.at[BoundaryCondition._indices_to_tuple(indices)].set(id_number) - - # Set the mask - mask = mask.at[BoundaryCondition._indices_to_tuple(indices)].set(True) - - return boundary_id, mask - - return set_boundary + raise NotImplementedError diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py index 828390a..6251660 100644 --- a/xlb/operator/boundary_condition/do_nothing.py +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -11,17 +11,18 @@ ImplementationStep, ) + class DoNothing(BoundaryCondition): """ A boundary condition that skips the streaming step. """ def __init__( - self, - set_boundary, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): + self, + set_boundary, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): super().__init__( set_boundary=set_boundary, implementation_step=ImplementationStep.STREAMING, @@ -31,22 +32,21 @@ def __init__( @classmethod def from_indices( - cls, - indices, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): + cls, + indices, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): """ Creates a boundary condition from a list of indices. """ - + return cls( set_boundary=cls._set_boundary_from_indices(indices), velocity_set=velocity_set, compute_backend=compute_backend, ) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) def apply_jax(self, f_pre, f_post, boundary, mask): do_nothing = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py index f615a53..fbc5418 100644 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -13,20 +13,21 @@ ImplementationStep, ) + class EquilibriumBoundary(BoundaryCondition): """ A boundary condition that skips the streaming step. """ def __init__( - self, - set_boundary, - rho: float, - u: tuple[float, float], - equilibrium: Equilibrium, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): + self, + set_boundary, + rho: float, + u: tuple[float, float], + equilibrium: Equilibrium, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): super().__init__( set_boundary=set_boundary, implementation_step=ImplementationStep.STREAMING, @@ -37,18 +38,18 @@ def __init__( @classmethod def from_indices( - cls, - indices, - rho: float, - u: tuple[float, float], - equilibrium: Equilibrium, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): + cls, + indices, + rho: float, + u: tuple[float, float], + equilibrium: Equilibrium, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): """ Creates a boundary condition from a list of indices. """ - + return cls( set_boundary=cls._set_boundary_from_indices(indices), rho=rho, @@ -58,7 +59,6 @@ def from_indices( compute_backend=compute_backend, ) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) def apply_jax(self, f_pre, f_post, boundary, mask): equilibrium_mask = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index fc883c8..4cd86b6 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -14,6 +14,11 @@ BoundaryCondition, ImplementationStep, ) +from xlb.operator.boundary_condition.boundary_masker import ( + BoundaryMasker, + IndicesBoundaryMasker, +) + class FullBounceBack(BoundaryCondition): """ @@ -21,37 +26,103 @@ class FullBounceBack(BoundaryCondition): """ def __init__( - self, - set_boundary, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): + self, + boundary_masker: BoundaryMasker, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): super().__init__( - set_boundary=set_boundary, - implementation_step=ImplementationStep.COLLISION, - velocity_set=velocity_set, - compute_backend=compute_backend, + ImplementationStep.COLLISION, + boundary_masker, + velocity_set, + precision_policy, + compute_backend, ) @classmethod def from_indices( - cls, - indices, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): + cls, indices: np.ndarray, velocity_set, precision_policy, compute_backend + ): """ - Creates a boundary condition from a list of indices. + Create a full bounce-back boundary condition from indices. """ - + # Create boundary mask + boundary_mask = IndicesBoundaryMask( + indices, False, velocity_set, precision_policy, compute_backend + ) + + # Create boundary condition return cls( - set_boundary=cls._set_boundary_from_indices(indices), - velocity_set=velocity_set, - compute_backend=compute_backend, + ImplementationStep.COLLISION, + boundary_mask, + velocity_set, + precision_policy, + compute_backend, ) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) def apply_jax(self, f_pre, f_post, boundary, mask): - flip = jnp.repeat(boundary[..., jnp.newaxis], self.velocity_set.q, axis=-1) - flipped_f = lax.select(flip, f_pre[..., self.velocity_set.opp_indices], f_post) + flip = jnp.repeat(boundary, self.velocity_set.q, axis=-1) + flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) return flipped_f + + def _construct_warp(self): + # Make constants for warp + _opp_indices = wp.constant(self.velocity_set.opp_indices) + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f_pre: self._warp_lattice_vec, + f_post: self._warp_lattice_vec, + mask: self._warp_bool_lattice_vec, + ): + fliped_f = self._warp_lattice_vec() + for l in range(_q): + fliped_f[l] = f_pre[_opp_indices[l]] + return fliped_f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: self._warp_array_type, + f_post: self._warp_array_type, + f: self._warp_array_type, + boundary: self._warp_bool_array_type, + mask: self._warp_bool_array_type, + ): + # Get the global index + i, j, k = wp.tid() + + # Make vectors for the lattice + _f_pre = self._warp_lattice_vec() + _f_post = self._warp_lattice_vec() + _mask = self._warp_bool_lattice_vec() + for l in range(_q): + _f_pre[l] = f_pre[l, i, j, k] + _f_post[l] = f_post[l, i, j, k] + _mask[l] = mask[l, i, j, k] + + # Check if the boundary is active + if boundary[i, j, k]: + _f = functional(_f_pre, _f_post, _mask) + else: + _f = _f_post + + # Write the result to the output + for l in range(_q): + f[l, i, j, k] = _f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, f, boundary, mask): + # Launch the warp kernel + wp.launch( + self._kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] + ) + return f diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index 3b5b6de..d8cb4ad 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -12,17 +12,18 @@ ImplementationStep, ) + class HalfwayBounceBack(BoundaryCondition): """ Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. """ def __init__( - self, - set_boundary, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): + self, + set_boundary, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): super().__init__( set_boundary=set_boundary, implementation_step=ImplementationStep.STREAMING, @@ -32,15 +33,15 @@ def __init__( @classmethod def from_indices( - cls, - indices, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): + cls, + indices, + velocity_set: VelocitySet, + compute_backend: ComputeBackend = ComputeBackend.JAX, + ): """ Creates a boundary condition from a list of indices. """ - + # Make stream operator to get edge points stream = Stream(velocity_set=velocity_set) @@ -62,24 +63,42 @@ def set_boundary(ijk, boundary_id, mask, id_number): """ # Get local indices from the meshgrid and the indices - local_indices = ijk[tuple(s[:, 0] for s in jnp.split(indices, velocity_set.d, axis=1))] + local_indices = ijk[ + tuple(s[:, 0] for s in jnp.split(indices, velocity_set.d, axis=1)) + ] # Make mask then stream to get the edge points pre_stream_mask = jnp.zeros_like(mask) - pre_stream_mask = pre_stream_mask.at[tuple([s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)])].set(True) + pre_stream_mask = pre_stream_mask.at[ + tuple( + [s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)] + ) + ].set(True) post_stream_mask = stream(pre_stream_mask) # Set false for points inside the boundary - post_stream_mask = post_stream_mask.at[post_stream_mask[..., 0] == True].set(False) + post_stream_mask = post_stream_mask.at[ + post_stream_mask[..., 0] == True + ].set(False) # Get indices on edges edge_indices = jnp.argwhere(post_stream_mask) # Set the boundary id - boundary_id = boundary_id.at[tuple([s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)])].set(id_number) + boundary_id = boundary_id.at[ + tuple( + [s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)] + ) + ].set(id_number) # Set the mask - mask = mask.at[edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], :].set(post_stream_mask[edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], :]) + mask = mask.at[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ].set( + post_stream_mask[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ] + ) return boundary_id, mask @@ -89,9 +108,10 @@ def set_boundary(ijk, boundary_id, mask, id_number): compute_backend=compute_backend, ) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) def apply_jax(self, f_pre, f_post, boundary, mask): flip_mask = boundary[..., jnp.newaxis] & mask - flipped_f = lax.select(flip_mask, f_pre[..., self.velocity_set.opp_indices], f_post) + flipped_f = lax.select( + flip_mask, f_pre[..., self.velocity_set.opp_indices], f_post + ) return flipped_f diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 51588ac..2d0a931 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -12,16 +12,6 @@ class BGK(Collision): BGK collision operator for LBM. """ - def __init__( - self, - omega: float, - velocity_set: VelocitySet = None, - compute_backend=None, - ): - super().__init__( - omega=omega, velocity_set=velocity_set, compute_backend=compute_backend - ) - @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): @@ -29,7 +19,56 @@ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fout = f - self.omega * fneq return fout + def _construct_warp(self): + # Make constants for warp + _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) + _q = wp.constant(self.velocity_set.q) + _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) + _d = wp.constant(self.velocity_set.d) + + # Construct the functional + @wp.func + def functional( + f: self._warp_lattice_vec, feq: self._warp_lattice_vec + ) -> self._warp_lattice_vec: + fneq = f - feq + fout = f - self.omega * fneq + return fout + + # Construct the warp kernel + @wp.kernel + def kernel( + f: self._warp_array_type, + feq: self._warp_array_type, + fout: self._warp_array_type, + ): + # Get the global index + i, j, k = wp.tid() + + # Get the equilibrium + _f = self._warp_lattice_vec() + _feq = self._warp_lattice_vec() + for l in range(_q): + _f[l] = f[l, i, j, k] + _feq[l] = feq[l, i, j, k] + _fout = functional(_f, _feq) + + # Write the result + for l in range(_q): + fout[l, i, j, k] = _fout[l] + + return functional, kernel + @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, *args, **kwargs): - # Implementation for the Warp backend - raise NotImplementedError + def warp_implementation(self, f, feq, fout): + # Launch the warp kernel + wp.launch( + self._kernel, + inputs=[ + f, + feq, + fout, + ], + dim=f.shape[1:], + ) + return fout diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py index 3243e6f..1fe0a5b 100644 --- a/xlb/operator/collision/collision.py +++ b/xlb/operator/collision/collision.py @@ -23,7 +23,8 @@ def __init__( self, omega: float, velocity_set: VelocitySet = None, + precision_policy=None, compute_backend=None, ): - super().__init__(velocity_set, compute_backend) + super().__init__(velocity_set, precision_policy, compute_backend) self.omega = omega diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 3f6a6b0..27c20b7 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -23,10 +23,14 @@ def __init__( self, omega, velocity_set: VelocitySet = None, + precision_policy=None, compute_backend=None, ): super().__init__( - omega=omega, velocity_set=velocity_set, compute_backend=compute_backend + omega=omega, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, ) self.epsilon = 1e-32 self.beta = self.omega * 0.5 @@ -53,7 +57,6 @@ def jax_implementation( Density. """ fneq = f - feq - print(self.velocity_set) if isinstance(self.velocity_set, D2Q9): shear = self.decompose_shear_d2q9_jax(fneq) delta_s = shear * rho / 4.0 # TODO: Check this diff --git a/xlb/operator/equilibrium/equilibrium.py b/xlb/operator/equilibrium/equilibrium.py index bc61793..726ca37 100644 --- a/xlb/operator/equilibrium/equilibrium.py +++ b/xlb/operator/equilibrium/equilibrium.py @@ -12,7 +12,7 @@ class Equilibrium(Operator): def __init__( self, velocity_set: VelocitySet = None, - presision_policy=None, + precision_policy=None, compute_backend=None, ): - super().__init__(velocity_set, presision_policy, compute_backend) + super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index e603db0..8838150 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -9,6 +9,7 @@ from xlb.operator import Operator from xlb.global_config import GlobalConfig + class QuadraticEquilibrium(Equilibrium): """ Quadratic equilibrium of Boltzmann equation using hermite polynomials. @@ -17,19 +18,6 @@ class QuadraticEquilibrium(Equilibrium): TODO: move this to a separate file and lower and higher order equilibriums """ - def __init__( - self, - velocity_set: VelocitySet = None, - precision_policy=None, - compute_backend=None, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Construct the warp implementation - if self.compute_backend == ComputeBackend.WARP: - self._warp_functional = self._construct_warp_functional() - self._warp_kernel = self._construct_warp_kernel() - @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) def jax_implementation(self, rho, u): @@ -39,22 +27,26 @@ def jax_implementation(self, rho, u): feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq - def _construct_warp_functional(self): + def _construct_warp(self): # Make constants for warp _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) _q = wp.constant(self.velocity_set.q) _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) + _d = wp.constant(self.velocity_set.d) + # Construct the equilibrium functional @wp.func - def equilibrium(rho: self.compute_dtype, u: self._warp_u_vec) -> self._warp_lattice_vec: - feq = self._warp_lattice_vec() # empty lattice vector - for i in range(_q): + def functional( + rho: self.compute_dtype, u: self._warp_u_vec + ) -> self._warp_lattice_vec: + feq = self._warp_lattice_vec() # empty lattice vector + for l in range(_q): # Compute cu cu = self.compute_dtype(0.0) - for d in range(_q): - if _c[i, d] == 1: + for d in range(_d): + if _c[l, d] == 1: cu += u[d] - elif _c[i, d] == -1: + elif _c[l, d] == -1: cu -= u[d] cu *= self.compute_dtype(3.0) @@ -62,22 +54,16 @@ def equilibrium(rho: self.compute_dtype, u: self._warp_u_vec) -> self._warp_latt usqr = 1.5 * wp.dot(u, u) # Compute feq - feq[i] = rho * _w[i] * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) + feq[l] = rho * _w[l] * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq - return equilibrium - - def _construct_warp_kernel(self): - # Make constants for warp - _d = wp.constant(self.velocity_set.d) - _q = wp.constant(self.velocity_set.q) - + # Construct the warp kernel @wp.kernel - def warp_kernel( - rho: self._warp_array_type, - u: self._warp_array_type, - f: self._warp_array_type + def kernel( + rho: self._warp_array_type, + u: self._warp_array_type, + f: self._warp_array_type, ): # Get the global index i, j, k = wp.tid() @@ -87,25 +73,24 @@ def warp_kernel( for d in range(_d): _u[i] = u[d, i, j, k] _rho = rho[0, i, j, k] - feq = self._warp_functional(_rho, _u) + feq = functional(_rho, _u) # Set the output for l in range(_q): f[l, i, j, k] = feq[l] - return warp_kernel - + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, rho, u, f): # Launch the warp kernel wp.launch( - self._warp_kernel, + self._kernel, inputs=[ rho, u, f, ], - dim=rho.shape, + dim=rho.shape[1:], ) return f diff --git a/xlb/operator/initializer/__init__.py b/xlb/operator/initializer/__init__.py index 4d3f07d..b2d14b9 100644 --- a/xlb/operator/initializer/__init__.py +++ b/xlb/operator/initializer/__init__.py @@ -1,2 +1,2 @@ from xlb.operator.initializer.equilibrium_init import EquilibriumInitializer -from xlb.operator.initializer.const_init import ConstInitializer \ No newline at end of file +from xlb.operator.initializer.const_init import ConstInitializer diff --git a/xlb/operator/initializer/equilibrium_init.py b/xlb/operator/initializer/equilibrium_init.py index bad7c85..1fb95f4 100644 --- a/xlb/operator/initializer/equilibrium_init.py +++ b/xlb/operator/initializer/equilibrium_init.py @@ -17,7 +17,7 @@ def __init__( velocity_set = velocity_set or GlobalConfig.velocity_set compute_backend = compute_backend or GlobalConfig.compute_backend local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1) - + self.init_values = np.zeros( grid.global_to_local_shape(grid.pop_shape) ) + velocity_set.w.reshape(local_shape) diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 8fefde5..6f06bf7 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -1,13 +1,15 @@ # Base class for all equilibriums -from xlb.global_config import GlobalConfig -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator - from functools import partial import jax.numpy as jnp from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator class Macroscopic(Operator): @@ -19,16 +21,6 @@ class Macroscopic(Operator): and other physic types (e.g. temperature, electromagnetism, etc...) """ - def __init__( - self, - velocity_set: VelocitySet = None, - compute_backend=None, - ): - self.velocity_set = velocity_set or GlobalConfig.velocity_set - self.compute_backend = compute_backend or GlobalConfig.compute_backend - - super().__init__(velocity_set, compute_backend) - @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0), inline=True) def jax_implementation(self, f): @@ -53,3 +45,64 @@ def jax_implementation(self, f): u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho return rho, u + + def _construct_warp(self): + # Make constants for warp + _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + + # Construct the functional + @wp.func + def functional(f: self._warp_lattice_vec): + # Compute rho and u + rho = self.compute_dtype(0.0) + u = self._warp_u_vec() + for l in range(_q): + rho += f[l] + for d in range(_d): + if _c[l, d] == 1: + u[d] += f[l] + elif _c[l, d] == -1: + u[d] -= f[l] + u /= rho + + return rho, u + # return u, rho + + # Construct the kernel + @wp.kernel + def kernel( + f: self._warp_array_type, + rho: self._warp_array_type, + u: self._warp_array_type, + ): + # Get the global index + i, j, k = wp.tid() + + # Get the equilibrium + _f = self._warp_lattice_vec() + for l in range(_q): + _f[l] = f[l, i, j, k] + (_rho, _u) = functional(_f) + + # Set the output + rho[0, i, j, k] = _rho + for d in range(_d): + u[d, i, j, k] = _u[d] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, rho, u): + # Launch the warp kernel + wp.launch( + self._kernel, + inputs=[ + f, + rho, + u, + ], + dim=rho.shape[1:], + ) + return rho, u diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 1de9b8c..20f7947 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -5,6 +5,7 @@ from xlb.precision_policy import PrecisionPolicy, Precision from xlb.global_config import GlobalConfig + class Operator: """ Base class for all operators, collision, streaming, equilibrium, etc. @@ -20,9 +21,14 @@ def __init__(self, velocity_set, precision_policy, compute_backend): self.precision_policy = precision_policy or GlobalConfig.precision_policy self.compute_backend = compute_backend or GlobalConfig.compute_backend + # Check if the compute backend is supported if self.compute_backend not in ComputeBackend: raise ValueError(f"Compute backend {compute_backend} is not supported") + # Construct the kernel based backend functions TODO: Maybe move this to the register or something + if self.compute_backend == ComputeBackend.WARP: + self._functional, self._kernel = self._construct_warp() + @classmethod def register_backend(cls, backend_name): """ @@ -130,14 +136,30 @@ def _warp_lattice_vec(self): """ Returns the warp type for the lattice """ - return wp.vec(len(self.velocity_set.w), dtype=self.compute_dtype) + return wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + @property + def _warp_int_lattice_vec(self): + """ + Returns the warp type for the streaming matrix (c) + """ + return wp.vec(self.velocity_set.q, dtype=wp.int32) + + @property + def _warp_bool_lattice_vec(self): + """ + Returns the warp type for the streaming matrix (c) + """ + return wp.vec(self.velocity_set.q, dtype=wp.bool) @property def _warp_stream_mat(self): """ Returns the warp type for the streaming matrix (c) """ - return wp.mat((self.velocity_set.d, self.velocity_set.q), dtype=self.compute_dtype) + return wp.mat( + (self.velocity_set.d, self.velocity_set.q), dtype=self.compute_dtype + ) @property def _warp_array_type(self): @@ -148,3 +170,31 @@ def _warp_array_type(self): return wp.array3d(dtype=self.store_dtype) elif self.velocity_set.d == 3: return wp.array4d(dtype=self.store_dtype) + + @property + def _warp_uint8_array_type(self): + """ + Returns the warp type for arrays + """ + if self.velocity_set.d == 2: + return wp.array3d(dtype=wp.bool) + elif self.velocity_set.d == 3: + return wp.array4d(dtype=wp.bool) + + @property + def _warp_bool_array_type(self): + """ + Returns the warp type for arrays + """ + if self.velocity_set.d == 2: + return wp.array3d(dtype=wp.bool) + elif self.velocity_set.d == 3: + return wp.array4d(dtype=wp.bool) + + def _construct_warp(self): + """ + Construct the warp functional and kernel of the operator + TODO: Maybe a better way to do this? + Maybe add this to the backend decorator? + """ + raise NotImplementedError("Children must implement this method") diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 3ccf5cc..af432df 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -16,10 +16,6 @@ class Stream(Operator): Base class for all streaming operators. """ - def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None): - self.grid = grid - super().__init__(velocity_set, compute_backend) - @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) def jax_implementation(self, f): @@ -31,16 +27,7 @@ def jax_implementation(self, f): f: jax.numpy.ndarray The distribution function. """ - in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,))) - out_specs = in_specs - return shard_map( - self._streaming_jax_m, - mesh=self.grid.global_mesh, - in_specs=in_specs, - out_specs=out_specs, - )(f) - - def _streaming_jax_p(self, f): + def _streaming_jax_i(f, c): """ Perform individual streaming operation in a direction. @@ -64,41 +51,73 @@ def _streaming_jax_i(f, c): f, jnp.array(self.velocity_set.c).T ) - def _streaming_jax_m(self, f): - """ - This function performs the streaming step in the Lattice Boltzmann Method, which is - the propagation of the distribution functions in the lattice. - - To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the - distribution functions that need to be communicated to the neighboring processes. - - The function then sends the left boundary slice to the right neighboring process and the right - boundary slice to the left neighboring process. The received data is then set to the - corresponding indices in the receiving domain. - - Parameters - ---------- - f: jax.numpy.ndarray - The array holding the distribution functions for the simulation. - - Returns - ------- - jax.numpy.ndarray - The distribution functions after the streaming operation. - """ - rightPerm = [(i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices)] - leftPerm = [((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices)] - - f = self._streaming_jax_p(f) - left_comm, right_comm = ( - f[self.velocity_set.right_indices, :1, ...], - f[self.velocity_set.left_indices, -1:, ...], - ) - left_comm, right_comm = ( - lax.ppermute(left_comm, perm=rightPerm, axis_name='x'), - lax.ppermute(right_comm, perm=leftPerm, axis_name='x'), + def _construct_warp(self): + # Make constants for warp + _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + + # Construct the funcional to get streamed indices + @wp.func + def functional( + l: int, + i: int, + j: int, + k: int, + max_i: int, + max_j: int, + max_k: int, + ): + streamed_i = i + _c[l, 0] + streamed_j = j + _c[l, 1] + streamed_k = k + _c[l, 2] + if streamed_i < 0: + streamed_i = max_i - 1 + elif streamed_i >= max_i: + streamed_i = 0 + if streamed_j < 0: + streamed_j = max_j - 1 + elif streamed_j >= max_j: + streamed_j = 0 + if streamed_k < 0: + streamed_k = max_k - 1 + elif streamed_k >= max_k: + streamed_k = 0 + return streamed_i, streamed_j, streamed_k + + # Construct the warp kernel + @wp.kernel + def kernel( + f_0: self._warp_array_type, + f_1: self._warp_array_type, + max_i: int, + max_j: int, + max_k: int, + ): + # Get the global index + i, j, k = wp.tid() + + # Set the output + for l in range(_q): + streamed_i, streamed_j, streamed_k = functional( + l, i, j, k, max_i, max_j, max_k + ) + f_1[l, streamed_i, streamed_j, streamed_k] = f_0[l, i, j, k] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_0, f_1): + # Launch the warp kernel + wp.launch( + self._kernel, + inputs=[ + f_0, + f_1, + f_0.shape[1], + f_0.shape[2], + f_0.shape[3], + ], + dim=f_0.shape[1:], ) - f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm) - f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm) - - return f + return f_1 From 18a13b8ef4d82bcf47e8f76dfdd149e984898175 Mon Sep 17 00:00:00 2001 From: Oliver Date: Fri, 16 Feb 2024 09:23:28 -0800 Subject: [PATCH 011/144] structure fully in --- examples/CFD/cavity2d.py | 4 +- examples/CFD_refactor/windtunnel3d.py | 139 +++++ examples/backend_comparisons/small_example.py | 27 + xlb/base.py | 474 ++++++++++++------ xlb/compute_backend.py | 2 + xlb/compute_backends.py | 8 - xlb/experimental/ooc/ooc_array.py | 9 +- xlb/experimental/ooc/out_of_core.py | 6 +- xlb/global_config.py | 1 + xlb/grid/__init__.py | 2 +- xlb/grid/grid.py | 45 +- xlb/grid/jax_grid.py | 32 +- xlb/grid_backend.py | 9 + .../boundary_masker/boundary_masker.py | 37 ++ .../indices_boundary_masker.py | 100 ++++ .../boundary_condition/halfway_bounce_back.py | 77 --- xlb/operator/collision/bgk.py | 21 +- xlb/operator/collision/kbc.py | 8 +- xlb/operator/initializer/const_init.py | 14 +- xlb/operator/initializer/initializer.py | 13 + xlb/operator/macroscopic/macroscopic.py | 56 +-- .../precision_caster/precision_caster.py | 98 ++++ xlb/operator/stepper/nse.py | 218 ++++++++ xlb/operator/stepper/stepper.py | 141 ++++++ xlb/operator/stream/stream.py | 24 +- xlb/physics_type.py | 1 + xlb/precision_policy.py | 2 + xlb/solver/nse.py | 237 ++++----- xlb/solver/solver.py | 32 +- xlb/utils/__init__.py | 10 +- xlb/utils/utils.py | 1 + 31 files changed, 1347 insertions(+), 501 deletions(-) create mode 100644 examples/CFD_refactor/windtunnel3d.py delete mode 100644 xlb/compute_backends.py create mode 100644 xlb/grid_backend.py create mode 100644 xlb/operator/boundary_condition/boundary_masker/boundary_masker.py create mode 100644 xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py create mode 100644 xlb/operator/initializer/initializer.py create mode 100644 xlb/operator/precision_caster/precision_caster.py create mode 100644 xlb/operator/stepper/nse.py create mode 100644 xlb/operator/stepper/stepper.py diff --git a/examples/CFD/cavity2d.py b/examples/CFD/cavity2d.py index 28536e4..8d1ffc8 100644 --- a/examples/CFD/cavity2d.py +++ b/examples/CFD/cavity2d.py @@ -21,8 +21,8 @@ import jax.numpy as jnp import os -from src.boundary_conditions import * -from src.models import BGKSim, KBCSim +from src. import * +from src.solver import BGKSim, KBCSim from src.lattice import LatticeD2Q9 from src.utils import * diff --git a/examples/CFD_refactor/windtunnel3d.py b/examples/CFD_refactor/windtunnel3d.py new file mode 100644 index 0000000..1de0e2d --- /dev/null +++ b/examples/CFD_refactor/windtunnel3d.py @@ -0,0 +1,139 @@ +import os +import jax +import trimesh +from time import time +import numpy as np +import jax.numpy as jnp +from jax import config + +from xlb.solver import IncompressibleNavierStokesSolver +from xlb.operator.boundary_condition import BounceBack, BounceBackHalfway, DoNothing, EquilibriumBC + + +class WindTunnel(IncompressibleNavierStokesSolver): + """ + This class extends the IncompressibleNavierStokesSolver class to define the boundary conditions for the wind tunnel simulation. + Units are in meters, seconds, and kilograms. + """ + + def __init__( + self, + stl_filename: str + stl_center: tuple[float, float, float] = (0.0, 0.0, 0.0), # m + inlet_velocity: float = 27.78 # m/s + lower_bounds: tuple[float, float, float] = (0.0, 0.0, 0.0), # m + upper_bounds: tuple[float, float, float] = (1.0, 0.5, 0.5), # m + dx: float = 0.01, # m + viscosity: float = 1.42e-5, # air at 20 degrees Celsius + density: float = 1.2754, # kg/m^3 + ): + + + + omega: float, + shape: tuple[int, int, int], + collision="BGK", + equilibrium="Quadratic", + boundary_conditions=[], + initializer=None, + forcing=None, + velocity_set: VelocitySet = None, + precision_policy=None, + compute_backend=None, + grid_backend=None, + grid_configs={}, + ): + + super().__init__(**kwargs) + + def voxelize_stl(self, stl_filename, length_lbm_unit): + mesh = trimesh.load_mesh(stl_filename, process=False) + length_phys_unit = mesh.extents.max() + pitch = length_phys_unit/length_lbm_unit + mesh_voxelized = mesh.voxelized(pitch=pitch) + mesh_matrix = mesh_voxelized.matrix + return mesh_matrix, pitch + + def set_boundary_conditions(self): + print('Voxelizing mesh...') + time_start = time() + stl_filename = 'stl-files/DrivAer-Notchback.stl' + car_length_lbm_unit = self.nx / 4 + car_voxelized, pitch = voxelize_stl(stl_filename, car_length_lbm_unit) + car_matrix = car_voxelized.matrix + print('Voxelization time for pitch={}: {} seconds'.format(pitch, time() - time_start)) + print("Car matrix shape: ", car_matrix.shape) + + self.car_area = np.prod(car_matrix.shape[1:]) + tx, ty, tz = np.array([nx, ny, nz]) - car_matrix.shape + shift = [tx//4, ty//2, 0] + car_indices = np.argwhere(car_matrix) + shift + self.BCs.append(BounceBackHalfway(tuple(car_indices.T), self.gridInfo, self.precisionPolicy)) + + wall = np.concatenate((self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'], + self.boundingBoxIndices['front'], self.boundingBoxIndices['back'])) + self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) + + doNothing = self.boundingBoxIndices['right'] + self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy)) + self.BCs[-1].implementationStep = 'PostCollision' + # rho_outlet = np.ones(doNothing.shape[0], dtype=self.precisionPolicy.compute_dtype) + # self.BCs.append(ZouHe(tuple(doNothing.T), + # self.gridInfo, + # self.precisionPolicy, + # 'pressure', rho_outlet)) + + inlet = self.boundingBoxIndices['left'] + rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) + vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype) + + vel_inlet[:, 0] = prescribed_vel + self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet)) + # self.BCs.append(ZouHe(tuple(inlet.T), + # self.gridInfo, + # self.precisionPolicy, + # 'velocity', vel_inlet)) + + def output_data(self, **kwargs): + # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) + rho = np.array(kwargs['rho'][..., 1:-1, 1:-1, :]) + u = np.array(kwargs['u'][..., 1:-1, 1:-1, :]) + timestep = kwargs['timestep'] + u_prev = kwargs['u_prev'][..., 1:-1, 1:-1, :] + + # compute lift and drag over the car + car = self.BCs[0] + boundary_force = car.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) + boundary_force = np.sum(boundary_force, axis=0) + drag = np.sqrt(boundary_force[0]**2 + boundary_force[1]**2) #xy-plane + lift = boundary_force[2] #z-direction + cd = 2. * drag / (prescribed_vel ** 2 * self.car_area) + cl = 2. * lift / (prescribed_vel ** 2 * self.car_area) + + u_old = np.linalg.norm(u_prev, axis=2) + u_new = np.linalg.norm(u, axis=2) + + err = np.sum(np.abs(u_old - u_new)) + print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) + fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} + save_fields_vtk(timestep, fields) + +if __name__ == '__main__': + precision = 'f32/f32' + lattice = LatticeD3Q27(precision) + + nx = 601 + ny = 351 + nz = 251 + + Re = 50000.0 + prescribed_vel = 0.05 + clength = nx - 1 + + visc = prescribed_vel * clength / Re + omega = 1.0 / (3. * visc + 0.5) + + os.system('rm -rf ./*.vtk && rm -rf ./*.png') + + sim = Car(**kwargs) + sim.run(200000) diff --git a/examples/backend_comparisons/small_example.py b/examples/backend_comparisons/small_example.py index 23ca639..6d6213b 100644 --- a/examples/backend_comparisons/small_example.py +++ b/examples/backend_comparisons/small_example.py @@ -79,6 +79,25 @@ def compute_u_and_p(f: lattice_vec): u /= p return u, p + # bc function + @wp.func + def bc_0(pre_f: lattice_vec, post_f: lattice_vec): + return pre_f + @wp.func + def bc_1(pre_f: lattice_vec, post_f: lattice_vec): + return post_f + tup_bc = tuple([bc_0, bc_1]) + single_bc = bc_0 + for bc in tup_bc: + def make_bc(bc, prev_bc): + @wp.func + def _bc(pre_f: lattice_vec, post_f: lattice_vec): + pre_f = prev_bc(pre_f, post_f) + post_f = single_bc(pre_f, post_f) + return bc(pre_f, post_f) + return _bc + single_bc = make_bc(bc, single_bc) + # Make function for getting stream index @wp.func def get_streamed_index( @@ -136,6 +155,14 @@ def collide_stream( # Compute equilibrium feq = compute_feq(p, uxu, exu) + # Set bc + if x == 0: + #tup_bc[0](feq, f) + bc_0(feq, f) + if x == width - 1: + bc_1(feq, f) + #tup_bc[1](feq, f) + # Set value new_f = f - (f - feq) / tau for i in range(q): diff --git a/xlb/base.py b/xlb/base.py index 7fb4f57..fa99d1d 100644 --- a/xlb/base.py +++ b/xlb/base.py @@ -23,14 +23,15 @@ # Local/Custom Libraries from src.utils import downsample_field -jax.config.update("jax_spmd_mode", 'allow_all') +jax.config.update("jax_spmd_mode", "allow_all") # Disables annoying TF warnings -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + class LBMBase(object): """ LBMBase: A class that represents a base for Lattice Boltzmann Method simulation. - + Parameters ---------- lattice (object): The lattice object that contains the lattice structure and weights. @@ -40,7 +41,7 @@ class LBMBase(object): nz (int, optional): Number of grid points in the z-direction. Defaults to 0. precision (str, optional): A string specifying the precision used for the simulation. Defaults to "f32/f32". """ - + def __init__(self, **kwargs): self.omega = kwargs.get("omega") self.nx = kwargs.get("nx") @@ -49,12 +50,15 @@ def __init__(self, **kwargs): self.precision = kwargs.get("precision") computedType, storedType = self.set_precisions(self.precision) - self.precisionPolicy = jmp.Policy(compute_dtype=computedType, - param_dtype=computedType, output_dtype=storedType) - + self.precisionPolicy = jmp.Policy( + compute_dtype=computedType, + param_dtype=computedType, + output_dtype=storedType, + ) + self.lattice = kwargs.get("lattice") self.checkpointRate = kwargs.get("checkpoint_rate", 0) - self.checkpointDir = kwargs.get("checkpoint_dir", './checkpoints') + self.checkpointDir = kwargs.get("checkpoint_dir", "./checkpoints") self.downsamplingFactor = kwargs.get("downsampling_factor", 1) self.printInfoRate = kwargs.get("print_info_rate", 100) self.ioRate = kwargs.get("io_rate", 0) @@ -72,8 +76,10 @@ def __init__(self, **kwargs): # Check for distributed mode if self.nDevices > jax.local_device_count(): - print("WARNING: Running in distributed mode. Make sure that jax.distributed.initialize is called before performing any JAX computations.") - + print( + "WARNING: Running in distributed mode. Make sure that jax.distributed.initialize is called before performing any JAX computations." + ) + self.c = self.lattice.c self.q = self.lattice.q self.w = self.lattice.w @@ -81,34 +87,44 @@ def __init__(self, **kwargs): # Set the checkpoint manager if self.checkpointRate > 0: - mngr_options = orb.CheckpointManagerOptions(save_interval_steps=self.checkpointRate, max_to_keep=1) - self.mngr = orb.CheckpointManager(self.checkpointDir, orb.PyTreeCheckpointer(), options=mngr_options) + mngr_options = orb.CheckpointManagerOptions( + save_interval_steps=self.checkpointRate, max_to_keep=1 + ) + self.mngr = orb.CheckpointManager( + self.checkpointDir, orb.PyTreeCheckpointer(), options=mngr_options + ) else: self.mngr = None - + # Adjust the number of grid points in the x direction, if necessary. # If the number of grid points is not divisible by the number of devices # it increases the number of grid points to the next multiple of the number of devices. # This is done in order to accommodate the domain sharding per XLA device nx, ny, nz = kwargs.get("nx"), kwargs.get("ny"), kwargs.get("nz") if None in {nx, ny, nz}: - raise ValueError("nx, ny, and nz must be provided. For 2D examples, nz must be set to 0.") + raise ValueError( + "nx, ny, and nz must be provided. For 2D examples, nz must be set to 0." + ) self.nx = nx if nx % self.nDevices: self.nx = nx + (self.nDevices - nx % self.nDevices) - print("WARNING: nx increased from {} to {} in order to accommodate domain sharding per XLA device.".format(nx, self.nx)) + print( + "WARNING: nx increased from {} to {} in order to accommodate domain sharding per XLA device.".format( + nx, self.nx + ) + ) self.ny = ny self.nz = nz self.show_simulation_parameters() - + # Store grid information self.gridInfo = { "nx": self.nx, "ny": self.ny, "nz": self.nz, "dim": self.lattice.d, - "lattice": self.lattice + "lattice": self.lattice, } P = PartitionSpec @@ -124,8 +140,15 @@ def __init__(self, **kwargs): self.mesh = Mesh(self.devices, axis_names=("x", "y", "value")) self.sharding = NamedSharding(self.mesh, P("x", "y", "value")) - self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh, - in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False)) + self.streaming = jit( + shard_map( + self.streaming_m, + mesh=self.mesh, + in_specs=P("x", None, None), + out_specs=P("x", None, None), + check_rep=False, + ) + ) # Set up the sharding and streaming for 2D and 3D simulations elif self.dim == 3: @@ -133,14 +156,21 @@ def __init__(self, **kwargs): self.mesh = Mesh(self.devices, axis_names=("x", "y", "z", "value")) self.sharding = NamedSharding(self.mesh, P("x", "y", "z", "value")) - self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh, - in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False)) + self.streaming = jit( + shard_map( + self.streaming_m, + mesh=self.mesh, + in_specs=P("x", None, None, None), + out_specs=P("x", None, None, None), + check_rep=False, + ) + ) else: raise ValueError(f"dim = {self.dim} not supported") - + # Compute the bounding box indices for boundary conditions - self.boundingBoxIndices= self.bounding_box_indices() + self.boundingBoxIndices = self.bounding_box_indices() # Create boundary data for the simulation self._create_boundary_data() self.force = self.get_force() @@ -153,11 +183,13 @@ def lattice(self): def lattice(self, value): if value is None: raise ValueError("Lattice type must be provided.") - if self.nz == 0 and value.name not in ['D2Q9']: + if self.nz == 0 and value.name not in ["D2Q9"]: raise ValueError("For 2D simulations, lattice type must be LatticeD2Q9.") - if self.nz != 0 and value.name not in ['D3Q19', 'D3Q27']: - raise ValueError("For 3D simulations, lattice type must be LatticeD3Q19, or LatticeD3Q27.") - + if self.nz != 0 and value.name not in ["D3Q19", "D3Q27"]: + raise ValueError( + "For 3D simulations, lattice type must be LatticeD3Q19, or LatticeD3Q27." + ) + self._lattice = value @property @@ -310,42 +342,60 @@ def nDevices(self, value): def show_simulation_parameters(self): attributes_to_show = [ - 'omega', 'nx', 'ny', 'nz', 'dim', 'precision', 'lattice', - 'checkpointRate', 'checkpointDir', 'downsamplingFactor', - 'printInfoRate', 'ioRate', 'computeMLUPS', - 'restore_checkpoint', 'backend', 'nDevices' + "omega", + "nx", + "ny", + "nz", + "dim", + "precision", + "lattice", + "checkpointRate", + "checkpointDir", + "downsamplingFactor", + "printInfoRate", + "ioRate", + "computeMLUPS", + "restore_checkpoint", + "backend", + "nDevices", ] descriptive_names = { - 'omega': 'Omega', - 'nx': 'Grid Points in X', - 'ny': 'Grid Points in Y', - 'nz': 'Grid Points in Z', - 'dim': 'Dimensionality', - 'precision': 'Precision Policy', - 'lattice': 'Lattice Type', - 'checkpointRate': 'Checkpoint Rate', - 'checkpointDir': 'Checkpoint Directory', - 'downsamplingFactor': 'Downsampling Factor', - 'printInfoRate': 'Print Info Rate', - 'ioRate': 'I/O Rate', - 'computeMLUPS': 'Compute MLUPS', - 'restore_checkpoint': 'Restore Checkpoint', - 'backend': 'Backend', - 'nDevices': 'Number of Devices' + "omega": "Omega", + "nx": "Grid Points in X", + "ny": "Grid Points in Y", + "nz": "Grid Points in Z", + "dim": "Dimensionality", + "precision": "Precision Policy", + "lattice": "Lattice Type", + "checkpointRate": "Checkpoint Rate", + "checkpointDir": "Checkpoint Directory", + "downsamplingFactor": "Downsampling Factor", + "printInfoRate": "Print Info Rate", + "ioRate": "I/O Rate", + "computeMLUPS": "Compute MLUPS", + "restore_checkpoint": "Restore Checkpoint", + "backend": "Backend", + "nDevices": "Number of Devices", } simulation_name = self.__class__.__name__ - - print(colored(f'**** Simulation Parameters for {simulation_name} ****', 'green')) - + + print( + colored(f"**** Simulation Parameters for {simulation_name} ****", "green") + ) + header = f"{colored('Parameter', 'blue'):>30} | {colored('Value', 'yellow')}" print(header) - print('-' * 50) - + print("-" * 50) + for attr in attributes_to_show: - value = getattr(self, attr, 'Attribute not set') - descriptive_name = descriptive_names.get(attr, attr) # Use the attribute name as a fallback - row = f"{colored(descriptive_name, 'blue'):>30} | {colored(value, 'yellow')}" + value = getattr(self, attr, "Attribute not set") + descriptive_name = descriptive_names.get( + attr, attr + ) # Use the attribute name as a fallback + row = ( + f"{colored(descriptive_name, 'blue'):>30} | {colored(value, 'yellow')}" + ) print(row) def _create_boundary_data(self): @@ -358,7 +408,9 @@ def _create_boundary_data(self): # Accumulate the indices of all BCs to create the grid mask with FALSE along directions that # stream into a boundary voxel. solid_halo_list = [np.array(bc.indices).T for bc in self.BCs if bc.isSolid] - solid_halo_voxels = np.unique(np.vstack(solid_halo_list), axis=0) if solid_halo_list else None + solid_halo_voxels = ( + np.unique(np.vstack(solid_halo_list), axis=0) if solid_halo_list else None + ) # Create the grid mask on each process start = time.time() @@ -367,7 +419,7 @@ def _create_boundary_data(self): start = time.time() for bc in self.BCs: - assert bc.implementationStep in ['PostStreaming', 'PostCollision'] + assert bc.implementationStep in ["PostStreaming", "PostCollision"] bc.create_local_mask_and_normal_arrays(grid_mask) print("Time to create the local masks and normal arrays:", time.time() - start) @@ -383,8 +435,8 @@ def _create_boundary_data(self): # if init_val is None: # x = jnp.zeros(shape=device_shape, dtype=type) # else: - # x = jnp.full(shape=device_shape, fill_value=init_val, dtype=type) - # arrays += [jax.device_put(x, d)] + # x = jnp.full(shape=device_shape, fill_value=init_val, dtype=type) + # arrays += [jax.device_put(x, d)] # jax.default_device = jax.devices()[0] # return jax.make_array_from_single_device_arrays(shape, self.sharding, arrays) @@ -407,18 +459,18 @@ def distributed_array_init(self, shape, type, init_val=0, sharding=None): """ if sharding is None: sharding = self.sharding - x = jnp.full(shape=shape, fill_value=init_val, dtype=type) + x = jnp.full(shape=shape, fill_value=init_val, dtype=type) return jax.lax.with_sharding_constraint(x, sharding) - + @partial(jit, static_argnums=(0,)) def create_grid_mask(self, solid_halo_voxels): """ This function creates a mask for the background grid that accounts for the location of the boundaries. - + Parameters ---------- solid_halo_voxels: A numpy array representing the voxels in the halo of the solid object. - + Returns ------- A JAX array representing the grid mask of the grid. @@ -427,19 +479,41 @@ def create_grid_mask(self, solid_halo_voxels): hw_x = self.nDevices hw_y = hw_z = 1 if self.dim == 2: - grid_mask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), jnp.bool_, init_val=True) - grid_mask = grid_mask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(None))].set(False) + grid_mask = self.distributed_array_init( + (self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), + jnp.bool_, + init_val=True, + ) + grid_mask = grid_mask.at[ + (slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(None)) + ].set(False) if solid_halo_voxels is not None: solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y) - grid_mask = grid_mask.at[tuple(solid_halo_voxels.T)].set(True) + grid_mask = grid_mask.at[tuple(solid_halo_voxels.T)].set(True) grid_mask = self.streaming(grid_mask) return lax.with_sharding_constraint(grid_mask, self.sharding) elif self.dim == 3: - grid_mask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.nz + 2 * hw_z, self.lattice.q), jnp.bool_, init_val=True) - grid_mask = grid_mask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(hw_z, -hw_z), slice(None))].set(False) + grid_mask = self.distributed_array_init( + ( + self.nx + 2 * hw_x, + self.ny + 2 * hw_y, + self.nz + 2 * hw_z, + self.lattice.q, + ), + jnp.bool_, + init_val=True, + ) + grid_mask = grid_mask.at[ + ( + slice(hw_x, -hw_x), + slice(hw_y, -hw_y), + slice(hw_z, -hw_z), + slice(None), + ) + ].set(False) if solid_halo_voxels is not None: solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y) @@ -463,11 +537,15 @@ def bounding_box_indices(self): # For a 2D grid, the bounding box consists of four edges: bottom, top, left, and right. # Each edge is represented as an array of indices. For example, the bottom edge includes # all points where the y-coordinate is 0, so its indices are [[i, 0] for i in range(self.nx)]. - bounding_box = {"bottom": np.array([[i, 0] for i in range(self.nx)], dtype=int), - "top": np.array([[i, self.ny - 1] for i in range(self.nx)], dtype=int), - "left": np.array([[0, i] for i in range(self.ny)], dtype=int), - "right": np.array([[self.nx - 1, i] for i in range(self.ny)], dtype=int)} - + bounding_box = { + "bottom": np.array([[i, 0] for i in range(self.nx)], dtype=int), + "top": np.array([[i, self.ny - 1] for i in range(self.nx)], dtype=int), + "left": np.array([[0, i] for i in range(self.ny)], dtype=int), + "right": np.array( + [[self.nx - 1, i] for i in range(self.ny)], dtype=int + ), + } + return bounding_box elif self.dim == 3: @@ -475,12 +553,43 @@ def bounding_box_indices(self): # Each face is represented as an array of indices. For example, the bottom face includes all points # where the z-coordinate is 0, so its indices are [[i, j, 0] for i in range(self.nx) for j in range(self.ny)]. bounding_box = { - "bottom": np.array([[i, j, 0] for i in range(self.nx) for j in range(self.ny)], dtype=int), - "top": np.array([[i, j, self.nz - 1] for i in range(self.nx) for j in range(self.ny)],dtype=int), - "left": np.array([[0, j, k] for j in range(self.ny) for k in range(self.nz)], dtype=int), - "right": np.array([[self.nx - 1, j, k] for j in range(self.ny) for k in range(self.nz)], dtype=int), - "front": np.array([[i, 0, k] for i in range(self.nx) for k in range(self.nz)], dtype=int), - "back": np.array([[i, self.ny - 1, k] for i in range(self.nx) for k in range(self.nz)], dtype=int)} + "bottom": np.array( + [[i, j, 0] for i in range(self.nx) for j in range(self.ny)], + dtype=int, + ), + "top": np.array( + [ + [i, j, self.nz - 1] + for i in range(self.nx) + for j in range(self.ny) + ], + dtype=int, + ), + "left": np.array( + [[0, j, k] for j in range(self.ny) for k in range(self.nz)], + dtype=int, + ), + "right": np.array( + [ + [self.nx - 1, j, k] + for j in range(self.ny) + for k in range(self.nz) + ], + dtype=int, + ), + "front": np.array( + [[i, 0, k] for i in range(self.nx) for k in range(self.nz)], + dtype=int, + ), + "back": np.array( + [ + [i, self.ny - 1, k] + for i in range(self.nx) + for k in range(self.nz) + ], + dtype=int, + ), + } return bounding_box @@ -522,7 +631,9 @@ def initialize_macroscopic_fields(self): None, None: The default density and velocity, both None. This indicates that the actual values should be set elsewhere. """ print("WARNING: Default initial conditions assumed: density = 1, velocity = 0") - print(" To set explicit initial density and velocity, use self.initialize_macroscopic_fields.") + print( + " To set explicit initial density and velocity, use self.initialize_macroscopic_fields." + ) return None, None def assign_fields_sharded(self): @@ -530,10 +641,10 @@ def assign_fields_sharded(self): This function is used to initialize the simulation by assigning the macroscopic fields and populations. The function first initializes the macroscopic fields, which are the density (rho0) and velocity (u0). - Depending on the dimension of the simulation (2D or 3D), it then sets the shape of the array that will hold the + Depending on the dimension of the simulation (2D or 3D), it then sets the shape of the array that will hold the distribution functions (f). - If the density or velocity are not provided, the function initializes the distribution functions with a default + If the density or velocity are not provided, the function initializes the distribution functions with a default value (self.w), representing density=1 and velocity=0. Otherwise, it uses the provided density and velocity to initialize the populations. Parameters @@ -550,18 +661,20 @@ def assign_fields_sharded(self): shape = (self.nx, self.ny, self.lattice.q) if self.dim == 3: shape = (self.nx, self.ny, self.nz, self.lattice.q) - + if rho0 is None or u0 is None: - f = self.distributed_array_init(shape, self.precisionPolicy.output_dtype, init_val=self.w) + f = self.distributed_array_init( + shape, self.precisionPolicy.output_dtype, init_val=self.w + ) else: f = self.initialize_populations(rho0, u0) return f - + def initialize_populations(self, rho0, u0): """ This function initializes the populations (distribution functions) for the simulation. - It uses the equilibrium distribution function, which is a function of the macroscopic + It uses the equilibrium distribution function, which is a function of the macroscopic density and velocity. Parameters @@ -596,7 +709,7 @@ def send_right(self, x, axis_name): The data after being sent to the right neighboring process. """ return lax.ppermute(x, perm=self.rightPerm, axis_name=axis_name) - + def send_left(self, x, axis_name): """ This function sends the data to the left neighboring process in a parallel computing environment. @@ -614,17 +727,17 @@ def send_left(self, x, axis_name): The data after being sent to the left neighboring process. """ return lax.ppermute(x, perm=self.leftPerm, axis_name=axis_name) - + def streaming_m(self, f): """ - This function performs the streaming step in the Lattice Boltzmann Method, which is + This function performs the streaming step in the Lattice Boltzmann Method, which is the propagation of the distribution functions in the lattice. To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the distribution functions that need to be communicated to the neighboring processes. - The function then sends the left boundary slice to the right neighboring process and the right - boundary slice to the left neighboring process. The received data is then set to the + The function then sends the left boundary slice to the right neighboring process and the right + boundary slice to the left neighboring process. The received data is then set to the corresponding indices in the receiving domain. Parameters @@ -638,9 +751,14 @@ def streaming_m(self, f): The distribution functions after the streaming operation. """ f = self.streaming_p(f) - left_comm, right_comm = f[:1, ..., self.lattice.right_indices], f[-1:, ..., self.lattice.left_indices] - - left_comm, right_comm = self.send_right(left_comm, 'x'), self.send_left(right_comm, 'x') + left_comm, right_comm = ( + f[:1, ..., self.lattice.right_indices], + f[-1:, ..., self.lattice.left_indices], + ) + + left_comm, right_comm = self.send_right(left_comm, "x"), self.send_left( + right_comm, "x" + ) f = f.at[:1, ..., self.lattice.right_indices].set(left_comm) f = f.at[-1:, ..., self.lattice.left_indices].set(right_comm) return f @@ -649,8 +767,8 @@ def streaming_m(self, f): def streaming_p(self, f): """ Perform streaming operation on a partitioned (in the x-direction) distribution function. - - The function uses the vmap operation provided by the JAX library to vectorize the computation + + The function uses the vmap operation provided by the JAX library to vectorize the computation over all lattice directions. Parameters @@ -661,6 +779,7 @@ def streaming_p(self, f): ------- The updated distribution function after streaming. """ + def streaming_i(f, c): """ Perform individual streaming operation in a direction. @@ -689,7 +808,7 @@ def equilibrium(self, rho, u, cast_output=True): The equilibrium distribution function is a function of the macroscopic density and velocity. The function first casts the density and velocity to the compute precision if the cast_output flag is True. - The function finally casts the equilibrium distribution function to the output precision if the cast_output + The function finally casts the equilibrium distribution function to the output precision if the cast_output flag is True. Parameters @@ -699,7 +818,7 @@ def equilibrium(self, rho, u, cast_output=True): u: jax.numpy.ndarray The macroscopic velocity. cast_output: bool, optional - A flag indicating whether to cast the density, velocity, and equilibrium distribution function to the + A flag indicating whether to cast the density, velocity, and equilibrium distribution function to the compute and output precisions. Default is True. Returns @@ -711,7 +830,7 @@ def equilibrium(self, rho, u, cast_output=True): if cast_output: rho, u = self.precisionPolicy.cast_to_compute((rho, u)) - # Cast c to compute precision so that XLA call FXX matmul, + # Cast c to compute precision so that XLA call FXX matmul, # which is faster (it is faster in some older versions of JAX, newer versions are smart enough to do this automatically) c = jnp.array(self.c, dtype=self.precisionPolicy.compute_dtype) cu = 3.0 * jnp.dot(u, c) @@ -726,10 +845,10 @@ def equilibrium(self, rho, u, cast_output=True): @partial(jit, static_argnums=(0,)) def momentum_flux(self, fneq): """ - This function computes the momentum flux, which is the product of the non-equilibrium + This function computes the momentum flux, which is the product of the non-equilibrium distribution functions (fneq) and the lattice moments (cc). - The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann + The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann Method (LBM). Parameters @@ -747,11 +866,11 @@ def momentum_flux(self, fneq): @partial(jit, static_argnums=(0,), inline=True) def update_macroscopic(self, f): """ - This function computes the macroscopic variables (density and velocity) based on the + This function computes the macroscopic variables (density and velocity) based on the distribution functions (f). - The density is computed as the sum of the distribution functions over all lattice directions. - The velocity is computed as the dot product of the distribution functions and the lattice + The density is computed as the sum of the distribution functions over all lattice directions. + The velocity is computed as the dot product of the distribution functions and the lattice velocities, divided by the density. Parameters @@ -766,19 +885,19 @@ def update_macroscopic(self, f): u: jax.numpy.ndarray Computed velocity. """ - rho =jnp.sum(f, axis=-1, keepdims=True) + rho = jnp.sum(f, axis=-1, keepdims=True) c = jnp.array(self.c, dtype=self.precisionPolicy.compute_dtype).T u = jnp.dot(f, c) / rho return rho, u - + @partial(jit, static_argnums=(0, 4), inline=True) def apply_bc(self, fout, fin, timestep, implementation_step): """ This function applies the boundary conditions to the distribution functions. - It iterates over all boundary conditions (BCs) and checks if the implementation step of the - boundary condition matches the provided implementation step. If it does, it applies the + It iterates over all boundary conditions (BCs) and checks if the implementation step of the + boundary condition matches the provided implementation step. If it does, it applies the boundary condition to the post-streaming distribution functions (fout). Parameters @@ -802,7 +921,7 @@ def apply_bc(self, fout, fin, timestep, implementation_step): fout = bc.apply(fout, fin, timestep) else: fout = fout.at[bc.indices].set(bc.apply(fout, fin)) - + return fout @partial(jit, static_argnums=(0, 3), donate_argnums=(1,)) @@ -810,12 +929,12 @@ def step(self, f_poststreaming, timestep, return_fpost=False): """ This function performs a single step of the LBM simulation. - It first performs the collision step, which is the relaxation of the distribution functions - towards their equilibrium values. It then applies the respective boundary conditions to the + It first performs the collision step, which is the relaxation of the distribution functions + towards their equilibrium values. It then applies the respective boundary conditions to the post-collision distribution functions. - The function then performs the streaming step, which is the propagation of the distribution - functions in the lattice. It then applies the respective boundary conditions to the post-streaming + The function then performs the streaming step, which is the propagation of the distribution + functions in the lattice. It then applies the respective boundary conditions to the post-streaming distribution functions. Parameters @@ -832,13 +951,17 @@ def step(self, f_poststreaming, timestep, return_fpost=False): f_poststreaming: jax.numpy.ndarray The post-streaming distribution functions after the simulation step. f_postcollision: jax.numpy.ndarray or None - The post-collision distribution functions after the simulation step, or None if + The post-collision distribution functions after the simulation step, or None if return_fpost is False. """ f_postcollision = self.collision(f_poststreaming) - f_postcollision = self.apply_bc(f_postcollision, f_poststreaming, timestep, "PostCollision") + f_postcollision = self.apply_bc( + f_postcollision, f_poststreaming, timestep, "PostCollision" + ) f_poststreaming = self.streaming(f_postcollision) - f_poststreaming = self.apply_bc(f_poststreaming, f_postcollision, timestep, "PostStreaming") + f_poststreaming = self.apply_bc( + f_poststreaming, f_postcollision, timestep, "PostStreaming" + ) if return_fpost: return f_poststreaming, f_postcollision @@ -849,10 +972,10 @@ def run(self, t_max): """ This function runs the LBM simulation for a specified number of time steps. - It first initializes the distribution functions and then enters a loop where it performs the + It first initializes the distribution functions and then enters a loop where it performs the simulation steps (collision, streaming, and boundary conditions) for each time step. - The function can also print the progress of the simulation, save the simulation data, and + The function can also print the progress of the simulation, save the simulation data, and compute the performance of the simulation in million lattice updates per second (MLUPS). Parameters @@ -871,25 +994,39 @@ def run(self, t_max): if latest_step is not None: # existing checkpoint present # Assert that the checkpoint manager is not None assert self.mngr is not None, "Checkpoint manager does not exist." - state = {'f': f} + state = {"f": f} shardings = jax.tree_map(lambda x: x.sharding, state) - restore_args = orb.checkpoint_utils.construct_restore_args(state, shardings) + restore_args = orb.checkpoint_utils.construct_restore_args( + state, shardings + ) try: - f = self.mngr.restore(latest_step, restore_kwargs={'restore_args': restore_args})['f'] + f = self.mngr.restore( + latest_step, restore_kwargs={"restore_args": restore_args} + )["f"] print(f"Restored checkpoint at step {latest_step}.") except ValueError: - raise ValueError(f"Failed to restore checkpoint at step {latest_step}.") - + raise ValueError( + f"Failed to restore checkpoint at step {latest_step}." + ) + start_step = latest_step + 1 if not (t_max > start_step): - raise ValueError(f"Simulation already exceeded maximum allowable steps (t_max = {t_max}). Consider increasing t_max.") + raise ValueError( + f"Simulation already exceeded maximum allowable steps (t_max = {t_max}). Consider increasing t_max." + ) if self.computeMLUPS: start = time.time() # Loop over all time steps for timestep in range(start_step, t_max + 1): - io_flag = self.ioRate > 0 and (timestep % self.ioRate == 0 or timestep == t_max) - print_iter_flag = self.printInfoRate> 0 and timestep % self.printInfoRate== 0 - checkpoint_flag = self.checkpointRate > 0 and timestep % self.checkpointRate == 0 + io_flag = self.ioRate > 0 and ( + timestep % self.ioRate == 0 or timestep == t_max + ) + print_iter_flag = ( + self.printInfoRate > 0 and timestep % self.printInfoRate == 0 + ) + checkpoint_flag = ( + self.checkpointRate > 0 and timestep % self.checkpointRate == 0 + ) if io_flag: # Update the macroscopic variables and save the previous values (for error computation) @@ -900,12 +1037,17 @@ def run(self, t_max): rho_prev = process_allgather(rho_prev) u_prev = process_allgather(u_prev) - # Perform one time-step (collision, streaming, and boundary conditions) f, fstar = self.step(f, timestep, return_fpost=self.returnFpost) # Print the progress of the simulation if print_iter_flag: - print(colored("Timestep ", 'blue') + colored(f"{timestep}", 'green') + colored(" of ", 'blue') + colored(f"{t_max}", 'green') + colored(" completed", 'blue')) + print( + colored("Timestep ", "blue") + + colored(f"{timestep}", "green") + + colored(" of ", "blue") + + colored(f"{t_max}", "green") + + colored(" completed", "blue") + ) if io_flag: # Save the simulation data @@ -913,20 +1055,20 @@ def run(self, t_max): rho, u = self.update_macroscopic(f) rho = downsample_field(rho, self.downsamplingFactor) u = downsample_field(u, self.downsamplingFactor) - + # Gather the data from all processes and convert it to numpy arrays (move to host memory) rho = process_allgather(rho) u = process_allgather(u) # Save the data self.handle_io_timestep(timestep, f, fstar, rho, u, rho_prev, u_prev) - + if checkpoint_flag: # Save the checkpoint print(f"Saving checkpoint at timestep {timestep}/{t_max}") - state = {'f': f} + state = {"f": f} self.mngr.save(timestep, state) - + # Start the timer for the MLUPS computation after the first timestep (to remove compilation overhead) if self.computeMLUPS and timestep == 1: jax.block_until_ready(f) @@ -937,14 +1079,41 @@ def run(self, t_max): jax.block_until_ready(f) end = time.time() if self.dim == 2: - print(colored("Domain: ", 'blue') + colored(f"{self.nx} x {self.ny}", 'green') if self.dim == 2 else colored(f"{self.nx} x {self.ny} x {self.nz}", 'green')) - print(colored("Number of voxels: ", 'blue') + colored(f"{self.nx * self.ny}", 'green') if self.dim == 2 else colored(f"{self.nx * self.ny * self.nz}", 'green')) - print(colored("MLUPS: ", 'blue') + colored(f"{self.nx * self.ny * t_max / (end - start) / 1e6}", 'red')) + print( + colored("Domain: ", "blue") + + colored(f"{self.nx} x {self.ny}", "green") + if self.dim == 2 + else colored(f"{self.nx} x {self.ny} x {self.nz}", "green") + ) + print( + colored("Number of voxels: ", "blue") + + colored(f"{self.nx * self.ny}", "green") + if self.dim == 2 + else colored(f"{self.nx * self.ny * self.nz}", "green") + ) + print( + colored("MLUPS: ", "blue") + + colored( + f"{self.nx * self.ny * t_max / (end - start) / 1e6}", "red" + ) + ) elif self.dim == 3: - print(colored("Domain: ", 'blue') + colored(f"{self.nx} x {self.ny} x {self.nz}", 'green')) - print(colored("Number of voxels: ", 'blue') + colored(f"{self.nx * self.ny * self.nz}", 'green')) - print(colored("MLUPS: ", 'blue') + colored(f"{self.nx * self.ny * self.nz * t_max / (end - start) / 1e6}", 'red')) + print( + colored("Domain: ", "blue") + + colored(f"{self.nx} x {self.ny} x {self.nz}", "green") + ) + print( + colored("Number of voxels: ", "blue") + + colored(f"{self.nx * self.ny * self.nz}", "green") + ) + print( + colored("MLUPS: ", "blue") + + colored( + f"{self.nx * self.ny * self.nz * t_max / (end - start) / 1e6}", + "red", + ) + ) return f @@ -952,7 +1121,7 @@ def handle_io_timestep(self, timestep, f, fstar, rho, u, rho_prev, u_prev): """ This function handles the input/output (I/O) operations at each time step of the simulation. - It prepares the data to be saved and calls the output_data function, which can be overwritten + It prepares the data to be saved and calls the output_data function, which can be overwritten by the user to customize the I/O operations. Parameters @@ -975,22 +1144,22 @@ def handle_io_timestep(self, timestep, f, fstar, rho, u, rho_prev, u_prev): "u": u, "u_prev": u_prev, "f_poststreaming": f, - "f_postcollision": fstar + "f_postcollision": fstar, } self.output_data(**kwargs) def output_data(self, **kwargs): """ - This function is intended to be overwritten by the user to customize the input/output (I/O) + This function is intended to be overwritten by the user to customize the input/output (I/O) operations of the simulation. - By default, it does nothing. When overwritten, it could save the simulation data to files, + By default, it does nothing. When overwritten, it could save the simulation data to files, display the simulation results in real time, send the data to another process for analysis, etc. Parameters ---------- **kwargs: dict - A dictionary containing the simulation data to be outputted. The keys are the names of the + A dictionary containing the simulation data to be outputted. The keys are the names of the data fields, and the values are the data fields themselves. """ pass @@ -999,10 +1168,10 @@ def set_boundary_conditions(self): """ This function sets the boundary conditions for the simulation. - It is intended to be overwritten by the user to specify the boundary conditions according to + It is intended to be overwritten by the user to specify the boundary conditions according to the specific problem being solved. - By default, it does nothing. When overwritten, it could set periodic boundaries, no-slip + By default, it does nothing. When overwritten, it could set periodic boundaries, no-slip boundaries, inflow/outflow boundaries, etc. """ pass @@ -1012,7 +1181,7 @@ def collision(self, fin): """ This function performs the collision step in the Lattice Boltzmann Method. - It is intended to be overwritten by the user to specify the collision operator according to + It is intended to be overwritten by the user to specify the collision operator according to the specific LBM model being used. By default, it does nothing. When overwritten, it could implement the BGK collision operator, @@ -1034,10 +1203,10 @@ def get_force(self): """ This function computes the force to be applied to the fluid in the Lattice Boltzmann Method. - It is intended to be overwritten by the user to specify the force according to the specific + It is intended to be overwritten by the user to specify the force according to the specific problem being solved. - By default, it does nothing and returns None. When overwritten, it could implement a constant + By default, it does nothing and returns None. When overwritten, it could implement a constant force term. Returns @@ -1063,12 +1232,12 @@ def apply_force(self, f_postcollision, feq, rho, u): u: jax.numpy.ndarray The velocity field. - + Returns ------- f_postcollision: jax.numpy.ndarray The post-collision distribution functions with the force applied. - + References ---------- Kupershtokh, A. (2004). New method of incorporating a body force term into the lattice Boltzmann equation. In @@ -1081,6 +1250,3 @@ def apply_force(self, f_postcollision, feq, rho, u): feq_force = self.equilibrium(rho, u + delta_u, cast_output=False) f_postcollision = f_postcollision + feq_force - feq return f_postcollision - - - diff --git a/xlb/compute_backend.py b/xlb/compute_backend.py index 15a2ea4..bcefed1 100644 --- a/xlb/compute_backend.py +++ b/xlb/compute_backend.py @@ -2,6 +2,8 @@ from enum import Enum, auto + class ComputeBackend(Enum): JAX = auto() + PALLAS = auto() WARP = auto() diff --git a/xlb/compute_backends.py b/xlb/compute_backends.py deleted file mode 100644 index aff65cc..0000000 --- a/xlb/compute_backends.py +++ /dev/null @@ -1,8 +0,0 @@ -# Enum used to keep track of the compute backends - -from enum import Enum - -class ComputeBackends(Enum): - JAX = 1 - PALLAS = 2 - WARP = 3 diff --git a/xlb/experimental/ooc/ooc_array.py b/xlb/experimental/ooc/ooc_array.py index f5e2c34..6effde6 100644 --- a/xlb/experimental/ooc/ooc_array.py +++ b/xlb/experimental/ooc/ooc_array.py @@ -1,11 +1,16 @@ import numpy as np import cupy as cp -#from mpi4py import MPI + +# from mpi4py import MPI import itertools from dataclasses import dataclass from xlb.experimental.ooc.tiles.dense_tile import DenseTile, DenseGPUTile, DenseCPUTile -from xlb.experimental.ooc.tiles.compressed_tile import CompressedTile, CompressedGPUTile, CompressedCPUTile +from xlb.experimental.ooc.tiles.compressed_tile import ( + CompressedTile, + CompressedGPUTile, + CompressedCPUTile, +) class OOCArray: diff --git a/xlb/experimental/ooc/out_of_core.py b/xlb/experimental/ooc/out_of_core.py index 5faa422..01851e8 100644 --- a/xlb/experimental/ooc/out_of_core.py +++ b/xlb/experimental/ooc/out_of_core.py @@ -8,7 +8,11 @@ import numpy as np from xlb.experimental.ooc.ooc_array import OOCArray -from xlb.experimental.ooc.utils import _cupy_to_backend, _backend_to_cupy, _stream_to_backend +from xlb.experimental.ooc.utils import ( + _cupy_to_backend, + _backend_to_cupy, + _stream_to_backend, +) def OOCmap(comm, ref_args, add_index=False, backend="jax"): diff --git a/xlb/global_config.py b/xlb/global_config.py index 563a75d..17a1839 100644 --- a/xlb/global_config.py +++ b/xlb/global_config.py @@ -9,5 +9,6 @@ def init(velocity_set, compute_backend, precision_policy): GlobalConfig.compute_backend = compute_backend GlobalConfig.precision_policy = precision_policy + def current_backend(): return GlobalConfig.compute_backend diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py index 583b72e..d44ce65 100644 --- a/xlb/grid/__init__.py +++ b/xlb/grid/__init__.py @@ -1 +1 @@ -from xlb.grid.grid import Grid \ No newline at end of file +from xlb.grid.grid import Grid diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 239786b..51f6cbf 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -5,34 +5,29 @@ class Grid(ABC): - def __init__(self, grid_shape, velocity_set, compute_backend): - self.velocity_set: VelocitySet = velocity_set - self.compute_backend = compute_backend - self.grid_shape = grid_shape - self.pop_shape = (self.velocity_set.q, *grid_shape) - self.u_shape = (self.velocity_set.d, *grid_shape) - self.rho_shape = (1, *grid_shape) - self.dim = self.velocity_set.d - - @abstractmethod - def create_field(self, cardinality, callback=None): - pass - @staticmethod - def create(grid_shape, velocity_set=None, compute_backend=None): - compute_backend = compute_backend or GlobalConfig.compute_backend - velocity_set = velocity_set or GlobalConfig.velocity_set + def __init__(self, shape, velocity_set, precision_policy, grid_backend): + # Set parameters + self.shape = shape + self.velocity_set = velocity_set + self.precision_policy = precision_policy + self.grid_backend = grid_backend + self.dim = self.velocity_set.d -<<<<<<< HEAD - if compute_backend == ComputeBackend.JAX: -======= - if compute_backend == ComputeBackends.JAX or compute_backend == ComputeBackends.PALLAS: ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 - from xlb.grid.jax_grid import JaxGrid # Avoids circular import + # Create field dict + self.fields = {} - return JaxGrid(grid_shape, velocity_set, compute_backend) - raise ValueError(f"Compute backend {compute_backend} is not supported") + def parallelize_operator(self, operator): + raise NotImplementedError("Parallelization not implemented, child class must implement") @abstractmethod - def field_global_to_local_shape(self, shape): + def create_field( + self, name: str, cardinality: int, precision: Precision, callback=None + ): pass + + def get_field(self, name: str): + return self.fields[name] + + def swap_fields(self, field1, field2): + self.fields[field1], self.fields[field2] = self.fields[field2], self.fields[field1] diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 6b95c25..42687c8 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -8,11 +8,11 @@ class JaxGrid(Grid): - def __init__(self, grid_shape, velocity_set, compute_backend): - super().__init__(grid_shape, velocity_set, compute_backend) - self.initialize_jax_backend() + def __init__(self, grid_shape, velocity_set, precision_policy, grid_backend): + super().__init__(grid_shape, velocity_set, precision_policy, grid_backend) + self._initialize_jax_backend() - def initialize_jax_backend(self): + def _initialize_jax_backend(self): self.nDevices = jax.device_count() self.backend = jax.default_backend() device_mesh = ( @@ -34,17 +34,17 @@ def initialize_jax_backend(self): self.grid_shape[0] // self.nDevices, ) + self.grid_shape[1:] - def field_global_to_local_shape(self, shape): - if len(shape) < 2: - raise ValueError("Shape must have at least two dimensions") + def create_field(self, name: str, cardinality: int, callback=None): + # Get shape of the field + shape = (cardinality,) + (self.shape) - new_second_index = shape[1] // self.nDevices - - return shape[:1] + (new_second_index,) + shape[2:] - - def create_field(self, cardinality, callback=None): + # Create field if callback is None: - f = ConstInitializer(self, cardinality=cardinality)(0.0) - return f - shape = (cardinality,) + (self.grid_shape) - return jax.make_array_from_callback(shape, self.sharding, callback) + f = jax.numpy.full(shape, 0.0, dtype=self.precision_policy) + if self.sharding is not None: + f = jax.make_sharded_array(self.sharding, f) + else: + f = jax.make_array_from_callback(shape, self.sharding, callback) + + # Add field to the field dictionary + self.fields[name] = f diff --git a/xlb/grid_backend.py b/xlb/grid_backend.py new file mode 100644 index 0000000..42cd022 --- /dev/null +++ b/xlb/grid_backend.py @@ -0,0 +1,9 @@ +# Enum used to keep track of the compute backends + +from enum import Enum, auto + + +class GridBackend(Enum): + JAX = auto() + WARP = auto() + OOC = auto() diff --git a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py new file mode 100644 index 0000000..ee511e2 --- /dev/null +++ b/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py @@ -0,0 +1,37 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class BoundaryMasker(Operator): + """ + Operator for creating a boundary mask + """ + + @classmethod + def from_jax_func( + cls, jax_func, precision_policy: PrecisionPolicy, velocity_set: VelocitySet + ): + """ + Create a boundary masker from a jax function + """ + raise NotImplementedError + + @classmethod + def from_warp_func( + cls, warp_func, precision_policy: PrecisionPolicy, velocity_set: VelocitySet + ): + """ + Create a boundary masker from a warp function + """ + raise NotImplementedError diff --git a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py new file mode 100644 index 0000000..3b57895 --- /dev/null +++ b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py @@ -0,0 +1,100 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream.stream import Stream + + +class IndicesBoundaryMasker(Operator): + """ + Operator for creating a boundary mask + """ + + def __init__( + self, + indices: np.ndarray, + stream_indices: bool, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + super().__init__(velocity_set, precision_policy, compute_backend) + + # Set indices + # TODO: handle multi-gpu case (this will usually implicitly work) + self.indices = indices + self.stream_indices = stream_indices + + # Make stream operator + self.stream = Stream(velocity_set, precision_policy, compute_backend) + + @staticmethod + def _indices_to_tuple(indices): + """ + Converts a tensor of indices to a tuple for indexing + TODO: Might be better to index + """ + return tuple([indices[:, i] for i in range(indices.shape[1])]) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), inline=True) + def jax_implementation(self, start_index, boundary_id, mask, id_number): + # Get local indices from the meshgrid and the indices + local_indices = self.indices - start_index + + # Remove any indices that are out of bounds + local_indices = local_indices[ + (local_indices[:, 0] >= 0) + & (local_indices[:, 0] < mask.shape[0]) + & (local_indices[:, 1] >= 0) + & (local_indices[:, 1] < mask.shape[1]) + & (local_indices[:, 2] >= 0) + & (local_indices[:, 2] < mask.shape[2]) + ] + + # Set the boundary id + boundary_id = boundary_id.at[self._indices_to_tuple(local_indices)].set( + id_number + ) + + # Stream mask if necessary + if self.stream_indices: + # Make mask then stream to get the edge points + pre_stream_mask = jnp.zeros_like(mask) + pre_stream_mask = pre_stream_mask.at[ + self._indices_to_tuple(local_indices) + ].set(True) + post_stream_mask = self.stream(pre_stream_mask) + + # Set false for points inside the boundary + post_stream_mask = post_stream_mask.at[ + post_stream_mask[..., 0] == True + ].set(False) + + # Get indices on edges + edge_indices = jnp.argwhere(post_stream_mask) + + # Set the mask + mask = mask.at[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ].set( + post_stream_mask[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ] + ) + + else: + # Set the mask + mask = mask.at[self._indices_to_tuple(local_indices)].set(True) + + return boundary_id, mask diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index d8cb4ad..0937f1a 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -31,83 +31,6 @@ def __init__( compute_backend=compute_backend, ) - @classmethod - def from_indices( - cls, - indices, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): - """ - Creates a boundary condition from a list of indices. - """ - - # Make stream operator to get edge points - stream = Stream(velocity_set=velocity_set) - - # Create a mask function - def set_boundary(ijk, boundary_id, mask, id_number): - """ - Sets the mask id for the boundary condition. - Halfway bounce-back is implemented by setting the mask to True for points in the boundary, - then streaming the mask to get the points on the surface. - - Parameters - ---------- - ijk : jnp.ndarray - Array of shape (N, N, N, 3) containing the meshgrid of lattice points. - boundary_id : jnp.ndarray - Array of shape (N, N, N) containing the boundary id. This will be modified in place and returned. - mask : jnp.ndarray - Array of shape (N, N, N, Q) containing the mask. This will be modified in place and returned. - """ - - # Get local indices from the meshgrid and the indices - local_indices = ijk[ - tuple(s[:, 0] for s in jnp.split(indices, velocity_set.d, axis=1)) - ] - - # Make mask then stream to get the edge points - pre_stream_mask = jnp.zeros_like(mask) - pre_stream_mask = pre_stream_mask.at[ - tuple( - [s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)] - ) - ].set(True) - post_stream_mask = stream(pre_stream_mask) - - # Set false for points inside the boundary - post_stream_mask = post_stream_mask.at[ - post_stream_mask[..., 0] == True - ].set(False) - - # Get indices on edges - edge_indices = jnp.argwhere(post_stream_mask) - - # Set the boundary id - boundary_id = boundary_id.at[ - tuple( - [s[:, 0] for s in jnp.split(local_indices, velocity_set.d, axis=1)] - ) - ].set(id_number) - - # Set the mask - mask = mask.at[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ].set( - post_stream_mask[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ] - ) - - return boundary_id, mask - - return cls( - set_boundary=set_boundary, - velocity_set=velocity_set, - compute_backend=compute_backend, - ) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) def apply_jax(self, f_pre, f_post, boundary, mask): flip_mask = boundary[..., jnp.newaxis] & mask diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 08d6ce3..10c4a5f 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -14,13 +14,17 @@ class BGK(Collision): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) - def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): + def jax_implementation( + self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray + ): fneq = f - feq fout = f - self.omega * fneq return fout - + @Operator.register_backend(ComputeBackends.PALLAS) - def pallas_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): + def pallas_implementation( + self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray + ): fneq = f - feq fout = f - self.omega * fneq return fout @@ -35,7 +39,10 @@ def _construct_warp(self): # Construct the functional @wp.func def functional( - f: self._warp_lattice_vec, feq: self._warp_lattice_vec + f: self._warp_lattice_vec, + feq: self._warp_lattice_vec, + rho: self.compute_dtype, + u: self._warp_u_vec, ) -> self._warp_lattice_vec: fneq = f - feq fout = f - self.omega * fneq @@ -46,6 +53,8 @@ def functional( def kernel( f: self._warp_array_type, feq: self._warp_array_type, + rho: self._warp_array_type, + u: self._warp_array_type, fout: self._warp_array_type, ): # Get the global index @@ -66,13 +75,15 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout): + def warp_implementation(self, f, feq, rho, u, fout): # Launch the warp kernel wp.launch( self._kernel, inputs=[ f, feq, + rho, + u, fout, ], dim=f.shape[1:], diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index b83a41f..8302978 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -20,19 +20,16 @@ class KBC(Collision): def __init__( self, + omega: float, velocity_set: VelocitySet = None, precision_policy=None, compute_backend=None, ): super().__init__( -<<<<<<< HEAD omega=omega, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, -======= - velocity_set=velocity_set, compute_backend=compute_backend ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 ) self.epsilon = 1e-32 self.beta = self.omega * 0.5 @@ -45,6 +42,7 @@ def jax_implementation( f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, + u: jnp.ndarray, ): """ KBC collision step for lattice. @@ -57,6 +55,8 @@ def jax_implementation( Equilibrium distribution function. rho : jax.numpy.array Density. + u : jax.numpy.array + Velocity. """ fneq = f - feq if isinstance(self.velocity_set, D2Q9): diff --git a/xlb/operator/initializer/const_init.py b/xlb/operator/initializer/const_init.py index b12ec41..e13d2db 100644 --- a/xlb/operator/initializer/const_init.py +++ b/xlb/operator/initializer/const_init.py @@ -11,8 +11,6 @@ class ConstInitializer(Operator): def __init__( self, - grid: Grid, - cardinality, type=np.float32, velocity_set: VelocitySet = None, compute_backend: ComputeBackends = None, @@ -21,7 +19,6 @@ def __init__( self.grid = grid velocity_set = velocity_set or GlobalConfig.velocity_set compute_backend = compute_backend or GlobalConfig.compute_backend - self.shape = (cardinality,) + (grid.grid_shape) super().__init__(velocity_set, compute_backend) @@ -30,17 +27,10 @@ def __init__( def jax_implementation(self, const_value, sharding=None): if sharding is None: sharding = self.grid.sharding - x = jax.numpy.full( - shape=self.shape, fill_value=const_value, dtype=self.type - ) + x = jax.numpy.full(shape=self.shape, fill_value=const_value, dtype=self.type) return jax.lax.with_sharding_constraint(x, sharding) @Operator.register_backend(ComputeBackends.PALLAS) @partial(jax.jit, static_argnums=(0, 2)) def jax_implementation(self, const_value, sharding=None): - if sharding is None: - sharding = self.grid.sharding - x = jax.numpy.full( - shape=self.shape, fill_value=const_value, dtype=self.type - ) - return jax.lax.with_sharding_constraint(x, sharding) + return self.jax_implementation(const_value, sharding) diff --git a/xlb/operator/initializer/initializer.py b/xlb/operator/initializer/initializer.py new file mode 100644 index 0000000..e813934 --- /dev/null +++ b/xlb/operator/initializer/initializer.py @@ -0,0 +1,13 @@ +from xlb.velocity_set import VelocitySet +from xlb.global_config import GlobalConfig +from xlb.compute_backends import ComputeBackends +from xlb.operator.operator import Operator +from xlb.grid.grid import Grid +import numpy as np +import jax + + +class Initializer(Operator): + """ + Base class for all initializers. + """ diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index eef2cfd..eb8a8d2 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -46,7 +46,33 @@ def jax_implementation(self, f): return rho, u -<<<<<<< HEAD + @Operator.register_backend(ComputeBackends.PALLAS) + def pallas_implementation(self, f): + # TODO: Maybe this can be done with jnp.sum + rho = jnp.sum(f, axis=0, keepdims=True) + + u = jnp.zeros((3, *rho.shape[1:])) + u.at[0].set( + -f[9] + - f[10] + - f[11] + - f[12] + - f[13] + + f[14] + + f[15] + + f[16] + + f[17] + + f[18] + ) / rho + u.at[1].set( + -f[3] - f[4] - f[5] + f[6] + f[7] + f[8] - f[12] + f[13] - f[17] + f[18] + ) / rho + u.at[2].set( + -f[1] + f[2] - f[4] + f[5] - f[7] + f[8] - f[10] + f[11] - f[15] + f[16] + ) / rho + + return rho, jnp.array(u) + def _construct_warp(self): # Make constants for warp _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) @@ -107,31 +133,3 @@ def warp_implementation(self, f, rho, u): dim=rho.shape[1:], ) return rho, u -======= - @Operator.register_backend(ComputeBackends.PALLAS) - def pallas_implementation(self, f): - # TODO: Maybe this can be done with jnp.sum - rho = jnp.sum(f, axis=0, keepdims=True) - - u = jnp.zeros((3, *rho.shape[1:])) - u.at[0].set( - -f[9] - - f[10] - - f[11] - - f[12] - - f[13] - + f[14] - + f[15] - + f[16] - + f[17] - + f[18] - ) / rho - u.at[1].set( - -f[3] - f[4] - f[5] + f[6] + f[7] + f[8] - f[12] + f[13] - f[17] + f[18] - ) / rho - u.at[2].set( - -f[1] + f[2] - f[4] + f[5] - f[7] + f[8] - f[10] + f[11] - f[15] + f[16] - ) / rho - - return rho, jnp.array(u) ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 diff --git a/xlb/operator/precision_caster/precision_caster.py b/xlb/operator/precision_caster/precision_caster.py new file mode 100644 index 0000000..be676f4 --- /dev/null +++ b/xlb/operator/precision_caster/precision_caster.py @@ -0,0 +1,98 @@ +""" +Base class for casting precision of the input data to the desired precision +""" + +import jax.numpy as jnp +from jax import jit, device_count +from functools import partial +import numpy as np +from enum import Enum + +from xlb.operator.operator import Operator +from xlb.velocity_set import VelocitySet +from xlb.precision_policy import Precision, PrecisionPolicy +from xlb.compute_backend import ComputeBackend + + +class PrecisionCaster(Operator): + """ + Class that handles the construction of lattice boltzmann precision casting operator + """ + + def __init__( + self, + input_precision: Precision, + output_precision: Precision, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + super().__init__( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + # Set the input and output precision based on the backend + self.input_precision = self._precision_to_dtype(input_precision) + self.output_precision = self._precision_to_dtype(output_precision) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0,)) + def jax_implementation(self, f: jnp.ndarray) -> jnp.ndarray: + return self.output_precision(f) + + def _construct_warp(self): + # Construct needed types and constants + from_lattice_vec = wp.vec(self.velocity_set.q, dtype=self.input_precision) + to_lattice_vec = wp.vec(self.velocity_set.q, dtype=self.output_precision) + from_array_type = wp.array4d(dtype=self.input_precision) + to_array_type = wp.array4d(dtype=self.output_precision) + _q = wp.constant(self.velocity_set.q) + + # Construct the functional + @wp.func + def functional( + from_f: from_lattice_vec, + ) -> to_lattice_vec: + to_f = to_lattice_vec() + for i in range(self.velocity_set.q): + to_f[i] = self.output_precision(from_f[i]) + return to_f + + # Construct the warp kernel + @wp.kernel + def kernel( + from_f: from_array_type, + to_f: to_array_type, + ): + # Get the global index + i, j, k = wp.tid() + + # Get f + _from_f = from_lattice_vec() + for l in range(_q): + _from_f[l] = from_f[l, i, j, k] + + # Cast the precision + _to_f = functional(_from_f) + + # Set f + for l in range(_q): + to_f[l, i, j, k] = _to_f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, feq, fout): + # Launch the warp kernel + wp.launch( + self._kernel, + inputs=[ + f, + feq, + fout, + ], + dim=f.shape[1:], + ) + return fout diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py new file mode 100644 index 0000000..ebb9e13 --- /dev/null +++ b/xlb/operator/stepper/nse.py @@ -0,0 +1,218 @@ +# Base class for all stepper operators + +from functools import partial +from jax import jit +from logging import warning + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.stepper.stepper import Stepper +from xlb.operator.boundary_condition import ImplementationStep +from xlb.operator.collision.bgk import BGK + + +class IncompressibleNavierStokesStepper(Stepper): + """ + Class that handles the construction of lattice boltzmann stepping operator + """ + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0, 5)) + def apply_jax(self, f, boundary_id, mask, timestep): + """ + Perform a single step of the lattice boltzmann method + """ + + # Cast to compute precision + f_pre_collision = self.precision_policy.cast_to_compute_jax(f) + + # Compute the macroscopic variables + rho, u = self.macroscopic(f_pre_collision) + + # Compute equilibrium + feq = self.equilibrium(rho, u) + + # Apply collision + f_post_collision = self.collision( + f, + feq, + rho, + u, + ) + + # Apply collision type boundary conditions + for id_number, bc in self.collision_boundary_conditions.items(): + f_post_collision = bc( + f_pre_collision, + f_post_collision, + boundary_id == id_number, + mask, + ) + f_pre_streaming = f_post_collision + + ## Apply forcing + # if self.forcing_op is not None: + # f = self.forcing_op.apply_jax(f, timestep) + + # Apply streaming + f_post_streaming = self.stream(f_pre_streaming) + + # Apply boundary conditions + for id_number, bc in self.stream_boundary_conditions.items(): + f_post_streaming = bc( + f_pre_streaming, + f_post_streaming, + boundary_id == id_number, + mask, + ) + + # Copy back to store precision + f = self.precision_policy.cast_to_store_jax(f_post_streaming) + + return f + + @Operator.register_backend(ComputeBackends.PALLAS) + @partial(jit, static_argnums=(0,)) + def apply_pallas(self, fin, boundary_id, mask, timestep): + # Raise warning that the boundary conditions are not implemented + ################################################################ + warning("Boundary conditions are not implemented for PALLAS backend currently") + ################################################################ + + from xlb.operator.parallel_operator import ParallelOperator + + def _pallas_collide(fin, fout): + idx = pl.program_id(0) + + f = pl.load(fin, (slice(None), idx, slice(None), slice(None))) + + print("f shape", f.shape) + + rho, u = self.macroscopic(f) + + print("rho shape", rho.shape) + print("u shape", u.shape) + + feq = self.equilibrium(rho, u) + + print("feq shape", feq.shape) + + for i in range(self.velocity_set.q): + print("f shape", f[i].shape) + f_post_collision = self.collision(f[i], feq[i]) + print("f_post_collision shape", f_post_collision.shape) + pl.store(fout, (i, idx, slice(None), slice(None)), f_post_collision) + # f_post_collision = self.collision(f, feq) + # pl.store(fout, (i, idx, slice(None), slice(None)), f_post_collision) + + @jit + def _pallas_collide_kernel(fin): + return pl.pallas_call( + partial(_pallas_collide), + out_shape=jax.ShapeDtypeStruct( + ((self.velocity_set.q,) + (self.grid.grid_shape_per_gpu)), fin.dtype + ), + # grid=1, + grid=(self.grid.grid_shape_per_gpu[0], 1, 1), + )(fin) + + def _pallas_collide_and_stream(f): + f = _pallas_collide_kernel(f) + # f = self.stream._streaming_jax_p(f) + + return f + + fout = ParallelOperator( + self.grid, _pallas_collide_and_stream, self.velocity_set + )(fin) + + return fout + + def _construct_warp(self): + # Make constants for warp + _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + _nr_boundary_conditions = wp.constant(len(self.boundary_conditions)) + + # Construct the kernel + @wp.kernel + def kernel( + f_0: self._warp_array_type, + f_1: self._warp_array_type, + boundary_id: self._warp_uint8_array_type, + mask: self._warp_array_bool_array_type, + timestep: wp.int32, + ): + # Get the global index + i, j, k = wp.tid() + + # Get the f, boundary id and mask + _f = self._warp_lattice_vec() + _boundary_id = boundary_id[0, i, j, k] + _mask = self._bool_lattice_vec() + for l in range(_q): + _f[l] = self.f_0[l, i, j, k] + _mask[l] = mask[l, i, j, k] + + # Compute rho and u + rho, u = self.macroscopic.functional(_f) + + # Compute equilibrium + feq = self.equilibrium.functional(rho, u) + + # Apply collision + f_post_collision = self.collision.functional( + _f, + feq, + rho, + u, + ) + + # Apply collision type boundary conditions + if _boundary_id == id_number: + f_post_collision = self.collision_boundary_conditions[ + id_number + ].functional( + _f, + f_post_collision, + _mask, + ) + f_pre_streaming = f_post_collision # store pre streaming vector + + # Apply forcing + # if self.forcing_op is not None: + # f = self.forcing.functional(f, timestep) + + # Apply streaming + for l in range(_q): + # Get the streamed indices + streamed_i, streamed_j, streamed_k = self.stream.functional( + l, i, j, k, self._warp_max_i, self._warp_max_j, self._warp_max_k + ) + streamed_l = l + + ## Modify the streamed indices based on streaming boundary condition + # if _boundary_id != 0: + # streamed_l, streamed_i, streamed_j, streamed_k = self.stream_boundary_conditions[id_number].functional( + # streamed_l, streamed_i, streamed_j, streamed_k, self._warp_max_i, self._warp_max_j, self._warp_max_k + # ) + + # Set the output + f_1[streamed_l, streamed_i, streamed_j, streamed_k] = f_pre_streaming[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, boundary_id, mask, timestep): + # Launch the warp kernel + wp.launch( + self._kernel, + inputs=[ + f, + rho, + u, + ], + dim=rho.shape[1:], + ) + return rho, u diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py new file mode 100644 index 0000000..48c62a8 --- /dev/null +++ b/xlb/operator/stepper/stepper.py @@ -0,0 +1,141 @@ +# Base class for all stepper operators + +import jax.numpy as jnp + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition import ImplementationStep +from xlb.operator.precision_caster import PrecisionCaster + + +class Stepper(Operator): + """ + Class that handles the construction of lattice boltzmann stepping operator + """ + + def __init__( + self, + collision, + stream, + equilibrium, + macroscopic, + boundary_conditions=[], + forcing=None, + ): + # Set parameters + self.collision = collision + self.stream = stream + self.equilibrium = equilibrium + self.macroscopic = macroscopic + self.boundary_conditions = boundary_conditions + self.forcing = forcing + + # Get velocity set, precision policy, and compute backend + velocity_sets = set([op.velocity_set for op in self.operators]) + assert len(velocity_sets) == 1, "All velocity sets must be the same" + velocity_set = velocity_sets.pop() + precision_policies = set([op.precision_policy for op in self.operators]) + assert len(precision_policies) == 1, "All precision policies must be the same" + precision_policy = precision_policies.pop() + compute_backends = set([op.compute_backend for op in self.operators]) + assert len(compute_backends) == 1, "All compute backends must be the same" + compute_backend = compute_backends.pop() + + # Get collision and stream boundary conditions + self.collision_boundary_conditions = {} + self.stream_boundary_conditions = {} + for id_number, bc in enumerate(self.boundary_conditions): + bc_id = id_number + 1 + if bc.implementation_step == ImplementationStep.COLLISION: + self.collision_boundary_conditions[bc_id] = bc + elif bc.implementation_step == ImplementationStep.STREAMING: + self.stream_boundary_conditions[bc_id] = bc + else: + raise ValueError("Boundary condition step not recognized") + + # Make operators for converting the precisions + self.cast_to_compute = PrecisionCaster( + + + # Make operator for setting boundary condition arrays + self.set_boundary = SetBoundary( + self.collision_boundary_conditions, + self.stream_boundary_conditions, + velocity_set, + precision_policy, + compute_backend, + ) + + # Get all operators for checking + self.operators = [ + collision, + stream, + equilibrium, + macroscopic, + *boundary_conditions, + self.set_boundary, + ] + + # Initialize operator + super().__init__(velocity_set, precision_policy, compute_backend) + + +class SetBoundary(Operator): + """ + Class that handles the construction of lattice boltzmann boundary condition operator + This will probably never be used directly and it might be better to refactor it + """ + + def __init__( + self, + collision_boundary_conditions, + stream_boundary_conditions, + velocity_set, + precision_policy, + compute_backend, + ): + super().__init__(velocity_set, precision_policy, compute_backend) + + # Set parameters + self.collision_boundary_conditions = collision_boundary_conditions + self.stream_boundary_conditions = stream_boundary_conditions + + def _apply_all_bc(self, ijk, boundary_id, mask, bc): + """ + Apply all boundary conditions + """ + for id_number, bc in self.collision_boundary_conditions.items(): + boundary_id, mask = bc.boundary_masker(ijk, boundary_id, mask, id_number) + for id_number, bc in self.stream_boundary_conditions.items(): + boundary_id, mask = bc.boundary_masker(ijk, boundary_id, mask, id_number) + return boundary_id, mask + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, ijk): + """ + Set boundary condition arrays + These store the boundary condition information for each boundary + """ + boundary_id = jnp.zeros(ijk.shape[:-1], dtype=jnp.uint8) + mask = jnp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=jnp.bool_) + return self._apply_all_bc(ijk, boundary_id, mask, bc) + + @Operator.register_backend(ComputeBackend.PALLAS) + def pallas_implementation(self, ijk): + """ + Set boundary condition arrays + These store the boundary condition information for each boundary + """ + raise NotImplementedError("Pallas implementation not available") + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, ijk): + """ + Set boundary condition arrays + These store the boundary condition information for each boundary + """ + boundary_id = wp.zeros(ijk.shape[:-1], dtype=wp.uint8) + mask = wp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=wp.bool) + return self._apply_all_bc(ijk, boundary_id, mask, bc) diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 5e7fc99..8f52607 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -3,34 +3,25 @@ from functools import partial import jax.numpy as jnp from jax import jit, vmap + from xlb.velocity_set.velocity_set import VelocitySet -<<<<<<< HEAD from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from jax.experimental.shard_map import shard_map -from jax.sharding import PartitionSpec as P -======= -from xlb.compute_backends import ComputeBackends -from xlb.operator import Operator -from xlb.operator import ParallelOperator ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 class Stream(Operator): """ Base class for all streaming operators. """ -<<<<<<< HEAD @Operator.register_backend(ComputeBackend.JAX) -======= def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None): self.grid = grid - self.parallel_operator = ParallelOperator(grid, self._streaming_jax_p, velocity_set) + self.parallel_operator = ParallelOperator( + grid, self._streaming_jax_p, velocity_set + ) super().__init__(velocity_set, compute_backend) - @Operator.register_backend(ComputeBackends.JAX) ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 @partial(jit, static_argnums=(0)) def jax_implementation(self, f): """ @@ -41,10 +32,6 @@ def jax_implementation(self, f): f: jax.numpy.ndarray The distribution function. """ -<<<<<<< HEAD -======= - return self.parallel_operator(f) ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 def _streaming_jax_i(f, c): """ @@ -68,7 +55,6 @@ def _streaming_jax_i(f, c): return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( f, jnp.array(self.velocity_set.c).T ) -<<<<<<< HEAD def _construct_warp(self): # Make constants for warp @@ -140,5 +126,3 @@ def warp_implementation(self, f_0, f_1): dim=f_0.shape[1:], ) return f_1 -======= ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 diff --git a/xlb/physics_type.py b/xlb/physics_type.py index 586eefe..39841fc 100644 --- a/xlb/physics_type.py +++ b/xlb/physics_type.py @@ -2,6 +2,7 @@ from enum import Enum, auto + class PhysicsType(Enum): NSE = auto() # Navier-Stokes Equations ADE = auto() # Advection-Diffusion Equations diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 7f3b075..0ba6c1c 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -2,11 +2,13 @@ from enum import Enum, auto + class Precision(Enum): FP64 = auto() FP32 = auto() FP16 = auto() + class PrecisionPolicy(Enum): FP64FP64 = auto() FP64FP32 = auto() diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index ae96039..626969b 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -16,156 +16,133 @@ from jax.experimental import pallas as pl -class IncompressibleNavierStokes(Solver): +class IncompressibleNavierStokesSolver(Solver): + + _equilibrium_registry = { + "Quadratic": QuadraticEquilibrium, + "Linear": LinearEquilibrium, + } + _collision_registry = { + "BGK": BGK, + "KBC": KBC, + } + def __init__( self, - grid, - omega, + omega: float, + shape: tuple[int, int, int], + collision="BGK", + equilibrium="Quadratic", + boundary_conditions=[], + initializer=None, + forcing=None, velocity_set: VelocitySet = None, - compute_backend=None, precision_policy=None, - boundary_conditions=[], - collision_kernel="BGK", + compute_backend=None, + grid_backend=None, + grid_configs={}, ): - self.grid = grid - self.omega = omega - self.collision_kernel = collision_kernel super().__init__( + shape=shape, + boundary_conditions=boundary_conditions, velocity_set=velocity_set, compute_backend=compute_backend, precision_policy=precision_policy, - boundary_conditions=boundary_conditions, + grid_backend=grid_backend, + grid_configs=grid_configs, ) - self.create_operators() - # Operators - def create_operators(self): - self.macroscopic = Macroscopic( - velocity_set=self.velocity_set, compute_backend=self.compute_backend - ) - self.equilibrium = QuadraticEquilibrium( - velocity_set=self.velocity_set, compute_backend=self.compute_backend - ) - self.collision = ( - KBC( - omega=self.omega, - velocity_set=self.velocity_set, - compute_backend=self.compute_backend, - ) - if self.collision_kernel == "KBC" - else BGK( - omega=self.omega, - velocity_set=self.velocity_set, - compute_backend=self.compute_backend, - ) + # Set omega + self.omega = omega + + # Add fields to grid + self.grid.create_field("rho", 1, self.precision_policy.store_precision) + self.grid.create_field("u", 3, self.precision_policy.store_precision) + self.grid.create_field("f0", self.velocity_set.q, self.precision_policy.store_precision) + self.grid.create_field("f1", self.velocity_set.q, self.precision_policy.store_precision) + + # Create operators + self.collision = self._get_collision(collision)( + omega=self.omega, + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, ) self.stream = Stream( - self.grid, velocity_set=self.velocity_set, + precision_policy=self.precision_policy, compute_backend=self.compute_backend, ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0,)) - def step(self, f, timestep): - """ - Perform a single step of the lattice boltzmann method - """ - - # Cast to compute precision - f = self.precision_policy.cast_to_compute(f) - - # Compute the macroscopic variables - rho, u = self.macroscopic(f) - - # Compute equilibrium - feq = self.equilibrium(rho, u) - - # Apply collision - f_post_collision = self.collision( - f, - feq, + self.equilibrium = self._get_equilibrium(equilibrium)( + velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend + ) + self.macroscopic = Macroscopic( + velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend + ) + if initializer is None: + self.initializer = EquilibriumInitializer( + rho=1.0, u=(0.0, 0.0, 0.0), + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + if forcing is not None: + raise NotImplementedError("Forcing not yet implemented") + + # Create stepper operator + self.stepper = IncompressibleNavierStokesStepper( + collision=self.collision, + stream=self.stream, + equilibrium=self.equilibrium, + macroscopic=self.macroscopic, + boundary_conditions=self.boundary_conditions, + forcing=None, ) - # # Apply collision type boundary conditions - # for id_number, bc in self.collision_boundary_conditions.items(): - # f_post_collision = bc( - # f_pre_collision, - # f_post_collision, - # boundary_id == id_number, - # mask, - # ) - f_pre_streaming = f_post_collision - - ## Apply forcing - # if self.forcing_op is not None: - # f = self.forcing_op.apply_jax(f, timestep) - - # Apply streaming - f_post_streaming = self.stream(f_pre_streaming) - - # Apply boundary conditions - # for id_number, bc in self.stream_boundary_conditions.items(): - # f_post_streaming = bc( - # f_pre_streaming, - # f_post_streaming, - # boundary_id == id_number, - # mask, - # ) - - # Copy back to store precision - f = self.precision_policy.cast_to_store(f_post_streaming) - - return f - - @Operator.register_backend(ComputeBackends.PALLAS) - @partial(jit, static_argnums=(0,)) - def step(self, fin, timestep): - from xlb.operator.parallel_operator import ParallelOperator - - def _pallas_collide(fin, fout): - idx = pl.program_id(0) - - f = (pl.load(fin, (slice(None), idx, slice(None), slice(None)))) - - print("f shape", f.shape) - - rho, u = self.macroscopic(f) - - print("rho shape", rho.shape) - print("u shape", u.shape) - - feq = self.equilibrium(rho, u) - - print("feq shape", feq.shape) - - for i in range(self.velocity_set.q): - print("f shape", f[i].shape) - f_post_collision = self.collision(f[i], feq[i]) - print("f_post_collision shape", f_post_collision.shape) - pl.store(fout, (i, idx, slice(None), slice(None)), f_post_collision) - # f_post_collision = self.collision(f, feq) - # pl.store(fout, (i, idx, slice(None), slice(None)), f_post_collision) + # Add parrallelization + self.stepper = self.grid.parallelize_operator(self.stepper) - @jit - def _pallas_collide_kernel(fin): - return pl.pallas_call( - partial(_pallas_collide), - out_shape=jax.ShapeDtypeStruct( - ((self.velocity_set.q,) + (self.grid.grid_shape_per_gpu)), fin.dtype - ), - # grid=1, - grid=(self.grid.grid_shape_per_gpu[0], 1, 1), - )(fin) + # Initialize + self.initialize() - def _pallas_collide_and_stream(f): - f = _pallas_collide_kernel(f) - # f = self.stream._streaming_jax_p(f) + def initialize(self): + self.initializer(f=self.grid.get_field("f0")) - return f + def monitor(self): + pass - fout = ParallelOperator( - self.grid, _pallas_collide_and_stream, self.velocity_set - )(fin) + def run(self, steps: int, monitor_frequency: int = 1, compute_mlups: bool = False): - return fout + # Run steps + for _ in range(steps): + # Run step + self.stepper( + f0=self.grid.get_field("f0"), + f1=self.grid.get_field("f1") + ) + self.grid.swap_fields("f0", "f1") + + def checkpoint(self): + raise NotImplementedError("Checkpointing not yet implemented") + + def _get_collision(self, collision: str): + if isinstance(collision, str): + try: + return self._collision_registry[collision] + except KeyError: + raise ValueError(f"Collision {collision} not recognized for incompressible Navier-Stokes solver") + elif issubclass(collision, Operator): + return collision + else: + raise ValueError(f"Collision {collision} not recognized for incompressible Navier-Stokes solver") + + def _get_equilibrium(self, equilibrium: str): + if isinstance(equilibrium, str): + try: + return self._equilibrium_registry[equilibrium] + except KeyError: + raise ValueError(f"Equilibrium {equilibrium} not recognized for incompressible Navier-Stokes solver") + elif issubclass(equilibrium, Operator): + return equilibrium + else: + raise ValueError(f"Equilibrium {equilibrium} not recognized for incompressible Navier-Stokes solver") diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py index 7c15344..06a06aa 100644 --- a/xlb/solver/solver.py +++ b/xlb/solver/solver.py @@ -13,25 +13,29 @@ class Solver(Operator): def __init__( self, + shape: tuple[int, int, int], + boundary_conditions=[], velocity_set=None, - compute_backend=None, precision_policy=None, - boundary_conditions=[], + compute_backend=None, + grid_backend=None, + grid_configs={}, ): + # Set parameters + self.shape = shape self.velocity_set = velocity_set or GlobalConfig.velocity_set - self.compute_backend = compute_backend or GlobalConfig.compute_backend self.precision_policy = precision_policy or GlobalConfig.precision_policy + self.compute_backend = compute_backend or GlobalConfig.compute_backend + self.grid_backend = grid_backend or GlobalConfig.grid_backend self.boundary_conditions = boundary_conditions - # Get collision and stream boundary conditions - self.collision_boundary_conditions = {} - self.stream_boundary_conditions = {} - for id_number, bc in enumerate(self.boundary_conditions): - bc_id = id_number + 1 - if bc.implementation_step == ImplementationStep.COLLISION: - self.collision_boundary_conditions[bc_id] = bc - elif bc.implementation_step == ImplementationStep.STREAMING: - self.stream_boundary_conditions[bc_id] = bc - else: - raise ValueError("Boundary condition step not recognized") + # Make grid + if self.grid_backend is GridBackend.JAX: + self.grid = JaxGrid(**grid_configs) + elif self.grid_backend is GridBackend.WARP: + self.grid = WarpGrid(**grid_configs) + elif self.grid_backend is GridBackend.OOC + self.grid = OOCGrid(**grid_configs) + else: + raise ValueError(f"Grid backend {self.grid_backend} not recognized") diff --git a/xlb/utils/__init__.py b/xlb/utils/__init__.py index 2107fc8..3c8032e 100644 --- a/xlb/utils/__init__.py +++ b/xlb/utils/__init__.py @@ -1 +1,9 @@ -from .utils import downsample_field, save_image, save_fields_vtk, save_BCs_vtk, rotate_geometry, voxelize_stl, axangle2mat +from .utils import ( + downsample_field, + save_image, + save_fields_vtk, + save_BCs_vtk, + rotate_geometry, + voxelize_stl, + axangle2mat, +) diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index c20dacc..2752314 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -140,6 +140,7 @@ def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"): grid.save(output_filename, binary=True) print(f"Saved {output_filename} in {time() - start:.6f} seconds.") + def save_BCs_vtk(timestep, BCs, gridInfo, output_dir="."): """ Save boundary conditions as VTK format to the specified directory. From 3d1245e5e9e31b9d98fb0614964b7922a464b243 Mon Sep 17 00:00:00 2001 From: Oliver Date: Tue, 20 Feb 2024 16:11:35 -0800 Subject: [PATCH 012/144] almost done --- examples/CFD_refactor/windtunnel3d.py | 58 ++++++++++++++----- xlb/grid/grid.py | 15 ++++- .../boundary_condition/boundary_condition.py | 8 ++- .../boundary_masker/boundary_masker.py | 5 +- .../boundary_condition/full_bounce_back.py | 10 ++-- xlb/operator/collision/bgk.py | 2 +- .../equilibrium/quadratic_equilibrium.py | 58 +++++++++---------- xlb/operator/initializer/__init__.py | 1 + xlb/operator/initializer/const_init.py | 2 +- xlb/operator/macroscopic/macroscopic.py | 2 +- xlb/operator/operator.py | 5 +- xlb/operator/stream/stream.py | 7 --- xlb/solver/__init__.py | 3 +- xlb/solver/nse.py | 1 - xlb/solver/solver.py | 2 +- xlb/velocity_set/velocity_set.py | 4 -- 16 files changed, 105 insertions(+), 78 deletions(-) diff --git a/examples/CFD_refactor/windtunnel3d.py b/examples/CFD_refactor/windtunnel3d.py index 1de0e2d..28d9208 100644 --- a/examples/CFD_refactor/windtunnel3d.py +++ b/examples/CFD_refactor/windtunnel3d.py @@ -7,6 +7,10 @@ from jax import config from xlb.solver import IncompressibleNavierStokesSolver +from xlb.velocity_set import D3Q27, D3Q19 +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.grid_backend import GridBackend from xlb.operator.boundary_condition import BounceBack, BounceBackHalfway, DoNothing, EquilibriumBC @@ -26,25 +30,49 @@ def __init__( dx: float = 0.01, # m viscosity: float = 1.42e-5, # air at 20 degrees Celsius density: float = 1.2754, # kg/m^3 - ): - - - - omega: float, - shape: tuple[int, int, int], collision="BGK", equilibrium="Quadratic", - boundary_conditions=[], - initializer=None, - forcing=None, - velocity_set: VelocitySet = None, - precision_policy=None, - compute_backend=None, - grid_backend=None, + velocity_set=D3Q27(), + precision_policy=PrecisionPolicy.FP32FP32, + compute_backend=ComputeBackend.JAX, + grid_backend=GridBackend.JAX, grid_configs={}, ): - - super().__init__(**kwargs) + + # Set parameters + self.inlet_velocity = inlet_velocity + self.lower_bounds = lower_bounds + self.upper_bounds = upper_bounds + self.dx = dx + self.viscosity = viscosity + self.density = density + + # Get fluid properties needed for the simulation + self.velocity_conversion = 0.05 / inlet_velocity + self.dt = self.dx * self.velocity_conversion + self.lbm_viscosity = self.viscosity * self.dt / (self.dx ** 2) + self.tau = 0.5 + self.lbm_viscosity + self.lbm_density = 1.0 + self.mass_conversion = self.dx ** 3 * (self.density / self.lbm_density) + + # Make boundary conditions + + + # Initialize the IncompressibleNavierStokesSolver + super().__init__( + omega=self.tau, + shape=shape, + collision=collision, + equilibrium=equilibrium, + boundary_conditions=boundary_conditions, + initializer=initializer, + forcing=forcing, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + grid_backend=grid_backend, + grid_configs=grid_configs, + ) def voxelize_stl(self, stl_filename, length_lbm_unit): mesh = trimesh.load_mesh(stl_filename, process=False) diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 51f6cbf..abdf2f0 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -1,12 +1,21 @@ from abc import ABC, abstractmethod -from xlb.compute_backend import ComputeBackend + from xlb.global_config import GlobalConfig +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy, Precision from xlb.velocity_set import VelocitySet +from xlb.operator import Operator class Grid(ABC): - def __init__(self, shape, velocity_set, precision_policy, grid_backend): + def __init__( + self, + shape : tuple, + velocity_set : VelocitySet, + precision_policy : PrecisionPolicy, + grid_backend : ComputeBackend + ): # Set parameters self.shape = shape self.velocity_set = velocity_set @@ -17,7 +26,7 @@ def __init__(self, shape, velocity_set, precision_policy, grid_backend): # Create field dict self.fields = {} - def parallelize_operator(self, operator): + def parallelize_operator(self, operator: Operator): raise NotImplementedError("Parallelization not implemented, child class must implement") @abstractmethod diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index f5ad4dd..42fa5d4 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -8,10 +8,14 @@ import numpy as np from enum import Enum -from xlb.operator.operator import Operator from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_masker import ( + BoundaryMasker, + IndicesBoundaryMasker, +) # Enum for implementation step @@ -55,7 +59,7 @@ def from_indices( Create a boundary condition from indices and boundary id. """ # Create boundary mask - boundary_mask = IndicesBoundaryMask( + boundary_mask = IndicesBoundaryMasker( indices, stream_indices, velocity_set, precision_policy, compute_backend ) diff --git a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py index ee511e2..20bf580 100644 --- a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py +++ b/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py @@ -1,14 +1,11 @@ # Base class for all equilibriums -from functools import partial -import numpy as np import jax.numpy as jnp from jax import jit import warp as wp -from typing import Tuple -from xlb.global_config import GlobalConfig from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index 4cd86b6..91fda7f 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -9,8 +9,10 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend -from xlb.operator.boundary_condition.boundary_condition import ( +from xlb.operator import Operator +from xlb.operator.boundary_condition import ( BoundaryCondition, ImplementationStep, ) @@ -48,13 +50,12 @@ def from_indices( Create a full bounce-back boundary condition from indices. """ # Create boundary mask - boundary_mask = IndicesBoundaryMask( + boundary_mask = IndicesBoundaryMasker( indices, False, velocity_set, precision_policy, compute_backend ) # Create boundary condition return cls( - ImplementationStep.COLLISION, boundary_mask, velocity_set, precision_policy, @@ -64,7 +65,8 @@ def from_indices( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) def apply_jax(self, f_pre, f_post, boundary, mask): - flip = jnp.repeat(boundary, self.velocity_set.q, axis=-1) + flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) + print(flip.shape) flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) return flipped_f diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 10c4a5f..5052a70 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -21,7 +21,7 @@ def jax_implementation( fout = f - self.omega * fneq return fout - @Operator.register_backend(ComputeBackends.PALLAS) + @Operator.register_backend(ComputeBackend.PALLAS) def pallas_implementation( self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray ): diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 532820b..06ec7f3 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -27,7 +27,34 @@ def jax_implementation(self, rho, u): feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq -<<<<<<< HEAD + @Operator.register_backend(ComputeBackend.PALLAS) + def pallas_implementation(self, rho, u): + u0, u1, u2 = u[0], u[1], u[2] + usqr = 1.5 * (u0**2 + u1**2 + u2**2) + + eq = [ + rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u0 + 4.5 * u0 * u0 - usqr), + rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u1 + 4.5 * u1 * u1 - usqr), + rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u2 + 4.5 * u2 * u2 - usqr), + ] + + combined_velocities = [u0 + u1, u0 - u1, u0 + u2, u0 - u2, u1 + u2, u1 - u2] + + for vel in combined_velocities: + eq.append( + rho[0] * (1.0 / 36.0) * (1.0 - 3.0 * vel + 4.5 * vel * vel - usqr) + ) + + eq.append(rho[0] * (1.0 / 3.0) * (1.0 - usqr)) + + for i in range(3): + eq.append(eq[i] + rho[0] * (1.0 / 18.0) * 6.0 * u[i]) + + for i, vel in enumerate(combined_velocities, 3): + eq.append(eq[i] + rho[0] * (1.0 / 36.0) * 6.0 * vel) + + return jnp.array(eq) + def _construct_warp(self): # Make constants for warp _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) @@ -95,32 +122,3 @@ def warp_implementation(self, rho, u, f): dim=rho.shape[1:], ) return f -======= - @Operator.register_backend(ComputeBackends.PALLAS) - def pallas_implementation(self, rho, u): - u0, u1, u2 = u[0], u[1], u[2] - usqr = 1.5 * (u0**2 + u1**2 + u2**2) - - eq = [ - rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u0 + 4.5 * u0 * u0 - usqr), - rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u1 + 4.5 * u1 * u1 - usqr), - rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u2 + 4.5 * u2 * u2 - usqr), - ] - - combined_velocities = [u0 + u1, u0 - u1, u0 + u2, u0 - u2, u1 + u2, u1 - u2] - - for vel in combined_velocities: - eq.append( - rho[0] * (1.0 / 36.0) * (1.0 - 3.0 * vel + 4.5 * vel * vel - usqr) - ) - - eq.append(rho[0] * (1.0 / 3.0) * (1.0 - usqr)) - - for i in range(3): - eq.append(eq[i] + rho[0] * (1.0 / 18.0) * 6.0 * u[i]) - - for i, vel in enumerate(combined_velocities, 3): - eq.append(eq[i] + rho[0] * (1.0 / 36.0) * 6.0 * vel) - - return jnp.array(eq) ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 diff --git a/xlb/operator/initializer/__init__.py b/xlb/operator/initializer/__init__.py index b2d14b9..026bee9 100644 --- a/xlb/operator/initializer/__init__.py +++ b/xlb/operator/initializer/__init__.py @@ -1,2 +1,3 @@ +from xlb.operator.initializer.initializer import Initializer from xlb.operator.initializer.equilibrium_init import EquilibriumInitializer from xlb.operator.initializer.const_init import ConstInitializer diff --git a/xlb/operator/initializer/const_init.py b/xlb/operator/initializer/const_init.py index e13d2db..17235cd 100644 --- a/xlb/operator/initializer/const_init.py +++ b/xlb/operator/initializer/const_init.py @@ -32,5 +32,5 @@ def jax_implementation(self, const_value, sharding=None): @Operator.register_backend(ComputeBackends.PALLAS) @partial(jax.jit, static_argnums=(0, 2)) - def jax_implementation(self, const_value, sharding=None): + def pallas_implementation(self, const_value, sharding=None): return self.jax_implementation(const_value, sharding) diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index eb8a8d2..ab537b6 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -46,7 +46,7 @@ def jax_implementation(self, f): return rho, u - @Operator.register_backend(ComputeBackends.PALLAS) + @Operator.register_backend(ComputeBackend.PALLAS) def pallas_implementation(self, f): # TODO: Maybe this can be done with jnp.sum rho = jnp.sum(f, axis=0, keepdims=True) diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 042faba..87c6f15 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -111,11 +111,10 @@ def store_dtype(self): """ return self._precision_to_dtype(self.precision_policy.store_precision) - @staticmethod - def _precision_to_dtype(precision): + def _precision_to_dtype(self, precision): """ Convert the precision to the corresponding dtype - TODO: Maybe move this + TODO: Maybe move this to precision policy? """ if precision == Precision.FP64: return self.backend.float64 diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 8f52607..e6a46c6 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -15,13 +15,6 @@ class Stream(Operator): """ @Operator.register_backend(ComputeBackend.JAX) - def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None): - self.grid = grid - self.parallel_operator = ParallelOperator( - grid, self._streaming_jax_p, velocity_set - ) - super().__init__(velocity_set, compute_backend) - @partial(jit, static_argnums=(0)) def jax_implementation(self, f): """ diff --git a/xlb/solver/__init__.py b/xlb/solver/__init__.py index 62dfc30..0304fda 100644 --- a/xlb/solver/__init__.py +++ b/xlb/solver/__init__.py @@ -1 +1,2 @@ -from xlb.solver.nse import IncompressibleNavierStokes +from xlb.solver.solver import Solver +from xlb.solver.nse import IncompressibleNavierStokesSolver diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index 626969b..6c4c0c2 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -20,7 +20,6 @@ class IncompressibleNavierStokesSolver(Solver): _equilibrium_registry = { "Quadratic": QuadraticEquilibrium, - "Linear": LinearEquilibrium, } _collision_registry = { "BGK": BGK, diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py index 06a06aa..7d3db77 100644 --- a/xlb/solver/solver.py +++ b/xlb/solver/solver.py @@ -35,7 +35,7 @@ def __init__( self.grid = JaxGrid(**grid_configs) elif self.grid_backend is GridBackend.WARP: self.grid = WarpGrid(**grid_configs) - elif self.grid_backend is GridBackend.OOC + elif self.grid_backend is GridBackend.OOC: self.grid = OOCGrid(**grid_configs) else: raise ValueError(f"Grid backend {self.grid_backend} not recognized") diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 2e03beb..0564fde 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -5,12 +5,8 @@ from functools import partial import jax.numpy as jnp from jax import jit, vmap -<<<<<<< HEAD import warp as wp -======= ->>>>>>> a48510cefc7af0cb965b67c86854a609b7d8d1d4 - class VelocitySet(object): """ From 93a130d4a875c359c356fa526db5d002706edae9 Mon Sep 17 00:00:00 2001 From: Oliver Date: Tue, 20 Feb 2024 16:12:01 -0800 Subject: [PATCH 013/144] almost done --- xlb/operator/boundary_condition/boundary_masker/__init__.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 xlb/operator/boundary_condition/boundary_masker/__init__.py diff --git a/xlb/operator/boundary_condition/boundary_masker/__init__.py b/xlb/operator/boundary_condition/boundary_masker/__init__.py new file mode 100644 index 0000000..e33e509 --- /dev/null +++ b/xlb/operator/boundary_condition/boundary_masker/__init__.py @@ -0,0 +1,2 @@ +from xlb.operator.boundary_condition.boundary_masker.boundary_masker import BoundaryMasker +from xlb.operator.boundary_condition.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker From 05b87bf72d146721f48954f87367d36e5b838860 Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 22 Feb 2024 11:53:03 -0800 Subject: [PATCH 014/144] stepper almost working --- examples/warp_backend/equilibrium.py | 38 ------ examples/warp_backend/testing.py | 108 ++++++++++++++++++ xlb/operator/__init__.py | 1 + .../boundary_condition/boundary_condition.py | 26 +++++ .../indices_boundary_masker.py | 22 +++- .../equilibrium_boundary.py | 21 +++- .../boundary_condition/full_bounce_back.py | 13 ++- xlb/operator/collision/bgk.py | 13 ++- xlb/operator/collision/collision.py | 2 +- xlb/operator/initializer/__init__.py | 3 - xlb/operator/initializer/const_init.py | 36 ------ xlb/operator/initializer/equilibrium_init.py | 33 ------ xlb/operator/initializer/initializer.py | 13 --- xlb/operator/operator.py | 16 ++- xlb/operator/precision_caster/__init__.py | 1 + .../precision_caster/precision_caster.py | 13 +-- xlb/operator/stepper/__init__.py | 2 + xlb/operator/stepper/nse.py | 16 +-- xlb/operator/stepper/stepper.py | 30 ++--- xlb/operator/stream/stream.py | 3 +- xlb/operator/test/test.py | 1 + 21 files changed, 240 insertions(+), 171 deletions(-) delete mode 100644 examples/warp_backend/equilibrium.py create mode 100644 examples/warp_backend/testing.py delete mode 100644 xlb/operator/initializer/__init__.py delete mode 100644 xlb/operator/initializer/const_init.py delete mode 100644 xlb/operator/initializer/equilibrium_init.py delete mode 100644 xlb/operator/initializer/initializer.py create mode 100644 xlb/operator/precision_caster/__init__.py create mode 100644 xlb/operator/stepper/__init__.py create mode 100644 xlb/operator/test/test.py diff --git a/examples/warp_backend/equilibrium.py b/examples/warp_backend/equilibrium.py deleted file mode 100644 index a99ace4..0000000 --- a/examples/warp_backend/equilibrium.py +++ /dev/null @@ -1,38 +0,0 @@ -# from IPython import display -import numpy as np -import jax -import jax.numpy as jnp -import scipy -import time -from tqdm import tqdm -import matplotlib.pyplot as plt - -import warp as wp -wp.init() - -import xlb - -if __name__ == "__main__": - - # Make operator - precision_policy = xlb.PrecisionPolicy.FP32FP32 - velocity_set = xlb.velocity_set.D3Q27() - compute_backend = xlb.ComputeBackend.WARP - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - - # Make warp arrays - nr = 128 - f = wp.zeros((27, nr, nr, nr), dtype=wp.float32) - u = wp.zeros((3, nr, nr, nr), dtype=wp.float32) - rho = wp.zeros((1, nr, nr, nr), dtype=wp.float32) - - # Run simulation - equilibrium(rho, u, f) - macroscopic(f, rho, u) diff --git a/examples/warp_backend/testing.py b/examples/warp_backend/testing.py new file mode 100644 index 0000000..3940378 --- /dev/null +++ b/examples/warp_backend/testing.py @@ -0,0 +1,108 @@ +# from IPython import display +import numpy as np +import jax +import jax.numpy as jnp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt + +import warp as wp +wp.init() + +import xlb + + +def test_backends(compute_backend): + + # Set parameters + precision_policy = xlb.PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q27() + + # Make operators + collision = xlb.operator.collision.BGK( + omega=1.0, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + stream = xlb.operator.stream.Stream( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + bounceback = xlb.operator.boundary_condition.FullBounceBack.from_indices( + indices=np.array([[0, 0, 0], [0, 0, 1]]), + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=collision, + equilibrium=equilibrium, + macroscopic=macroscopic, + stream=stream, + boundary_conditions=[bounceback]) + + # Test operators + if compute_backend == xlb.ComputeBackend.WARP: + # Make warp arrays + nr = 128 + f_0 = wp.zeros((27, nr, nr, nr), dtype=wp.float32) + f_1 = wp.zeros((27, nr, nr, nr), dtype=wp.float32) + f_out = wp.zeros((27, nr, nr, nr), dtype=wp.float32) + u = wp.zeros((3, nr, nr, nr), dtype=wp.float32) + rho = wp.zeros((1, nr, nr, nr), dtype=wp.float32) + boundary_id = wp.zeros((1, nr, nr, nr), dtype=wp.uint8) + boundary = wp.zeros((1, nr, nr, nr), dtype=wp.bool) + mask = wp.zeros((27, nr, nr, nr), dtype=wp.bool) + + # Test operators + collision(f_0, f_1, rho, u, f_out) + equilibrium(rho, u, f_0) + macroscopic(f_0, rho, u) + stream(f_0, f_1) + bounceback(f_0, f_1, f_out, boundary, mask) + #bounceback.boundary_masker((0, 0, 0), boundary_id, mask, 1) + + + + elif compute_backend == xlb.ComputeBackend.JAX: + # Make jax arrays + nr = 128 + f_0 = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) + f_1 = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) + f_out = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) + u = jnp.zeros((3, nr, nr, nr), dtype=jnp.float32) + rho = jnp.zeros((1, nr, nr, nr), dtype=jnp.float32) + boundary_id = jnp.zeros((1, nr, nr, nr), dtype=jnp.uint8) + boundary = jnp.zeros((1, nr, nr, nr), dtype=jnp.bool_) + mask = jnp.zeros((27, nr, nr, nr), dtype=jnp.bool_) + + # Test operators + collision(f_0, f_1, rho, u) + equilibrium(rho, u) + macroscopic(f_0) + stream(f_0) + bounceback(f_0, f_1, boundary, mask) + bounceback.boundary_masker((0, 0, 0), boundary_id, mask, 1) + stepper(f_0, boundary_id, mask, 0) + + + +if __name__ == "__main__": + + # Test backends + compute_backends = [ + xlb.ComputeBackend.WARP, + xlb.ComputeBackend.JAX + ] + + for compute_backend in compute_backends: + test_backends(compute_backend) + print(f"Backend {compute_backend} passed all tests.") diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index 02b8a59..501a7af 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1,2 +1,3 @@ from xlb.operator.operator import Operator from xlb.operator.parallel_operator import ParallelOperator +import xlb.operator.stepper # diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 42fa5d4..95b1265 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -45,6 +45,32 @@ def __init__( # Set boundary masker self.boundary_masker = boundary_masker + @classmethod + def from_function( + cls, + implementation_step: ImplementationStep, + boundary_function, + velocity_set, + precision_policy, + compute_backend, + ): + """ + Create a boundary condition from a function. + """ + # Create boundary mask + boundary_mask = BoundaryMasker.from_function( + boundary_function, velocity_set, precision_policy, compute_backend + ) + + # Create boundary condition + return cls( + implementation_step, + boundary_mask, + velocity_set, + precision_policy, + compute_backend, + ) + @classmethod def from_indices( cls, diff --git a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py index 3b57895..fdf8ced 100644 --- a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py @@ -47,10 +47,10 @@ def _indices_to_tuple(indices): return tuple([indices[:, i] for i in range(indices.shape[1])]) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), inline=True) + #@partial(jit, static_argnums=(0), inline=True) TODO: Fix this def jax_implementation(self, start_index, boundary_id, mask, id_number): # Get local indices from the meshgrid and the indices - local_indices = self.indices - start_index + local_indices = self.indices - np.array(start_index)[np.newaxis, :] # Remove any indices that are out of bounds local_indices = local_indices[ @@ -98,3 +98,21 @@ def jax_implementation(self, start_index, boundary_id, mask, id_number): mask = mask.at[self._indices_to_tuple(local_indices)].set(True) return boundary_id, mask + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, start_index, boundary_id, mask, id_number): + # Reuse the jax implementation, TODO: implement a warp version + # Convert to jax + boundary_id = wp.jax.to_jax(boundary_id) + mask = wp.jax.to_jax(mask) + + # Call jax implementation + boundary_id, mask = self.jax_implementation( + start_index, boundary_id, mask, id_number + ) + + # Convert back to warp + boundary_id = wp.jax.to_warp(boundary_id) + mask = wp.jax.to_warp(mask) + + return boundary_id, mask diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py index fbc5418..4b47980 100644 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -5,18 +5,24 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend -from xlb.operator.stream.stream import Stream +from xlb.operator import Operator from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator.boundary_condition.boundary_condition import ( BoundaryCondition, ImplementationStep, ) +from xlb.operator.boundary_condition.boundary_masker import ( + BoundaryMasker, + IndicesBoundaryMasker, +) + class EquilibriumBoundary(BoundaryCondition): """ - A boundary condition that skips the streaming step. + Equilibrium boundary condition for a lattice Boltzmann method simulation. """ def __init__( @@ -25,11 +31,13 @@ def __init__( rho: float, u: tuple[float, float], equilibrium: Equilibrium, + boundary_masker: BoundaryMasker, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, ): super().__init__( - set_boundary=set_boundary, + ImplementationStep.COLLISION, implementation_step=ImplementationStep.STREAMING, velocity_set=velocity_set, compute_backend=compute_backend, @@ -39,12 +47,13 @@ def __init__( @classmethod def from_indices( cls, - indices, + indices: np.ndarray, rho: float, u: tuple[float, float], equilibrium: Equilibrium, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index 91fda7f..ed0ec5a 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -7,6 +7,7 @@ import jax.lax as lax from functools import partial import numpy as np +import warp as wp from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -32,7 +33,7 @@ def __init__( boundary_masker: BoundaryMasker, velocity_set: VelocitySet, precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, + compute_backend: ComputeBackend, ): super().__init__( ImplementationStep.COLLISION, @@ -66,13 +67,12 @@ def from_indices( @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) def apply_jax(self, f_pre, f_post, boundary, mask): flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) - print(flip.shape) flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) return flipped_f def _construct_warp(self): # Make constants for warp - _opp_indices = wp.constant(self.velocity_set.opp_indices) + _opp_indices = wp.constant(self._warp_int_lattice_vec(self.velocity_set.opp_indices)) _q = wp.constant(self.velocity_set.q) _d = wp.constant(self.velocity_set.d) @@ -107,7 +107,12 @@ def kernel( for l in range(_q): _f_pre[l] = f_pre[l, i, j, k] _f_post[l] = f_post[l, i, j, k] - _mask[l] = mask[l, i, j, k] + + # TODO fix vec bool + if mask[l, i, j, k]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) # Check if the boundary is active if boundary[i, j, k]: diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 5052a70..4071345 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -1,5 +1,7 @@ import jax.numpy as jnp from jax import jit +import warp as wp + from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision @@ -18,7 +20,7 @@ def jax_implementation( self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray ): fneq = f - feq - fout = f - self.omega * fneq + fout = f - self.compute_dtype(self.omega) * fneq return fout @Operator.register_backend(ComputeBackend.PALLAS) @@ -35,6 +37,7 @@ def _construct_warp(self): _q = wp.constant(self.velocity_set.q) _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) _d = wp.constant(self.velocity_set.d) + _omega = wp.constant(self.compute_dtype(self.omega)) # Construct the functional @wp.func @@ -45,7 +48,7 @@ def functional( u: self._warp_u_vec, ) -> self._warp_lattice_vec: fneq = f - feq - fout = f - self.omega * fneq + fout = f - _omega * fneq return fout # Construct the warp kernel @@ -66,7 +69,11 @@ def kernel( for l in range(_q): _f[l] = f[l, i, j, k] _feq[l] = feq[l, i, j, k] - _fout = functional(_f, _feq) + _u = self._warp_u_vec() + for l in range(_d): + _u[l] = u[l, i, j, k] + _rho = rho[0, i, j, k] + _fout = functional(_f, _feq, _rho, _u) # Write the result for l in range(_q): diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py index 1fe0a5b..acf4538 100644 --- a/xlb/operator/collision/collision.py +++ b/xlb/operator/collision/collision.py @@ -26,5 +26,5 @@ def __init__( precision_policy=None, compute_backend=None, ): - super().__init__(velocity_set, precision_policy, compute_backend) self.omega = omega + super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/initializer/__init__.py b/xlb/operator/initializer/__init__.py deleted file mode 100644 index 026bee9..0000000 --- a/xlb/operator/initializer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from xlb.operator.initializer.initializer import Initializer -from xlb.operator.initializer.equilibrium_init import EquilibriumInitializer -from xlb.operator.initializer.const_init import ConstInitializer diff --git a/xlb/operator/initializer/const_init.py b/xlb/operator/initializer/const_init.py deleted file mode 100644 index 17235cd..0000000 --- a/xlb/operator/initializer/const_init.py +++ /dev/null @@ -1,36 +0,0 @@ -from xlb.velocity_set import VelocitySet -from xlb.global_config import GlobalConfig -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from xlb.grid.grid import Grid -from functools import partial -import numpy as np -import jax - - -class ConstInitializer(Operator): - def __init__( - self, - type=np.float32, - velocity_set: VelocitySet = None, - compute_backend: ComputeBackends = None, - ): - self.type = type - self.grid = grid - velocity_set = velocity_set or GlobalConfig.velocity_set - compute_backend = compute_backend or GlobalConfig.compute_backend - - super().__init__(velocity_set, compute_backend) - - @Operator.register_backend(ComputeBackends.JAX) - @partial(jax.jit, static_argnums=(0, 2)) - def jax_implementation(self, const_value, sharding=None): - if sharding is None: - sharding = self.grid.sharding - x = jax.numpy.full(shape=self.shape, fill_value=const_value, dtype=self.type) - return jax.lax.with_sharding_constraint(x, sharding) - - @Operator.register_backend(ComputeBackends.PALLAS) - @partial(jax.jit, static_argnums=(0, 2)) - def pallas_implementation(self, const_value, sharding=None): - return self.jax_implementation(const_value, sharding) diff --git a/xlb/operator/initializer/equilibrium_init.py b/xlb/operator/initializer/equilibrium_init.py deleted file mode 100644 index 5d96fbb..0000000 --- a/xlb/operator/initializer/equilibrium_init.py +++ /dev/null @@ -1,33 +0,0 @@ -from xlb.velocity_set import VelocitySet -from xlb.global_config import GlobalConfig -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from xlb.grid.grid import Grid -import numpy as np -import jax - - -class EquilibriumInitializer(Operator): - def __init__( - self, - grid: Grid, - velocity_set: VelocitySet = None, - compute_backend: ComputeBackends = None, - ): - velocity_set = velocity_set or GlobalConfig.velocity_set - compute_backend = compute_backend or GlobalConfig.compute_backend - local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1) - - self.init_values = np.zeros( - grid.field_global_to_local_shape(grid.pop_shape) - ) + velocity_set.w.reshape(local_shape) - - super().__init__(velocity_set, compute_backend) - - @Operator.register_backend(ComputeBackends.JAX) - def jax_implementation(self, index): - return self.init_values - - @Operator.register_backend(ComputeBackends.PALLAS) - def jax_implementation(self, index): - return self.init_values diff --git a/xlb/operator/initializer/initializer.py b/xlb/operator/initializer/initializer.py deleted file mode 100644 index e813934..0000000 --- a/xlb/operator/initializer/initializer.py +++ /dev/null @@ -1,13 +0,0 @@ -from xlb.velocity_set import VelocitySet -from xlb.global_config import GlobalConfig -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from xlb.grid.grid import Grid -import numpy as np -import jax - - -class Initializer(Operator): - """ - Base class for all initializers. - """ diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 87c6f15..81b3035 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -92,7 +92,7 @@ def backend(self): This should be used with caution as all backends may not have the same API. """ if self.compute_backend == ComputeBackend.JAX: - import jax as backend + import jax.numpy as backend elif self.compute_backend == ComputeBackend.WARP: import warp as backend return backend @@ -152,7 +152,8 @@ def _warp_bool_lattice_vec(self): """ Returns the warp type for the streaming matrix (c) """ - return wp.vec(self.velocity_set.q, dtype=wp.bool) + #return wp.vec(self.velocity_set.q, dtype=wp.bool) + return wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO bool breaks @property def _warp_stream_mat(self): @@ -163,6 +164,15 @@ def _warp_stream_mat(self): (self.velocity_set.d, self.velocity_set.q), dtype=self.compute_dtype ) + @property + def _warp_int_stream_mat(self): + """ + Returns the warp type for the streaming matrix (c) + """ + return wp.mat( + (self.velocity_set.d, self.velocity_set.q), dtype=wp.int32 + ) + @property def _warp_array_type(self): """ @@ -199,4 +209,4 @@ def _construct_warp(self): TODO: Maybe a better way to do this? Maybe add this to the backend decorator? """ - raise NotImplementedError("Children must implement this method") + return None, None diff --git a/xlb/operator/precision_caster/__init__.py b/xlb/operator/precision_caster/__init__.py new file mode 100644 index 0000000..a027c52 --- /dev/null +++ b/xlb/operator/precision_caster/__init__.py @@ -0,0 +1 @@ +from xlb.operator.precision_caster.precision_caster import PrecisionCaster diff --git a/xlb/operator/precision_caster/precision_caster.py b/xlb/operator/precision_caster/precision_caster.py index be676f4..cb441c5 100644 --- a/xlb/operator/precision_caster/precision_caster.py +++ b/xlb/operator/precision_caster/precision_caster.py @@ -16,7 +16,7 @@ class PrecisionCaster(Operator): """ - Class that handles the construction of lattice boltzmann precision casting operator + Class that handles the construction of lattice boltzmann precision casting operator. """ def __init__( @@ -84,15 +84,14 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout): + def warp_implementation(self, from_f, to_f): # Launch the warp kernel wp.launch( self._kernel, inputs=[ - f, - feq, - fout, + from_f, + to_f, ], - dim=f.shape[1:], + dim=from_f.shape[1:], ) - return fout + return to_f diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py new file mode 100644 index 0000000..44ff137 --- /dev/null +++ b/xlb/operator/stepper/__init__.py @@ -0,0 +1,2 @@ +from xlb.operator.stepper.stepper import Stepper +from xlb.operator.stepper.nse import IncompressibleNavierStokesStepper diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index ebb9e13..55a7703 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -1,14 +1,16 @@ # Base class for all stepper operators +from logging import warning from functools import partial from jax import jit -from logging import warning +import warp as wp -from xlb.velocity_set.velocity_set import VelocitySet +from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend -from xlb.operator.stepper.stepper import Stepper +from xlb.operator import Operator +from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition import ImplementationStep -from xlb.operator.collision.bgk import BGK +from xlb.operator.collision import BGK class IncompressibleNavierStokesStepper(Stepper): @@ -71,7 +73,7 @@ def apply_jax(self, f, boundary_id, mask, timestep): return f - @Operator.register_backend(ComputeBackends.PALLAS) + @Operator.register_backend(ComputeBackend.PALLAS) @partial(jit, static_argnums=(0,)) def apply_pallas(self, fin, boundary_id, mask, timestep): # Raise warning that the boundary conditions are not implemented @@ -141,7 +143,7 @@ def kernel( f_0: self._warp_array_type, f_1: self._warp_array_type, boundary_id: self._warp_uint8_array_type, - mask: self._warp_array_bool_array_type, + mask: self._warp_bool_array_type, timestep: wp.int32, ): # Get the global index @@ -201,7 +203,7 @@ def kernel( # Set the output f_1[streamed_l, streamed_i, streamed_j, streamed_k] = f_pre_streaming[l] - return functional, kernel + return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f, boundary_id, mask, timestep): diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 48c62a8..b65924b 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,10 +1,13 @@ # Base class for all stepper operators +from functools import partial import jax.numpy as jnp +from jax import jit +import warp as wp -from xlb.velocity_set.velocity_set import VelocitySet +from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator +from xlb.operator import Operator from xlb.operator.boundary_condition import ImplementationStep from xlb.operator.precision_caster import PrecisionCaster @@ -31,6 +34,15 @@ def __init__( self.boundary_conditions = boundary_conditions self.forcing = forcing + # Get all operators for checking + self.operators = [ + collision, + stream, + equilibrium, + macroscopic, + *boundary_conditions, + ] + # Get velocity set, precision policy, and compute backend velocity_sets = set([op.velocity_set for op in self.operators]) assert len(velocity_sets) == 1, "All velocity sets must be the same" @@ -55,8 +67,7 @@ def __init__( raise ValueError("Boundary condition step not recognized") # Make operators for converting the precisions - self.cast_to_compute = PrecisionCaster( - + #self.cast_to_compute = PrecisionCaster( # Make operator for setting boundary condition arrays self.set_boundary = SetBoundary( @@ -66,16 +77,7 @@ def __init__( precision_policy, compute_backend, ) - - # Get all operators for checking - self.operators = [ - collision, - stream, - equilibrium, - macroscopic, - *boundary_conditions, - self.set_boundary, - ] + self.operators.append(self.set_boundary) # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index e6a46c6..ebb932b 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -3,6 +3,7 @@ from functools import partial import jax.numpy as jnp from jax import jit, vmap +import warp as wp from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend @@ -51,7 +52,7 @@ def _streaming_jax_i(f, c): def _construct_warp(self): # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) + _c = wp.constant(self._warp_int_stream_mat(self.velocity_set.c)) _q = wp.constant(self.velocity_set.q) _d = wp.constant(self.velocity_set.d) diff --git a/xlb/operator/test/test.py b/xlb/operator/test/test.py new file mode 100644 index 0000000..7d4290a --- /dev/null +++ b/xlb/operator/test/test.py @@ -0,0 +1 @@ +x = 1 From 9ebf2441c6e77b9635f6691e70d41a6692429272 Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 22 Feb 2024 11:56:22 -0800 Subject: [PATCH 015/144] parralization --- xlb/grid/jax_grid.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 42687c8..092ea1d 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -34,6 +34,47 @@ def _initialize_jax_backend(self): self.grid_shape[0] // self.nDevices, ) + self.grid_shape[1:] + + def parallelize_operator(self, operator: Operator): + # TODO: fix this + + # Make parallel function + def _parallel_operator(f): + rightPerm = [ + (i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices) + ] + leftPerm = [ + ((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices) + ] + f = self.func(f) + left_comm, right_comm = ( + f[self.velocity_set.right_indices, :1, ...], + f[self.velocity_set.left_indices, -1:, ...], + ) + left_comm, right_comm = ( + lax.ppermute(left_comm, perm=rightPerm, axis_name="x"), + lax.ppermute(right_comm, perm=leftPerm, axis_name="x"), + ) + f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm) + f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm) + + return f + + in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,))) + out_specs = in_specs + + f = shard_map( + self._parallel_func, + mesh=self.grid.global_mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + )(f) + return f + + + + def create_field(self, name: str, cardinality: int, callback=None): # Get shape of the field shape = (cardinality,) + (self.shape) From 299c3ccec73a11b6127800fe757a176f1166dbc5 Mon Sep 17 00:00:00 2001 From: Oliver Date: Tue, 27 Feb 2024 16:21:43 -0800 Subject: [PATCH 016/144] stepper functional --- examples/interfaces/functional_interface.py | 136 ++++++++++++++++++ xlb/grid/__init__.py | 2 + xlb/grid/grid.py | 19 +-- xlb/grid/jax_grid.py | 14 +- xlb/grid/warp_grid.py | 26 ++++ .../equilibrium/quadratic_equilibrium.py | 2 +- xlb/operator/operator.py | 6 +- xlb/operator/stepper/nse.py | 66 +++++---- xlb/operator/stepper/stepper.py | 70 +++++++++ 9 files changed, 286 insertions(+), 55 deletions(-) create mode 100644 examples/interfaces/functional_interface.py create mode 100644 xlb/grid/warp_grid.py diff --git a/examples/interfaces/functional_interface.py b/examples/interfaces/functional_interface.py new file mode 100644 index 0000000..f9dc016 --- /dev/null +++ b/examples/interfaces/functional_interface.py @@ -0,0 +1,136 @@ +# Simple Taylor green example using the functional interface to xlb + +import time +from tqdm import tqdm +import matplotlib.pyplot as plt + +import warp as wp +wp.init() + +import xlb +from xlb.operator import Operator + +class TaylorGreenInitializer(Operator): + + def _construct_warp(self): + # Construct the warp kernel + @wp.kernel + def kernel( + rho: self._warp_array_type, + u: self._warp_array_type, + vel: float, + nr: int, + ): + # Get the global index + i, j, k = wp.tid() + + # Get real pos + x = 2.0 * wp.pi * wp.float(i) / wp.float(nr) + y = 2.0 * wp.pi * wp.float(j) / wp.float(nr) + z = 2.0 * wp.pi * wp.float(k) / wp.float(nr) + + # Compute u + u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) + u[1, i, j, k] = - vel * wp.cos(x) * wp.sin(y) * wp.cos(z) + u[2, i, j, k] = 0.0 + + # Compute rho + rho[0, i, j, k] = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * ( + wp.cos(2.0 * x) + + (wp.cos(2.0 * y) + * (wp.cos(2.0 * z) + 2.0)) + ) + + 1.0 + ) + + return None, kernel + + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, rho, u, vel, nr): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + rho, + u, + vel, + nr, + ], + dim=rho.shape[1:], + ) + return rho, u + +if __name__ == "__main__": + + # Set parameters + compute_backend = xlb.ComputeBackend.WARP + precision_policy = xlb.PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q19() + + # Make feilds + nr = 256 + shape = (nr, nr, nr) + grid = xlb.grid.WarpGrid(shape=shape) + rho = grid.create_field(cardinality=1, dtype=wp.float32) + u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) + f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) + f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) + boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) + mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) + + # Make operators + initializer = TaylorGreenInitializer( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + collision = xlb.operator.collision.BGK( + omega=1.0, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + stream = xlb.operator.stream.Stream( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=collision, + equilibrium=equilibrium, + macroscopic=macroscopic, + stream=stream, + boundary_conditions=[]) + + # Parrallelize the stepper + #stepper = grid.parallelize_operator(stepper) + + # Set initial conditions + rho, u = initializer(rho, u, 0.1, nr) + f0 = equilibrium(rho, u, f0) + + # Plot initial conditions + #plt.imshow(f0[0, nr//2, :, :].numpy()) + #plt.show() + + # Time stepping + num_steps = 1024 + start = time.time() + for _ in tqdm(range(num_steps)): + f1 = stepper(f0, f1, boundary_id, mask, _) + f1, f0 = f0, f1 + wp.synchronize() + end = time.time() + + # Print MLUPS + print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py index d44ce65..a777f8f 100644 --- a/xlb/grid/__init__.py +++ b/xlb/grid/__init__.py @@ -1 +1,3 @@ from xlb.grid.grid import Grid +from xlb.grid.warp_grid import WarpGrid +from xlb.grid.jax_grid import JaxGrid diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index abdf2f0..2276929 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -12,19 +12,10 @@ class Grid(ABC): def __init__( self, shape : tuple, - velocity_set : VelocitySet, - precision_policy : PrecisionPolicy, - grid_backend : ComputeBackend ): # Set parameters self.shape = shape - self.velocity_set = velocity_set - self.precision_policy = precision_policy - self.grid_backend = grid_backend - self.dim = self.velocity_set.d - - # Create field dict - self.fields = {} + self.dim = len(shape) def parallelize_operator(self, operator: Operator): raise NotImplementedError("Parallelization not implemented, child class must implement") @@ -33,10 +24,4 @@ def parallelize_operator(self, operator: Operator): def create_field( self, name: str, cardinality: int, precision: Precision, callback=None ): - pass - - def get_field(self, name: str): - return self.fields[name] - - def swap_fields(self, field1, field2): - self.fields[field1], self.fields[field2] = self.fields[field2], self.fields[field1] + raise NotImplementedError("create_field not implemented, child class must implement") diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 092ea1d..1698c34 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -1,15 +1,15 @@ -from xlb.grid.grid import Grid -from xlb.compute_backends import ComputeBackends from jax.sharding import PartitionSpec as P from jax.sharding import NamedSharding, Mesh from jax.experimental import mesh_utils -from xlb.operator.initializer import ConstInitializer import jax +from xlb.grid import Grid +from xlb.compute_backend import ComputeBackend +from xlb.operator import Operator class JaxGrid(Grid): - def __init__(self, grid_shape, velocity_set, precision_policy, grid_backend): - super().__init__(grid_shape, velocity_set, precision_policy, grid_backend) + def __init__(self, shape): + super().__init__(shape) self._initialize_jax_backend() def _initialize_jax_backend(self): @@ -73,8 +73,6 @@ def _parallel_operator(f): return f - - def create_field(self, name: str, cardinality: int, callback=None): # Get shape of the field shape = (cardinality,) + (self.shape) @@ -88,4 +86,4 @@ def create_field(self, name: str, cardinality: int, callback=None): f = jax.make_array_from_callback(shape, self.sharding, callback) # Add field to the field dictionary - self.fields[name] = f + return f diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py new file mode 100644 index 0000000..e4d160e --- /dev/null +++ b/xlb/grid/warp_grid.py @@ -0,0 +1,26 @@ +import warp as wp + +from xlb.grid import Grid +from xlb.operator import Operator + +class WarpGrid(Grid): + def __init__(self, shape): + super().__init__(shape) + + def parallelize_operator(self, operator: Operator): + # TODO: Implement parallelization of the operator + raise NotImplementedError("Parallelization of the operator is not implemented yet for the WarpGrid") + + def create_field(self, cardinality: int, dtype, callback=None): + # Get shape of the field + shape = (cardinality,) + (self.shape) + + # Create the field + f = wp.zeros(shape, dtype=dtype) + + # Raise error on callback + if callback is not None: + raise ValueError("Callback is not supported in the WarpGrid") + + # Add field to the field dictionary + return f diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 06ec7f3..b5e5dc0 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -113,7 +113,7 @@ def kernel( def warp_implementation(self, rho, u, f): # Launch the warp kernel wp.launch( - self._kernel, + self.warp_kernel, inputs=[ rho, u, diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 81b3035..f3ea901 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -27,7 +27,7 @@ def __init__(self, velocity_set, precision_policy, compute_backend): # Construct the kernel based backend functions TODO: Maybe move this to the register or something if self.compute_backend == ComputeBackend.WARP: - self._functional, self._kernel = self._construct_warp() + self.warp_functional, self.warp_kernel = self._construct_warp() @classmethod def register_backend(cls, backend_name): @@ -189,9 +189,9 @@ def _warp_uint8_array_type(self): Returns the warp type for arrays """ if self.velocity_set.d == 2: - return wp.array3d(dtype=wp.bool) + return wp.array3d(dtype=wp.uint8) elif self.velocity_set.d == 3: - return wp.array4d(dtype=wp.bool) + return wp.array4d(dtype=wp.uint8) @property def _warp_bool_array_type(self): diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index 55a7703..d5151c9 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -136,6 +136,7 @@ def _construct_warp(self): _q = wp.constant(self.velocity_set.q) _d = wp.constant(self.velocity_set.d) _nr_boundary_conditions = wp.constant(len(self.boundary_conditions)) + print(_q) # Construct the kernel @wp.kernel @@ -145,6 +146,9 @@ def kernel( boundary_id: self._warp_uint8_array_type, mask: self._warp_bool_array_type, timestep: wp.int32, + max_i: wp.int32, + max_j: wp.int32, + max_k: wp.int32, ): # Get the global index i, j, k = wp.tid() @@ -152,51 +156,56 @@ def kernel( # Get the f, boundary id and mask _f = self._warp_lattice_vec() _boundary_id = boundary_id[0, i, j, k] - _mask = self._bool_lattice_vec() + _mask = self._warp_bool_lattice_vec() for l in range(_q): - _f[l] = self.f_0[l, i, j, k] - _mask[l] = mask[l, i, j, k] + _f[l] = f_0[l, i, j, k] + + # TODO fix vec bool + if mask[l, i, j, k]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) # Compute rho and u - rho, u = self.macroscopic.functional(_f) + rho, u = self.macroscopic.warp_functional(_f) # Compute equilibrium - feq = self.equilibrium.functional(rho, u) + feq = self.equilibrium.warp_functional(rho, u) # Apply collision - f_post_collision = self.collision.functional( + f_post_collision = self.collision.warp_functional( _f, feq, rho, u, ) - # Apply collision type boundary conditions - if _boundary_id == id_number: - f_post_collision = self.collision_boundary_conditions[ - id_number - ].functional( - _f, - f_post_collision, - _mask, - ) + ## Apply collision type boundary conditions + #if _boundary_id != wp.uint8(0): + # f_post_collision = self.collision_boundary_conditions[ + # _boundary_id + # ].warp_functional( + # _f, + # f_post_collision, + # _mask, + # ) f_pre_streaming = f_post_collision # store pre streaming vector # Apply forcing # if self.forcing_op is not None: - # f = self.forcing.functional(f, timestep) + # f = self.forcing.warp_functional(f, timestep) # Apply streaming for l in range(_q): # Get the streamed indices - streamed_i, streamed_j, streamed_k = self.stream.functional( - l, i, j, k, self._warp_max_i, self._warp_max_j, self._warp_max_k + streamed_i, streamed_j, streamed_k = self.stream.warp_functional( + l, i, j, k, max_i, max_j, max_k ) streamed_l = l ## Modify the streamed indices based on streaming boundary condition # if _boundary_id != 0: - # streamed_l, streamed_i, streamed_j, streamed_k = self.stream_boundary_conditions[id_number].functional( + # streamed_l, streamed_i, streamed_j, streamed_k = self.stream_boundary_conditions[id_number].warp_functional( # streamed_l, streamed_i, streamed_j, streamed_k, self._warp_max_i, self._warp_max_j, self._warp_max_k # ) @@ -206,15 +215,20 @@ def kernel( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, boundary_id, mask, timestep): + def warp_implementation(self, f_0, f_1, boundary_id, mask, timestep): # Launch the warp kernel wp.launch( - self._kernel, + self.warp_kernel, inputs=[ - f, - rho, - u, + f_0, + f_1, + boundary_id, + mask, + timestep, + f_0.shape[1], + f_0.shape[2], + f_0.shape[3], ], - dim=rho.shape[1:], + dim=f_0.shape[1:], ) - return rho, u + return f_1 diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index b65924b..005e54d 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -82,6 +82,76 @@ def __init__( # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) + ###################################################### + # TODO: This is a hacky way to do this. Need to refactor + ###################################################### + """ + def _construct_warp_bc_functional(self): + # identity collision boundary condition + @wp.func + def identity( + f_pre: self._warp_lattice_vec, + f_post: self._warp_lattice_vec, + mask: self._warp_bool_lattice_vec, + ): + return f_post + def get_bc_functional(id_number, self.collision_boundary_conditions): + if id_number in self.collision_boundary_conditions.keys(): + return self.collision_boundary_conditions[id_number].warp_functional + else: + return identity + + # Manually set the boundary conditions TODO: Extremely hacky + collision_bc_functional_0 = get_bc_functional(0, self.collision_boundary_conditions) + collision_bc_functional_1 = get_bc_functional(1, self.collision_boundary_conditions) + collision_bc_functional_2 = get_bc_functional(2, self.collision_boundary_conditions) + collision_bc_functional_3 = get_bc_functional(3, self.collision_boundary_conditions) + collision_bc_functional_4 = get_bc_functional(4, self.collision_boundary_conditions) + collision_bc_functional_5 = get_bc_functional(5, self.collision_boundary_conditions) + collision_bc_functional_6 = get_bc_functional(6, self.collision_boundary_conditions) + collision_bc_functional_7 = get_bc_functional(7, self.collision_boundary_conditions) + collision_bc_functional_8 = get_bc_functional(8, self.collision_boundary_conditions) + + # Make the warp boundary condition functional + @wp.func + def warp_bc( + f_pre: self._warp_lattice_vec, + f_post: self._warp_lattice_vec, + mask: self._warp_bool_lattice_vec, + boundary_id: wp.uint8, + ): + if boundary_id == 0: + f_post = collision_bc_functional_0(f_pre, f_post, mask) + elif boundary_id == 1: + f_post = collision_bc_functional_1(f_pre, f_post, mask) + elif boundary_id == 2: + f_post = collision_bc_functional_2(f_pre, f_post, mask) + elif boundary_id == 3: + f_post = collision_bc_functional_3(f_pre, f_post, mask) + elif boundary_id == 4: + f_post = collision_bc_functional_4(f_pre, f_post, mask) + elif boundary_id == 5: + f_post = collision_bc_functional_5(f_pre, f_post, mask) + elif boundary_id == 6: + f_post = collision_bc_functional_6(f_pre, f_post, mask) + elif boundary_id == 7: + f_post = collision_bc_functional_7(f_pre, f_post, mask) + elif boundary_id == 8: + f_post = collision_bc_functional_8(f_pre, f_post, mask) + + return f_post + + + + + ###################################################### + """ + + + + + + class SetBoundary(Operator): """ From cb8459c0144e7d42a36f9bcde25cb50ccec41d10 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 28 Feb 2024 22:46:40 -0800 Subject: [PATCH 017/144] fixed warp bug --- examples/interfaces/functional_interface.py | 24 +++++++++++++------ .../equilibrium/quadratic_equilibrium.py | 8 +++---- xlb/operator/macroscopic/macroscopic.py | 6 ++--- xlb/operator/stepper/nse.py | 9 ++++--- xlb/operator/stream/stream.py | 6 ++--- 5 files changed, 31 insertions(+), 22 deletions(-) diff --git a/examples/interfaces/functional_interface.py b/examples/interfaces/functional_interface.py index f9dc016..e1419a1 100644 --- a/examples/interfaces/functional_interface.py +++ b/examples/interfaces/functional_interface.py @@ -2,6 +2,7 @@ import time from tqdm import tqdm +import os import matplotlib.pyplot as plt import warp as wp @@ -16,6 +17,7 @@ def _construct_warp(self): # Construct the warp kernel @wp.kernel def kernel( + f0: self._warp_array_type, rho: self._warp_array_type, u: self._warp_array_type, vel: float, @@ -51,11 +53,12 @@ def kernel( return None, kernel @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u, vel, nr): + def warp_implementation(self, f0, rho, u, vel, nr): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ + f0, rho, u, vel, @@ -89,7 +92,7 @@ def warp_implementation(self, rho, u, vel, nr): precision_policy=precision_policy, compute_backend=compute_backend) collision = xlb.operator.collision.BGK( - omega=1.0, + omega=1.9, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) @@ -116,19 +119,26 @@ def warp_implementation(self, rho, u, vel, nr): #stepper = grid.parallelize_operator(stepper) # Set initial conditions - rho, u = initializer(rho, u, 0.1, nr) + rho, u = initializer(f0, rho, u, 0.1, nr) f0 = equilibrium(rho, u, f0) - # Plot initial conditions - #plt.imshow(f0[0, nr//2, :, :].numpy()) - #plt.show() - # Time stepping + plot_freq = 32 + save_dir = "taylor_green" + os.makedirs(save_dir, exist_ok=True) + #compute_mlup = False # Plotting results + compute_mlup = True num_steps = 1024 start = time.time() for _ in tqdm(range(num_steps)): f1 = stepper(f0, f1, boundary_id, mask, _) f1, f0 = f0, f1 + if (_ % plot_freq == 0) and (not compute_mlup): + rho, u = macroscopic(f0, rho, u) + plt.imshow(u[0, :, nr//2, :].numpy()) + plt.colorbar() + plt.savefig(f"{save_dir}/{str(_).zfill(4)}.png") + plt.close() wp.synchronize() end = time.time() diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index b5e5dc0..a1245f1 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -69,12 +69,12 @@ def functional( ) -> self._warp_lattice_vec: feq = self._warp_lattice_vec() # empty lattice vector for l in range(_q): - # Compute cu + ## Compute cu cu = self.compute_dtype(0.0) for d in range(_d): - if _c[l, d] == 1: + if _c[d, l] == 1: cu += u[d] - elif _c[l, d] == -1: + elif _c[d, l] == -1: cu -= u[d] cu *= self.compute_dtype(3.0) @@ -99,7 +99,7 @@ def kernel( # Get the equilibrium _u = self._warp_u_vec() for d in range(_d): - _u[i] = u[d, i, j, k] + _u[d] = u[d, i, j, k] _rho = rho[0, i, j, k] feq = functional(_rho, _u) diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index ab537b6..97bd10a 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -88,9 +88,9 @@ def functional(f: self._warp_lattice_vec): for l in range(_q): rho += f[l] for d in range(_d): - if _c[l, d] == 1: + if _c[d, l] == 1: u[d] += f[l] - elif _c[l, d] == -1: + elif _c[d, l] == -1: u[d] -= f[l] u /= rho @@ -124,7 +124,7 @@ def kernel( def warp_implementation(self, f, rho, u): # Launch the warp kernel wp.launch( - self._kernel, + self.warp_kernel, inputs=[ f, rho, diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index d5151c9..7b0acfe 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -136,7 +136,6 @@ def _construct_warp(self): _q = wp.constant(self.velocity_set.q) _d = wp.constant(self.velocity_set.d) _nr_boundary_conditions = wp.constant(len(self.boundary_conditions)) - print(_q) # Construct the kernel @wp.kernel @@ -145,10 +144,10 @@ def kernel( f_1: self._warp_array_type, boundary_id: self._warp_uint8_array_type, mask: self._warp_bool_array_type, - timestep: wp.int32, - max_i: wp.int32, - max_j: wp.int32, - max_k: wp.int32, + timestep: int, + max_i: int, + max_j: int, + max_k: int, ): # Get the global index i, j, k = wp.tid() diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index ebb932b..5033cac 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -67,9 +67,9 @@ def functional( max_j: int, max_k: int, ): - streamed_i = i + _c[l, 0] - streamed_j = j + _c[l, 1] - streamed_k = k + _c[l, 2] + streamed_i = i + _c[0, l] + streamed_j = j + _c[1, l] + streamed_k = k + _c[2, l] if streamed_i < 0: streamed_i = max_i - 1 elif streamed_i >= max_i: From b62bab108b74c2bcadb0148164b7a93c4df53a63 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 6 Mar 2024 16:25:12 -0800 Subject: [PATCH 018/144] first commit --- xlb/operator/boundary_condition/__init__.py | 2 + .../boundary_condition/boundary_condition.py | 50 ++++++++++++++ .../boundary_condition_registry.py | 25 +++++++ .../collision_boundary_condition.py | 67 +++++++++++++++++++ xlb/operator/boundary_condition/do_nothing.py | 2 + .../equilibrium_boundary.py | 2 + .../boundary_condition/full_bounce_back.py | 5 +- .../boundary_condition/halfway_bounce_back.py | 2 + xlb/operator/stepper/stepper.py | 62 +++++++++++------ 9 files changed, 194 insertions(+), 23 deletions(-) create mode 100644 xlb/operator/boundary_condition/boundary_condition_registry.py create mode 100644 xlb/operator/boundary_condition/collision_boundary_condition.py diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 3d10b59..6a0b10a 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -6,3 +6,5 @@ from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBack from xlb.operator.boundary_condition.do_nothing import DoNothing from xlb.operator.boundary_condition.equilibrium_boundary import EquilibriumBoundary + +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 95b1265..8d1daea 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -112,3 +112,53 @@ def from_stl( Create a boundary condition from an STL file. """ raise NotImplementedError + + +class CollisionBoundaryCondition(Operator): + """ + Class for combining collision and boundary conditions together + into a single operator. + """ + + def __init__( + self, + boundary_conditions: list[BoundaryCondition], + ): + # Set boundary conditions + self.boundary_conditions = boundary_conditions + + # Check that all boundary conditions have the same implementation step other properties + for bc in self.boundary_conditions: + assert bc.implementation_step == ImplementationStep.COLLISION, ( + "All boundary conditions must be applied during the collision step." + ) + + # Get velocity set, precision policy, and compute backend + velocity_sets = set([bc.velocity_set for bc in self.boundary_conditions]) + assert len(velocity_sets) == 1, "All velocity sets must be the same" + velocity_set = velocity_sets.pop() + precision_policies = set([bc.precision_policy for bc in self.boundary_conditions]) + assert len(precision_policies) == 1, "All precision policies must be the same" + precision_policy = precision_policies.pop() + compute_backends = set([bc.compute_backend for bc in self.boundary_conditions]) + assert len(compute_backends) == 1, "All compute backends must be the same" + compute_backend = compute_backends.pop() + + super().__init__( + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, mask, boundary_id): + """ + Apply collision boundary conditions + """ + for bc in self.boundary_conditions: + f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) + return f_post, mask + + + def _construct_warp(self): diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py new file mode 100644 index 0000000..23a0d17 --- /dev/null +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -0,0 +1,25 @@ +""" +Registry for boundary conditions in a LBM simulation. +""" + +class BoundaryConditionRegistry: + """ + Registry for boundary conditions in a LBM simulation. + """ + + def __init__( + self, + ): + self.ids = {} + self.next_id = 0 + + def register_boundary_condition(self, boundary_condition): + """ + Register a boundary condition. + """ + id = self.next_id + self.next_id += 1 + self.ids[boundary_condition] = id + return id + +boundary_condition_registry = BoundaryConditionRegistry() diff --git a/xlb/operator/boundary_condition/collision_boundary_condition.py b/xlb/operator/boundary_condition/collision_boundary_condition.py new file mode 100644 index 0000000..c75f602 --- /dev/null +++ b/xlb/operator/boundary_condition/collision_boundary_condition.py @@ -0,0 +1,67 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +from functools import partial +import numpy as np +from enum import Enum + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + +# Import all collision boundary conditions +from xlb.boundary_condition.full_bounce_back import FullBounceBack + + +class CollisionBoundaryCondition(Operator): + """ + Class for combining collision and boundary conditions together + into a single operator. + """ + + def __init__( + self, + boundary_conditions: list[BoundaryCondition], + ): + # Set boundary conditions + self.boundary_conditions = boundary_conditions + + # Check that all boundary conditions have the same implementation step other properties + for bc in self.boundary_conditions: + assert bc.implementation_step == ImplementationStep.COLLISION, ( + "All boundary conditions must be applied during the collision step." + ) + + # Get velocity set, precision policy, and compute backend + velocity_sets = set([bc.velocity_set for bc in self.boundary_conditions]) + assert len(velocity_sets) == 1, "All velocity sets must be the same" + velocity_set = velocity_sets.pop() + precision_policies = set([bc.precision_policy for bc in self.boundary_conditions]) + assert len(precision_policies) == 1, "All precision policies must be the same" + precision_policy = precision_policies.pop() + compute_backends = set([bc.compute_backend for bc in self.boundary_conditions]) + assert len(compute_backends) == 1, "All compute backends must be the same" + compute_backend = compute_backends.pop() + + super().__init__( + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, mask, boundary_id): + """ + Apply collision boundary conditions + """ + for bc in self.boundary_conditions: + f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) + return f_post, mask + + def _construct_warp(self): + diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py index 6251660..9f85da2 100644 --- a/xlb/operator/boundary_condition/do_nothing.py +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -10,12 +10,14 @@ BoundaryCondition, ImplementationStep, ) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry class DoNothing(BoundaryCondition): """ A boundary condition that skips the streaming step. """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( self, diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py index 4b47980..fc0cc11 100644 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -17,6 +17,7 @@ BoundaryMasker, IndicesBoundaryMasker, ) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry @@ -24,6 +25,7 @@ class EquilibriumBoundary(BoundaryCondition): """ Equilibrium boundary condition for a lattice Boltzmann method simulation. """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( self, diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index ed0ec5a..f572e83 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -21,12 +21,14 @@ BoundaryMasker, IndicesBoundaryMasker, ) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry class FullBounceBack(BoundaryCondition): """ Full Bounce-back boundary condition for a lattice Boltzmann method simulation. """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( self, @@ -65,7 +67,8 @@ def from_indices( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): + def apply_jax(self, f_pre, f_post, boundary_id, mask): + boundary = boundary_id == self.id flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) return flipped_f diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index 0937f1a..708a2fa 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -11,12 +11,14 @@ BoundaryCondition, ImplementationStep, ) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry class HalfwayBounceBack(BoundaryCondition): """ Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( self, diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 005e54d..d4ba8c7 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -55,24 +55,15 @@ def __init__( compute_backend = compute_backends.pop() # Get collision and stream boundary conditions - self.collision_boundary_conditions = {} - self.stream_boundary_conditions = {} - for id_number, bc in enumerate(self.boundary_conditions): - bc_id = id_number + 1 - if bc.implementation_step == ImplementationStep.COLLISION: - self.collision_boundary_conditions[bc_id] = bc - elif bc.implementation_step == ImplementationStep.STREAMING: - self.stream_boundary_conditions[bc_id] = bc - else: - raise ValueError("Boundary condition step not recognized") + self.collision_boundary_conditions = [bc for bc in boundary_conditions if bc.implementation_step == ImplementationStep.COLLISION] + self.stream_boundary_conditions = [bc for bc in boundary_conditions if bc.implementation_step == ImplementationStep.STREAMING] # Make operators for converting the precisions #self.cast_to_compute = PrecisionCaster( # Make operator for setting boundary condition arrays self.set_boundary = SetBoundary( - self.collision_boundary_conditions, - self.stream_boundary_conditions, + self.boundary_conditions, velocity_set, precision_policy, compute_backend, @@ -148,11 +139,41 @@ def warp_bc( """ +class ApplyCollisionBoundaryConditions(Operator): + """ + Class that handles the construction of lattice boltzmann collision boundary condition operator + """ - + def __init__( + self, + boundary_conditions, + velocity_set, + precision_policy, + compute_backend, + ): + super().__init__(velocity_set, precision_policy, compute_backend) + # Set boundary conditions + self.boundary_conditions = boundary_conditions + # Check that all boundary conditions are collision boundary conditions + for bc in boundary_conditions: + assert bc.implementation_step == ImplementationStep.COLLISION, "All boundary conditions must be collision boundary conditions" + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, mask, boundary_id): + """ + Apply collision boundary conditions + """ + for bc in self.boundary_conditions: + f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) + return f_post, mask + + def _construct_warp(self): + + + class SetBoundary(Operator): """ Class that handles the construction of lattice boltzmann boundary condition operator @@ -161,26 +182,23 @@ class SetBoundary(Operator): def __init__( self, - collision_boundary_conditions, - stream_boundary_conditions, + boundary_conditions, velocity_set, precision_policy, compute_backend, ): super().__init__(velocity_set, precision_policy, compute_backend) - # Set parameters - self.collision_boundary_conditions = collision_boundary_conditions - self.stream_boundary_conditions = stream_boundary_conditions + # Set boundary conditions + self.boundary_conditions = boundary_conditions + def _apply_all_bc(self, ijk, boundary_id, mask, bc): """ Apply all boundary conditions """ - for id_number, bc in self.collision_boundary_conditions.items(): - boundary_id, mask = bc.boundary_masker(ijk, boundary_id, mask, id_number) - for id_number, bc in self.stream_boundary_conditions.items(): - boundary_id, mask = bc.boundary_masker(ijk, boundary_id, mask, id_number) + for bc in self.boundary_conditions: + boundary_id, mask = bc.boundary_masker(ijk, boundary_id, mask, bc.id) return boundary_id, mask @Operator.register_backend(ComputeBackend.JAX) From 414047dcd564b65d94b2948e22e279302ffabda2 Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 7 Mar 2024 14:54:39 -0800 Subject: [PATCH 019/144] boundary condition restructuring --- .../collision_boundary_condition.py | 76 ++++++++ .../equilibrium_boundary.py | 1 - .../boundary_condition/full_bounce_back.py | 7 +- .../boundary_applier/boundary_applier.py | 42 +++++ .../collision_boundary_applier.py | 143 +++++++++++++++ .../full_bounce_back_applier.py | 117 ++++++++++++ .../stream_boundary_applier.py | 143 +++++++++++++++ .../boundary_condition.py | 53 ++++++ .../boundary_masker/__init__.py | 2 + .../boundary_masker/boundary_masker.py | 8 + .../indices_boundary_masker.py | 117 ++++++++++++ .../full_bounce_back.py | 70 ++++++++ xlb/operator/stepper/stepper.py | 170 +----------------- 13 files changed, 778 insertions(+), 171 deletions(-) create mode 100644 xlb/operator/boundary_condition_new/boundary_applier/boundary_applier.py create mode 100644 xlb/operator/boundary_condition_new/boundary_applier/collision_boundary_applier.py create mode 100644 xlb/operator/boundary_condition_new/boundary_applier/full_bounce_back_applier.py create mode 100644 xlb/operator/boundary_condition_new/boundary_applier/stream_boundary_applier.py create mode 100644 xlb/operator/boundary_condition_new/boundary_condition.py create mode 100644 xlb/operator/boundary_condition_new/boundary_masker/__init__.py create mode 100644 xlb/operator/boundary_condition_new/boundary_masker/boundary_masker.py create mode 100644 xlb/operator/boundary_condition_new/boundary_masker/indices_boundary_masker.py create mode 100644 xlb/operator/boundary_condition_new/full_bounce_back.py diff --git a/xlb/operator/boundary_condition/collision_boundary_condition.py b/xlb/operator/boundary_condition/collision_boundary_condition.py index c75f602..5d0f764 100644 --- a/xlb/operator/boundary_condition/collision_boundary_condition.py +++ b/xlb/operator/boundary_condition/collision_boundary_condition.py @@ -47,6 +47,11 @@ def __init__( assert len(compute_backends) == 1, "All compute backends must be the same" compute_backend = compute_backends.pop() + # Make all possible collision boundary conditions to obtain the warp functions + self.full_bounce_back = FullBounceBack( + None, velocity_set, precision_policy, compute_backend + ) + super().__init__( velocity_set, precision_policy, @@ -64,4 +69,75 @@ def jax_implementation(self, f_pre, f_post, mask, boundary_id): return f_post, mask def _construct_warp(self): + """ + Construct the warp kernel for the collision boundary condition. + """ + + # Make constants for warp + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + + # Get boolean constants for all boundary conditions + if any([isinstance(bc, FullBounceBack) for bc in self.boundary_conditions]): + _use_full_bounce_back = wp.constant(True) + + # Construct the funcional for all boundary conditions + @wp.func + def functional( + f_pre: self._warp_lattice_vec, + f_post: self._warp_lattice_vec, + boundary_id: wp.uint8, + mask: self._warp_bool_lattice_vec, + ): + # Apply all boundary conditions + # Full bounce-back + if _use_full_bounce_back: + if boundary_id == self.full_bounce_back.id: + f_post = self.full_bounce_back.warp_functional(f_pre, f_post, mask) + + return f_post + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: self._warp_array_type, + f_post: self._warp_array_type, + f: self._warp_array_type, + boundary_id: self._warp_uint8_array_type, + mask: self._warp_bool_array_type, + ): + # Get the global index + i, j, k = wp.tid() + + # Make vectors for the lattice + _f_pre = self._warp_lattice_vec() + _f_post = self._warp_lattice_vec() + _mask = self._warp_bool_lattice_vec() + _boundary_id = wp.uint8(boundary_id[0, i, j, k]) + for l in range(_q): + _f_pre[l] = f_pre[l, i, j, k] + _f_post[l] = f_post[l, i, j, k] + + # TODO fix vec bool + if mask[l, i, j, k]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) + + # Apply all boundary conditions + if _boundary_id != wp.uint8(0): + _f_post = functional(_f_pre, _f_post, _boundary_id, _mask) + + # Write the result to the output + for l in range(_q): + f[l, i, j, k] = _f_post[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, f, boundary_id, mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, inputs=[f_pre, f_post, f, boundary_id, mask], dim=f_pre.shape[1:] + ) + return f diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py index fc0cc11..9122fe3 100644 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -39,7 +39,6 @@ def __init__( compute_backend: ComputeBackend, ): super().__init__( - ImplementationStep.COLLISION, implementation_step=ImplementationStep.STREAMING, velocity_set=velocity_set, compute_backend=compute_backend, diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index f572e83..4393fe4 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -78,6 +78,7 @@ def _construct_warp(self): _opp_indices = wp.constant(self._warp_int_lattice_vec(self.velocity_set.opp_indices)) _q = wp.constant(self.velocity_set.q) _d = wp.constant(self.velocity_set.d) + _id = wp.constant(self.id) # Construct the funcional to get streamed indices @wp.func @@ -97,7 +98,7 @@ def kernel( f_pre: self._warp_array_type, f_post: self._warp_array_type, f: self._warp_array_type, - boundary: self._warp_bool_array_type, + boundary_id: self._warp_uint8_array_type, mask: self._warp_bool_array_type, ): # Get the global index @@ -118,7 +119,7 @@ def kernel( _mask[l] = wp.uint8(0) # Check if the boundary is active - if boundary[i, j, k]: + if boundary_id[i, j, k] == wp.uint8(_id: _f = functional(_f_pre, _f_post, _mask) else: _f = _f_post @@ -133,6 +134,6 @@ def kernel( def warp_implementation(self, f_pre, f_post, f, boundary, mask): # Launch the warp kernel wp.launch( - self._kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] + self.warp_kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] ) return f diff --git a/xlb/operator/boundary_condition_new/boundary_applier/boundary_applier.py b/xlb/operator/boundary_condition_new/boundary_applier/boundary_applier.py new file mode 100644 index 0000000..2cd6cc6 --- /dev/null +++ b/xlb/operator/boundary_condition_new/boundary_applier/boundary_applier.py @@ -0,0 +1,42 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +from functools import partial +import numpy as np +from enum import Enum + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_masker import ( + BoundaryMasker, + IndicesBoundaryMasker, +) + + +# Enum for implementation step +class ImplementationStep(Enum): + COLLISION = 1 + STREAMING = 2 + + +class BoundaryApplier(Operator): + """ + Base class for boundary conditions in a LBM simulation. + """ + + def __init__( + self, + implementation_step: ImplementationStep, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + super().__init__(velocity_set, precision_policy, compute_backend) + + # Set implementation step + self.implementation_step = implementation_step diff --git a/xlb/operator/boundary_condition_new/boundary_applier/collision_boundary_applier.py b/xlb/operator/boundary_condition_new/boundary_applier/collision_boundary_applier.py new file mode 100644 index 0000000..2e68c0b --- /dev/null +++ b/xlb/operator/boundary_condition_new/boundary_applier/collision_boundary_applier.py @@ -0,0 +1,143 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +from functools import partial +import numpy as np +from enum import Enum + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + +# Import all collision boundary conditions +from xlb.boundary_condition.full_bounce_back import FullBounceBack + + +class CollisionBoundaryCondition(Operator): + """ + Class for combining collision and boundary conditions together + into a single operator. + """ + + def __init__( + self, + boundary_appliers: list[BoundaryApplier], + ): + # Set boundary conditions + self.boundary_appliers = boundary_appliers + + # Check that all boundary conditions have the same implementation step other properties + for bc in self.boundary_appliers: + assert bc.implementation_step == ImplementationStep.COLLISION, ( + "All boundary conditions must be applied during the collision step." + ) + + # Get velocity set, precision policy, and compute backend + velocity_sets = set([bc.velocity_set for bc in self.boundary_appliers]) + assert len(velocity_sets) == 1, "All velocity sets must be the same" + velocity_set = velocity_sets.pop() + precision_policies = set([bc.precision_policy for bc in self.boundary_appliers]) + assert len(precision_policies) == 1, "All precision policies must be the same" + precision_policy = precision_policies.pop() + compute_backends = set([bc.compute_backend for bc in self.boundary_appliers]) + assert len(compute_backends) == 1, "All compute backends must be the same" + compute_backend = compute_backends.pop() + + # Make all possible collision boundary conditions to obtain the warp functions + self.full_bounce_back = FullBounceBack( + None, velocity_set, precision_policy, compute_backend + ) + + super().__init__( + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, mask, boundary_id): + """ + Apply collision boundary conditions + """ + for bc in self.boundary_conditions: + f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) + return f_post, mask + + def _construct_warp(self): + """ + Construct the warp kernel for the collision boundary condition. + """ + + # Make constants for warp + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + + # Get boolean constants for all boundary conditions + if any([isinstance(bc, FullBounceBack) for bc in self.boundary_conditions]): + _use_full_bounce_back = wp.constant(True) + + # Construct the funcional for all boundary conditions + @wp.func + def functional( + f_pre: self._warp_lattice_vec, + f_post: self._warp_lattice_vec, + boundary_id: wp.uint8, + mask: self._warp_bool_lattice_vec, + ): + # Apply all boundary conditions + # Full bounce-back + if _use_full_bounce_back: + if boundary_id == self.full_bounce_back.id: + f_post = self.full_bounce_back.warp_functional(f_pre, f_post, mask) + + return f_post + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: self._warp_array_type, + f_post: self._warp_array_type, + f: self._warp_array_type, + boundary_id: self._warp_uint8_array_type, + mask: self._warp_bool_array_type, + ): + # Get the global index + i, j, k = wp.tid() + + # Make vectors for the lattice + _f_pre = self._warp_lattice_vec() + _f_post = self._warp_lattice_vec() + _mask = self._warp_bool_lattice_vec() + _boundary_id = wp.uint8(boundary_id[0, i, j, k]) + for l in range(_q): + _f_pre[l] = f_pre[l, i, j, k] + _f_post[l] = f_post[l, i, j, k] + + # TODO fix vec bool + if mask[l, i, j, k]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) + + # Apply all boundary conditions + if _boundary_id != wp.uint8(0): + _f_post = functional(_f_pre, _f_post, _boundary_id, _mask) + + # Write the result to the output + for l in range(_q): + f[l, i, j, k] = _f_post[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, f, boundary_id, mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, inputs=[f_pre, f_post, f, boundary_id, mask], dim=f_pre.shape[1:] + ) + return f diff --git a/xlb/operator/boundary_condition_new/boundary_applier/full_bounce_back_applier.py b/xlb/operator/boundary_condition_new/boundary_applier/full_bounce_back_applier.py new file mode 100644 index 0000000..502a3ca --- /dev/null +++ b/xlb/operator/boundary_condition_new/boundary_applier/full_bounce_back_applier.py @@ -0,0 +1,117 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator import Operator +from xlb.operator.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) +from xlb.operator.boundary_condition.boundary_masker import ( + BoundaryMasker, + IndicesBoundaryMasker, +) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + + +class FullBounceBackApplier(BoundaryApplier): + """ + Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + super().__init__( + ImplementationStep.COLLISION, + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary_id, mask): + boundary = boundary_id == self.id + flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) + flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) + return flipped_f + + def _construct_warp(self): + # Make constants for warp + _opp_indices = wp.constant(self._warp_int_lattice_vec(self.velocity_set.opp_indices)) + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + _id = wp.constant(self.id) + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f_pre: self._warp_lattice_vec, + f_post: self._warp_lattice_vec, + mask: self._warp_bool_lattice_vec, + ): + fliped_f = self._warp_lattice_vec() + for l in range(_q): + fliped_f[l] = f_pre[_opp_indices[l]] + return fliped_f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: self._warp_array_type, + f_post: self._warp_array_type, + f: self._warp_array_type, + boundary_id: self._warp_uint8_array_type, + mask: self._warp_bool_array_type, + ): + # Get the global index + i, j, k = wp.tid() + + # Make vectors for the lattice + _f_pre = self._warp_lattice_vec() + _f_post = self._warp_lattice_vec() + _mask = self._warp_bool_lattice_vec() + for l in range(_q): + _f_pre[l] = f_pre[l, i, j, k] + _f_post[l] = f_post[l, i, j, k] + + # TODO fix vec bool + if mask[l, i, j, k]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) + + # Check if the boundary is active + if boundary_id[i, j, k] == wp.uint8(_id: + _f = functional(_f_pre, _f_post, _mask) + else: + _f = _f_post + + # Write the result to the output + for l in range(_q): + f[l, i, j, k] = _f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, f, boundary, mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] + ) + return f diff --git a/xlb/operator/boundary_condition_new/boundary_applier/stream_boundary_applier.py b/xlb/operator/boundary_condition_new/boundary_applier/stream_boundary_applier.py new file mode 100644 index 0000000..2e68c0b --- /dev/null +++ b/xlb/operator/boundary_condition_new/boundary_applier/stream_boundary_applier.py @@ -0,0 +1,143 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +from functools import partial +import numpy as np +from enum import Enum + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + +# Import all collision boundary conditions +from xlb.boundary_condition.full_bounce_back import FullBounceBack + + +class CollisionBoundaryCondition(Operator): + """ + Class for combining collision and boundary conditions together + into a single operator. + """ + + def __init__( + self, + boundary_appliers: list[BoundaryApplier], + ): + # Set boundary conditions + self.boundary_appliers = boundary_appliers + + # Check that all boundary conditions have the same implementation step other properties + for bc in self.boundary_appliers: + assert bc.implementation_step == ImplementationStep.COLLISION, ( + "All boundary conditions must be applied during the collision step." + ) + + # Get velocity set, precision policy, and compute backend + velocity_sets = set([bc.velocity_set for bc in self.boundary_appliers]) + assert len(velocity_sets) == 1, "All velocity sets must be the same" + velocity_set = velocity_sets.pop() + precision_policies = set([bc.precision_policy for bc in self.boundary_appliers]) + assert len(precision_policies) == 1, "All precision policies must be the same" + precision_policy = precision_policies.pop() + compute_backends = set([bc.compute_backend for bc in self.boundary_appliers]) + assert len(compute_backends) == 1, "All compute backends must be the same" + compute_backend = compute_backends.pop() + + # Make all possible collision boundary conditions to obtain the warp functions + self.full_bounce_back = FullBounceBack( + None, velocity_set, precision_policy, compute_backend + ) + + super().__init__( + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, mask, boundary_id): + """ + Apply collision boundary conditions + """ + for bc in self.boundary_conditions: + f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) + return f_post, mask + + def _construct_warp(self): + """ + Construct the warp kernel for the collision boundary condition. + """ + + # Make constants for warp + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + + # Get boolean constants for all boundary conditions + if any([isinstance(bc, FullBounceBack) for bc in self.boundary_conditions]): + _use_full_bounce_back = wp.constant(True) + + # Construct the funcional for all boundary conditions + @wp.func + def functional( + f_pre: self._warp_lattice_vec, + f_post: self._warp_lattice_vec, + boundary_id: wp.uint8, + mask: self._warp_bool_lattice_vec, + ): + # Apply all boundary conditions + # Full bounce-back + if _use_full_bounce_back: + if boundary_id == self.full_bounce_back.id: + f_post = self.full_bounce_back.warp_functional(f_pre, f_post, mask) + + return f_post + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: self._warp_array_type, + f_post: self._warp_array_type, + f: self._warp_array_type, + boundary_id: self._warp_uint8_array_type, + mask: self._warp_bool_array_type, + ): + # Get the global index + i, j, k = wp.tid() + + # Make vectors for the lattice + _f_pre = self._warp_lattice_vec() + _f_post = self._warp_lattice_vec() + _mask = self._warp_bool_lattice_vec() + _boundary_id = wp.uint8(boundary_id[0, i, j, k]) + for l in range(_q): + _f_pre[l] = f_pre[l, i, j, k] + _f_post[l] = f_post[l, i, j, k] + + # TODO fix vec bool + if mask[l, i, j, k]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) + + # Apply all boundary conditions + if _boundary_id != wp.uint8(0): + _f_post = functional(_f_pre, _f_post, _boundary_id, _mask) + + # Write the result to the output + for l in range(_q): + f[l, i, j, k] = _f_post[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, f, boundary_id, mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, inputs=[f_pre, f_post, f, boundary_id, mask], dim=f_pre.shape[1:] + ) + return f diff --git a/xlb/operator/boundary_condition_new/boundary_condition.py b/xlb/operator/boundary_condition_new/boundary_condition.py new file mode 100644 index 0000000..ebd814c --- /dev/null +++ b/xlb/operator/boundary_condition_new/boundary_condition.py @@ -0,0 +1,53 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +from functools import partial +import numpy as np +from enum import Enum + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_masker import BoundaryMasker +from xlb.operator.boundary_condition.boundary_applier import BoundaryApplier + + +# Enum for implementation step +class ImplementationStep(Enum): + COLLISION = 1 + STREAMING = 2 + + +class BoundaryCondition(): + """ + Base class for boundary conditions in a LBM simulation. + + Boundary conditions are unique in that they are not operators themselves, + but rather hold operators for applying and making masks for boundary conditions. + """ + + def __init__( + self, + boundary_applier: BoundaryApplier, + boundary_masker: BoundaryMasker, + ): + super().__init__(velocity_set, precision_policy, compute_backend) + + # Set operators + self.boundary_applier = boundary_applier + self.boundary_masker = boundary_masker + + # Get velocity set, precision policy, and compute backend + velocity_sets = set([boundary_applier.velocity_set, boundary_masker.velocity_set]) + assert len(velocity_sets) == 1, "All velocity sets must be the same" + velocity_set = velocity_sets.pop() + precision_policies = set([boundary_applier.precision_policy, boundary_masker.precision_policy]) + assert len(precision_policies) == 1, "All precision policies must be the same" + precision_policy = precision_policies.pop() + compute_backends = set([boundary_applier.compute_backend, boundary_masker.compute_backend]) + assert len(compute_backends) == 1, "All compute backends must be the same" + compute_backend = compute_backends.pop() diff --git a/xlb/operator/boundary_condition_new/boundary_masker/__init__.py b/xlb/operator/boundary_condition_new/boundary_masker/__init__.py new file mode 100644 index 0000000..e33e509 --- /dev/null +++ b/xlb/operator/boundary_condition_new/boundary_masker/__init__.py @@ -0,0 +1,2 @@ +from xlb.operator.boundary_condition.boundary_masker.boundary_masker import BoundaryMasker +from xlb.operator.boundary_condition.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker diff --git a/xlb/operator/boundary_condition_new/boundary_masker/boundary_masker.py b/xlb/operator/boundary_condition_new/boundary_masker/boundary_masker.py new file mode 100644 index 0000000..f6b73f0 --- /dev/null +++ b/xlb/operator/boundary_condition_new/boundary_masker/boundary_masker.py @@ -0,0 +1,8 @@ +# Base class for all boundary masker operators + +from xlb.operator.operator import Operator + +class BoundaryMasker(Operator): + """ + Operator for creating a boundary mask + """ diff --git a/xlb/operator/boundary_condition_new/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_condition_new/boundary_masker/indices_boundary_masker.py new file mode 100644 index 0000000..ccc52a3 --- /dev/null +++ b/xlb/operator/boundary_condition_new/boundary_masker/indices_boundary_masker.py @@ -0,0 +1,117 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream.stream import Stream + + +class IndicesBoundaryMasker(Operator): + """ + Operator for creating a boundary mask + """ + + def __init__( + self, + stream_indices: bool, + id_number: int, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + super().__init__(velocity_set, precision_policy, compute_backend) + + # Set indices + self.id_number = id_number + self.stream_indices = stream_indices + + # Make stream operator + self.stream = Stream(velocity_set, precision_policy, compute_backend) + + @staticmethod + def _indices_to_tuple(indices): + """ + Converts a tensor of indices to a tuple for indexing + TODO: Might be better to index + """ + return tuple([indices[:, i] for i in range(indices.shape[1])]) + + @Operator.register_backend(ComputeBackend.JAX) + #@partial(jit, static_argnums=(0), inline=True) TODO: Fix this + def jax_implementation(self, indicies, start_index, boundary_id, mask, id_number): + # Get local indices from the meshgrid and the indices + local_indices = self.indices - np.array(start_index)[np.newaxis, :] + + # Remove any indices that are out of bounds + local_indices = local_indices[ + (local_indices[:, 0] >= 0) + & (local_indices[:, 0] < mask.shape[0]) + & (local_indices[:, 1] >= 0) + & (local_indices[:, 1] < mask.shape[1]) + & (local_indices[:, 2] >= 0) + & (local_indices[:, 2] < mask.shape[2]) + ] + + # Set the boundary id + boundary_id = boundary_id.at[self._indices_to_tuple(local_indices)].set( + id_number + ) + + # Stream mask if necessary + if self.stream_indices: + # Make mask then stream to get the edge points + pre_stream_mask = jnp.zeros_like(mask) + pre_stream_mask = pre_stream_mask.at[ + self._indices_to_tuple(local_indices) + ].set(True) + post_stream_mask = self.stream(pre_stream_mask) + + # Set false for points inside the boundary + post_stream_mask = post_stream_mask.at[ + post_stream_mask[..., 0] == True + ].set(False) + + # Get indices on edges + edge_indices = jnp.argwhere(post_stream_mask) + + # Set the mask + mask = mask.at[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ].set( + post_stream_mask[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ] + ) + + else: + # Set the mask + mask = mask.at[self._indices_to_tuple(local_indices)].set(True) + + return boundary_id, mask + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, start_index, boundary_id, mask, id_number): + # Reuse the jax implementation, TODO: implement a warp version + # Convert to jax + boundary_id = wp.jax.to_jax(boundary_id) + mask = wp.jax.to_jax(mask) + + # Call jax implementation + boundary_id, mask = self.jax_implementation( + start_index, boundary_id, mask, id_number + ) + + # Convert back to warp + boundary_id = wp.jax.to_warp(boundary_id) + mask = wp.jax.to_warp(mask) + + return boundary_id, mask diff --git a/xlb/operator/boundary_condition_new/full_bounce_back.py b/xlb/operator/boundary_condition_new/full_bounce_back.py new file mode 100644 index 0000000..1559609 --- /dev/null +++ b/xlb/operator/boundary_condition_new/full_bounce_back.py @@ -0,0 +1,70 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator import Operator +from xlb.operator.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) +from xlb.operator.boundary_condition.boundary_masker import ( + BoundaryMasker, + IndicesBoundaryMasker, +) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + + +class FullBounceBack(BoundaryCondition): + """ + Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + + def __init__( + self, + boundary_masker: BoundaryMasker, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + + boundary_applier = FullBounceBackApplier( + velocity_set, precision_policy, compute_backend + ) + + super().__init__( + boundary_applier, + boundary_masker, + velocity_set, + precision_policy, + compute_backend, + ) + + @classmethod + def from_indices( + cls, velocity_set, precision_policy, compute_backend + ): + """ + Create a full bounce-back boundary condition from indices. + """ + # Create boundary mask + boundary_mask = IndicesBoundaryMasker( + False, velocity_set, precision_policy, compute_backend + ) + + # Create boundary condition + return cls( + boundary_mask, + velocity_set, + precision_policy, + compute_backend, + ) diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index d4ba8c7..0e0675a 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -40,7 +40,9 @@ def __init__( stream, equilibrium, macroscopic, - *boundary_conditions, + *[bc.boundary_applier for bc in boundary_conditions], + *[bc.boundary_masker for bc in boundary_conditions], + forcing, ] # Get velocity set, precision policy, and compute backend @@ -61,171 +63,5 @@ def __init__( # Make operators for converting the precisions #self.cast_to_compute = PrecisionCaster( - # Make operator for setting boundary condition arrays - self.set_boundary = SetBoundary( - self.boundary_conditions, - velocity_set, - precision_policy, - compute_backend, - ) - self.operators.append(self.set_boundary) - # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) - - ###################################################### - # TODO: This is a hacky way to do this. Need to refactor - ###################################################### - """ - def _construct_warp_bc_functional(self): - # identity collision boundary condition - @wp.func - def identity( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - mask: self._warp_bool_lattice_vec, - ): - return f_post - def get_bc_functional(id_number, self.collision_boundary_conditions): - if id_number in self.collision_boundary_conditions.keys(): - return self.collision_boundary_conditions[id_number].warp_functional - else: - return identity - - # Manually set the boundary conditions TODO: Extremely hacky - collision_bc_functional_0 = get_bc_functional(0, self.collision_boundary_conditions) - collision_bc_functional_1 = get_bc_functional(1, self.collision_boundary_conditions) - collision_bc_functional_2 = get_bc_functional(2, self.collision_boundary_conditions) - collision_bc_functional_3 = get_bc_functional(3, self.collision_boundary_conditions) - collision_bc_functional_4 = get_bc_functional(4, self.collision_boundary_conditions) - collision_bc_functional_5 = get_bc_functional(5, self.collision_boundary_conditions) - collision_bc_functional_6 = get_bc_functional(6, self.collision_boundary_conditions) - collision_bc_functional_7 = get_bc_functional(7, self.collision_boundary_conditions) - collision_bc_functional_8 = get_bc_functional(8, self.collision_boundary_conditions) - - # Make the warp boundary condition functional - @wp.func - def warp_bc( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - mask: self._warp_bool_lattice_vec, - boundary_id: wp.uint8, - ): - if boundary_id == 0: - f_post = collision_bc_functional_0(f_pre, f_post, mask) - elif boundary_id == 1: - f_post = collision_bc_functional_1(f_pre, f_post, mask) - elif boundary_id == 2: - f_post = collision_bc_functional_2(f_pre, f_post, mask) - elif boundary_id == 3: - f_post = collision_bc_functional_3(f_pre, f_post, mask) - elif boundary_id == 4: - f_post = collision_bc_functional_4(f_pre, f_post, mask) - elif boundary_id == 5: - f_post = collision_bc_functional_5(f_pre, f_post, mask) - elif boundary_id == 6: - f_post = collision_bc_functional_6(f_pre, f_post, mask) - elif boundary_id == 7: - f_post = collision_bc_functional_7(f_pre, f_post, mask) - elif boundary_id == 8: - f_post = collision_bc_functional_8(f_pre, f_post, mask) - - return f_post - - - - - ###################################################### - """ - - -class ApplyCollisionBoundaryConditions(Operator): - """ - Class that handles the construction of lattice boltzmann collision boundary condition operator - """ - - def __init__( - self, - boundary_conditions, - velocity_set, - precision_policy, - compute_backend, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Set boundary conditions - self.boundary_conditions = boundary_conditions - - # Check that all boundary conditions are collision boundary conditions - for bc in boundary_conditions: - assert bc.implementation_step == ImplementationStep.COLLISION, "All boundary conditions must be collision boundary conditions" - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, mask, boundary_id): - """ - Apply collision boundary conditions - """ - for bc in self.boundary_conditions: - f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) - return f_post, mask - - def _construct_warp(self): - - - -class SetBoundary(Operator): - """ - Class that handles the construction of lattice boltzmann boundary condition operator - This will probably never be used directly and it might be better to refactor it - """ - - def __init__( - self, - boundary_conditions, - velocity_set, - precision_policy, - compute_backend, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Set boundary conditions - self.boundary_conditions = boundary_conditions - - - def _apply_all_bc(self, ijk, boundary_id, mask, bc): - """ - Apply all boundary conditions - """ - for bc in self.boundary_conditions: - boundary_id, mask = bc.boundary_masker(ijk, boundary_id, mask, bc.id) - return boundary_id, mask - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, ijk): - """ - Set boundary condition arrays - These store the boundary condition information for each boundary - """ - boundary_id = jnp.zeros(ijk.shape[:-1], dtype=jnp.uint8) - mask = jnp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=jnp.bool_) - return self._apply_all_bc(ijk, boundary_id, mask, bc) - - @Operator.register_backend(ComputeBackend.PALLAS) - def pallas_implementation(self, ijk): - """ - Set boundary condition arrays - These store the boundary condition information for each boundary - """ - raise NotImplementedError("Pallas implementation not available") - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, ijk): - """ - Set boundary condition arrays - These store the boundary condition information for each boundary - """ - boundary_id = wp.zeros(ijk.shape[:-1], dtype=wp.uint8) - mask = wp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=wp.bool) - return self._apply_all_bc(ijk, boundary_id, mask, bc) From c715b69d72feea638b09b69bfe65cdf4c9117961 Mon Sep 17 00:00:00 2001 From: Oliver Date: Mon, 11 Mar 2024 09:59:13 -0700 Subject: [PATCH 020/144] base boundary condition structure --- .../{functional_interface.py => ldc.py} | 0 examples/interfaces/taylor_green.py | 146 ++++++++++++++++++ xlb/operator/boundary_condition/__init__.py | 10 -- .../boundary_applier/boundary_applier.py | 7 - .../collision_boundary_applier.py | 2 +- .../full_bounce_back_applier.py | 8 +- .../stream_boundary_applier.py | 0 .../boundary_condition/boundary_condition.py | 143 ++--------------- .../boundary_masker/boundary_masker.py | 28 +--- .../indices_boundary_masker.py | 7 +- .../collision_boundary_condition.py | 143 ----------------- xlb/operator/boundary_condition/do_nothing.py | 56 ------- .../equilibrium_boundary.py | 79 ---------- .../boundary_condition/full_bounce_back.py | 85 +--------- .../boundary_condition/halfway_bounce_back.py | 42 ----- .../boundary_condition.py | 53 ------- .../boundary_masker/__init__.py | 2 - .../boundary_masker/boundary_masker.py | 8 - .../indices_boundary_masker.py | 117 -------------- .../full_bounce_back.py | 70 --------- xlb/operator/stepper/nse.py | 24 ++- xlb/operator/stepper/stepper.py | 17 +- 22 files changed, 195 insertions(+), 852 deletions(-) rename examples/interfaces/{functional_interface.py => ldc.py} (100%) create mode 100644 examples/interfaces/taylor_green.py delete mode 100644 xlb/operator/boundary_condition/__init__.py rename xlb/operator/{boundary_condition_new => boundary_condition}/boundary_applier/boundary_applier.py (90%) rename xlb/operator/{boundary_condition_new => boundary_condition}/boundary_applier/collision_boundary_applier.py (99%) rename xlb/operator/{boundary_condition_new => boundary_condition}/boundary_applier/full_bounce_back_applier.py (95%) rename xlb/operator/{boundary_condition_new => boundary_condition}/boundary_applier/stream_boundary_applier.py (100%) delete mode 100644 xlb/operator/boundary_condition/collision_boundary_condition.py delete mode 100644 xlb/operator/boundary_condition/do_nothing.py delete mode 100644 xlb/operator/boundary_condition/equilibrium_boundary.py delete mode 100644 xlb/operator/boundary_condition/halfway_bounce_back.py delete mode 100644 xlb/operator/boundary_condition_new/boundary_condition.py delete mode 100644 xlb/operator/boundary_condition_new/boundary_masker/__init__.py delete mode 100644 xlb/operator/boundary_condition_new/boundary_masker/boundary_masker.py delete mode 100644 xlb/operator/boundary_condition_new/boundary_masker/indices_boundary_masker.py delete mode 100644 xlb/operator/boundary_condition_new/full_bounce_back.py diff --git a/examples/interfaces/functional_interface.py b/examples/interfaces/ldc.py similarity index 100% rename from examples/interfaces/functional_interface.py rename to examples/interfaces/ldc.py diff --git a/examples/interfaces/taylor_green.py b/examples/interfaces/taylor_green.py new file mode 100644 index 0000000..e1419a1 --- /dev/null +++ b/examples/interfaces/taylor_green.py @@ -0,0 +1,146 @@ +# Simple Taylor green example using the functional interface to xlb + +import time +from tqdm import tqdm +import os +import matplotlib.pyplot as plt + +import warp as wp +wp.init() + +import xlb +from xlb.operator import Operator + +class TaylorGreenInitializer(Operator): + + def _construct_warp(self): + # Construct the warp kernel + @wp.kernel + def kernel( + f0: self._warp_array_type, + rho: self._warp_array_type, + u: self._warp_array_type, + vel: float, + nr: int, + ): + # Get the global index + i, j, k = wp.tid() + + # Get real pos + x = 2.0 * wp.pi * wp.float(i) / wp.float(nr) + y = 2.0 * wp.pi * wp.float(j) / wp.float(nr) + z = 2.0 * wp.pi * wp.float(k) / wp.float(nr) + + # Compute u + u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) + u[1, i, j, k] = - vel * wp.cos(x) * wp.sin(y) * wp.cos(z) + u[2, i, j, k] = 0.0 + + # Compute rho + rho[0, i, j, k] = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * ( + wp.cos(2.0 * x) + + (wp.cos(2.0 * y) + * (wp.cos(2.0 * z) + 2.0)) + ) + + 1.0 + ) + + return None, kernel + + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, f0, rho, u, vel, nr): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + f0, + rho, + u, + vel, + nr, + ], + dim=rho.shape[1:], + ) + return rho, u + +if __name__ == "__main__": + + # Set parameters + compute_backend = xlb.ComputeBackend.WARP + precision_policy = xlb.PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q19() + + # Make feilds + nr = 256 + shape = (nr, nr, nr) + grid = xlb.grid.WarpGrid(shape=shape) + rho = grid.create_field(cardinality=1, dtype=wp.float32) + u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) + f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) + f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) + boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) + mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) + + # Make operators + initializer = TaylorGreenInitializer( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + collision = xlb.operator.collision.BGK( + omega=1.9, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + stream = xlb.operator.stream.Stream( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=collision, + equilibrium=equilibrium, + macroscopic=macroscopic, + stream=stream, + boundary_conditions=[]) + + # Parrallelize the stepper + #stepper = grid.parallelize_operator(stepper) + + # Set initial conditions + rho, u = initializer(f0, rho, u, 0.1, nr) + f0 = equilibrium(rho, u, f0) + + # Time stepping + plot_freq = 32 + save_dir = "taylor_green" + os.makedirs(save_dir, exist_ok=True) + #compute_mlup = False # Plotting results + compute_mlup = True + num_steps = 1024 + start = time.time() + for _ in tqdm(range(num_steps)): + f1 = stepper(f0, f1, boundary_id, mask, _) + f1, f0 = f0, f1 + if (_ % plot_freq == 0) and (not compute_mlup): + rho, u = macroscopic(f0, rho, u) + plt.imshow(u[0, :, nr//2, :].numpy()) + plt.colorbar() + plt.savefig(f"{save_dir}/{str(_).zfill(4)}.png") + plt.close() + wp.synchronize() + end = time.time() + + # Print MLUPS + print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py deleted file mode 100644 index 6a0b10a..0000000 --- a/xlb/operator/boundary_condition/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from xlb.operator.boundary_condition.boundary_condition import ( - BoundaryCondition, - ImplementationStep, -) -from xlb.operator.boundary_condition.full_bounce_back import FullBounceBack -from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBack -from xlb.operator.boundary_condition.do_nothing import DoNothing -from xlb.operator.boundary_condition.equilibrium_boundary import EquilibriumBoundary - -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry diff --git a/xlb/operator/boundary_condition_new/boundary_applier/boundary_applier.py b/xlb/operator/boundary_condition/boundary_applier/boundary_applier.py similarity index 90% rename from xlb/operator/boundary_condition_new/boundary_applier/boundary_applier.py rename to xlb/operator/boundary_condition/boundary_applier/boundary_applier.py index 2cd6cc6..5eb3de0 100644 --- a/xlb/operator/boundary_condition_new/boundary_applier/boundary_applier.py +++ b/xlb/operator/boundary_condition/boundary_applier/boundary_applier.py @@ -17,13 +17,6 @@ IndicesBoundaryMasker, ) - -# Enum for implementation step -class ImplementationStep(Enum): - COLLISION = 1 - STREAMING = 2 - - class BoundaryApplier(Operator): """ Base class for boundary conditions in a LBM simulation. diff --git a/xlb/operator/boundary_condition_new/boundary_applier/collision_boundary_applier.py b/xlb/operator/boundary_condition/boundary_applier/collision_boundary_applier.py similarity index 99% rename from xlb/operator/boundary_condition_new/boundary_applier/collision_boundary_applier.py rename to xlb/operator/boundary_condition/boundary_applier/collision_boundary_applier.py index 2e68c0b..633ef47 100644 --- a/xlb/operator/boundary_condition_new/boundary_applier/collision_boundary_applier.py +++ b/xlb/operator/boundary_condition/boundary_applier/collision_boundary_applier.py @@ -17,7 +17,7 @@ from xlb.boundary_condition.full_bounce_back import FullBounceBack -class CollisionBoundaryCondition(Operator): +class CollisionBoundaryApplier(Operator): """ Class for combining collision and boundary conditions together into a single operator. diff --git a/xlb/operator/boundary_condition_new/boundary_applier/full_bounce_back_applier.py b/xlb/operator/boundary_condition/boundary_applier/full_bounce_back_applier.py similarity index 95% rename from xlb/operator/boundary_condition_new/boundary_applier/full_bounce_back_applier.py rename to xlb/operator/boundary_condition/boundary_applier/full_bounce_back_applier.py index 502a3ca..48860fd 100644 --- a/xlb/operator/boundary_condition_new/boundary_applier/full_bounce_back_applier.py +++ b/xlb/operator/boundary_condition/boundary_applier/full_bounce_back_applier.py @@ -13,14 +13,10 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -from xlb.operator.boundary_condition import ( - BoundaryCondition, +from xlb.operator.boundary_condition.boundary_applier import ( + BoundaryApplier, ImplementationStep, ) -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry diff --git a/xlb/operator/boundary_condition_new/boundary_applier/stream_boundary_applier.py b/xlb/operator/boundary_condition/boundary_applier/stream_boundary_applier.py similarity index 100% rename from xlb/operator/boundary_condition_new/boundary_applier/stream_boundary_applier.py rename to xlb/operator/boundary_condition/boundary_applier/stream_boundary_applier.py diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 8d1daea..cf64c66 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -12,153 +12,38 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) +from xlb.operator.boundary_condition.boundary_masker import BoundaryMasker +from xlb.operator.boundary_condition.boundary_applier import BoundaryApplier - -# Enum for implementation step -class ImplementationStep(Enum): - COLLISION = 1 - STREAMING = 2 - - -class BoundaryCondition(Operator): +class BoundaryCondition(): """ Base class for boundary conditions in a LBM simulation. + + Boundary conditions are unique in that they are not operators themselves, + but rather hold operators for applying and making masks for boundary conditions. """ def __init__( self, - implementation_step: ImplementationStep, + boundary_applier: BoundaryApplier, boundary_masker: BoundaryMasker, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, ): super().__init__(velocity_set, precision_policy, compute_backend) - # Set implementation step - self.implementation_step = implementation_step - - # Set boundary masker + # Set operators + self.boundary_applier = boundary_applier self.boundary_masker = boundary_masker - @classmethod - def from_function( - cls, - implementation_step: ImplementationStep, - boundary_function, - velocity_set, - precision_policy, - compute_backend, - ): - """ - Create a boundary condition from a function. - """ - # Create boundary mask - boundary_mask = BoundaryMasker.from_function( - boundary_function, velocity_set, precision_policy, compute_backend - ) - - # Create boundary condition - return cls( - implementation_step, - boundary_mask, - velocity_set, - precision_policy, - compute_backend, - ) - - @classmethod - def from_indices( - cls, - implementation_step: ImplementationStep, - indices: np.ndarray, - stream_indices: bool, - velocity_set, - precision_policy, - compute_backend, - ): - """ - Create a boundary condition from indices and boundary id. - """ - # Create boundary mask - boundary_mask = IndicesBoundaryMasker( - indices, stream_indices, velocity_set, precision_policy, compute_backend - ) - - # Create boundary condition - return cls( - implementation_step, - boundary_mask, - velocity_set, - precision_policy, - compute_backend, - ) - - @classmethod - def from_stl( - cls, - implementation_step: ImplementationStep, - stl_file: str, - stream_indices: bool, - velocity_set, - precision_policy, - compute_backend, - ): - """ - Create a boundary condition from an STL file. - """ - raise NotImplementedError - - -class CollisionBoundaryCondition(Operator): - """ - Class for combining collision and boundary conditions together - into a single operator. - """ - - def __init__( - self, - boundary_conditions: list[BoundaryCondition], - ): - # Set boundary conditions - self.boundary_conditions = boundary_conditions - - # Check that all boundary conditions have the same implementation step other properties - for bc in self.boundary_conditions: - assert bc.implementation_step == ImplementationStep.COLLISION, ( - "All boundary conditions must be applied during the collision step." - ) - # Get velocity set, precision policy, and compute backend - velocity_sets = set([bc.velocity_set for bc in self.boundary_conditions]) + velocity_sets = set([boundary_applier.velocity_set, boundary_masker.velocity_set]) assert len(velocity_sets) == 1, "All velocity sets must be the same" velocity_set = velocity_sets.pop() - precision_policies = set([bc.precision_policy for bc in self.boundary_conditions]) + precision_policies = set([boundary_applier.precision_policy, boundary_masker.precision_policy]) assert len(precision_policies) == 1, "All precision policies must be the same" precision_policy = precision_policies.pop() - compute_backends = set([bc.compute_backend for bc in self.boundary_conditions]) + compute_backends = set([boundary_applier.compute_backend, boundary_masker.compute_backend]) assert len(compute_backends) == 1, "All compute backends must be the same" compute_backend = compute_backends.pop() - super().__init__( - velocity_set, - precision_policy, - compute_backend, - ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, mask, boundary_id): - """ - Apply collision boundary conditions - """ - for bc in self.boundary_conditions: - f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) - return f_post, mask - - - def _construct_warp(self): + # Get implementation step from boundary applier (TODO: Maybe not add this to the base class) + self.implementation_step = boundary_applier.implementation_step diff --git a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py index 20bf580..f6b73f0 100644 --- a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py +++ b/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py @@ -1,34 +1,8 @@ -# Base class for all equilibriums +# Base class for all boundary masker operators -import jax.numpy as jnp -from jax import jit -import warp as wp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator - class BoundaryMasker(Operator): """ Operator for creating a boundary mask """ - - @classmethod - def from_jax_func( - cls, jax_func, precision_policy: PrecisionPolicy, velocity_set: VelocitySet - ): - """ - Create a boundary masker from a jax function - """ - raise NotImplementedError - - @classmethod - def from_warp_func( - cls, warp_func, precision_policy: PrecisionPolicy, velocity_set: VelocitySet - ): - """ - Create a boundary masker from a warp function - """ - raise NotImplementedError diff --git a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py index fdf8ced..ccc52a3 100644 --- a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py @@ -22,8 +22,8 @@ class IndicesBoundaryMasker(Operator): def __init__( self, - indices: np.ndarray, stream_indices: bool, + id_number: int, velocity_set: VelocitySet, precision_policy: PrecisionPolicy, compute_backend: ComputeBackend.JAX, @@ -31,8 +31,7 @@ def __init__( super().__init__(velocity_set, precision_policy, compute_backend) # Set indices - # TODO: handle multi-gpu case (this will usually implicitly work) - self.indices = indices + self.id_number = id_number self.stream_indices = stream_indices # Make stream operator @@ -48,7 +47,7 @@ def _indices_to_tuple(indices): @Operator.register_backend(ComputeBackend.JAX) #@partial(jit, static_argnums=(0), inline=True) TODO: Fix this - def jax_implementation(self, start_index, boundary_id, mask, id_number): + def jax_implementation(self, indicies, start_index, boundary_id, mask, id_number): # Get local indices from the meshgrid and the indices local_indices = self.indices - np.array(start_index)[np.newaxis, :] diff --git a/xlb/operator/boundary_condition/collision_boundary_condition.py b/xlb/operator/boundary_condition/collision_boundary_condition.py deleted file mode 100644 index 5d0f764..0000000 --- a/xlb/operator/boundary_condition/collision_boundary_condition.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np -from enum import Enum - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator - -# Import all collision boundary conditions -from xlb.boundary_condition.full_bounce_back import FullBounceBack - - -class CollisionBoundaryCondition(Operator): - """ - Class for combining collision and boundary conditions together - into a single operator. - """ - - def __init__( - self, - boundary_conditions: list[BoundaryCondition], - ): - # Set boundary conditions - self.boundary_conditions = boundary_conditions - - # Check that all boundary conditions have the same implementation step other properties - for bc in self.boundary_conditions: - assert bc.implementation_step == ImplementationStep.COLLISION, ( - "All boundary conditions must be applied during the collision step." - ) - - # Get velocity set, precision policy, and compute backend - velocity_sets = set([bc.velocity_set for bc in self.boundary_conditions]) - assert len(velocity_sets) == 1, "All velocity sets must be the same" - velocity_set = velocity_sets.pop() - precision_policies = set([bc.precision_policy for bc in self.boundary_conditions]) - assert len(precision_policies) == 1, "All precision policies must be the same" - precision_policy = precision_policies.pop() - compute_backends = set([bc.compute_backend for bc in self.boundary_conditions]) - assert len(compute_backends) == 1, "All compute backends must be the same" - compute_backend = compute_backends.pop() - - # Make all possible collision boundary conditions to obtain the warp functions - self.full_bounce_back = FullBounceBack( - None, velocity_set, precision_policy, compute_backend - ) - - super().__init__( - velocity_set, - precision_policy, - compute_backend, - ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, mask, boundary_id): - """ - Apply collision boundary conditions - """ - for bc in self.boundary_conditions: - f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) - return f_post, mask - - def _construct_warp(self): - """ - Construct the warp kernel for the collision boundary condition. - """ - - # Make constants for warp - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - - # Get boolean constants for all boundary conditions - if any([isinstance(bc, FullBounceBack) for bc in self.boundary_conditions]): - _use_full_bounce_back = wp.constant(True) - - # Construct the funcional for all boundary conditions - @wp.func - def functional( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - boundary_id: wp.uint8, - mask: self._warp_bool_lattice_vec, - ): - # Apply all boundary conditions - # Full bounce-back - if _use_full_bounce_back: - if boundary_id == self.full_bounce_back.id: - f_post = self.full_bounce_back.warp_functional(f_pre, f_post, mask) - - return f_post - - # Construct the warp kernel - @wp.kernel - def kernel( - f_pre: self._warp_array_type, - f_post: self._warp_array_type, - f: self._warp_array_type, - boundary_id: self._warp_uint8_array_type, - mask: self._warp_bool_array_type, - ): - # Get the global index - i, j, k = wp.tid() - - # Make vectors for the lattice - _f_pre = self._warp_lattice_vec() - _f_post = self._warp_lattice_vec() - _mask = self._warp_bool_lattice_vec() - _boundary_id = wp.uint8(boundary_id[0, i, j, k]) - for l in range(_q): - _f_pre[l] = f_pre[l, i, j, k] - _f_post[l] = f_post[l, i, j, k] - - # TODO fix vec bool - if mask[l, i, j, k]: - _mask[l] = wp.uint8(1) - else: - _mask[l] = wp.uint8(0) - - # Apply all boundary conditions - if _boundary_id != wp.uint8(0): - _f_post = functional(_f_pre, _f_post, _boundary_id, _mask) - - # Write the result to the output - for l in range(_q): - f[l, i, j, k] = _f_post[l] - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary_id, mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, inputs=[f_pre, f_post, f, boundary_id, mask], dim=f_pre.shape[1:] - ) - return f diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py deleted file mode 100644 index 9f85da2..0000000 --- a/xlb/operator/boundary_condition/do_nothing.py +++ /dev/null @@ -1,56 +0,0 @@ -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend -from xlb.operator.boundary_condition.boundary_condition import ( - BoundaryCondition, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - -class DoNothing(BoundaryCondition): - """ - A boundary condition that skips the streaming step. - """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - - def __init__( - self, - set_boundary, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): - super().__init__( - set_boundary=set_boundary, - implementation_step=ImplementationStep.STREAMING, - velocity_set=velocity_set, - compute_backend=compute_backend, - ) - - @classmethod - def from_indices( - cls, - indices, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): - """ - Creates a boundary condition from a list of indices. - """ - - return cls( - set_boundary=cls._set_boundary_from_indices(indices), - velocity_set=velocity_set, - compute_backend=compute_backend, - ) - - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - do_nothing = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) - f = lax.select(do_nothing, f_pre, f_post) - return f diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py deleted file mode 100644 index 9122fe3..0000000 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ /dev/null @@ -1,79 +0,0 @@ -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator import Operator -from xlb.operator.equilibrium.equilibrium import Equilibrium -from xlb.operator.boundary_condition.boundary_condition import ( - BoundaryCondition, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - - -class EquilibriumBoundary(BoundaryCondition): - """ - Equilibrium boundary condition for a lattice Boltzmann method simulation. - """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - - def __init__( - self, - set_boundary, - rho: float, - u: tuple[float, float], - equilibrium: Equilibrium, - boundary_masker: BoundaryMasker, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - super().__init__( - implementation_step=ImplementationStep.STREAMING, - velocity_set=velocity_set, - compute_backend=compute_backend, - ) - self.f = equilibrium(rho, u) - - @classmethod - def from_indices( - cls, - indices: np.ndarray, - rho: float, - u: tuple[float, float], - equilibrium: Equilibrium, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - """ - Creates a boundary condition from a list of indices. - """ - - return cls( - set_boundary=cls._set_boundary_from_indices(indices), - rho=rho, - u=u, - equilibrium=equilibrium, - velocity_set=velocity_set, - compute_backend=compute_backend, - ) - - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - equilibrium_mask = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) - equilibrium_f = jnp.repeat(self.f[None, ...], boundary.shape[0], axis=0) - equilibrium_f = jnp.repeat(equilibrium_f[:, None], boundary.shape[1], axis=1) - equilibrium_f = jnp.repeat(equilibrium_f[:, :, None], boundary.shape[2], axis=2) - f = lax.select(equilibrium_mask, equilibrium_f, f_post) - return f diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index 4393fe4..1559609 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -28,7 +28,6 @@ class FullBounceBack(BoundaryCondition): """ Full Bounce-back boundary condition for a lattice Boltzmann method simulation. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( self, @@ -37,8 +36,13 @@ def __init__( precision_policy: PrecisionPolicy, compute_backend: ComputeBackend, ): + + boundary_applier = FullBounceBackApplier( + velocity_set, precision_policy, compute_backend + ) + super().__init__( - ImplementationStep.COLLISION, + boundary_applier, boundary_masker, velocity_set, precision_policy, @@ -47,14 +51,14 @@ def __init__( @classmethod def from_indices( - cls, indices: np.ndarray, velocity_set, precision_policy, compute_backend + cls, velocity_set, precision_policy, compute_backend ): """ Create a full bounce-back boundary condition from indices. """ # Create boundary mask boundary_mask = IndicesBoundaryMasker( - indices, False, velocity_set, precision_policy, compute_backend + False, velocity_set, precision_policy, compute_backend ) # Create boundary condition @@ -64,76 +68,3 @@ def from_indices( precision_policy, compute_backend, ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary_id, mask): - boundary = boundary_id == self.id - flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) - flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) - return flipped_f - - def _construct_warp(self): - # Make constants for warp - _opp_indices = wp.constant(self._warp_int_lattice_vec(self.velocity_set.opp_indices)) - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - _id = wp.constant(self.id) - - # Construct the funcional to get streamed indices - @wp.func - def functional( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - mask: self._warp_bool_lattice_vec, - ): - fliped_f = self._warp_lattice_vec() - for l in range(_q): - fliped_f[l] = f_pre[_opp_indices[l]] - return fliped_f - - # Construct the warp kernel - @wp.kernel - def kernel( - f_pre: self._warp_array_type, - f_post: self._warp_array_type, - f: self._warp_array_type, - boundary_id: self._warp_uint8_array_type, - mask: self._warp_bool_array_type, - ): - # Get the global index - i, j, k = wp.tid() - - # Make vectors for the lattice - _f_pre = self._warp_lattice_vec() - _f_post = self._warp_lattice_vec() - _mask = self._warp_bool_lattice_vec() - for l in range(_q): - _f_pre[l] = f_pre[l, i, j, k] - _f_post[l] = f_post[l, i, j, k] - - # TODO fix vec bool - if mask[l, i, j, k]: - _mask[l] = wp.uint8(1) - else: - _mask[l] = wp.uint8(0) - - # Check if the boundary is active - if boundary_id[i, j, k] == wp.uint8(_id: - _f = functional(_f_pre, _f_post, _mask) - else: - _f = _f_post - - # Write the result to the output - for l in range(_q): - f[l, i, j, k] = _f[l] - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary, mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] - ) - return f diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py deleted file mode 100644 index 708a2fa..0000000 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ /dev/null @@ -1,42 +0,0 @@ -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend -from xlb.operator.stream.stream import Stream -from xlb.operator.boundary_condition.boundary_condition import ( - BoundaryCondition, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - -class HalfwayBounceBack(BoundaryCondition): - """ - Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. - """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - - def __init__( - self, - set_boundary, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): - super().__init__( - set_boundary=set_boundary, - implementation_step=ImplementationStep.STREAMING, - velocity_set=velocity_set, - compute_backend=compute_backend, - ) - - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - flip_mask = boundary[..., jnp.newaxis] & mask - flipped_f = lax.select( - flip_mask, f_pre[..., self.velocity_set.opp_indices], f_post - ) - return flipped_f diff --git a/xlb/operator/boundary_condition_new/boundary_condition.py b/xlb/operator/boundary_condition_new/boundary_condition.py deleted file mode 100644 index ebd814c..0000000 --- a/xlb/operator/boundary_condition_new/boundary_condition.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np -from enum import Enum - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.boundary_masker import BoundaryMasker -from xlb.operator.boundary_condition.boundary_applier import BoundaryApplier - - -# Enum for implementation step -class ImplementationStep(Enum): - COLLISION = 1 - STREAMING = 2 - - -class BoundaryCondition(): - """ - Base class for boundary conditions in a LBM simulation. - - Boundary conditions are unique in that they are not operators themselves, - but rather hold operators for applying and making masks for boundary conditions. - """ - - def __init__( - self, - boundary_applier: BoundaryApplier, - boundary_masker: BoundaryMasker, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Set operators - self.boundary_applier = boundary_applier - self.boundary_masker = boundary_masker - - # Get velocity set, precision policy, and compute backend - velocity_sets = set([boundary_applier.velocity_set, boundary_masker.velocity_set]) - assert len(velocity_sets) == 1, "All velocity sets must be the same" - velocity_set = velocity_sets.pop() - precision_policies = set([boundary_applier.precision_policy, boundary_masker.precision_policy]) - assert len(precision_policies) == 1, "All precision policies must be the same" - precision_policy = precision_policies.pop() - compute_backends = set([boundary_applier.compute_backend, boundary_masker.compute_backend]) - assert len(compute_backends) == 1, "All compute backends must be the same" - compute_backend = compute_backends.pop() diff --git a/xlb/operator/boundary_condition_new/boundary_masker/__init__.py b/xlb/operator/boundary_condition_new/boundary_masker/__init__.py deleted file mode 100644 index e33e509..0000000 --- a/xlb/operator/boundary_condition_new/boundary_masker/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from xlb.operator.boundary_condition.boundary_masker.boundary_masker import BoundaryMasker -from xlb.operator.boundary_condition.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker diff --git a/xlb/operator/boundary_condition_new/boundary_masker/boundary_masker.py b/xlb/operator/boundary_condition_new/boundary_masker/boundary_masker.py deleted file mode 100644 index f6b73f0..0000000 --- a/xlb/operator/boundary_condition_new/boundary_masker/boundary_masker.py +++ /dev/null @@ -1,8 +0,0 @@ -# Base class for all boundary masker operators - -from xlb.operator.operator import Operator - -class BoundaryMasker(Operator): - """ - Operator for creating a boundary mask - """ diff --git a/xlb/operator/boundary_condition_new/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_condition_new/boundary_masker/indices_boundary_masker.py deleted file mode 100644 index ccc52a3..0000000 --- a/xlb/operator/boundary_condition_new/boundary_masker/indices_boundary_masker.py +++ /dev/null @@ -1,117 +0,0 @@ -# Base class for all equilibriums - -from functools import partial -import numpy as np -import jax.numpy as jnp -from jax import jit -import warp as wp -from typing import Tuple - -from xlb.global_config import GlobalConfig -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator -from xlb.operator.stream.stream import Stream - - -class IndicesBoundaryMasker(Operator): - """ - Operator for creating a boundary mask - """ - - def __init__( - self, - stream_indices: bool, - id_number: int, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Set indices - self.id_number = id_number - self.stream_indices = stream_indices - - # Make stream operator - self.stream = Stream(velocity_set, precision_policy, compute_backend) - - @staticmethod - def _indices_to_tuple(indices): - """ - Converts a tensor of indices to a tuple for indexing - TODO: Might be better to index - """ - return tuple([indices[:, i] for i in range(indices.shape[1])]) - - @Operator.register_backend(ComputeBackend.JAX) - #@partial(jit, static_argnums=(0), inline=True) TODO: Fix this - def jax_implementation(self, indicies, start_index, boundary_id, mask, id_number): - # Get local indices from the meshgrid and the indices - local_indices = self.indices - np.array(start_index)[np.newaxis, :] - - # Remove any indices that are out of bounds - local_indices = local_indices[ - (local_indices[:, 0] >= 0) - & (local_indices[:, 0] < mask.shape[0]) - & (local_indices[:, 1] >= 0) - & (local_indices[:, 1] < mask.shape[1]) - & (local_indices[:, 2] >= 0) - & (local_indices[:, 2] < mask.shape[2]) - ] - - # Set the boundary id - boundary_id = boundary_id.at[self._indices_to_tuple(local_indices)].set( - id_number - ) - - # Stream mask if necessary - if self.stream_indices: - # Make mask then stream to get the edge points - pre_stream_mask = jnp.zeros_like(mask) - pre_stream_mask = pre_stream_mask.at[ - self._indices_to_tuple(local_indices) - ].set(True) - post_stream_mask = self.stream(pre_stream_mask) - - # Set false for points inside the boundary - post_stream_mask = post_stream_mask.at[ - post_stream_mask[..., 0] == True - ].set(False) - - # Get indices on edges - edge_indices = jnp.argwhere(post_stream_mask) - - # Set the mask - mask = mask.at[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ].set( - post_stream_mask[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ] - ) - - else: - # Set the mask - mask = mask.at[self._indices_to_tuple(local_indices)].set(True) - - return boundary_id, mask - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, start_index, boundary_id, mask, id_number): - # Reuse the jax implementation, TODO: implement a warp version - # Convert to jax - boundary_id = wp.jax.to_jax(boundary_id) - mask = wp.jax.to_jax(mask) - - # Call jax implementation - boundary_id, mask = self.jax_implementation( - start_index, boundary_id, mask, id_number - ) - - # Convert back to warp - boundary_id = wp.jax.to_warp(boundary_id) - mask = wp.jax.to_warp(mask) - - return boundary_id, mask diff --git a/xlb/operator/boundary_condition_new/full_bounce_back.py b/xlb/operator/boundary_condition_new/full_bounce_back.py deleted file mode 100644 index 1559609..0000000 --- a/xlb/operator/boundary_condition_new/full_bounce_back.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np -import warp as wp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator import Operator -from xlb.operator.boundary_condition import ( - BoundaryCondition, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - -class FullBounceBack(BoundaryCondition): - """ - Full Bounce-back boundary condition for a lattice Boltzmann method simulation. - """ - - def __init__( - self, - boundary_masker: BoundaryMasker, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - - boundary_applier = FullBounceBackApplier( - velocity_set, precision_policy, compute_backend - ) - - super().__init__( - boundary_applier, - boundary_masker, - velocity_set, - precision_policy, - compute_backend, - ) - - @classmethod - def from_indices( - cls, velocity_set, precision_policy, compute_backend - ): - """ - Create a full bounce-back boundary condition from indices. - """ - # Create boundary mask - boundary_mask = IndicesBoundaryMasker( - False, velocity_set, precision_policy, compute_backend - ) - - # Create boundary condition - return cls( - boundary_mask, - velocity_set, - precision_policy, - compute_backend, - ) diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index 7b0acfe..f910b18 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -43,13 +43,9 @@ def apply_jax(self, f, boundary_id, mask, timestep): ) # Apply collision type boundary conditions - for id_number, bc in self.collision_boundary_conditions.items(): - f_post_collision = bc( - f_pre_collision, - f_post_collision, - boundary_id == id_number, - mask, - ) + f_post_collision = self.collision_boundary_applier.jax_implementation( + f_pre_collision, f_post_collision, mask, boundary_id + ) f_pre_streaming = f_post_collision ## Apply forcing @@ -180,14 +176,12 @@ def kernel( ) ## Apply collision type boundary conditions - #if _boundary_id != wp.uint8(0): - # f_post_collision = self.collision_boundary_conditions[ - # _boundary_id - # ].warp_functional( - # _f, - # f_post_collision, - # _mask, - # ) + f_post_collision = self.collision_boundary_applier.warp_functional( + _f, + f_post_collision, + _boundary_id, + _mask, + ) f_pre_streaming = f_post_collision # store pre streaming vector # Apply forcing diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 0e0675a..b8cbbea 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -9,6 +9,10 @@ from xlb.compute_backend import ComputeBackend from xlb.operator import Operator from xlb.operator.boundary_condition import ImplementationStep +from xlb.operator.boundary_condition.boundary_applier import ( + CollisionBoundaryApplier + StreamingBoundaryApplier +) from xlb.operator.precision_caster import PrecisionCaster @@ -56,12 +60,13 @@ def __init__( assert len(compute_backends) == 1, "All compute backends must be the same" compute_backend = compute_backends.pop() - # Get collision and stream boundary conditions - self.collision_boundary_conditions = [bc for bc in boundary_conditions if bc.implementation_step == ImplementationStep.COLLISION] - self.stream_boundary_conditions = [bc for bc in boundary_conditions if bc.implementation_step == ImplementationStep.STREAMING] - - # Make operators for converting the precisions - #self.cast_to_compute = PrecisionCaster( + # Make single operators for all collision and streaming boundary conditions + self.collision_boundary_applier = CollisionBoundaryApplier( + [bc.boundary_applier for bc in boundary_conditions if bc.implementation_step == ImplementationStep.COLLISION] + ) + self.streaming_boundary_applier = StreamingBoundaryApplier( + [bc.boundary_applier for bc in boundary_conditions if bc.implementation_step == ImplementationStep.STREAMING] + ) # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) From 9a296397b2930e5bd1f5196ac5f7e8deba412034 Mon Sep 17 00:00:00 2001 From: Oliver Date: Mon, 25 Mar 2024 10:03:00 -0700 Subject: [PATCH 021/144] added boundar condition example --- examples/CFD_refactor/windtunnel3d.py | 1 + examples/interfaces/boundary_conditions.py | 3 + examples/interfaces/flow_past_sphere.py | 220 ++++++++++++++++++ examples/interfaces/ldc.py | 62 +---- examples/interfaces/taylor_green.py | 13 +- xlb/operator/__init__.py | 3 +- xlb/operator/boundary_condition/__init__.py | 3 + .../boundary_applier/boundary_applier.py | 35 --- .../boundary_condition/boundary_condition.py | 37 +-- .../boundary_condition_registry.py | 2 +- .../boundary_masker/__init__.py | 2 - .../indices_boundary_masker.py | 117 ---------- .../collision_boundary_applier.py | 0 xlb/operator/boundary_condition/do_nothing.py | 89 +++++++ .../boundary_condition/do_nothing_applier.py | 47 ++++ .../boundary_condition/equilibrium.py | 95 ++++++++ .../boundary_condition/equilibrium_applier.py | 50 ++++ .../boundary_condition/full_bounce_back.py | 21 +- .../full_bounce_back_applier.py | 0 .../boundary_condition/halfway_bounce_back.py | 112 +++++++++ .../stream_boundary_applier.py | 2 +- xlb/operator/boundary_masker/__init__.py | 9 + .../boundary_masker/boundary_masker.py | 1 + .../indices_boundary_masker.py | 155 ++++++++++++ .../boundary_masker/planar_boundary_masker.py | 128 ++++++++++ .../boundary_masker/stl_boundary_masker.py | 109 +++++++++ xlb/operator/collision/bgk.py | 54 ++--- .../equilibrium/quadratic_equilibrium.py | 47 ++-- xlb/operator/macroscopic/macroscopic.py | 36 +-- xlb/operator/operator.py | 82 +------ xlb/operator/stepper/nse.py | 109 ++++----- xlb/operator/stepper/stepper.py | 38 +-- xlb/operator/stream/stream.py | 83 ++++--- xlb/solver/nse.py | 6 +- xlb/solver/solver.py | 3 +- xlb/velocity_set/velocity_set.py | 7 + 36 files changed, 1259 insertions(+), 522 deletions(-) create mode 100644 examples/interfaces/boundary_conditions.py create mode 100644 examples/interfaces/flow_past_sphere.py create mode 100644 xlb/operator/boundary_condition/__init__.py delete mode 100644 xlb/operator/boundary_condition/boundary_applier/boundary_applier.py delete mode 100644 xlb/operator/boundary_condition/boundary_masker/__init__.py delete mode 100644 xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py rename xlb/operator/boundary_condition/{boundary_applier => }/collision_boundary_applier.py (100%) create mode 100644 xlb/operator/boundary_condition/do_nothing.py create mode 100644 xlb/operator/boundary_condition/do_nothing_applier.py create mode 100644 xlb/operator/boundary_condition/equilibrium.py create mode 100644 xlb/operator/boundary_condition/equilibrium_applier.py rename xlb/operator/boundary_condition/{boundary_applier => }/full_bounce_back_applier.py (100%) create mode 100644 xlb/operator/boundary_condition/halfway_bounce_back.py rename xlb/operator/boundary_condition/{boundary_applier => }/stream_boundary_applier.py (99%) create mode 100644 xlb/operator/boundary_masker/__init__.py rename xlb/operator/{boundary_condition => }/boundary_masker/boundary_masker.py (99%) create mode 100644 xlb/operator/boundary_masker/indices_boundary_masker.py create mode 100644 xlb/operator/boundary_masker/planar_boundary_masker.py create mode 100644 xlb/operator/boundary_masker/stl_boundary_masker.py diff --git a/examples/CFD_refactor/windtunnel3d.py b/examples/CFD_refactor/windtunnel3d.py index 28d9208..9af75f3 100644 --- a/examples/CFD_refactor/windtunnel3d.py +++ b/examples/CFD_refactor/windtunnel3d.py @@ -14,6 +14,7 @@ from xlb.operator.boundary_condition import BounceBack, BounceBackHalfway, DoNothing, EquilibriumBC + class WindTunnel(IncompressibleNavierStokesSolver): """ This class extends the IncompressibleNavierStokesSolver class to define the boundary conditions for the wind tunnel simulation. diff --git a/examples/interfaces/boundary_conditions.py b/examples/interfaces/boundary_conditions.py new file mode 100644 index 0000000..4f7fc9a --- /dev/null +++ b/examples/interfaces/boundary_conditions.py @@ -0,0 +1,3 @@ +from xlb.operator.boundary_condition import boundary_condition_registry + +print(boundary_condition_registry.ids) diff --git a/examples/interfaces/flow_past_sphere.py b/examples/interfaces/flow_past_sphere.py new file mode 100644 index 0000000..8bfe945 --- /dev/null +++ b/examples/interfaces/flow_past_sphere.py @@ -0,0 +1,220 @@ +# Simple flow past sphere example using the functional interface to xlb + +import time +from tqdm import tqdm +import os +import matplotlib.pyplot as plt +from typing import Any +import numpy as np + +import warp as wp + +wp.init() + +import xlb +from xlb.operator import Operator + +class UniformInitializer(Operator): + + def _construct_warp(self): + # Construct the warp kernel + @wp.kernel + def kernel( + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + vel: float, + ): + # Get the global index + i, j, k = wp.tid() + + # Set the velocity + u[0, i, j, k] = vel + u[1, i, j, k] = 0.0 + u[2, i, j, k] = 0.0 + + # Set the density + rho[0, i, j, k] = 1.0 + + return None, kernel + + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, rho, u, vel): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + rho, + u, + vel, + ], + dim=rho.shape[1:], + ) + return rho, u + + +if __name__ == "__main__": + # Set parameters + compute_backend = xlb.ComputeBackend.WARP + precision_policy = xlb.PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q19() + + # Make feilds + nr = 256 + vel = 0.05 + shape = (nr, nr, nr) + grid = xlb.grid.WarpGrid(shape=shape) + rho = grid.create_field(cardinality=1, dtype=wp.float32) + u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) + f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) + f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) + boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) + missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) + + # Make operators + initializer = UniformInitializer( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + collision = xlb.operator.collision.BGK( + omega=1.95, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + stream = xlb.operator.stream.Stream( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(vel, 0.0, 0.0), + equilibrium_operator=equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=collision, + equilibrium=equilibrium, + macroscopic=macroscopic, + stream=stream, + equilibrium_bc=equilibrium_bc, + do_nothing_bc=do_nothing_bc, + half_way_bc=half_way_bc, + ) + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + # Make indices for boundary conditions (sphere) + sphere_radius = 32 + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + indices = np.array(indices).T + indices = wp.from_numpy(indices, dtype=wp.int32) + + # Set boundary conditions on the indices + boundary_id, missing_mask = indices_boundary_masker( + indices, + half_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set inlet bc + lower_bound = (0, 0, 0) + upper_bound = (0, nr, nr) + direction = (1, 0, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + equilibrium_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set outlet bc + lower_bound = (nr-1, 0, 0) + upper_bound = (nr-1, nr, nr) + direction = (-1, 0, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + do_nothing_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set initial conditions + rho, u = initializer(rho, u, vel) + f0 = equilibrium(rho, u, f0) + + # Time stepping + plot_freq = 512 + save_dir = "flow_past_sphere" + os.makedirs(save_dir, exist_ok=True) + #compute_mlup = False # Plotting results + compute_mlup = True + num_steps = 1024 * 8 + start = time.time() + for _ in tqdm(range(num_steps)): + f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1, f0 = f0, f1 + if (_ % plot_freq == 0) and (not compute_mlup): + rho, u = macroscopic(f0, rho, u) + + # Plot the velocity field and boundary id side by side + plt.subplot(1, 2, 1) + plt.imshow(u[0, :, nr // 2, :].numpy()) + plt.colorbar() + plt.subplot(1, 2, 2) + plt.imshow(boundary_id[0, :, nr // 2, :].numpy()) + plt.colorbar() + plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") + plt.close() + + wp.synchronize() + end = time.time() + + # Print MLUPS + print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") diff --git a/examples/interfaces/ldc.py b/examples/interfaces/ldc.py index e1419a1..962357f 100644 --- a/examples/interfaces/ldc.py +++ b/examples/interfaces/ldc.py @@ -11,63 +11,6 @@ import xlb from xlb.operator import Operator -class TaylorGreenInitializer(Operator): - - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - f0: self._warp_array_type, - rho: self._warp_array_type, - u: self._warp_array_type, - vel: float, - nr: int, - ): - # Get the global index - i, j, k = wp.tid() - - # Get real pos - x = 2.0 * wp.pi * wp.float(i) / wp.float(nr) - y = 2.0 * wp.pi * wp.float(j) / wp.float(nr) - z = 2.0 * wp.pi * wp.float(k) / wp.float(nr) - - # Compute u - u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) - u[1, i, j, k] = - vel * wp.cos(x) * wp.sin(y) * wp.cos(z) - u[2, i, j, k] = 0.0 - - # Compute rho - rho[0, i, j, k] = ( - 3.0 - * vel - * vel - * (1.0 / 16.0) - * ( - wp.cos(2.0 * x) - + (wp.cos(2.0 * y) - * (wp.cos(2.0 * z) + 2.0)) - ) - + 1.0 - ) - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, f0, rho, u, vel, nr): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - f0, - rho, - u, - vel, - nr, - ], - dim=rho.shape[1:], - ) - return rho, u - if __name__ == "__main__": # Set parameters @@ -87,10 +30,6 @@ def warp_implementation(self, f0, rho, u, vel, nr): mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) # Make operators - initializer = TaylorGreenInitializer( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) collision = xlb.operator.collision.BGK( omega=1.9, velocity_set=velocity_set, @@ -108,6 +47,7 @@ def warp_implementation(self, f0, rho, u, vel, nr): velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) + stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( collision=collision, equilibrium=equilibrium, diff --git a/examples/interfaces/taylor_green.py b/examples/interfaces/taylor_green.py index e1419a1..0529b11 100644 --- a/examples/interfaces/taylor_green.py +++ b/examples/interfaces/taylor_green.py @@ -4,6 +4,7 @@ from tqdm import tqdm import os import matplotlib.pyplot as plt +from typing import Any import warp as wp wp.init() @@ -17,9 +18,9 @@ def _construct_warp(self): # Construct the warp kernel @wp.kernel def kernel( - f0: self._warp_array_type, - rho: self._warp_array_type, - u: self._warp_array_type, + f0: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), vel: float, nr: int, ): @@ -84,7 +85,7 @@ def warp_implementation(self, f0, rho, u, vel, nr): f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) - mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) + missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) # Make operators initializer = TaylorGreenInitializer( @@ -128,10 +129,10 @@ def warp_implementation(self, f0, rho, u, vel, nr): os.makedirs(save_dir, exist_ok=True) #compute_mlup = False # Plotting results compute_mlup = True - num_steps = 1024 + num_steps = 1024 * 8 start = time.time() for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_id, mask, _) + f1 = stepper(f0, f1, boundary_id, missing_mask, _) f1, f0 = f0, f1 if (_ % plot_freq == 0) and (not compute_mlup): rho, u = macroscopic(f0, rho, u) diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index 501a7af..c1232a3 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1,3 +1,4 @@ from xlb.operator.operator import Operator from xlb.operator.parallel_operator import ParallelOperator -import xlb.operator.stepper # +import xlb.operator.stepper +import xlb.operator.boundary_masker diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py new file mode 100644 index 0000000..c4d44e8 --- /dev/null +++ b/xlb/operator/boundary_condition/__init__.py @@ -0,0 +1,3 @@ +from xlb.operator.boundary_condition.equilibrium import EquilibriumBC +from xlb.operator.boundary_condition.do_nothing import DoNothingBC +from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBackBC diff --git a/xlb/operator/boundary_condition/boundary_applier/boundary_applier.py b/xlb/operator/boundary_condition/boundary_applier/boundary_applier.py deleted file mode 100644 index 5eb3de0..0000000 --- a/xlb/operator/boundary_condition/boundary_applier/boundary_applier.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np -from enum import Enum - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) - -class BoundaryApplier(Operator): - """ - Base class for boundary conditions in a LBM simulation. - """ - - def __init__( - self, - implementation_step: ImplementationStep, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Set implementation step - self.implementation_step = implementation_step diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index cf64c66..ef32bfb 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -12,38 +12,25 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.boundary_masker import BoundaryMasker -from xlb.operator.boundary_condition.boundary_applier import BoundaryApplier -class BoundaryCondition(): +# Enum for implementation step +class ImplementationStep(Enum): + COLLISION = 1 + STREAMING = 2 + +class BoundaryCondition(Operator): """ Base class for boundary conditions in a LBM simulation. - - Boundary conditions are unique in that they are not operators themselves, - but rather hold operators for applying and making masks for boundary conditions. """ def __init__( self, - boundary_applier: BoundaryApplier, - boundary_masker: BoundaryMasker, + implementation_step: ImplementationStep, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend ): super().__init__(velocity_set, precision_policy, compute_backend) - # Set operators - self.boundary_applier = boundary_applier - self.boundary_masker = boundary_masker - - # Get velocity set, precision policy, and compute backend - velocity_sets = set([boundary_applier.velocity_set, boundary_masker.velocity_set]) - assert len(velocity_sets) == 1, "All velocity sets must be the same" - velocity_set = velocity_sets.pop() - precision_policies = set([boundary_applier.precision_policy, boundary_masker.precision_policy]) - assert len(precision_policies) == 1, "All precision policies must be the same" - precision_policy = precision_policies.pop() - compute_backends = set([boundary_applier.compute_backend, boundary_masker.compute_backend]) - assert len(compute_backends) == 1, "All compute backends must be the same" - compute_backend = compute_backends.pop() - - # Get implementation step from boundary applier (TODO: Maybe not add this to the base class) - self.implementation_step = boundary_applier.implementation_step + # Set the implementation step + self.implementation_step = implementation_step diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py index 23a0d17..5990d53 100644 --- a/xlb/operator/boundary_condition/boundary_condition_registry.py +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -11,7 +11,7 @@ def __init__( self, ): self.ids = {} - self.next_id = 0 + self.next_id = 1 # 0 is reserved for regular streaming def register_boundary_condition(self, boundary_condition): """ diff --git a/xlb/operator/boundary_condition/boundary_masker/__init__.py b/xlb/operator/boundary_condition/boundary_masker/__init__.py deleted file mode 100644 index e33e509..0000000 --- a/xlb/operator/boundary_condition/boundary_masker/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from xlb.operator.boundary_condition.boundary_masker.boundary_masker import BoundaryMasker -from xlb.operator.boundary_condition.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker diff --git a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py deleted file mode 100644 index ccc52a3..0000000 --- a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py +++ /dev/null @@ -1,117 +0,0 @@ -# Base class for all equilibriums - -from functools import partial -import numpy as np -import jax.numpy as jnp -from jax import jit -import warp as wp -from typing import Tuple - -from xlb.global_config import GlobalConfig -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator -from xlb.operator.stream.stream import Stream - - -class IndicesBoundaryMasker(Operator): - """ - Operator for creating a boundary mask - """ - - def __init__( - self, - stream_indices: bool, - id_number: int, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Set indices - self.id_number = id_number - self.stream_indices = stream_indices - - # Make stream operator - self.stream = Stream(velocity_set, precision_policy, compute_backend) - - @staticmethod - def _indices_to_tuple(indices): - """ - Converts a tensor of indices to a tuple for indexing - TODO: Might be better to index - """ - return tuple([indices[:, i] for i in range(indices.shape[1])]) - - @Operator.register_backend(ComputeBackend.JAX) - #@partial(jit, static_argnums=(0), inline=True) TODO: Fix this - def jax_implementation(self, indicies, start_index, boundary_id, mask, id_number): - # Get local indices from the meshgrid and the indices - local_indices = self.indices - np.array(start_index)[np.newaxis, :] - - # Remove any indices that are out of bounds - local_indices = local_indices[ - (local_indices[:, 0] >= 0) - & (local_indices[:, 0] < mask.shape[0]) - & (local_indices[:, 1] >= 0) - & (local_indices[:, 1] < mask.shape[1]) - & (local_indices[:, 2] >= 0) - & (local_indices[:, 2] < mask.shape[2]) - ] - - # Set the boundary id - boundary_id = boundary_id.at[self._indices_to_tuple(local_indices)].set( - id_number - ) - - # Stream mask if necessary - if self.stream_indices: - # Make mask then stream to get the edge points - pre_stream_mask = jnp.zeros_like(mask) - pre_stream_mask = pre_stream_mask.at[ - self._indices_to_tuple(local_indices) - ].set(True) - post_stream_mask = self.stream(pre_stream_mask) - - # Set false for points inside the boundary - post_stream_mask = post_stream_mask.at[ - post_stream_mask[..., 0] == True - ].set(False) - - # Get indices on edges - edge_indices = jnp.argwhere(post_stream_mask) - - # Set the mask - mask = mask.at[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ].set( - post_stream_mask[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ] - ) - - else: - # Set the mask - mask = mask.at[self._indices_to_tuple(local_indices)].set(True) - - return boundary_id, mask - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, start_index, boundary_id, mask, id_number): - # Reuse the jax implementation, TODO: implement a warp version - # Convert to jax - boundary_id = wp.jax.to_jax(boundary_id) - mask = wp.jax.to_jax(mask) - - # Call jax implementation - boundary_id, mask = self.jax_implementation( - start_index, boundary_id, mask, id_number - ) - - # Convert back to warp - boundary_id = wp.jax.to_warp(boundary_id) - mask = wp.jax.to_warp(mask) - - return boundary_id, mask diff --git a/xlb/operator/boundary_condition/boundary_applier/collision_boundary_applier.py b/xlb/operator/boundary_condition/collision_boundary_applier.py similarity index 100% rename from xlb/operator/boundary_condition/boundary_applier/collision_boundary_applier.py rename to xlb/operator/boundary_condition/collision_boundary_applier.py diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py new file mode 100644 index 0000000..5d060ed --- /dev/null +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -0,0 +1,89 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp +from typing import Tuple, Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep, BoundaryCondition +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + +class DoNothingBC(BoundaryCondition): + """ + Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + boundary = boundary_id == self.id + flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) + skipped_f = lax.select(flip, f_pre, f_post) + return skipped_f + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f: wp.array4d(dtype=Any), + missing_mask: wp.array4d(dtype=wp.bool), + index: Any, + ): + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + return _f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get boundary id + if boundary_id[0, index[0], index[1], index[2]] == wp.uint8(DoNothing.id): + _f = functional(f_pre, index) + for l in range(_q): + f_post[l, index[0], index[1], index[2]] = _f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, f, boundary, mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] + ) + return f diff --git a/xlb/operator/boundary_condition/do_nothing_applier.py b/xlb/operator/boundary_condition/do_nothing_applier.py new file mode 100644 index 0000000..222dff6 --- /dev/null +++ b/xlb/operator/boundary_condition/do_nothing_applier.py @@ -0,0 +1,47 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator import Operator +from xlb.operator.boundary_condition.boundary_applier import ( + BoundaryApplier, + ImplementationStep, +) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + + +class DoNothingApplier(BoundaryApplier): + """ + Do nothing boundary condition. Basically skips the streaming step + """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + do_nothing = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) + f = lax.select(do_nothing, f_pre, f_post) + return f diff --git a/xlb/operator/boundary_condition/equilibrium.py b/xlb/operator/boundary_condition/equilibrium.py new file mode 100644 index 0000000..013871b --- /dev/null +++ b/xlb/operator/boundary_condition/equilibrium.py @@ -0,0 +1,95 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp +from typing import Tuple, Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep, BoundaryCondition +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + +class EquilibriumBC(BoundaryCondition): + """ + Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + rho: float, + u: Tuple[float, float, float], + equilibrium_operator: Operator, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + # Store the equilibrium information + self.rho = rho + self.u = u + self.equilibrium_operator = equilibrium_operator + + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + raise NotImplementedError + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _rho = wp.float32(self.rho) + _u = _u_vec(self.u[0], self.u[1], self.u[2]) + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f: wp.array4d(dtype=Any), + missing_mask: wp.array4d(dtype=wp.bool), + index: Any, + ): + _f = self.equilibrium_operator.warp_functional(_rho, _u) + return _f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get boundary id + if boundary_id[0, index[0], index[1], index[2]] == wp.uint8(DoNothing.id): + _f = functional(f_pre, index) + for l in range(_q): + f_post[l, index[0], index[1], index[2]] = _f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, f, boundary, mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] + ) + return f diff --git a/xlb/operator/boundary_condition/equilibrium_applier.py b/xlb/operator/boundary_condition/equilibrium_applier.py new file mode 100644 index 0000000..3b3ef5b --- /dev/null +++ b/xlb/operator/boundary_condition/equilibrium_applier.py @@ -0,0 +1,50 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator import Operator +from xlb.operator.boundary_condition.boundary_applier import ( + BoundaryApplier, + ImplementationStep, +) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + + +class EquilibriumApplier(BoundaryApplier): + """ + Apply Equilibrium boundary condition to the distribution function. + """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary, mask): + equilibrium_mask = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) + equilibrium_f = jnp.repeat(self.f[None, ...], boundary.shape[0], axis=0) + equilibrium_f = jnp.repeat(equilibrium_f[:, None], boundary.shape[1], axis=1) + equilibrium_f = jnp.repeat(equilibrium_f[:, :, None], boundary.shape[2], axis=2) + f = lax.select(equilibrium_mask, equilibrium_f, f_post) + return f diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index 1559609..97f680b 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -17,11 +17,12 @@ BoundaryCondition, ImplementationStep, ) +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.boundary_condition.boundary_masker import ( BoundaryMasker, IndicesBoundaryMasker, ) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.boundary_condition.boundary_applier import FullBounceBackApplier class FullBounceBack(BoundaryCondition): @@ -68,3 +69,21 @@ def from_indices( precision_policy, compute_backend, ) + + @classmethod + def from_stl(cls, velocity_set, precision_policy, compute_backend): + """ + Create a full bounce-back boundary condition from an STL file. + """ + # Create boundary mask + boundary_mask = STLBoundaryMasker( + False, velocity_set, precision_policy, compute_backend + ) + + # Create boundary condition + return cls( + boundary_mask, + velocity_set, + precision_policy, + compute_backend, + ) diff --git a/xlb/operator/boundary_condition/boundary_applier/full_bounce_back_applier.py b/xlb/operator/boundary_condition/full_bounce_back_applier.py similarity index 100% rename from xlb/operator/boundary_condition/boundary_applier/full_bounce_back_applier.py rename to xlb/operator/boundary_condition/full_bounce_back_applier.py diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py new file mode 100644 index 0000000..af6d3e1 --- /dev/null +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -0,0 +1,112 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp +from typing import Tuple, Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep, BoundaryCondition +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + +class HalfwayBounceBackBC(BoundaryCondition): + """ + Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. + + TODO: Implement moving boundary conditions for this + """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + raise NotImplementedError + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _opp_indices = self.velocity_set.wp_opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f: wp.array4d(dtype=Any), + missing_mask: wp.array4d(dtype=wp.bool), + index: Any, + ): + + # Pull the distribution function + _f = _f_vec() + for l in range(self.velocity_set.q): + + # Get pull index + pull_index = type(index)() + + # If the mask is missing then take the opposite index + if missing_mask[l, index[0], index[1], index[2]] == wp.bool(True): + use_l = _opp_indices[l] + for d in range(self.velocity_set.d): + pull_index[d] = index[d] + + # Pull the distribution function + else: + use_l = l + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - _c[d, l] + + # Get the distribution function + _f[l] = f[use_l, pull_index[0], pull_index[1], pull_index[2]] + + return _f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get boundary id + if boundary_id[0, index[0], index[1], index[2]] == wp.uint8(DoNothing.id): + _f = functional(f_pre, index) + for l in range(_q): + f_post[l, index[0], index[1], index[2]] = _f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, f, boundary, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, inputs=[f_pre, f_post, f, boundary, missing_mask], dim=f_pre.shape[1:] + ) + return f diff --git a/xlb/operator/boundary_condition/boundary_applier/stream_boundary_applier.py b/xlb/operator/boundary_condition/stream_boundary_applier.py similarity index 99% rename from xlb/operator/boundary_condition/boundary_applier/stream_boundary_applier.py rename to xlb/operator/boundary_condition/stream_boundary_applier.py index 2e68c0b..fd51c5a 100644 --- a/xlb/operator/boundary_condition/boundary_applier/stream_boundary_applier.py +++ b/xlb/operator/boundary_condition/stream_boundary_applier.py @@ -17,7 +17,7 @@ from xlb.boundary_condition.full_bounce_back import FullBounceBack -class CollisionBoundaryCondition(Operator): +class StreamBoundaryCondition(Operator): """ Class for combining collision and boundary conditions together into a single operator. diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py new file mode 100644 index 0000000..a9069c6 --- /dev/null +++ b/xlb/operator/boundary_masker/__init__.py @@ -0,0 +1,9 @@ +from xlb.operator.boundary_masker.boundary_masker import ( + BoundaryMasker, +) +from xlb.operator.boundary_masker.indices_boundary_masker import ( + IndicesBoundaryMasker, +) +from xlb.operator.boundary_masker.planar_boundary_masker import ( + PlanarBoundaryMasker, +) diff --git a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py b/xlb/operator/boundary_masker/boundary_masker.py similarity index 99% rename from xlb/operator/boundary_condition/boundary_masker/boundary_masker.py rename to xlb/operator/boundary_masker/boundary_masker.py index f6b73f0..6fe487f 100644 --- a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py +++ b/xlb/operator/boundary_masker/boundary_masker.py @@ -2,6 +2,7 @@ from xlb.operator.operator import Operator + class BoundaryMasker(Operator): """ Operator for creating a boundary mask diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py new file mode 100644 index 0000000..02a8a87 --- /dev/null +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -0,0 +1,155 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream.stream import Stream + + +class IndicesBoundaryMasker(Operator): + """ + Operator for creating a boundary mask + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + # Make stream operator + self.stream = Stream(velocity_set, precision_policy, compute_backend) + + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + @staticmethod + def _indices_to_tuple(indices): + """ + Converts a tensor of indices to a tuple for indexing + TODO: Might be better to index + """ + return tuple([indices[:, i] for i in range(indices.shape[1])]) + + @Operator.register_backend(ComputeBackend.JAX) + # @partial(jit, static_argnums=(0), inline=True) TODO: Fix this + def jax_implementation(self, indices, id_number, boundary_id, mask, start_index=(0, 0, 0)): + # Get local indices from the meshgrid and the indices + local_indices = self.indices - np.array(start_index)[np.newaxis, :] + + # Remove any indices that are out of bounds + local_indices = local_indices[ + (local_indices[:, 0] >= 0) + & (local_indices[:, 0] < mask.shape[0]) + & (local_indices[:, 1] >= 0) + & (local_indices[:, 1] < mask.shape[1]) + & (local_indices[:, 2] >= 0) + & (local_indices[:, 2] < mask.shape[2]) + ] + + # Set the boundary id + boundary_id = boundary_id.at[0, self._indices_to_tuple(local_indices)].set( + id_number + ) + + # Make mask then stream to get the edge points + pre_stream_mask = jnp.zeros_like(mask) + pre_stream_mask = pre_stream_mask.at[self._indices_to_tuple(local_indices)].set( + True + ) + post_stream_mask = self.stream(pre_stream_mask) + + # Set false for points inside the boundary (NOTE: removing this to be more consistent with the other boundary maskers) + # post_stream_mask = post_stream_mask.at[ + # post_stream_mask[0, ...] == True + # ].set(False) + + # Get indices on edges + edge_indices = jnp.argwhere(post_stream_mask) + + # Set the mask + mask = mask.at[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ].set( + post_stream_mask[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ] + ) + + return boundary_id, mask + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.wp_c + _q = wp.constant(self.velocity_set.q) + + # Construct the warp kernel + @wp.kernel + def kernel( + indices: wp.array2d(dtype=wp.int32), + id_number: wp.int32, + boundary_id: wp.array4d(dtype=wp.uint8), + mask: wp.array4d(dtype=wp.bool), + start_index: wp.vec3i, + ): + # Get the index of indices + ii = wp.tid() + + # Get local indices + index = wp.vec3i() + index[0] = indices[ii, 0] - start_index[0] + index[1] = indices[ii, 1] - start_index[1] + index[2] = indices[ii, 2] - start_index[2] + + # Check if in bounds + if ( + index[0] >= 0 + and index[0] < mask.shape[1] + and index[1] >= 0 + and index[1] < mask.shape[2] + and index[2] >= 0 + and index[2] < mask.shape[3] + ): + + # Stream indices + for l in range(_q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + + # Set the boundary id and mask + boundary_id[0, push_index[0], push_index[1], push_index[2]] = wp.uint8( + id_number + ) + mask[l, push_index[0], push_index[1], push_index[2]] = True + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, indices, id_number, boundary_id, mask, start_index=(0, 0, 0) + ): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + indices, + id_number, + boundary_id, + mask, + start_index, + ], + dim=indices.shape[0], + ) + + return boundary_id, mask diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py new file mode 100644 index 0000000..24156c9 --- /dev/null +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -0,0 +1,128 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream.stream import Stream + + +class PlanarBoundaryMasker(Operator): + """ + Operator for creating a boundary mask on a plane of the domain + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + @Operator.register_backend(ComputeBackend.JAX) + # @partial(jit, static_argnums=(0), inline=True) TODO: Fix this + def jax_implementation(self, edge, start_index, boundary_id, mask, id_number): + raise NotImplementedError + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.wp_c + _q = wp.constant(self.velocity_set.q) + + # Construct the warp kernel + @wp.kernel + def kernel( + lower_bound: wp.vec3i, + upper_bound: wp.vec3i, + direction: wp.vec3i, + id_number: wp.int32, + boundary_id: wp.array4d(dtype=wp.uint8), + mask: wp.array4d(dtype=wp.bool), + start_index: wp.vec3i, + ): + # Get the indices of the plane to mask + plane_i, plane_j = wp.tid() + + # Get local indices + if direction[0] != 0: + i = lower_bound[0] - start_index[0] + j = plane_i - start_index[1] + k = plane_j - start_index[2] + elif direction[1] != 0: + i = plane_i - start_index[0] + j = lower_bound[1] - start_index[1] + k = plane_j - start_index[2] + elif direction[2] != 0: + i = plane_i - start_index[0] + j = plane_j - start_index[1] + k = lower_bound[2] - start_index[2] + + # Check if in bounds + if ( + i >= 0 + and i < mask.shape[1] + and j >= 0 + and j < mask.shape[2] + and k >= 0 + and k < mask.shape[3] + ): + # Set the boundary id + boundary_id[0, i, j, k] = wp.uint8(id_number) + + # Set mask for just directions coming from the boundary + for l in range(_q): + d_dot_c = ( + direction[0] * _c[0, l] + + direction[1] * _c[1, l] + + direction[2] * _c[2, l] + ) + if d_dot_c >= 0: + mask[l, i, j, k] = wp.bool(True) + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + lower_bound, + upper_bound, + direction, + id_number, + boundary_id, + mask, + start_index=(0, 0, 0), + ): + # Get plane dimensions + if direction[0] != 0: + dim = (upper_bound[1] - lower_bound[1], upper_bound[2] - lower_bound[2]) + elif direction[1] != 0: + dim = (upper_bound[0] - lower_bound[0], upper_bound[2] - lower_bound[2]) + elif direction[2] != 0: + dim = (upper_bound[0] - lower_bound[0], upper_bound[1] - lower_bound[1]) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + lower_bound, + upper_bound, + direction, + id_number, + boundary_id, + mask, + start_index, + ], + dim=dim, + ) + + return boundary_id, mask diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py new file mode 100644 index 0000000..8ae8456 --- /dev/null +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -0,0 +1,109 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream.stream import Stream + + +class STLBoundaryMasker(Operator): + """ + Operator for creating a boundary mask from an STL file + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + super().__init__(velocity_set, precision_policy, compute_backend) + + # Make stream operator + self.stream = Stream(velocity_set, precision_policy, compute_backend) + + @Operator.register_backend(ComputeBackend.JAX) + def jax_implementation(self, mesh, id_number, boundary_id, mask, start_index=(0, 0, 0)): + # TODO: Implement this + raise NotImplementedError + + def _construct_warp(self): + # Make constants for warp + _opp_indices = wp.constant( + self._warp_int_lattice_vec(self.velocity_set.opp_indices) + ) + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + _id = wp.constant(self.id) + + # Construct the warp kernel + @wp.kernel + def _voxelize_mesh( + voxels: wp.array3d(dtype=wp.uint8), + mesh: wp.uint64, + spacing: wp.vec3, + origin: wp.vec3, + shape: wp.vec(3, wp.uint32), + max_length: float, + material_id: int, + ): + # get index of voxel + i, j, k = wp.tid() + + # position of voxel + ijk = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) + ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center + pos = wp.cw_mul(ijk, spacing) + origin + + # Only evaluate voxel if not set yet + if voxels[i, j, k] != wp.uint8(0): + return + + # evaluate distance of point + face_index = int(0) + face_u = float(0.0) + face_v = float(0.0) + sign = float(0.0) + if wp.mesh_query_point( + mesh, pos, max_length, sign, face_index, face_u, face_v + ): + p = wp.mesh_eval_position(mesh, face_index, face_u, face_v) + delta = pos - p + norm = wp.sqrt(wp.dot(delta, delta)) + + # set point to be solid + if norm < wp.min(spacing): + voxels[i, j, k] = wp.uint8(255) + elif sign < 0: # TODO: fix this + voxels[i, j, k] = wp.uint8(material_id) + else: + pass + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, start_index, boundary_id, mask, id_number): + # Reuse the jax implementation, TODO: implement a warp version + # Convert to jax + boundary_id = wp.jax.to_jax(boundary_id) + mask = wp.jax.to_jax(mask) + + # Call jax implementation + boundary_id, mask = self.jax_implementation( + start_index, boundary_id, mask, id_number + ) + + # Convert back to warp + boundary_id = wp.jax.to_warp(boundary_id) + mask = wp.jax.to_warp(mask) + + return boundary_id, mask diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 4071345..b70be6c 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -1,6 +1,7 @@ import jax.numpy as jnp from jax import jit import warp as wp +from typing import Any from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend @@ -32,21 +33,19 @@ def pallas_implementation( return fout def _construct_warp(self): - # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) - _d = wp.constant(self.velocity_set.d) + # Set local constants TODO: This is a hack and should be fixed with warp update + _w = self.velocity_set.wp_w _omega = wp.constant(self.compute_dtype(self.omega)) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) # Construct the functional @wp.func def functional( - f: self._warp_lattice_vec, - feq: self._warp_lattice_vec, - rho: self.compute_dtype, - u: self._warp_u_vec, - ) -> self._warp_lattice_vec: + f: Any, + feq: Any, + rho: Any, + u: Any, + ): fneq = f - feq fout = f - _omega * fneq return fout @@ -54,30 +53,33 @@ def functional( # Construct the warp kernel @wp.kernel def kernel( - f: self._warp_array_type, - feq: self._warp_array_type, - rho: self._warp_array_type, - u: self._warp_array_type, - fout: self._warp_array_type, + f: wp.array4d(dtype=Any), + feq: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + fout: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this - # Get the equilibrium - _f = self._warp_lattice_vec() - _feq = self._warp_lattice_vec() - for l in range(_q): - _f[l] = f[l, i, j, k] - _feq[l] = feq[l, i, j, k] + # Load needed values + _f = _f_vec() + _feq = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _feq[l] = feq[l, index[0], index[1], index[2]] _u = self._warp_u_vec() for l in range(_d): - _u[l] = u[l, i, j, k] - _rho = rho[0, i, j, k] + _u[l] = u[l, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] + + # Compute the collision _fout = functional(_f, _feq, _rho, _u) # Write the result - for l in range(_q): - fout[l, i, j, k] = _fout[l] + for l in range(self.velocity_set.q): + fout[l, index[0], index[1], index[2]] = _fout[l] return functional, kernel @@ -85,7 +87,7 @@ def kernel( def warp_implementation(self, f, feq, rho, u, fout): # Launch the warp kernel wp.launch( - self._kernel, + self.warp_kernel, inputs=[ f, feq, diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index a1245f1..65a7fb6 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -2,6 +2,7 @@ import jax.numpy as jnp from jax import jit import warp as wp +from typing import Any from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend @@ -56,22 +57,27 @@ def pallas_implementation(self, rho, u): return jnp.array(eq) def _construct_warp(self): - # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) - _d = wp.constant(self.velocity_set.d) + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _w = self.velocity_set.wp_w + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) # Construct the equilibrium functional @wp.func def functional( - rho: self.compute_dtype, u: self._warp_u_vec - ) -> self._warp_lattice_vec: - feq = self._warp_lattice_vec() # empty lattice vector - for l in range(_q): - ## Compute cu + rho: Any, + u: Any, + ): + # Allocate the equilibrium + feq = _f_vec() + + # Compute the equilibrium + for l in range(self.velocity_set.q): + + # Compute cu cu = self.compute_dtype(0.0) - for d in range(_d): + for d in range(self.velocity_set.d): if _c[d, l] == 1: cu += u[d] elif _c[d, l] == -1: @@ -89,23 +95,24 @@ def functional( # Construct the warp kernel @wp.kernel def kernel( - rho: self._warp_array_type, - u: self._warp_array_type, - f: self._warp_array_type, + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # Get the equilibrium - _u = self._warp_u_vec() - for d in range(_d): - _u[d] = u[d, i, j, k] - _rho = rho[0, i, j, k] + _u = _u_vec() + for d in range(self.velocity_set.d): + _u[d] = u[d, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] feq = functional(_rho, _u) # Set the output - for l in range(_q): - f[l, i, j, k] = feq[l] + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = feq[l] return functional, kernel diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 97bd10a..161705e 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jax import jit import warp as wp -from typing import Tuple +from typing import Tuple, Any from xlb.global_config import GlobalConfig from xlb.velocity_set.velocity_set import VelocitySet @@ -75,19 +75,19 @@ def pallas_implementation(self, f): def _construct_warp(self): # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) + _c = self.velocity_set.wp_c + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) # Construct the functional @wp.func - def functional(f: self._warp_lattice_vec): + def functional(f: _f_vec): # Compute rho and u rho = self.compute_dtype(0.0) - u = self._warp_u_vec() - for l in range(_q): + u = _u_vec() + for l in range(self.velocity_set.q): rho += f[l] - for d in range(_d): + for d in range(self.velocity_set.d): if _c[d, l] == 1: u[d] += f[l] elif _c[d, l] == -1: @@ -95,28 +95,28 @@ def functional(f: self._warp_lattice_vec): u /= rho return rho, u - # return u, rho # Construct the kernel @wp.kernel def kernel( - f: self._warp_array_type, - rho: self._warp_array_type, - u: self._warp_array_type, + f: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # Get the equilibrium - _f = self._warp_lattice_vec() - for l in range(_q): - _f[l] = f[l, i, j, k] + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] (_rho, _u) = functional(_f) # Set the output - rho[0, i, j, k] = _rho - for d in range(_d): - u[d, i, j, k] = _u[d] + rho[0, index[0], index[1], index[2]] = _rho + for d in range(self.velocity_set.d): + u[d, index[0], index[1], index[2]] = _u[d] return functional, kernel diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index f3ea901..ff9ba58 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,5 +1,6 @@ # Base class for all operators, (collision, streaming, equilibrium, etc.) import warp as wp +from typing import Any from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy, Precision @@ -123,90 +124,11 @@ def _precision_to_dtype(self, precision): elif precision == Precision.FP16: return self.backend.float16 - ### WARP specific types ### - # These are used to define the types for the warp backend - # TODO: There might be a better place to put these - @property - def _warp_u_vec(self): - """ - Returns the warp type for velocity - """ - return wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - - @property - def _warp_lattice_vec(self): - """ - Returns the warp type for the lattice - """ - return wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - - @property - def _warp_int_lattice_vec(self): - """ - Returns the warp type for the streaming matrix (c) - """ - return wp.vec(self.velocity_set.q, dtype=wp.int32) - - @property - def _warp_bool_lattice_vec(self): - """ - Returns the warp type for the streaming matrix (c) - """ - #return wp.vec(self.velocity_set.q, dtype=wp.bool) - return wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO bool breaks - - @property - def _warp_stream_mat(self): - """ - Returns the warp type for the streaming matrix (c) - """ - return wp.mat( - (self.velocity_set.d, self.velocity_set.q), dtype=self.compute_dtype - ) - - @property - def _warp_int_stream_mat(self): - """ - Returns the warp type for the streaming matrix (c) - """ - return wp.mat( - (self.velocity_set.d, self.velocity_set.q), dtype=wp.int32 - ) - - @property - def _warp_array_type(self): - """ - Returns the warp type for arrays - """ - if self.velocity_set.d == 2: - return wp.array3d(dtype=self.store_dtype) - elif self.velocity_set.d == 3: - return wp.array4d(dtype=self.store_dtype) - - @property - def _warp_uint8_array_type(self): - """ - Returns the warp type for arrays - """ - if self.velocity_set.d == 2: - return wp.array3d(dtype=wp.uint8) - elif self.velocity_set.d == 3: - return wp.array4d(dtype=wp.uint8) - - @property - def _warp_bool_array_type(self): - """ - Returns the warp type for arrays - """ - if self.velocity_set.d == 2: - return wp.array3d(dtype=wp.bool) - elif self.velocity_set.d == 3: - return wp.array4d(dtype=wp.bool) - def _construct_warp(self): """ Construct the warp functional and kernel of the operator TODO: Maybe a better way to do this? Maybe add this to the backend decorator? + Leave it for now, as it is not clear how the warp backend will evolve """ return None, None diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index f910b18..8e89e51 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -4,12 +4,13 @@ from functools import partial from jax import jit import warp as wp +from typing import Any from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator from xlb.operator.stepper import Stepper -from xlb.operator.boundary_condition import ImplementationStep +#from xlb.operator.boundary_condition import ImplementationStep from xlb.operator.collision import BGK @@ -20,7 +21,7 @@ class IncompressibleNavierStokesStepper(Stepper): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0, 5)) - def apply_jax(self, f, boundary_id, mask, timestep): + def apply_jax(self, f, boundary_id, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ @@ -44,7 +45,7 @@ def apply_jax(self, f, boundary_id, mask, timestep): # Apply collision type boundary conditions f_post_collision = self.collision_boundary_applier.jax_implementation( - f_pre_collision, f_post_collision, mask, boundary_id + f_pre_collision, f_post_collision, missing_mask, boundary_id ) f_pre_streaming = f_post_collision @@ -61,7 +62,7 @@ def apply_jax(self, f, boundary_id, mask, timestep): f_pre_streaming, f_post_streaming, boundary_id == id_number, - mask, + missing_mask, ) # Copy back to store precision @@ -71,7 +72,7 @@ def apply_jax(self, f, boundary_id, mask, timestep): @Operator.register_backend(ComputeBackend.PALLAS) @partial(jit, static_argnums=(0,)) - def apply_pallas(self, fin, boundary_id, mask, timestep): + def apply_pallas(self, fin, boundary_id, missing_mask, timestep): # Raise warning that the boundary conditions are not implemented ################################################################ warning("Boundary conditions are not implemented for PALLAS backend currently") @@ -127,88 +128,73 @@ def _pallas_collide_and_stream(f): return fout def _construct_warp(self): - # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - _nr_boundary_conditions = wp.constant(len(self.boundary_conditions)) + # Set local constants TODO: This is a hack and should be fixed with warp update + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) + _do_nothing_bc = wp.uint8(self.do_nothing_bc.id) + _half_way_bc = wp.uint8(self.half_way_bc.id) # Construct the kernel @wp.kernel def kernel( - f_0: self._warp_array_type, - f_1: self._warp_array_type, - boundary_id: self._warp_uint8_array_type, - mask: self._warp_bool_array_type, + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=Any), + missing_mask: wp.array4d(dtype=Any), timestep: int, - max_i: int, - max_j: int, - max_k: int, ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO warp should fix this - # Get the f, boundary id and mask - _f = self._warp_lattice_vec() - _boundary_id = boundary_id[0, i, j, k] - _mask = self._warp_bool_lattice_vec() - for l in range(_q): - _f[l] = f_0[l, i, j, k] - + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): # TODO fix vec bool - if mask[l, i, j, k]: - _mask[l] = wp.uint8(1) + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) else: - _mask[l] = wp.uint8(0) + _missing_mask[l] = wp.uint8(0) + + # Apply streaming boundary conditions + if _boundary_id == wp.uint8(0): + # Regular streaming + _post_stream_f = self.stream.warp_functional(f_0, index) + elif _boundary_id == _equilibrium_bc: + # Equilibrium boundary condition + _post_stream_f = self.equilibrium_bc.warp_functional(f_0, missing_mask, index) + elif _boundary_id == _do_nothing_bc: + # Do nothing boundary condition + _post_stream_f = self.do_nothing_bc.warp_functional(f_0, missing_mask, index) + elif _boundary_id == _half_way_bc: + # Half way boundary condition + _post_stream_f = self.half_way_bc.warp_functional(f_0, missing_mask, index) + #_post_stream_f = self.stream.warp_functional(f_0, index) # Compute rho and u - rho, u = self.macroscopic.warp_functional(_f) + rho, u = self.macroscopic.warp_functional(_post_stream_f) # Compute equilibrium feq = self.equilibrium.warp_functional(rho, u) # Apply collision f_post_collision = self.collision.warp_functional( - _f, + _post_stream_f, feq, rho, u, ) - ## Apply collision type boundary conditions - f_post_collision = self.collision_boundary_applier.warp_functional( - _f, - f_post_collision, - _boundary_id, - _mask, - ) - f_pre_streaming = f_post_collision # store pre streaming vector - - # Apply forcing - # if self.forcing_op is not None: - # f = self.forcing.warp_functional(f, timestep) - - # Apply streaming - for l in range(_q): - # Get the streamed indices - streamed_i, streamed_j, streamed_k = self.stream.warp_functional( - l, i, j, k, max_i, max_j, max_k - ) - streamed_l = l - - ## Modify the streamed indices based on streaming boundary condition - # if _boundary_id != 0: - # streamed_l, streamed_i, streamed_j, streamed_k = self.stream_boundary_conditions[id_number].warp_functional( - # streamed_l, streamed_i, streamed_j, streamed_k, self._warp_max_i, self._warp_max_j, self._warp_max_k - # ) - - # Set the output - f_1[streamed_l, streamed_i, streamed_j, streamed_k] = f_pre_streaming[l] + # Set the output + for l in range(self.velocity_set.q): + f_1[l, index[0], index[1], index[2]] = f_post_collision[l] return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, boundary_id, mask, timestep): + def warp_implementation(self, f_0, f_1, boundary_id, missing_mask, timestep): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -216,11 +202,8 @@ def warp_implementation(self, f_0, f_1, boundary_id, mask, timestep): f_0, f_1, boundary_id, - mask, + missing_mask, timestep, - f_0.shape[1], - f_0.shape[2], - f_0.shape[3], ], dim=f_0.shape[1:], ) diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index b8cbbea..c1acefa 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -8,11 +8,9 @@ from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -from xlb.operator.boundary_condition import ImplementationStep -from xlb.operator.boundary_condition.boundary_applier import ( - CollisionBoundaryApplier - StreamingBoundaryApplier -) +#from xlb.operator.boundary_condition.boundary_condition import ImplementationStep +#from xlb.operator.boundary_condition.boundary_applier.collision_boundary_applier import CollisionBoundaryApplier +#from xlb.operator.boundary_condition.boundary_applier.stream_boundary_applier import StreamBoundaryApplier from xlb.operator.precision_caster import PrecisionCaster @@ -27,7 +25,9 @@ def __init__( stream, equilibrium, macroscopic, - boundary_conditions=[], + equilibrium_bc, + do_nothing_bc, + half_way_bc, forcing=None, ): # Set parameters @@ -35,7 +35,14 @@ def __init__( self.stream = stream self.equilibrium = equilibrium self.macroscopic = macroscopic - self.boundary_conditions = boundary_conditions + self.equilibrium_bc = equilibrium_bc + self.do_nothing_bc = do_nothing_bc + self.half_way_bc = half_way_bc + self.boundary_conditions = [ + equilibrium_bc, + do_nothing_bc, + half_way_bc, + ] self.forcing = forcing # Get all operators for checking @@ -44,9 +51,8 @@ def __init__( stream, equilibrium, macroscopic, - *[bc.boundary_applier for bc in boundary_conditions], - *[bc.boundary_masker for bc in boundary_conditions], - forcing, + *self.boundary_conditions, + #forcing, ] # Get velocity set, precision policy, and compute backend @@ -61,12 +67,12 @@ def __init__( compute_backend = compute_backends.pop() # Make single operators for all collision and streaming boundary conditions - self.collision_boundary_applier = CollisionBoundaryApplier( - [bc.boundary_applier for bc in boundary_conditions if bc.implementation_step == ImplementationStep.COLLISION] - ) - self.streaming_boundary_applier = StreamingBoundaryApplier( - [bc.boundary_applier for bc in boundary_conditions if bc.implementation_step == ImplementationStep.STREAMING] - ) + #self.collision_boundary_applier = CollisionBoundaryApplier( + # [bc.boundary_applier for bc in boundary_conditions if bc.implementation_step == ImplementationStep.COLLISION] + #) + #self.streaming_boundary_applier = StreamBoundaryApplier( + # [bc.boundary_applier for bc in boundary_conditions if bc.implementation_step == ImplementationStep.STREAMING] + #) # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 5033cac..e4d255c 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit, vmap import warp as wp +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend @@ -12,7 +13,7 @@ class Stream(Operator): """ - Base class for all streaming operators. + Base class for all streaming operators. This is used for pulling the distribution """ @Operator.register_backend(ComputeBackend.JAX) @@ -21,6 +22,8 @@ def jax_implementation(self, f): """ JAX implementation of the streaming step. + TODO: Make sure this works with pull scheme. + Parameters ---------- f: jax.numpy.ndarray @@ -42,66 +45,61 @@ def _streaming_jax_i(f, c): The updated distribution function after streaming. """ if self.velocity_set.d == 2: - return jnp.roll(f, (c[0], c[1]), axis=(0, 1)) + return jnp.roll(f, (-c[0], -c[1]), axis=(0, 1)) # Negative sign is used to pull the distribution instead of pushing elif self.velocity_set.d == 3: - return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) + return jnp.roll(f, (-c[0], -c[1], -c[2]), axis=(0, 1, 2)) return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( f, jnp.array(self.velocity_set.c).T ) def _construct_warp(self): - # Make constants for warp - _c = wp.constant(self._warp_int_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) # Construct the funcional to get streamed indices @wp.func def functional( - l: int, - i: int, - j: int, - k: int, - max_i: int, - max_j: int, - max_k: int, + f: wp.array4d(dtype=Any), + index: Any, ): - streamed_i = i + _c[0, l] - streamed_j = j + _c[1, l] - streamed_k = k + _c[2, l] - if streamed_i < 0: - streamed_i = max_i - 1 - elif streamed_i >= max_i: - streamed_i = 0 - if streamed_j < 0: - streamed_j = max_j - 1 - elif streamed_j >= max_j: - streamed_j = 0 - if streamed_k < 0: - streamed_k = max_k - 1 - elif streamed_k >= max_k: - streamed_k = 0 - return streamed_i, streamed_j, streamed_k + + # Pull the distribution function + _f = _f_vec() + for l in range(self.velocity_set.q): + + # Get pull index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - _c[d, l] + + if pull_index[d] < 0: + pull_index[d] = f.shape[d + 1] - 1 + elif pull_index[d] >= f.shape[d + 1]: + pull_index[d] = 0 + + # Read the distribution function + _f[l] = f[l, pull_index[0], pull_index[1], pull_index[2]] + + return _f # Construct the warp kernel @wp.kernel def kernel( - f_0: self._warp_array_type, - f_1: self._warp_array_type, - max_i: int, - max_j: int, - max_k: int, + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # Set the output - for l in range(_q): - streamed_i, streamed_j, streamed_k = functional( - l, i, j, k, max_i, max_j, max_k - ) - f_1[l, streamed_i, streamed_j, streamed_k] = f_0[l, i, j, k] + _f = functional(f_0, index) + + # Write the output + for l in range(self.velocity_set.q): + f_1[l, index[0], index[1], index[2]] = _f[l] return functional, kernel @@ -109,13 +107,10 @@ def kernel( def warp_implementation(self, f_0, f_1): # Launch the warp kernel wp.launch( - self._kernel, + self.warp_kernel, inputs=[ f_0, f_1, - f_0.shape[1], - f_0.shape[2], - f_0.shape[3], ], dim=f_0.shape[1:], ) diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index 6c4c0c2..1b0ab43 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -6,9 +6,9 @@ from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend -from xlb.operator.boundary_condition import ImplementationStep -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.collision import BGK, KBC +from xlb.operator.equilibrium.quadratic_equilibrium import QuadraticEquilibrium +from xlb.operator.collision.bgk import BGK +from xlb.operator.collision.kbc import KBC from xlb.operator.stream import Stream from xlb.operator.macroscopic import Macroscopic from xlb.solver.solver import Solver diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py index 7d3db77..335fd72 100644 --- a/xlb/solver/solver.py +++ b/xlb/solver/solver.py @@ -1,9 +1,8 @@ # Base class for all stepper operators from xlb.compute_backend import ComputeBackend -from xlb.operator.boundary_condition import ImplementationStep from xlb.global_config import GlobalConfig -from xlb.operator import Operator +from xlb.operator.operator import Operator class Solver(Operator): diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 0564fde..a137b87 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -43,6 +43,13 @@ def __init__(self, d, q, c, w): self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() + # Make warp constants for these vectors + # TODO: Following warp updates these may not be necessary + self.wp_c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self.c)) + self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow + self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) + + def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) From 5c4732f3e81944b7c5a5b388dcd7971f09144dc1 Mon Sep 17 00:00:00 2001 From: Oliver Date: Tue, 26 Mar 2024 16:03:47 -0700 Subject: [PATCH 022/144] LDC example --- examples/interfaces/boundary_conditions.py | 157 ++++++++- examples/interfaces/ldc.py | 323 +++++++++++++++--- examples/interfaces/taylor_green.py | 131 +++++-- xlb/__init__.py | 2 +- xlb/grid/jax_grid.py | 13 +- xlb/grid/warp_grid.py | 5 +- xlb/operator/boundary_condition/__init__.py | 5 + .../boundary_condition/boundary_condition.py | 4 +- .../boundary_condition_registry.py | 10 +- .../collision_boundary_applier.py | 143 -------- xlb/operator/boundary_condition/do_nothing.py | 56 ++- .../boundary_condition/do_nothing_applier.py | 47 --- .../boundary_condition/equilibrium.py | 66 +++- .../boundary_condition/equilibrium_applier.py | 50 --- .../boundary_condition/full_bounce_back.py | 89 ----- ...back_applier.py => fullway_bounce_back.py} | 81 +++-- .../boundary_condition/halfway_bounce_back.py | 61 +++- .../stream_boundary_applier.py | 143 -------- .../indices_boundary_masker.py | 23 +- .../boundary_masker/planar_boundary_masker.py | 65 +++- .../boundary_masker/stl_boundary_masker.py | 7 +- xlb/operator/collision/bgk.py | 2 +- .../equilibrium/quadratic_equilibrium.py | 3 - xlb/operator/operator.py | 22 +- xlb/operator/stepper/nse.py | 95 ++++-- xlb/operator/stepper/stepper.py | 81 +++-- xlb/operator/stream/stream.py | 6 +- xlb/operator/test/test.py | 1 - xlb/precision_policy.py | 33 ++ 29 files changed, 989 insertions(+), 735 deletions(-) delete mode 100644 xlb/operator/boundary_condition/collision_boundary_applier.py delete mode 100644 xlb/operator/boundary_condition/do_nothing_applier.py delete mode 100644 xlb/operator/boundary_condition/equilibrium_applier.py delete mode 100644 xlb/operator/boundary_condition/full_bounce_back.py rename xlb/operator/boundary_condition/{full_bounce_back_applier.py => fullway_bounce_back.py} (50%) delete mode 100644 xlb/operator/boundary_condition/stream_boundary_applier.py delete mode 100644 xlb/operator/test/test.py diff --git a/examples/interfaces/boundary_conditions.py b/examples/interfaces/boundary_conditions.py index 4f7fc9a..70648bf 100644 --- a/examples/interfaces/boundary_conditions.py +++ b/examples/interfaces/boundary_conditions.py @@ -1,3 +1,156 @@ -from xlb.operator.boundary_condition import boundary_condition_registry +# Simple script to run different boundary conditions with jax and warp backends +import time +from tqdm import tqdm +import os +import matplotlib.pyplot as plt +from typing import Any +import numpy as np +import jax.numpy as jnp +import warp as wp -print(boundary_condition_registry.ids) +wp.init() + +import xlb + +def run_boundary_conditions(backend): + + # Set the compute backend + if backend == "warp": + compute_backend = xlb.ComputeBackend.WARP + elif backend == "jax": + compute_backend = xlb.ComputeBackend.JAX + + # Set the precision policy + precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set + velocity_set = xlb.velocity_set.D3Q19() + + # Make grid + nr = 256 + shape = (nr, nr, nr) + if backend == "jax": + grid = xlb.grid.JaxGrid(shape=shape) + elif backend == "warp": + grid = xlb.grid.WarpGrid(shape=shape) + + # Make feilds + f_pre = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + f_post = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + f = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) + + # Make needed operators + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(0.0, 0.0, 0.0), + equilibrium_operator=equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + halfway_bounce_back_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + fullway_bounce_back_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + # Make indices for boundary conditions (sphere) + sphere_radius = 32 + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + indices = np.array(indices).T + if backend == "jax": + indices = jnp.array(indices) + elif backend == "warp": + indices = wp.from_numpy(indices, dtype=wp.int32) + + # Test equilibrium boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, + equilibrium_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + if backend == "jax": + f = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask, f) + print(f"Equilibrium BC test passed for {backend}") + + # Test do nothing boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, + do_nothing_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + if backend == "jax": + f = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask, f) + + # Test halfway bounce back boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, + halfway_bounce_back_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + if backend == "jax": + f = halfway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f = halfway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask, f) + print(f"Halfway bounce back BC test passed for {backend}") + + # Test the full boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, + fullway_bounce_back_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + if backend == "jax": + f = fullway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f = fullway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask, f) + print(f"Fullway bounce back BC test passed for {backend}") + + +if __name__ == "__main__": + + # Test the boundary conditions + backends = ["warp", "jax"] + for backend in backends: + run_boundary_conditions(backend) diff --git a/examples/interfaces/ldc.py b/examples/interfaces/ldc.py index 962357f..ced31fe 100644 --- a/examples/interfaces/ldc.py +++ b/examples/interfaces/ldc.py @@ -1,86 +1,319 @@ -# Simple Taylor green example using the functional interface to xlb +# Simple flow past sphere example using the functional interface to xlb import time from tqdm import tqdm import os import matplotlib.pyplot as plt +from typing import Any +import numpy as np import warp as wp + wp.init() import xlb from xlb.operator import Operator -if __name__ == "__main__": +class UniformInitializer(Operator): + + def _construct_warp(self): + # Construct the warp kernel + @wp.kernel + def kernel( + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + + # Set the velocity + u[0, i, j, k] = 0.0 + u[1, i, j, k] = 0.0 + u[2, i, j, k] = 0.0 + + # Set the density + rho[0, i, j, k] = 1.0 + + return None, kernel + + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, rho, u): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + rho, + u, + ], + dim=rho.shape[1:], + ) + return rho, u + - # Set parameters - compute_backend = xlb.ComputeBackend.WARP +def run_ldc(backend, compute_mlup=True): + + # Set the compute backend + if backend == "warp": + compute_backend = xlb.ComputeBackend.WARP + elif backend == "jax": + compute_backend = xlb.ComputeBackend.JAX + + # Set the precision policy precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set velocity_set = xlb.velocity_set.D3Q19() - # Make feilds + # Make grid nr = 256 shape = (nr, nr, nr) - grid = xlb.grid.WarpGrid(shape=shape) - rho = grid.create_field(cardinality=1, dtype=wp.float32) - u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) - f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) - mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) + if backend == "jax": + grid = xlb.grid.JaxGrid(shape=shape) + elif backend == "warp": + grid = xlb.grid.WarpGrid(shape=shape) + + # Make feilds + rho = grid.create_field(cardinality=1, precision=xlb.Precision.FP32) + u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) + f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators + initializer = UniformInitializer( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) collision = xlb.operator.collision.BGK( - omega=1.9, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) + omega=1.9, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) stream = xlb.operator.stream.Stream( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(0, 0.10, 0.0), + equilibrium_operator=equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, - equilibrium=equilibrium, - macroscopic=macroscopic, - stream=stream, - boundary_conditions=[]) + collision=collision, + equilibrium=equilibrium, + macroscopic=macroscopic, + stream=stream, + boundary_conditions=[equilibrium_bc, do_nothing_bc, half_way_bc, full_way_bc], + ) + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + # Set inlet bc (bottom x face) + lower_bound = (0, 1, 1) + upper_bound = (0, nr-1, nr-1) + direction = (1, 0, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + equilibrium_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) - # Parrallelize the stepper - #stepper = grid.parallelize_operator(stepper) + # Set outlet bc (top x face) + lower_bound = (nr-1, 1, 1) + upper_bound = (nr-1, nr-1, nr-1) + direction = (-1, 0, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + #do_nothing_bc.id, + full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set half way bc (bottom y face) + lower_bound = (1, 0, 1) + upper_bound = (nr, 0, nr) + direction = (0, 1, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + #half_way_bc.id, + full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set half way bc (top y face) + lower_bound = (1, nr-1, 1) + upper_bound = (nr, nr-1, nr) + direction = (0, -1, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + #half_way_bc.id, + full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set half way bc (bottom z face) + lower_bound = (1, 1, 0) + upper_bound = (nr, nr, 0) + direction = (0, 0, 1) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + #half_way_bc.id, + full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set half way bc (top z face) + lower_bound = (1, 1, nr-1) + upper_bound = (nr, nr, nr-1) + direction = (0, 0, -1) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + #half_way_bc.id, + full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set full way bc (sphere) + """ + sphere_radius = nr // 8 + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + indices = np.array(indices).T + indices = wp.from_numpy(indices, dtype=wp.int32) + boundary_id, missing_mask = indices_boundary_masker( + indices, + full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + """ # Set initial conditions - rho, u = initializer(f0, rho, u, 0.1, nr) - f0 = equilibrium(rho, u, f0) + if backend == "warp": + rho, u = initializer(rho, u) + f0 = equilibrium(rho, u, f0) + elif backend == "jax": + rho = rho + 1.0 + f0 = equilibrium(rho, u) # Time stepping - plot_freq = 32 - save_dir = "taylor_green" + plot_freq = 512 + save_dir = "ldc" os.makedirs(save_dir, exist_ok=True) - #compute_mlup = False # Plotting results - compute_mlup = True - num_steps = 1024 + num_steps = nr * 32 start = time.time() + for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_id, mask, _) - f1, f0 = f0, f1 + # Time step + if backend == "warp": + f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1, f0 = f0, f1 + elif backend == "jax": + f0 = stepper(f0, boundary_id, missing_mask, _) + + # Plot if necessary if (_ % plot_freq == 0) and (not compute_mlup): - rho, u = macroscopic(f0, rho, u) - plt.imshow(u[0, :, nr//2, :].numpy()) + if backend == "warp": + rho, u = macroscopic(f0, rho, u) + local_rho = rho.numpy() + local_u = u.numpy() + elif backend == "jax": + local_rho, local_u = macroscopic(f0) + + # Plot the velocity field, rho and boundary id side by side + plt.subplot(1, 3, 1) + plt.imshow(np.linalg.norm(u[:, :, nr // 2, :], axis=0)) + plt.colorbar() + plt.subplot(1, 3, 2) + plt.imshow(rho[0, :, nr // 2, :]) plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(4)}.png") + plt.subplot(1, 3, 3) + plt.imshow(boundary_id[0, :, nr // 2, :]) + plt.colorbar() + plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() + wp.synchronize() end = time.time() # Print MLUPS print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") + +if __name__ == "__main__": + + # Run the LDC example + backends = ["warp", "jax"] + #backends = ["jax"] + for backend in backends: + run_ldc(backend, compute_mlup=True) diff --git a/examples/interfaces/taylor_green.py b/examples/interfaces/taylor_green.py index 0529b11..4d97476 100644 --- a/examples/interfaces/taylor_green.py +++ b/examples/interfaces/taylor_green.py @@ -4,21 +4,61 @@ from tqdm import tqdm import os import matplotlib.pyplot as plt +from functools import partial from typing import Any - +import jax.numpy as jnp +from jax import jit import warp as wp + wp.init() import xlb from xlb.operator import Operator class TaylorGreenInitializer(Operator): + """ + Initialize the Taylor-Green vortex. + """ + + @Operator.register_backend(xlb.ComputeBackend.JAX) + #@partial(jit, static_argnums=(0)) + def jax_implementation(self, vel, nr): + # Make meshgrid + x = jnp.linspace(0, 2 * jnp.pi, nr) + y = jnp.linspace(0, 2 * jnp.pi, nr) + z = jnp.linspace(0, 2 * jnp.pi, nr) + X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij") + + # Compute u + u = jnp.stack( + [ + vel * jnp.sin(X) * jnp.cos(Y) * jnp.cos(Z), + - vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), + jnp.zeros_like(X), + ], + axis=0, + ) + + # Compute rho + rho = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * ( + jnp.cos(2.0 * X) + + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0)) + ) + + 1.0 + ) + rho = jnp.expand_dims(rho, axis=0) + + return rho, u def _construct_warp(self): # Construct the warp kernel @wp.kernel def kernel( - f0: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), vel: float, @@ -54,12 +94,11 @@ def kernel( return None, kernel @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, f0, rho, u, vel, nr): + def warp_implementation(self, rho, u, vel, nr): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ - f0, rho, u, vel, @@ -69,23 +108,35 @@ def warp_implementation(self, f0, rho, u, vel, nr): ) return rho, u -if __name__ == "__main__": +def run_taylor_green(backend, compute_mlup=True): + + # Set the compute backend + if backend == "warp": + compute_backend = xlb.ComputeBackend.WARP + elif backend == "jax": + compute_backend = xlb.ComputeBackend.JAX - # Set parameters - compute_backend = xlb.ComputeBackend.WARP + # Set the precision policy precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set velocity_set = xlb.velocity_set.D3Q19() - # Make feilds - nr = 256 + # Make grid + nr = 128 shape = (nr, nr, nr) - grid = xlb.grid.WarpGrid(shape=shape) - rho = grid.create_field(cardinality=1, dtype=wp.float32) - u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) - f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) - missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) + if backend == "jax": + grid = xlb.grid.JaxGrid(shape=shape) + elif backend == "warp": + grid = xlb.grid.WarpGrid(shape=shape) + + # Make feilds + rho = grid.create_field(cardinality=1, precision=xlb.Precision.FP32) + u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) + f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators initializer = TaylorGreenInitializer( @@ -113,35 +164,57 @@ def warp_implementation(self, f0, rho, u, vel, nr): collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, - stream=stream, - boundary_conditions=[]) + stream=stream) - # Parrallelize the stepper + # Parrallelize the stepper TODO: Add this functionality #stepper = grid.parallelize_operator(stepper) # Set initial conditions - rho, u = initializer(f0, rho, u, 0.1, nr) - f0 = equilibrium(rho, u, f0) + if backend == "warp": + rho, u = initializer(rho, u, 0.1, nr) + f0 = equilibrium(rho, u, f0) + elif backend == "jax": + rho, u = initializer(0.1, nr) + f0 = equilibrium(rho, u) # Time stepping plot_freq = 32 save_dir = "taylor_green" os.makedirs(save_dir, exist_ok=True) - #compute_mlup = False # Plotting results - compute_mlup = True - num_steps = 1024 * 8 + num_steps = 8192 start = time.time() + for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_id, missing_mask, _) - f1, f0 = f0, f1 + # Time step + if backend == "warp": + f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1, f0 = f0, f1 + elif backend == "jax": + f0 = stepper(f0, boundary_id, missing_mask, _) + + # Plot if needed if (_ % plot_freq == 0) and (not compute_mlup): - rho, u = macroscopic(f0, rho, u) - plt.imshow(u[0, :, nr//2, :].numpy()) + if backend == "warp": + rho, u = macroscopic(f0, rho, u) + local_u = u.numpy() + elif backend == "jax": + rho, local_u = macroscopic(f0) + + + plt.imshow(local_u[0, :, nr//2, :]) plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(4)}.png") + plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() wp.synchronize() end = time.time() # Print MLUPS print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") + +if __name__ == "__main__": + + # Run Taylor-Green vortex on different backends + #backends = ["warp", "jax"] + backends = ["jax"] + for backend in backends: + run_taylor_green(backend, compute_mlup=False) diff --git a/xlb/__init__.py b/xlb/__init__.py index 88dcff2..84d38c5 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -1,6 +1,6 @@ # Enum classes from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy +from xlb.precision_policy import PrecisionPolicy, Precision from xlb.physics_type import PhysicsType # Config diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 1698c34..d2b579a 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -6,6 +6,7 @@ from xlb.grid import Grid from xlb.compute_backend import ComputeBackend from xlb.operator import Operator +from xlb.precision_policy import Precision class JaxGrid(Grid): def __init__(self, shape): @@ -31,8 +32,8 @@ def _initialize_jax_backend(self): else NamedSharding(self.global_mesh, P("cardinality", "x", "y", "z")) ) self.grid_shape_per_gpu = ( - self.grid_shape[0] // self.nDevices, - ) + self.grid_shape[1:] + self.shape[0] // self.nDevices, + ) + self.shape[1:] def parallelize_operator(self, operator: Operator): @@ -73,15 +74,15 @@ def _parallel_operator(f): return f - def create_field(self, name: str, cardinality: int, callback=None): + def create_field(self, cardinality: int, precision: Precision, callback=None): # Get shape of the field shape = (cardinality,) + (self.shape) # Create field if callback is None: - f = jax.numpy.full(shape, 0.0, dtype=self.precision_policy) - if self.sharding is not None: - f = jax.make_sharded_array(self.sharding, f) + f = jax.numpy.full(shape, 0.0, dtype=precision.jax_dtype) + #if self.sharding is not None: + # f = jax.make_sharded_array(self.sharding, f) else: f = jax.make_array_from_callback(shape, self.sharding, callback) diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index e4d160e..97b337b 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -2,6 +2,7 @@ from xlb.grid import Grid from xlb.operator import Operator +from xlb.precision_policy import Precision class WarpGrid(Grid): def __init__(self, shape): @@ -11,12 +12,12 @@ def parallelize_operator(self, operator: Operator): # TODO: Implement parallelization of the operator raise NotImplementedError("Parallelization of the operator is not implemented yet for the WarpGrid") - def create_field(self, cardinality: int, dtype, callback=None): + def create_field(self, cardinality: int, precision: Precision, callback=None): # Get shape of the field shape = (cardinality,) + (self.shape) # Create the field - f = wp.zeros(shape, dtype=dtype) + f = wp.zeros(shape, dtype=precision.wp_dtype) # Raise error on callback if callback is not None: diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index c4d44e8..27e0472 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -1,3 +1,8 @@ +from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition +from xlb.operator.boundary_condition.boundary_condition_registry import ( + BoundaryConditionRegistry, +) from xlb.operator.boundary_condition.equilibrium import EquilibriumBC from xlb.operator.boundary_condition.do_nothing import DoNothingBC from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBackBC +from xlb.operator.boundary_condition.fullway_bounce_back import FullwayBounceBackBC diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index ef32bfb..7f7aadc 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -13,11 +13,13 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator + # Enum for implementation step class ImplementationStep(Enum): COLLISION = 1 STREAMING = 2 + class BoundaryCondition(Operator): """ Base class for boundary conditions in a LBM simulation. @@ -28,7 +30,7 @@ def __init__( implementation_step: ImplementationStep, velocity_set: VelocitySet, precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend + compute_backend: ComputeBackend, ): super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py index 5990d53..0a3b2c7 100644 --- a/xlb/operator/boundary_condition/boundary_condition_registry.py +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -2,6 +2,7 @@ Registry for boundary conditions in a LBM simulation. """ + class BoundaryConditionRegistry: """ Registry for boundary conditions in a LBM simulation. @@ -10,8 +11,9 @@ class BoundaryConditionRegistry: def __init__( self, ): - self.ids = {} - self.next_id = 1 # 0 is reserved for regular streaming + self.id_to_bc = {} # Maps id number to boundary condition + self.bc_to_id = {} # Maps boundary condition to id number + self.next_id = 1 # 0 is reserved for no boundary condition def register_boundary_condition(self, boundary_condition): """ @@ -19,7 +21,9 @@ def register_boundary_condition(self, boundary_condition): """ id = self.next_id self.next_id += 1 - self.ids[boundary_condition] = id + self.id_to_bc[id] = boundary_condition + self.bc_to_id[boundary_condition] = id return id + boundary_condition_registry = BoundaryConditionRegistry() diff --git a/xlb/operator/boundary_condition/collision_boundary_applier.py b/xlb/operator/boundary_condition/collision_boundary_applier.py deleted file mode 100644 index 633ef47..0000000 --- a/xlb/operator/boundary_condition/collision_boundary_applier.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np -from enum import Enum - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator - -# Import all collision boundary conditions -from xlb.boundary_condition.full_bounce_back import FullBounceBack - - -class CollisionBoundaryApplier(Operator): - """ - Class for combining collision and boundary conditions together - into a single operator. - """ - - def __init__( - self, - boundary_appliers: list[BoundaryApplier], - ): - # Set boundary conditions - self.boundary_appliers = boundary_appliers - - # Check that all boundary conditions have the same implementation step other properties - for bc in self.boundary_appliers: - assert bc.implementation_step == ImplementationStep.COLLISION, ( - "All boundary conditions must be applied during the collision step." - ) - - # Get velocity set, precision policy, and compute backend - velocity_sets = set([bc.velocity_set for bc in self.boundary_appliers]) - assert len(velocity_sets) == 1, "All velocity sets must be the same" - velocity_set = velocity_sets.pop() - precision_policies = set([bc.precision_policy for bc in self.boundary_appliers]) - assert len(precision_policies) == 1, "All precision policies must be the same" - precision_policy = precision_policies.pop() - compute_backends = set([bc.compute_backend for bc in self.boundary_appliers]) - assert len(compute_backends) == 1, "All compute backends must be the same" - compute_backend = compute_backends.pop() - - # Make all possible collision boundary conditions to obtain the warp functions - self.full_bounce_back = FullBounceBack( - None, velocity_set, precision_policy, compute_backend - ) - - super().__init__( - velocity_set, - precision_policy, - compute_backend, - ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, mask, boundary_id): - """ - Apply collision boundary conditions - """ - for bc in self.boundary_conditions: - f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) - return f_post, mask - - def _construct_warp(self): - """ - Construct the warp kernel for the collision boundary condition. - """ - - # Make constants for warp - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - - # Get boolean constants for all boundary conditions - if any([isinstance(bc, FullBounceBack) for bc in self.boundary_conditions]): - _use_full_bounce_back = wp.constant(True) - - # Construct the funcional for all boundary conditions - @wp.func - def functional( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - boundary_id: wp.uint8, - mask: self._warp_bool_lattice_vec, - ): - # Apply all boundary conditions - # Full bounce-back - if _use_full_bounce_back: - if boundary_id == self.full_bounce_back.id: - f_post = self.full_bounce_back.warp_functional(f_pre, f_post, mask) - - return f_post - - # Construct the warp kernel - @wp.kernel - def kernel( - f_pre: self._warp_array_type, - f_post: self._warp_array_type, - f: self._warp_array_type, - boundary_id: self._warp_uint8_array_type, - mask: self._warp_bool_array_type, - ): - # Get the global index - i, j, k = wp.tid() - - # Make vectors for the lattice - _f_pre = self._warp_lattice_vec() - _f_post = self._warp_lattice_vec() - _mask = self._warp_bool_lattice_vec() - _boundary_id = wp.uint8(boundary_id[0, i, j, k]) - for l in range(_q): - _f_pre[l] = f_pre[l, i, j, k] - _f_post[l] = f_post[l, i, j, k] - - # TODO fix vec bool - if mask[l, i, j, k]: - _mask[l] = wp.uint8(1) - else: - _mask[l] = wp.uint8(0) - - # Apply all boundary conditions - if _boundary_id != wp.uint8(0): - _f_post = functional(_f_pre, _f_post, _boundary_id, _mask) - - # Write the result to the output - for l in range(_q): - f[l, i, j, k] = _f_post[l] - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary_id, mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, inputs=[f_pre, f_post, f, boundary_id, mask], dim=f_pre.shape[1:] - ) - return f diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py index 5d060ed..46a6fdd 100644 --- a/xlb/operator/boundary_condition/do_nothing.py +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -14,13 +14,21 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.boundary_condition import ImplementationStep, BoundaryCondition -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + class DoNothingBC(BoundaryCondition): """ - Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + Do nothing boundary condition. This boundary condition skips the streaming step for the + boundary nodes. """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( @@ -37,8 +45,10 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + # TODO: This is unoptimized boundary = boundary_id == self.id flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) skipped_f = lax.select(flip, f_pre, f_post) @@ -47,12 +57,15 @@ def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func def functional( f: wp.array4d(dtype=Any), - missing_mask: wp.array4d(dtype=wp.bool), + missing_mask: Any, index: Any, ): _f = _f_vec() @@ -67,23 +80,42 @@ def kernel( f_post: wp.array4d(dtype=Any), boundary_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), + f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) - # Get boundary id - if boundary_id[0, index[0], index[1], index[2]] == wp.uint8(DoNothing.id): - _f = functional(f_pre, index) - for l in range(_q): - f_post[l, index[0], index[1], index[2]] = _f[l] + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(DoNothingBC.id): + _f = functional(f_pre, _missing_mask, index) + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1], index[2]] + + # Write the result + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = _f[l] return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary, mask): + def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): # Launch the warp kernel wp.launch( - self.warp_kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] + self.warp_kernel, + inputs=[f_pre, f_post, boundary_id, missing_mask, f], + dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/do_nothing_applier.py b/xlb/operator/boundary_condition/do_nothing_applier.py deleted file mode 100644 index 222dff6..0000000 --- a/xlb/operator/boundary_condition/do_nothing_applier.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np -import warp as wp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator import Operator -from xlb.operator.boundary_condition.boundary_applier import ( - BoundaryApplier, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - -class DoNothingApplier(BoundaryApplier): - """ - Do nothing boundary condition. Basically skips the streaming step - """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - - def __init__( - self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - super().__init__( - ImplementationStep.STREAMING, - velocity_set, - precision_policy, - compute_backend, - ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - do_nothing = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) - f = lax.select(do_nothing, f_pre, f_post) - return f diff --git a/xlb/operator/boundary_condition/equilibrium.py b/xlb/operator/boundary_condition/equilibrium.py index 013871b..6de68ec 100644 --- a/xlb/operator/boundary_condition/equilibrium.py +++ b/xlb/operator/boundary_condition/equilibrium.py @@ -14,13 +14,20 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.boundary_condition import ImplementationStep, BoundaryCondition -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + class EquilibriumBC(BoundaryCondition): """ Full Bounce-back boundary condition for a lattice Boltzmann method simulation. """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( @@ -34,7 +41,7 @@ def __init__( ): # Store the equilibrium information self.rho = rho - self.u = u + self.u = u self.equilibrium_operator = equilibrium_operator # Call the parent constructor @@ -46,21 +53,35 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): - raise NotImplementedError + # TODO: This is unoptimized + feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) + feq = jnp.reshape(feq, (self.velocity_set.q, 1, 1, 1)) + feq = jnp.repeat(feq, f_pre.shape[1], axis=1) + feq = jnp.repeat(feq, f_pre.shape[2], axis=2) + feq = jnp.repeat(feq, f_pre.shape[3], axis=3) + boundary = boundary_id == self.id + boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + skipped_f = lax.select(boundary, feq, f_post) + return skipped_f def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(self.rho) _u = _u_vec(self.u[0], self.u[1], self.u[2]) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func def functional( f: wp.array4d(dtype=Any), - missing_mask: wp.array4d(dtype=wp.bool), + missing_mask: Any, index: Any, ): _f = self.equilibrium_operator.warp_functional(_rho, _u) @@ -73,23 +94,42 @@ def kernel( f_post: wp.array4d(dtype=Any), boundary_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), + f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) - # Get boundary id - if boundary_id[0, index[0], index[1], index[2]] == wp.uint8(DoNothing.id): - _f = functional(f_pre, index) - for l in range(_q): - f_post[l, index[0], index[1], index[2]] = _f[l] + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(EquilibriumBC.id): + _f = functional(f_pre, _missing_mask, index) + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1], index[2]] + + # Write the result + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = _f[l] return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary, mask): + def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): # Launch the warp kernel wp.launch( - self.warp_kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] + self.warp_kernel, + inputs=[f_pre, f_post, boundary_id, missing_mask, f], + dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/equilibrium_applier.py b/xlb/operator/boundary_condition/equilibrium_applier.py deleted file mode 100644 index 3b3ef5b..0000000 --- a/xlb/operator/boundary_condition/equilibrium_applier.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np -import warp as wp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator import Operator -from xlb.operator.boundary_condition.boundary_applier import ( - BoundaryApplier, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - -class EquilibriumApplier(BoundaryApplier): - """ - Apply Equilibrium boundary condition to the distribution function. - """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - - def __init__( - self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - super().__init__( - ImplementationStep.STREAMING, - velocity_set, - precision_policy, - compute_backend, - ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - equilibrium_mask = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) - equilibrium_f = jnp.repeat(self.f[None, ...], boundary.shape[0], axis=0) - equilibrium_f = jnp.repeat(equilibrium_f[:, None], boundary.shape[1], axis=1) - equilibrium_f = jnp.repeat(equilibrium_f[:, :, None], boundary.shape[2], axis=2) - f = lax.select(equilibrium_mask, equilibrium_f, f_post) - return f diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py deleted file mode 100644 index 97f680b..0000000 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np -import warp as wp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator import Operator -from xlb.operator.boundary_condition import ( - BoundaryCondition, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) -from xlb.operator.boundary_condition.boundary_applier import FullBounceBackApplier - - -class FullBounceBack(BoundaryCondition): - """ - Full Bounce-back boundary condition for a lattice Boltzmann method simulation. - """ - - def __init__( - self, - boundary_masker: BoundaryMasker, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - - boundary_applier = FullBounceBackApplier( - velocity_set, precision_policy, compute_backend - ) - - super().__init__( - boundary_applier, - boundary_masker, - velocity_set, - precision_policy, - compute_backend, - ) - - @classmethod - def from_indices( - cls, velocity_set, precision_policy, compute_backend - ): - """ - Create a full bounce-back boundary condition from indices. - """ - # Create boundary mask - boundary_mask = IndicesBoundaryMasker( - False, velocity_set, precision_policy, compute_backend - ) - - # Create boundary condition - return cls( - boundary_mask, - velocity_set, - precision_policy, - compute_backend, - ) - - @classmethod - def from_stl(cls, velocity_set, precision_policy, compute_backend): - """ - Create a full bounce-back boundary condition from an STL file. - """ - # Create boundary mask - boundary_mask = STLBoundaryMasker( - False, velocity_set, precision_policy, compute_backend - ) - - # Create boundary condition - return cls( - boundary_mask, - velocity_set, - precision_policy, - compute_backend, - ) diff --git a/xlb/operator/boundary_condition/full_bounce_back_applier.py b/xlb/operator/boundary_condition/fullway_bounce_back.py similarity index 50% rename from xlb/operator/boundary_condition/full_bounce_back_applier.py rename to xlb/operator/boundary_condition/fullway_bounce_back.py index 48860fd..547cde1 100644 --- a/xlb/operator/boundary_condition/full_bounce_back_applier.py +++ b/xlb/operator/boundary_condition/fullway_bounce_back.py @@ -8,22 +8,26 @@ from functools import partial import numpy as np import warp as wp +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -from xlb.operator.boundary_condition.boundary_applier import ( - BoundaryApplier, +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, ImplementationStep, ) -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) -class FullBounceBackApplier(BoundaryApplier): +class FullwayBounceBackBC(BoundaryCondition): """ Full Bounce-back boundary condition for a lattice Boltzmann method simulation. """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( @@ -40,28 +44,30 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary_id, mask): + #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): boundary = boundary_id == self.id - flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) - flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) - return flipped_f + boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + return lax.select(boundary, f_pre[self.velocity_set.opp_indices], f_post) def _construct_warp(self): - # Make constants for warp - _opp_indices = wp.constant(self._warp_int_lattice_vec(self.velocity_set.opp_indices)) + # Set local constants TODO: This is a hack and should be fixed with warp update + _opp_indices = self.velocity_set.wp_opp_indices _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - _id = wp.constant(self.id) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func def functional( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - mask: self._warp_bool_lattice_vec, + f_pre: Any, + f_post: Any, + missing_mask: Any, ): - fliped_f = self._warp_lattice_vec() + fliped_f = _f_vec() for l in range(_q): fliped_f[l] = f_pre[_opp_indices[l]] return fliped_f @@ -69,45 +75,52 @@ def functional( # Construct the warp kernel @wp.kernel def kernel( - f_pre: self._warp_array_type, - f_post: self._warp_array_type, - f: self._warp_array_type, - boundary_id: self._warp_uint8_array_type, - mask: self._warp_bool_array_type, + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + # Make vectors for the lattice - _f_pre = self._warp_lattice_vec() - _f_post = self._warp_lattice_vec() - _mask = self._warp_bool_lattice_vec() - for l in range(_q): - _f_pre[l] = f_pre[l, i, j, k] - _f_post[l] = f_post[l, i, j, k] + _f_pre = _f_vec() + _f_post = _f_vec() + _mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] # TODO fix vec bool - if mask[l, i, j, k]: + if missing_mask[l, index[0], index[1], index[2]]: _mask[l] = wp.uint8(1) else: _mask[l] = wp.uint8(0) # Check if the boundary is active - if boundary_id[i, j, k] == wp.uint8(_id: + if _boundary_id == wp.uint8(FullwayBounceBackBC.id): _f = functional(_f_pre, _f_post, _mask) else: _f = _f_post # Write the result to the output - for l in range(_q): - f[l, i, j, k] = _f[l] + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = _f[l] return functional, kernel + @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary, mask): + def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): # Launch the warp kernel wp.launch( - self.warp_kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] + self.warp_kernel, + inputs=[f_pre, f_post, boundary_id, missing_mask, f], + dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index af6d3e1..f8a4dd7 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -14,8 +14,14 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.boundary_condition import ImplementationStep, BoundaryCondition -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + class HalfwayBounceBackBC(BoundaryCondition): """ @@ -23,6 +29,7 @@ class HalfwayBounceBackBC(BoundaryCondition): TODO: Implement moving boundary conditions for this """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) def __init__( @@ -31,7 +38,6 @@ def __init__( precision_policy: PrecisionPolicy, compute_backend: ComputeBackend, ): - # Call the parent constructor super().__init__( ImplementationStep.STREAMING, @@ -41,33 +47,37 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) + #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): - raise NotImplementedError + boundary = boundary_id == self.id + boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + return lax.select(missing_mask & boundary, f_pre[self.velocity_set.opp_indices], f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _c = self.velocity_set.wp_c _opp_indices = self.velocity_set.wp_opp_indices _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func def functional( f: wp.array4d(dtype=Any), - missing_mask: wp.array4d(dtype=wp.bool), + missing_mask: Any, index: Any, ): - # Pull the distribution function _f = _f_vec() for l in range(self.velocity_set.q): - # Get pull index pull_index = type(index)() # If the mask is missing then take the opposite index - if missing_mask[l, index[0], index[1], index[2]] == wp.bool(True): + if missing_mask[l] == wp.uint8(1): use_l = _opp_indices[l] for d in range(self.velocity_set.d): pull_index[d] = index[d] @@ -90,23 +100,42 @@ def kernel( f_post: wp.array4d(dtype=Any), boundary_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), + f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) - # Get boundary id - if boundary_id[0, index[0], index[1], index[2]] == wp.uint8(DoNothing.id): - _f = functional(f_pre, index) - for l in range(_q): - f_post[l, index[0], index[1], index[2]] = _f[l] + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + _f = functional(f_pre, _missing_mask, index) + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1], index[2]] + + # Write the distribution function + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = _f[l] return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): # Launch the warp kernel wp.launch( - self.warp_kernel, inputs=[f_pre, f_post, f, boundary, missing_mask], dim=f_pre.shape[1:] + self.warp_kernel, + inputs=[f_pre, f_post, boundary_id, missing_mask, f], + dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/stream_boundary_applier.py b/xlb/operator/boundary_condition/stream_boundary_applier.py deleted file mode 100644 index fd51c5a..0000000 --- a/xlb/operator/boundary_condition/stream_boundary_applier.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np -from enum import Enum - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator - -# Import all collision boundary conditions -from xlb.boundary_condition.full_bounce_back import FullBounceBack - - -class StreamBoundaryCondition(Operator): - """ - Class for combining collision and boundary conditions together - into a single operator. - """ - - def __init__( - self, - boundary_appliers: list[BoundaryApplier], - ): - # Set boundary conditions - self.boundary_appliers = boundary_appliers - - # Check that all boundary conditions have the same implementation step other properties - for bc in self.boundary_appliers: - assert bc.implementation_step == ImplementationStep.COLLISION, ( - "All boundary conditions must be applied during the collision step." - ) - - # Get velocity set, precision policy, and compute backend - velocity_sets = set([bc.velocity_set for bc in self.boundary_appliers]) - assert len(velocity_sets) == 1, "All velocity sets must be the same" - velocity_set = velocity_sets.pop() - precision_policies = set([bc.precision_policy for bc in self.boundary_appliers]) - assert len(precision_policies) == 1, "All precision policies must be the same" - precision_policy = precision_policies.pop() - compute_backends = set([bc.compute_backend for bc in self.boundary_appliers]) - assert len(compute_backends) == 1, "All compute backends must be the same" - compute_backend = compute_backends.pop() - - # Make all possible collision boundary conditions to obtain the warp functions - self.full_bounce_back = FullBounceBack( - None, velocity_set, precision_policy, compute_backend - ) - - super().__init__( - velocity_set, - precision_policy, - compute_backend, - ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, mask, boundary_id): - """ - Apply collision boundary conditions - """ - for bc in self.boundary_conditions: - f_post, mask = bc.jax_implementation(f_pre, f_post, mask, boundary_id) - return f_post, mask - - def _construct_warp(self): - """ - Construct the warp kernel for the collision boundary condition. - """ - - # Make constants for warp - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - - # Get boolean constants for all boundary conditions - if any([isinstance(bc, FullBounceBack) for bc in self.boundary_conditions]): - _use_full_bounce_back = wp.constant(True) - - # Construct the funcional for all boundary conditions - @wp.func - def functional( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - boundary_id: wp.uint8, - mask: self._warp_bool_lattice_vec, - ): - # Apply all boundary conditions - # Full bounce-back - if _use_full_bounce_back: - if boundary_id == self.full_bounce_back.id: - f_post = self.full_bounce_back.warp_functional(f_pre, f_post, mask) - - return f_post - - # Construct the warp kernel - @wp.kernel - def kernel( - f_pre: self._warp_array_type, - f_post: self._warp_array_type, - f: self._warp_array_type, - boundary_id: self._warp_uint8_array_type, - mask: self._warp_bool_array_type, - ): - # Get the global index - i, j, k = wp.tid() - - # Make vectors for the lattice - _f_pre = self._warp_lattice_vec() - _f_post = self._warp_lattice_vec() - _mask = self._warp_bool_lattice_vec() - _boundary_id = wp.uint8(boundary_id[0, i, j, k]) - for l in range(_q): - _f_pre[l] = f_pre[l, i, j, k] - _f_post[l] = f_post[l, i, j, k] - - # TODO fix vec bool - if mask[l, i, j, k]: - _mask[l] = wp.uint8(1) - else: - _mask[l] = wp.uint8(0) - - # Apply all boundary conditions - if _boundary_id != wp.uint8(0): - _f_post = functional(_f_pre, _f_post, _boundary_id, _mask) - - # Write the result to the output - for l in range(_q): - f[l, i, j, k] = _f_post[l] - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary_id, mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, inputs=[f_pre, f_post, f, boundary_id, mask], dim=f_pre.shape[1:] - ) - return f diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 02a8a87..b9e9f5b 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -42,9 +42,13 @@ def _indices_to_tuple(indices): @Operator.register_backend(ComputeBackend.JAX) # @partial(jit, static_argnums=(0), inline=True) TODO: Fix this - def jax_implementation(self, indices, id_number, boundary_id, mask, start_index=(0, 0, 0)): + def jax_implementation( + self, indices, id_number, boundary_id, mask, start_index=(0, 0, 0) + ): + # TODO: This is somewhat untested and unoptimized + # Get local indices from the meshgrid and the indices - local_indices = self.indices - np.array(start_index)[np.newaxis, :] + local_indices = indices - np.array(start_index)[np.newaxis, :] # Remove any indices that are out of bounds local_indices = local_indices[ @@ -68,7 +72,7 @@ def jax_implementation(self, indices, id_number, boundary_id, mask, start_index= ) post_stream_mask = self.stream(pre_stream_mask) - # Set false for points inside the boundary (NOTE: removing this to be more consistent with the other boundary maskers) + # Set false for points inside the boundary (NOTE: removing this to be more consistent with the other boundary maskers, maybe add back in later) # post_stream_mask = post_stream_mask.at[ # post_stream_mask[0, ...] == True # ].set(False) @@ -119,7 +123,6 @@ def kernel( and index[2] >= 0 and index[2] < mask.shape[3] ): - # Stream indices for l in range(_q): # Get the index of the streaming direction @@ -128,16 +131,16 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_id[0, push_index[0], push_index[1], push_index[2]] = wp.uint8( - id_number - ) + boundary_id[ + 0, push_index[0], push_index[1], push_index[2] + ] = wp.uint8(id_number) mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation( - self, indices, id_number, boundary_id, mask, start_index=(0, 0, 0) + self, indices, id_number, boundary_id, missing_mask, start_index=(0, 0, 0) ): # Launch the warp kernel wp.launch( @@ -146,10 +149,10 @@ def warp_implementation( indices, id_number, boundary_id, - mask, + missing_mask, start_index, ], dim=indices.shape[0], ) - return boundary_id, mask + return boundary_id, missing_mask diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index 24156c9..4cd405c 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -31,8 +31,54 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) # @partial(jit, static_argnums=(0), inline=True) TODO: Fix this - def jax_implementation(self, edge, start_index, boundary_id, mask, id_number): - raise NotImplementedError + def jax_implementation( + self, + lower_bound, + upper_bound, + direction, + id_number, + boundary_id, + mask, + start_index=(0, 0, 0), + ): + # Get plane dimensions + if direction[0] != 0: + dim = ( + upper_bound[1] - lower_bound[1] + 1, + upper_bound[2] - lower_bound[2] + 1, + ) + elif direction[1] != 0: + dim = ( + upper_bound[0] - lower_bound[0] + 1, + upper_bound[2] - lower_bound[2] + 1, + ) + elif direction[2] != 0: + dim = ( + upper_bound[0] - lower_bound[0] + 1, + upper_bound[1] - lower_bound[1] + 1, + ) + + # Get the constants + _c = self.velocity_set.wp_c + _q = self.velocity_set.q + + # Get the mask + for i in range(dim[0]): + for j in range(dim[1]): + for k in range(_q): + d_dot_c = ( + direction[0] * _c[0, k] + + direction[1] * _c[1, k] + + direction[2] * _c[2, k] + ) + if d_dot_c >= 0: + mask[k, i, j] = True + + # Get the boundary id + boundary_id[:, :, :] = id_number + + return boundary_id, mask + def _construct_warp(self): # Make constants for warp @@ -104,11 +150,20 @@ def warp_implementation( ): # Get plane dimensions if direction[0] != 0: - dim = (upper_bound[1] - lower_bound[1], upper_bound[2] - lower_bound[2]) + dim = ( + upper_bound[1] - lower_bound[1] + 1, + upper_bound[2] - lower_bound[2] + 1, + ) elif direction[1] != 0: - dim = (upper_bound[0] - lower_bound[0], upper_bound[2] - lower_bound[2]) + dim = ( + upper_bound[0] - lower_bound[0] + 1, + upper_bound[2] - lower_bound[2] + 1, + ) elif direction[2] != 0: - dim = (upper_bound[0] - lower_bound[0], upper_bound[1] - lower_bound[1]) + dim = ( + upper_bound[0] - lower_bound[0] + 1, + upper_bound[1] - lower_bound[1] + 1, + ) # Launch the warp kernel wp.launch( diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index 8ae8456..cda8c00 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -28,11 +28,16 @@ def __init__( ): super().__init__(velocity_set, precision_policy, compute_backend) + # TODO: Implement this + raise NotImplementedError + # Make stream operator self.stream = Stream(velocity_set, precision_policy, compute_backend) @Operator.register_backend(ComputeBackend.JAX) - def jax_implementation(self, mesh, id_number, boundary_id, mask, start_index=(0, 0, 0)): + def jax_implementation( + self, mesh, id_number, boundary_id, mask, start_index=(0, 0, 0) + ): # TODO: Implement this raise NotImplementedError diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index b70be6c..69718aa 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -61,7 +61,7 @@ def kernel( ): # Get the global index i, j, k = wp.tid() - index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this + index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this # Load needed values _f = _f_vec() diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 65a7fb6..d5bb20a 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -15,8 +15,6 @@ class QuadraticEquilibrium(Equilibrium): """ Quadratic equilibrium of Boltzmann equation using hermite polynomials. Standard equilibrium model for LBM. - - TODO: move this to a separate file and lower and higher order equilibriums """ @Operator.register_backend(ComputeBackend.JAX) @@ -74,7 +72,6 @@ def functional( # Compute the equilibrium for l in range(self.velocity_set.q): - # Compute cu cu = self.compute_dtype(0.0) for d in range(self.velocity_set.d): diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index ff9ba58..1724ffc 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -103,26 +103,20 @@ def compute_dtype(self): """ Returns the compute dtype """ - return self._precision_to_dtype(self.precision_policy.compute_precision) + if self.compute_backend == ComputeBackend.JAX: + return self.precision_policy.compute_precision.jax_dtype + elif self.compute_backend == ComputeBackend.WARP: + return self.precision_policy.compute_precision.wp_dtype @property def store_dtype(self): """ Returns the store dtype """ - return self._precision_to_dtype(self.precision_policy.store_precision) - - def _precision_to_dtype(self, precision): - """ - Convert the precision to the corresponding dtype - TODO: Maybe move this to precision policy? - """ - if precision == Precision.FP64: - return self.backend.float64 - elif precision == Precision.FP32: - return self.backend.float32 - elif precision == Precision.FP16: - return self.backend.float16 + if self.compute_backend == ComputeBackend.JAX: + return self.precision_policy.store_precision.jax_dtype + elif self.compute_backend == ComputeBackend.WARP: + return self.precision_policy.store_precision.wp_dtype def _construct_warp(self): """ diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index 8e89e51..63c32ee 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -10,8 +10,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator import Operator from xlb.operator.stepper import Stepper -#from xlb.operator.boundary_condition import ImplementationStep -from xlb.operator.collision import BGK +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep class IncompressibleNavierStokesStepper(Stepper): @@ -20,17 +19,17 @@ class IncompressibleNavierStokesStepper(Stepper): """ @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0, 5)) + @partial(jit, static_argnums=(0, 4), donate_argnums=(1)) def apply_jax(self, f, boundary_id, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ - # Cast to compute precision - f_pre_collision = self.precision_policy.cast_to_compute_jax(f) + # Cast to compute precision TODO add this back in + #f_pre_collision = self.precision_policy.cast_to_compute_jax(f) # Compute the macroscopic variables - rho, u = self.macroscopic(f_pre_collision) + rho, u = self.macroscopic(f) # Compute equilibrium feq = self.equilibrium(rho, u) @@ -44,39 +43,42 @@ def apply_jax(self, f, boundary_id, missing_mask, timestep): ) # Apply collision type boundary conditions - f_post_collision = self.collision_boundary_applier.jax_implementation( - f_pre_collision, f_post_collision, missing_mask, boundary_id - ) - f_pre_streaming = f_post_collision + for bc in self.boundary_conditions: + if bc.implementation_step == ImplementationStep.COLLISION: + f_post_collision = bc( + f, + f_post_collision, + boundary_id, + missing_mask, + ) ## Apply forcing # if self.forcing_op is not None: # f = self.forcing_op.apply_jax(f, timestep) # Apply streaming - f_post_streaming = self.stream(f_pre_streaming) + f_post_streaming = self.stream(f_post_collision) # Apply boundary conditions - for id_number, bc in self.stream_boundary_conditions.items(): - f_post_streaming = bc( - f_pre_streaming, - f_post_streaming, - boundary_id == id_number, - missing_mask, - ) + for bc in self.boundary_conditions: + if bc.implementation_step == ImplementationStep.STREAMING: + f_post_streaming = bc( + f_post_collision, + f_post_streaming, + boundary_id, + missing_mask, + ) # Copy back to store precision - f = self.precision_policy.cast_to_store_jax(f_post_streaming) + #f = self.precision_policy.cast_to_store_jax(f_post_streaming) - return f + return f_post_streaming @Operator.register_backend(ComputeBackend.PALLAS) @partial(jit, static_argnums=(0,)) def apply_pallas(self, fin, boundary_id, missing_mask, timestep): # Raise warning that the boundary conditions are not implemented - ################################################################ warning("Boundary conditions are not implemented for PALLAS backend currently") - ################################################################ from xlb.operator.parallel_operator import ParallelOperator @@ -130,10 +132,15 @@ def _pallas_collide_and_stream(f): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool + + # Get the boundary condition ids _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) _do_nothing_bc = wp.uint8(self.do_nothing_bc.id) - _half_way_bc = wp.uint8(self.half_way_bc.id) + _halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id) + _fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id) # Construct the kernel @wp.kernel @@ -146,7 +153,7 @@ def kernel( ): # Get the global index i, j, k = wp.tid() - index = wp.vec3i(i, j, k) # TODO warp should fix this + index = wp.vec3i(i, j, k) # TODO warp should fix this # Get the boundary id and missing mask _boundary_id = boundary_id[0, index[0], index[1], index[2]] @@ -159,34 +166,48 @@ def kernel( _missing_mask[l] = wp.uint8(0) # Apply streaming boundary conditions - if _boundary_id == wp.uint8(0): - # Regular streaming - _post_stream_f = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + if _boundary_id == _equilibrium_bc: # Equilibrium boundary condition - _post_stream_f = self.equilibrium_bc.warp_functional(f_0, missing_mask, index) + f_post_stream = self.equilibrium_bc.warp_functional( + f_0, _missing_mask, index + ) elif _boundary_id == _do_nothing_bc: # Do nothing boundary condition - _post_stream_f = self.do_nothing_bc.warp_functional(f_0, missing_mask, index) - elif _boundary_id == _half_way_bc: + f_post_stream = self.do_nothing_bc.warp_functional( + f_0, _missing_mask, index + ) + elif _boundary_id == _halfway_bounce_back_bc: # Half way boundary condition - _post_stream_f = self.half_way_bc.warp_functional(f_0, missing_mask, index) - #_post_stream_f = self.stream.warp_functional(f_0, index) - + f_post_stream = self.halfway_bounce_back_bc.warp_functional( + f_0, _missing_mask, index + ) + else: + # Regular streaming + f_post_stream = self.stream.warp_functional(f_0, index) + # Compute rho and u - rho, u = self.macroscopic.warp_functional(_post_stream_f) + rho, u = self.macroscopic.warp_functional(f_post_stream) # Compute equilibrium feq = self.equilibrium.warp_functional(rho, u) # Apply collision f_post_collision = self.collision.warp_functional( - _post_stream_f, + f_post_stream, feq, rho, u, ) + # Apply collision type boundary conditions + if _boundary_id == _fullway_bounce_back_bc: + # Full way boundary condition + f_post_collision = self.fullway_bounce_back_bc.warp_functional( + f_post_stream, + f_post_collision, + _missing_mask, + ) + # Set the output for l in range(self.velocity_set.q): f_1[l, index[0], index[1], index[2]] = f_post_collision[l] diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index c1acefa..e1eed44 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -8,9 +8,6 @@ from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -#from xlb.operator.boundary_condition.boundary_condition import ImplementationStep -#from xlb.operator.boundary_condition.boundary_applier.collision_boundary_applier import CollisionBoundaryApplier -#from xlb.operator.boundary_condition.boundary_applier.stream_boundary_applier import StreamBoundaryApplier from xlb.operator.precision_caster import PrecisionCaster @@ -25,24 +22,15 @@ def __init__( stream, equilibrium, macroscopic, - equilibrium_bc, - do_nothing_bc, - half_way_bc, - forcing=None, + boundary_conditions=[], + forcing=None, # TODO: Add forcing later ): - # Set parameters + # Add operators self.collision = collision self.stream = stream self.equilibrium = equilibrium self.macroscopic = macroscopic - self.equilibrium_bc = equilibrium_bc - self.do_nothing_bc = do_nothing_bc - self.half_way_bc = half_way_bc - self.boundary_conditions = [ - equilibrium_bc, - do_nothing_bc, - half_way_bc, - ] + self.boundary_conditions = boundary_conditions self.forcing = forcing # Get all operators for checking @@ -52,8 +40,9 @@ def __init__( equilibrium, macroscopic, *self.boundary_conditions, - #forcing, ] + if forcing is not None: + self.operators.append(forcing) # Get velocity set, precision policy, and compute backend velocity_sets = set([op.velocity_set for op in self.operators]) @@ -66,13 +55,57 @@ def __init__( assert len(compute_backends) == 1, "All compute backends must be the same" compute_backend = compute_backends.pop() - # Make single operators for all collision and streaming boundary conditions - #self.collision_boundary_applier = CollisionBoundaryApplier( - # [bc.boundary_applier for bc in boundary_conditions if bc.implementation_step == ImplementationStep.COLLISION] - #) - #self.streaming_boundary_applier = StreamBoundaryApplier( - # [bc.boundary_applier for bc in boundary_conditions if bc.implementation_step == ImplementationStep.STREAMING] - #) + # Add boundary conditions + # Warp cannot handle lists of functions currently + # Because of this we manually unpack the boundary conditions + ############################################ + # TODO: Fix this later + ############################################ + from xlb.operator.boundary_condition.equilibrium import EquilibriumBC + from xlb.operator.boundary_condition.do_nothing import DoNothingBC + from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBackBC + from xlb.operator.boundary_condition.fullway_bounce_back import FullwayBounceBackBC + self.equilibrium_bc = None + self.do_nothing_bc = None + self.halfway_bounce_back_bc = None + self.fullway_bounce_back_bc = None + for bc in boundary_conditions: + if isinstance(bc, EquilibriumBC): + self.equilibrium_bc = bc + elif isinstance(bc, DoNothingBC): + self.do_nothing_bc = bc + elif isinstance(bc, HalfwayBounceBackBC): + self.halfway_bounce_back_bc = bc + elif isinstance(bc, FullwayBounceBackBC): + self.fullway_bounce_back_bc = bc + if self.equilibrium_bc is None: + self.equilibrium_bc = EquilibriumBC( + rho=1.0, + u=(0.0, 0.0, 0.0), + equilibrium_operator=self.equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend + ) + if self.do_nothing_bc is None: + self.do_nothing_bc = DoNothingBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend + ) + if self.halfway_bounce_back_bc is None: + self.halfway_bounce_back_bc = HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend + ) + if self.fullway_bounce_back_bc is None: + self.fullway_bounce_back_bc = FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend + ) + ############################################ # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index e4d255c..c5fd16e 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -45,7 +45,9 @@ def _streaming_jax_i(f, c): The updated distribution function after streaming. """ if self.velocity_set.d == 2: - return jnp.roll(f, (-c[0], -c[1]), axis=(0, 1)) # Negative sign is used to pull the distribution instead of pushing + return jnp.roll( + f, (-c[0], -c[1]), axis=(0, 1) + ) # Negative sign is used to pull the distribution instead of pushing elif self.velocity_set.d == 3: return jnp.roll(f, (-c[0], -c[1], -c[2]), axis=(0, 1, 2)) @@ -64,11 +66,9 @@ def functional( f: wp.array4d(dtype=Any), index: Any, ): - # Pull the distribution function _f = _f_vec() for l in range(self.velocity_set.q): - # Get pull index pull_index = type(index)() for d in range(self.velocity_set.d): diff --git a/xlb/operator/test/test.py b/xlb/operator/test/test.py deleted file mode 100644 index 7d4290a..0000000 --- a/xlb/operator/test/test.py +++ /dev/null @@ -1 +0,0 @@ -x = 1 diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 0ba6c1c..db8a422 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -2,12 +2,45 @@ from enum import Enum, auto +import jax.numpy as jnp +import warp as wp class Precision(Enum): FP64 = auto() FP32 = auto() FP16 = auto() + UINT8 = auto() + BOOL = auto() + @property + def wp_dtype(self): + if self == Precision.FP64: + return wp.float64 + elif self == Precision.FP32: + return wp.float32 + elif self == Precision.FP16: + return wp.float16 + elif self == Precision.UINT8: + return wp.uint8 + elif self == Precision.BOOL: + return wp.bool + else: + raise ValueError("Invalid precision") + + @property + def jax_dtype(self): + if self == Precision.FP64: + return jnp.float64 + elif self == Precision.FP32: + return jnp.float32 + elif self == Precision.FP16: + return jnp.float16 + elif self == Precision.UINT8: + return jnp.uint8 + elif self == Precision.BOOL: + return jnp.bool_ + else: + raise ValueError("Invalid precision") class PrecisionPolicy(Enum): FP64FP64 = auto() From 03872479d159b90f5491648dcbbdf3111675fb97 Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 28 Mar 2024 20:16:31 -0700 Subject: [PATCH 023/144] Working! --- examples/interfaces/ldc.py | 72 ++++-------------- examples/interfaces/taylor_green.py | 6 +- .../boundary_masker/planar_boundary_masker.py | 74 +++++++++++-------- xlb/operator/stepper/nse.py | 8 +- 4 files changed, 67 insertions(+), 93 deletions(-) diff --git a/examples/interfaces/ldc.py b/examples/interfaces/ldc.py index ced31fe..d75ce8a 100644 --- a/examples/interfaces/ldc.py +++ b/examples/interfaces/ldc.py @@ -65,7 +65,7 @@ def run_ldc(backend, compute_mlup=True): velocity_set = xlb.velocity_set.D3Q19() # Make grid - nr = 256 + nr = 128 shape = (nr, nr, nr) if backend == "jax": grid = xlb.grid.JaxGrid(shape=shape) @@ -115,32 +115,17 @@ def run_ldc(backend, compute_mlup=True): precision_policy=precision_policy, compute_backend=compute_backend, ) - do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, ) - full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, stream=stream, - boundary_conditions=[equilibrium_bc, do_nothing_bc, half_way_bc, full_way_bc], - ) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, + boundary_conditions=[equilibrium_bc, half_way_bc], ) planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( velocity_set=velocity_set, @@ -170,8 +155,7 @@ def run_ldc(backend, compute_mlup=True): lower_bound, upper_bound, direction, - #do_nothing_bc.id, - full_way_bc.id, + half_way_bc.id, boundary_id, missing_mask, (0, 0, 0) @@ -185,8 +169,7 @@ def run_ldc(backend, compute_mlup=True): lower_bound, upper_bound, direction, - #half_way_bc.id, - full_way_bc.id, + half_way_bc.id, boundary_id, missing_mask, (0, 0, 0) @@ -200,8 +183,7 @@ def run_ldc(backend, compute_mlup=True): lower_bound, upper_bound, direction, - #half_way_bc.id, - full_way_bc.id, + half_way_bc.id, boundary_id, missing_mask, (0, 0, 0) @@ -215,8 +197,7 @@ def run_ldc(backend, compute_mlup=True): lower_bound, upper_bound, direction, - #half_way_bc.id, - full_way_bc.id, + half_way_bc.id, boundary_id, missing_mask, (0, 0, 0) @@ -230,34 +211,11 @@ def run_ldc(backend, compute_mlup=True): lower_bound, upper_bound, direction, - #half_way_bc.id, - full_way_bc.id, - boundary_id, - missing_mask, - (0, 0, 0) - ) - - # Set full way bc (sphere) - """ - sphere_radius = nr // 8 - x = np.arange(nr) - y = np.arange(nr) - z = np.arange(nr) - X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) - indices = np.array(indices).T - indices = wp.from_numpy(indices, dtype=wp.int32) - boundary_id, missing_mask = indices_boundary_masker( - indices, - full_way_bc.id, + half_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) - """ # Set initial conditions if backend == "warp": @@ -271,7 +229,7 @@ def run_ldc(backend, compute_mlup=True): plot_freq = 512 save_dir = "ldc" os.makedirs(save_dir, exist_ok=True) - num_steps = nr * 32 + num_steps = nr * 512 start = time.time() for _ in tqdm(range(num_steps)): @@ -288,20 +246,22 @@ def run_ldc(backend, compute_mlup=True): rho, u = macroscopic(f0, rho, u) local_rho = rho.numpy() local_u = u.numpy() + local_boundary_id = boundary_id.numpy() elif backend == "jax": local_rho, local_u = macroscopic(f0) + local_boundary_id = boundary_id # Plot the velocity field, rho and boundary id side by side plt.subplot(1, 3, 1) - plt.imshow(np.linalg.norm(u[:, :, nr // 2, :], axis=0)) + plt.imshow(np.linalg.norm(local_u[:, :, nr // 2, :], axis=0)) plt.colorbar() plt.subplot(1, 3, 2) - plt.imshow(rho[0, :, nr // 2, :]) + plt.imshow(local_rho[0, :, nr // 2, :]) plt.colorbar() plt.subplot(1, 3, 3) - plt.imshow(boundary_id[0, :, nr // 2, :]) + plt.imshow(local_boundary_id[0, :, nr // 2, :]) plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") + plt.savefig(f"{save_dir}/{backend}_{str(_).zfill(6)}.png") plt.close() wp.synchronize() @@ -314,6 +274,6 @@ def run_ldc(backend, compute_mlup=True): # Run the LDC example backends = ["warp", "jax"] - #backends = ["jax"] + compute_mlup = False for backend in backends: - run_ldc(backend, compute_mlup=True) + run_ldc(backend, compute_mlup=compute_mlup) diff --git a/examples/interfaces/taylor_green.py b/examples/interfaces/taylor_green.py index 4d97476..f842107 100644 --- a/examples/interfaces/taylor_green.py +++ b/examples/interfaces/taylor_green.py @@ -214,7 +214,7 @@ def run_taylor_green(backend, compute_mlup=True): if __name__ == "__main__": # Run Taylor-Green vortex on different backends - #backends = ["warp", "jax"] - backends = ["jax"] + backends = ["warp", "jax"] + #backends = ["jax"] for backend in backends: - run_taylor_green(backend, compute_mlup=False) + run_taylor_green(backend, compute_mlup=True) diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index 4cd405c..832a8b2 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -41,41 +41,55 @@ def jax_implementation( mask, start_index=(0, 0, 0), ): - # Get plane dimensions + # TODO: Optimize this + + # x plane if direction[0] != 0: - dim = ( - upper_bound[1] - lower_bound[1] + 1, - upper_bound[2] - lower_bound[2] + 1, - ) + + # Set boundary id + boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1] + 1, lower_bound[2] : upper_bound[2] + 1].set(id_number) + + # Set mask + for l in range(self.velocity_set.q): + d_dot_c = ( + direction[0] * self.velocity_set.c[0, l] + + direction[1] * self.velocity_set.c[1, l] + + direction[2] * self.velocity_set.c[2, l] + ) + if d_dot_c >= 0: + mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1] + 1, lower_bound[2] : upper_bound[2] + 1].set(True) + + # y plane elif direction[1] != 0: - dim = ( - upper_bound[0] - lower_bound[0] + 1, - upper_bound[2] - lower_bound[2] + 1, - ) - elif direction[2] != 0: - dim = ( - upper_bound[0] - lower_bound[0] + 1, - upper_bound[1] - lower_bound[1] + 1, - ) - # Get the constants - _c = self.velocity_set.wp_c - _q = self.velocity_set.q + # Set boundary id + boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0] + 1, lower_bound[1], lower_bound[2] : upper_bound[2] + 1].set(id_number) - # Get the mask - for i in range(dim[0]): - for j in range(dim[1]): - for k in range(_q): - d_dot_c = ( - direction[0] * _c[0, k] - + direction[1] * _c[1, k] - + direction[2] * _c[2, k] - ) - if d_dot_c >= 0: - mask[k, i, j] = True + # Set mask + for l in range(self.velocity_set.q): + d_dot_c = ( + direction[0] * self.velocity_set.c[0, l] + + direction[1] * self.velocity_set.c[1, l] + + direction[2] * self.velocity_set.c[2, l] + ) + if d_dot_c >= 0: + mask = mask.at[l, lower_bound[0] : upper_bound[0] + 1, lower_bound[1], lower_bound[2] : upper_bound[2] + 1].set(True) + + # z plane + elif direction[2] != 0: - # Get the boundary id - boundary_id[:, :, :] = id_number + # Set boundary id + boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0] + 1, lower_bound[1] : upper_bound[1] + 1, lower_bound[2]].set(id_number) + + # Set mask + for l in range(self.velocity_set.q): + d_dot_c = ( + direction[0] * self.velocity_set.c[0, l] + + direction[1] * self.velocity_set.c[1, l] + + direction[2] * self.velocity_set.c[2, l] + ) + if d_dot_c >= 0: + mask = mask.at[l, lower_bound[0] : upper_bound[0] + 1, lower_bound[1] : upper_bound[1] + 1, lower_bound[2]].set(True) return boundary_id, mask diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index 63c32ee..9c6d56e 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -166,7 +166,10 @@ def kernel( _missing_mask[l] = wp.uint8(0) # Apply streaming boundary conditions - if _boundary_id == _equilibrium_bc: + if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: + # Regular streaming + f_post_stream = self.stream.warp_functional(f_0, index) + elif _boundary_id == _equilibrium_bc: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional( f_0, _missing_mask, index @@ -181,9 +184,6 @@ def kernel( f_post_stream = self.halfway_bounce_back_bc.warp_functional( f_0, _missing_mask, index ) - else: - # Regular streaming - f_post_stream = self.stream.warp_functional(f_0, index) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) From 18abd12f515b9e04524f5d2bba39476fd0063f75 Mon Sep 17 00:00:00 2001 From: Oliver Date: Fri, 29 Mar 2024 19:51:51 -0700 Subject: [PATCH 024/144] Finalized changes --- examples/interfaces/ldc.py | 29 ++++++++++----- .../boundary_condition/boundary_condition.py | 6 ++-- .../boundary_condition/halfway_bounce_back.py | 2 +- .../boundary_masker/planar_boundary_masker.py | 36 +++++++++---------- xlb/operator/stream/stream.py | 6 ++-- 5 files changed, 45 insertions(+), 34 deletions(-) diff --git a/examples/interfaces/ldc.py b/examples/interfaces/ldc.py index d75ce8a..e5ca559 100644 --- a/examples/interfaces/ldc.py +++ b/examples/interfaces/ldc.py @@ -120,12 +120,18 @@ def run_ldc(backend, compute_mlup=True): precision_policy=precision_policy, compute_backend=compute_backend, ) + full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, stream=stream, - boundary_conditions=[equilibrium_bc, half_way_bc], + #boundary_conditions=[equilibrium_bc, half_way_bc, full_way_bc], + boundary_conditions=[half_way_bc, full_way_bc, equilibrium_bc], ) planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( velocity_set=velocity_set, @@ -148,21 +154,22 @@ def run_ldc(backend, compute_mlup=True): ) # Set outlet bc (top x face) - lower_bound = (nr-1, 1, 1) - upper_bound = (nr-1, nr-1, nr-1) + lower_bound = (nr-1, 0, 0) + upper_bound = (nr-1, nr, nr) direction = (-1, 0, 0) boundary_id, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) # Set half way bc (bottom y face) - lower_bound = (1, 0, 1) + lower_bound = (0, 0, 0) upper_bound = (nr, 0, nr) direction = (0, 1, 0) boundary_id, missing_mask = planar_boundary_masker( @@ -170,13 +177,14 @@ def run_ldc(backend, compute_mlup=True): upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) # Set half way bc (top y face) - lower_bound = (1, nr-1, 1) + lower_bound = (0, nr-1, 0) upper_bound = (nr, nr-1, nr) direction = (0, -1, 0) boundary_id, missing_mask = planar_boundary_masker( @@ -184,13 +192,14 @@ def run_ldc(backend, compute_mlup=True): upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) # Set half way bc (bottom z face) - lower_bound = (1, 1, 0) + lower_bound = (0, 0, 0) upper_bound = (nr, nr, 0) direction = (0, 0, 1) boundary_id, missing_mask = planar_boundary_masker( @@ -198,13 +207,14 @@ def run_ldc(backend, compute_mlup=True): upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) ) # Set half way bc (top z face) - lower_bound = (1, 1, nr-1) + lower_bound = (0, 0, nr-1) upper_bound = (nr, nr, nr-1) direction = (0, 0, -1) boundary_id, missing_mask = planar_boundary_masker( @@ -212,6 +222,7 @@ def run_ldc(backend, compute_mlup=True): upper_bound, direction, half_way_bc.id, + #full_way_bc.id, boundary_id, missing_mask, (0, 0, 0) @@ -226,10 +237,10 @@ def run_ldc(backend, compute_mlup=True): f0 = equilibrium(rho, u) # Time stepping - plot_freq = 512 + plot_freq = 128 save_dir = "ldc" os.makedirs(save_dir, exist_ok=True) - num_steps = nr * 512 + num_steps = nr * 16 start = time.time() for _ in tqdm(range(num_steps)): diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 7f7aadc..92a2c1f 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -6,7 +6,7 @@ from jax import jit, device_count from functools import partial import numpy as np -from enum import Enum +from enum import Enum, auto from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -16,8 +16,8 @@ # Enum for implementation step class ImplementationStep(Enum): - COLLISION = 1 - STREAMING = 2 + COLLISION = auto() + STREAMING = auto() class BoundaryCondition(Operator): diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index f8a4dd7..e47cc26 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -52,7 +52,7 @@ def __init__( def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): boundary = boundary_id == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) - return lax.select(missing_mask & boundary, f_pre[self.velocity_set.opp_indices], f_post) + return lax.select(jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index 832a8b2..572f345 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -47,7 +47,7 @@ def jax_implementation( if direction[0] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1] + 1, lower_bound[2] : upper_bound[2] + 1].set(id_number) + boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -57,13 +57,13 @@ def jax_implementation( + direction[2] * self.velocity_set.c[2, l] ) if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1] + 1, lower_bound[2] : upper_bound[2] + 1].set(True) + mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(True) # y plane elif direction[1] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0] + 1, lower_bound[1], lower_bound[2] : upper_bound[2] + 1].set(id_number) + boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -73,13 +73,13 @@ def jax_implementation( + direction[2] * self.velocity_set.c[2, l] ) if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0] : upper_bound[0] + 1, lower_bound[1], lower_bound[2] : upper_bound[2] + 1].set(True) + mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(True) # z plane elif direction[2] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0] + 1, lower_bound[1] : upper_bound[1] + 1, lower_bound[2]].set(id_number) + boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -89,7 +89,7 @@ def jax_implementation( + direction[2] * self.velocity_set.c[2, l] ) if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0] : upper_bound[0] + 1, lower_bound[1] : upper_bound[1] + 1, lower_bound[2]].set(True) + mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(True) return boundary_id, mask @@ -116,15 +116,15 @@ def kernel( # Get local indices if direction[0] != 0: i = lower_bound[0] - start_index[0] - j = plane_i - start_index[1] - k = plane_j - start_index[2] + j = plane_i + lower_bound[1] - start_index[1] + k = plane_j + lower_bound[2] - start_index[2] elif direction[1] != 0: - i = plane_i - start_index[0] + i = plane_i + lower_bound[0] - start_index[0] j = lower_bound[1] - start_index[1] - k = plane_j - start_index[2] + k = plane_j + lower_bound[2] - start_index[2] elif direction[2] != 0: - i = plane_i - start_index[0] - j = plane_j - start_index[1] + i = plane_i + lower_bound[0] - start_index[0] + j = plane_j + lower_bound[1] - start_index[1] k = lower_bound[2] - start_index[2] # Check if in bounds @@ -165,18 +165,18 @@ def warp_implementation( # Get plane dimensions if direction[0] != 0: dim = ( - upper_bound[1] - lower_bound[1] + 1, - upper_bound[2] - lower_bound[2] + 1, + upper_bound[1] - lower_bound[1], + upper_bound[2] - lower_bound[2], ) elif direction[1] != 0: dim = ( - upper_bound[0] - lower_bound[0] + 1, - upper_bound[2] - lower_bound[2] + 1, + upper_bound[0] - lower_bound[0], + upper_bound[2] - lower_bound[2], ) elif direction[2] != 0: dim = ( - upper_bound[0] - lower_bound[0] + 1, - upper_bound[1] - lower_bound[1] + 1, + upper_bound[0] - lower_bound[0], + upper_bound[1] - lower_bound[1], ) # Launch the warp kernel diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index c5fd16e..8bb2568 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -46,10 +46,10 @@ def _streaming_jax_i(f, c): """ if self.velocity_set.d == 2: return jnp.roll( - f, (-c[0], -c[1]), axis=(0, 1) - ) # Negative sign is used to pull the distribution instead of pushing + f, (c[0], c[1]), axis=(0, 1) + ) elif self.velocity_set.d == 3: - return jnp.roll(f, (-c[0], -c[1], -c[2]), axis=(0, 1, 2)) + return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( f, jnp.array(self.velocity_set.c).T From 94fd5f0ed409b4fb4996b849616c31867668caa1 Mon Sep 17 00:00:00 2001 From: Oliver Date: Mon, 8 Apr 2024 09:57:52 -0700 Subject: [PATCH 025/144] added kbc --- examples/CFD_refactor/windtunnel3d.py | 592 ++++++++++++++---- xlb/operator/boundary_masker/__init__.py | 3 + .../boundary_masker/stl_boundary_masker.py | 126 ++-- xlb/operator/collision/kbc.py | 178 +++++- xlb/velocity_set/velocity_set.py | 13 +- 5 files changed, 716 insertions(+), 196 deletions(-) diff --git a/examples/CFD_refactor/windtunnel3d.py b/examples/CFD_refactor/windtunnel3d.py index 9af75f3..156a219 100644 --- a/examples/CFD_refactor/windtunnel3d.py +++ b/examples/CFD_refactor/windtunnel3d.py @@ -1,168 +1,512 @@ +# Wind tunnel simulation using the XLB library + +from typing import Any import os import jax import trimesh from time import time import numpy as np -import jax.numpy as jnp -from jax import config +import warp as wp +import pyvista as pv +import tqdm +import matplotlib.pyplot as plt + +wp.init() + +import xlb +from xlb.operator import Operator + +class UniformInitializer(Operator): + + def _construct_warp(self): + # Construct the warp kernel + @wp.kernel + def kernel( + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + vel: float, + ): + # Get the global index + i, j, k = wp.tid() + + # Set the velocity + u[0, i, j, k] = vel + u[1, i, j, k] = 0.0 + u[2, i, j, k] = 0.0 + + # Set the density + rho[0, i, j, k] = 1.0 + + return None, kernel + + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, rho, u, vel): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + rho, + u, + vel, + ], + dim=rho.shape[1:], + ) + return rho, u + +class MomentumTransfer(Operator): + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _opp_indices = self.velocity_set.wp_opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool + + # Find velocity index for 0, 0, 0 + for l in range(self.velocity_set.q): + if _c[0, l] == 0 and _c[1, l] == 0 and _c[2, l] == 0: + zero_index = l + _zero_index = wp.int32(zero_index) + print(f"Zero index: {_zero_index}") + + # Construct the warp kernel + @wp.kernel + def kernel( + f: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + momentum: wp.array(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Determin if boundary is an edge by checking if center is missing + is_edge = wp.bool(False) + if _boundary_id == wp.uint8(xlb.operator.boundary_condition.HalfwayBounceBackBC.id): + if _missing_mask[_zero_index] != wp.uint8(1): + is_edge = wp.bool(True) + + # If the boundary is an edge then add the momentum transfer + m = wp.vec3() + if is_edge: + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + phi = 2.0 * f[_opp_indices[l], index[0], index[1], index[2]] + + # Compute the momentum transfer + for d in range(self.velocity_set.d): + m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) -from xlb.solver import IncompressibleNavierStokesSolver -from xlb.velocity_set import D3Q27, D3Q19 -from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy -from xlb.grid_backend import GridBackend -from xlb.operator.boundary_condition import BounceBack, BounceBackHalfway, DoNothing, EquilibriumBC + wp.atomic_add(momentum, 0, m) + return None, kernel + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, f, boundary_id, missing_mask): -class WindTunnel(IncompressibleNavierStokesSolver): + # Allocate the momentum field + momentum = wp.zeros((1), dtype=wp.vec3) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f, boundary_id, missing_mask, momentum], + dim=f.shape[1:], + ) + return momentum.numpy() + + +class WindTunnel: """ - This class extends the IncompressibleNavierStokesSolver class to define the boundary conditions for the wind tunnel simulation. - Units are in meters, seconds, and kilograms. + Wind tunnel simulation using the XLB library """ def __init__( self, - stl_filename: str - stl_center: tuple[float, float, float] = (0.0, 0.0, 0.0), # m - inlet_velocity: float = 27.78 # m/s + stl_filename: str, + inlet_velocity: float = 27.78, # m/s lower_bounds: tuple[float, float, float] = (0.0, 0.0, 0.0), # m upper_bounds: tuple[float, float, float] = (1.0, 0.5, 0.5), # m dx: float = 0.01, # m viscosity: float = 1.42e-5, # air at 20 degrees Celsius density: float = 1.2754, # kg/m^3 - collision="BGK", + solve_time: float = 1.0, # s + #collision="BGK", + collision="KBC", equilibrium="Quadratic", - velocity_set=D3Q27(), - precision_policy=PrecisionPolicy.FP32FP32, - compute_backend=ComputeBackend.JAX, - grid_backend=GridBackend.JAX, + velocity_set="D3Q27", + precision_policy=xlb.PrecisionPolicy.FP32FP32, + compute_backend=xlb.ComputeBackend.WARP, grid_configs={}, + save_state_frequency=1024, + monitor_frequency=32, ): # Set parameters + self.stl_filename = stl_filename self.inlet_velocity = inlet_velocity self.lower_bounds = lower_bounds self.upper_bounds = upper_bounds self.dx = dx + self.solve_time = solve_time self.viscosity = viscosity self.density = density + self.save_state_frequency = save_state_frequency + self.monitor_frequency = monitor_frequency # Get fluid properties needed for the simulation - self.velocity_conversion = 0.05 / inlet_velocity + self.base_velocity = 0.05 # LBM units + self.velocity_conversion = self.base_velocity / inlet_velocity self.dt = self.dx * self.velocity_conversion self.lbm_viscosity = self.viscosity * self.dt / (self.dx ** 2) self.tau = 0.5 + self.lbm_viscosity + self.omega = 1.0 / self.tau + print(f"tau: {self.tau}") + print(f"omega: {self.omega}") self.lbm_density = 1.0 self.mass_conversion = self.dx ** 3 * (self.density / self.lbm_density) + self.nr_steps = int(solve_time / self.dt) - # Make boundary conditions - - - # Initialize the IncompressibleNavierStokesSolver - super().__init__( - omega=self.tau, - shape=shape, - collision=collision, - equilibrium=equilibrium, - boundary_conditions=boundary_conditions, - initializer=initializer, - forcing=forcing, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - grid_backend=grid_backend, - grid_configs=grid_configs, - ) - - def voxelize_stl(self, stl_filename, length_lbm_unit): - mesh = trimesh.load_mesh(stl_filename, process=False) - length_phys_unit = mesh.extents.max() - pitch = length_phys_unit/length_lbm_unit - mesh_voxelized = mesh.voxelized(pitch=pitch) - mesh_matrix = mesh_voxelized.matrix - return mesh_matrix, pitch - - def set_boundary_conditions(self): - print('Voxelizing mesh...') - time_start = time() - stl_filename = 'stl-files/DrivAer-Notchback.stl' - car_length_lbm_unit = self.nx / 4 - car_voxelized, pitch = voxelize_stl(stl_filename, car_length_lbm_unit) - car_matrix = car_voxelized.matrix - print('Voxelization time for pitch={}: {} seconds'.format(pitch, time() - time_start)) - print("Car matrix shape: ", car_matrix.shape) - - self.car_area = np.prod(car_matrix.shape[1:]) - tx, ty, tz = np.array([nx, ny, nz]) - car_matrix.shape - shift = [tx//4, ty//2, 0] - car_indices = np.argwhere(car_matrix) + shift - self.BCs.append(BounceBackHalfway(tuple(car_indices.T), self.gridInfo, self.precisionPolicy)) - - wall = np.concatenate((self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'], - self.boundingBoxIndices['front'], self.boundingBoxIndices['back'])) - self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) - - doNothing = self.boundingBoxIndices['right'] - self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy)) - self.BCs[-1].implementationStep = 'PostCollision' - # rho_outlet = np.ones(doNothing.shape[0], dtype=self.precisionPolicy.compute_dtype) - # self.BCs.append(ZouHe(tuple(doNothing.T), - # self.gridInfo, - # self.precisionPolicy, - # 'pressure', rho_outlet)) - - inlet = self.boundingBoxIndices['left'] - rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype) - - vel_inlet[:, 0] = prescribed_vel - self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet)) - # self.BCs.append(ZouHe(tuple(inlet.T), - # self.gridInfo, - # self.precisionPolicy, - # 'velocity', vel_inlet)) - - def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs['rho'][..., 1:-1, 1:-1, :]) - u = np.array(kwargs['u'][..., 1:-1, 1:-1, :]) - timestep = kwargs['timestep'] - u_prev = kwargs['u_prev'][..., 1:-1, 1:-1, :] - - # compute lift and drag over the car - car = self.BCs[0] - boundary_force = car.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) - boundary_force = np.sum(boundary_force, axis=0) - drag = np.sqrt(boundary_force[0]**2 + boundary_force[1]**2) #xy-plane - lift = boundary_force[2] #z-direction - cd = 2. * drag / (prescribed_vel ** 2 * self.car_area) - cl = 2. * lift / (prescribed_vel ** 2 * self.car_area) - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - - err = np.sum(np.abs(u_old - u_new)) - print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) - fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} - save_fields_vtk(timestep, fields) + # Get the grid shape + self.nx = int((upper_bounds[0] - lower_bounds[0]) / dx) + self.ny = int((upper_bounds[1] - lower_bounds[1]) / dx) + self.nz = int((upper_bounds[2] - lower_bounds[2]) / dx) + self.shape = (self.nx, self.ny, self.nz) -if __name__ == '__main__': - precision = 'f32/f32' - lattice = LatticeD3Q27(precision) + # Set the compute backend + self.compute_backend = xlb.ComputeBackend.WARP + + # Set the precision policy + self.precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set + if velocity_set == "D3Q27": + self.velocity_set = xlb.velocity_set.D3Q27() + elif velocity_set == "D3Q19": + self.velocity_set = xlb.velocity_set.D3Q19() + else: + raise ValueError("Invalid velocity set") + + # Make grid + self.grid = xlb.grid.WarpGrid(shape=self.shape) + + # Make feilds + self.rho = self.grid.create_field(cardinality=1, precision=xlb.Precision.FP32) + self.u = self.grid.create_field(cardinality=self.velocity_set.d, precision=xlb.Precision.FP32) + self.f0 = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.FP32) + self.f1 = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.FP32) + self.boundary_id = self.grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + self.missing_mask = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.BOOL) + + # Make operators + self.initializer = UniformInitializer( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.momentum_transfer = MomentumTransfer( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + if collision == "BGK": + self.collision = xlb.operator.collision.BGK( + omega=self.omega, + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + elif collision == "KBC": + self.collision = xlb.operator.collision.KBC( + omega=self.omega, + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.stream = xlb.operator.stream.Stream( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=self.lbm_density, + u=(self.base_velocity, 0.0, 0.0), + equilibrium_operator=self.equilibrium, + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=self.collision, + equilibrium=self.equilibrium, + macroscopic=self.macroscopic, + stream=self.stream, + boundary_conditions=[ + self.half_way_bc, + self.full_way_bc, + self.equilibrium_bc, + self.do_nothing_bc + ], + ) + self.planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.stl_boundary_masker = xlb.operator.boundary_masker.STLBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + + # Make list to store drag coefficients + self.drag_coefficients = [] + + def initialize_flow(self): + """ + Initialize the flow field + """ + + # Set initial conditions + self.rho, self.u = self.initializer(self.rho, self.u, self.base_velocity) + self.f0 = self.equilibrium(self.rho, self.u, self.f0) + + def initialize_boundary_conditions(self): + """ + Initialize the boundary conditions + """ + + # Set inlet bc (bottom x face) + lower_bound = (0, 1, 1) # no edges + upper_bound = (0, self.ny-1, self.nz-1) + direction = (1, 0, 0) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.equilibrium_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set outlet bc (top x face) + lower_bound = (self.nx-1, 1, 1) + upper_bound = (self.nx-1, self.ny-1, self.nz-1) + direction = (-1, 0, 0) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.do_nothing_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set full way bc (bottom y face) + lower_bound = (0, 0, 0) + upper_bound = (self.nx, 0, self.nz) + direction = (0, 1, 0) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.full_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set full way bc (top y face) + lower_bound = (0, self.ny-1, 0) + upper_bound = (self.nx, self.ny-1, self.nz) + direction = (0, -1, 0) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.full_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set full way bc (bottom z face) + lower_bound = (0, 0, 0) + upper_bound = (self.nx, self.ny, 0) + direction = (0, 0, 1) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.full_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set full way bc (top z face) + lower_bound = (0, 0, self.nz-1) + upper_bound = (self.nx, self.ny, self.nz-1) + direction = (0, 0, -1) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.full_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set stl half way bc + self.boundary_id, self.missing_mask = self.stl_boundary_masker( + self.stl_filename, + self.lower_bounds, + (self.dx, self.dx, self.dx), + self.half_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + def save_state( + self, + postfix: str, + save_velocity_distribution: bool = False, + ): + """ + Save the solid id array. + """ + + # Create grid + grid = pv.RectilinearGrid( + np.linspace(self.lower_bounds[0], self.upper_bounds[0], self.nx, endpoint=False), + np.linspace(self.lower_bounds[1], self.upper_bounds[1], self.ny, endpoint=False), + np.linspace(self.lower_bounds[2], self.upper_bounds[2], self.nz, endpoint=False), + ) # TODO off by one? + grid["boundary_id"] = self.boundary_id.numpy().flatten("F") + grid["u"] = self.u.numpy().transpose(1, 2, 3, 0).reshape(-1, 3, order="F") + grid["rho"] = self.rho.numpy().flatten("F") + if save_velocity_distribution: + grid["f0"] = self.f0.numpy().transpose(1, 2, 3, 0).reshape(-1, self.velocity_set.q, order="F") + grid.save(f"state_{postfix}.vtk") + + def step(self): + self.f1 = self.stepper(self.f0, self.f1, self.boundary_id, self.missing_mask, 0) + self.f0, self.f1 = self.f1, self.f0 + + def compute_rho_u(self): + self.rho, self.u = self.macroscopic(self.f0, self.rho, self.u) + + def monitor(self): + # Compute the momentum transfer + momentum = self.momentum_transfer(self.f0, self.boundary_id, self.missing_mask)[0] + drag = momentum[0] + lift = momentum[2] + c_d = 2.0 * drag / (self.base_velocity ** 2 * self.cross_section) + c_l = 2.0 * lift / (self.base_velocity ** 2 * self.cross_section) + self.drag_coefficients.append(c_d) + + def plot_drag_coefficient(self): + plt.plot(self.drag_coefficients[-30:]) + plt.xlabel("Time step") + plt.ylabel("Drag coefficient") + plt.savefig("drag_coefficient.png") + plt.close() + + def run(self): - nx = 601 - ny = 351 - nz = 251 + # Initialize the flow field + self.initialize_flow() + + # Initialize the boundary conditions + self.initialize_boundary_conditions() + + # Compute cross section + np_boundary_id = self.boundary_id.numpy() + cross_section = np.sum(np_boundary_id == self.half_way_bc.id, axis=(0, 1)) + self.cross_section = np.sum(cross_section > 0) + + # Run the simulation + for i in tqdm.tqdm(range(self.nr_steps)): + + # Step + self.step() + + # Monitor + if i % self.monitor_frequency == 0: + self.monitor() + + # Save monitor plot + if i % (self.monitor_frequency * 10) == 0: + self.plot_drag_coefficient() + + # Save state + if i % self.save_state_frequency == 0: + self.compute_rho_u() + self.save_state(str(i).zfill(8)) + +if __name__ == '__main__': - Re = 50000.0 - prescribed_vel = 0.05 - clength = nx - 1 + # Parameters + inlet_velocity = 0.01 # m/s + stl_filename = "fastback_baseline.stl" + lower_bounds = (-4.0, -2.5, -1.5) + upper_bounds = (12.0, 2.5, 2.5) + dx = 0.03 + solve_time = 10000.0 - visc = prescribed_vel * clength / Re - omega = 1.0 / (3. * visc + 0.5) + # Make wind tunnel + wind_tunnel = WindTunnel( + stl_filename=stl_filename, + inlet_velocity=inlet_velocity, + lower_bounds=lower_bounds, + upper_bounds=upper_bounds, + solve_time=solve_time, + dx=dx, + ) - os.system('rm -rf ./*.vtk && rm -rf ./*.png') + # Run the simulation + wind_tunnel.run() + wind_tunnel.save_state("final", save_velocity_distribution=True) - sim = Car(**kwargs) - sim.run(200000) diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index a9069c6..f69252f 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -7,3 +7,6 @@ from xlb.operator.boundary_masker.planar_boundary_masker import ( PlanarBoundaryMasker, ) +from xlb.operator.boundary_masker.stl_boundary_masker import ( + STLBoundaryMasker, +) diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index cda8c00..148e9b8 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -2,6 +2,7 @@ from functools import partial import numpy as np +from stl import mesh as np_mesh import jax.numpy as jnp from jax import jit import warp as wp @@ -26,89 +27,106 @@ def __init__( precision_policy: PrecisionPolicy, compute_backend: ComputeBackend.JAX, ): + # Call super super().__init__(velocity_set, precision_policy, compute_backend) - # TODO: Implement this - raise NotImplementedError - - # Make stream operator - self.stream = Stream(velocity_set, precision_policy, compute_backend) - - @Operator.register_backend(ComputeBackend.JAX) - def jax_implementation( - self, mesh, id_number, boundary_id, mask, start_index=(0, 0, 0) - ): - # TODO: Implement this - raise NotImplementedError - def _construct_warp(self): # Make constants for warp - _opp_indices = wp.constant( - self._warp_int_lattice_vec(self.velocity_set.opp_indices) - ) + _c = self.velocity_set.wp_c _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - _id = wp.constant(self.id) # Construct the warp kernel @wp.kernel - def _voxelize_mesh( - voxels: wp.array3d(dtype=wp.uint8), + def kernel( mesh: wp.uint64, - spacing: wp.vec3, origin: wp.vec3, - shape: wp.vec(3, wp.uint32), - max_length: float, - material_id: int, + spacing: wp.vec3, + id_number: wp.int32, + boundary_id: wp.array4d(dtype=wp.uint8), + mask: wp.array4d(dtype=wp.bool), + start_index: wp.vec3i, ): - # get index of voxel + # get index i, j, k = wp.tid() - # position of voxel - ijk = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) + # Get local indices + index = wp.vec3i() + index[0] = i - start_index[0] + index[1] = j - start_index[1] + index[2] = k - start_index[2] + + # position of the point + ijk = wp.vec3( + wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]) + ) ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center pos = wp.cw_mul(ijk, spacing) + origin - # Only evaluate voxel if not set yet - if voxels[i, j, k] != wp.uint8(0): - return + # Compute the maximum length + max_length = wp.sqrt( + (spacing[0] * wp.float32(boundary_id.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(boundary_id.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(boundary_id.shape[3])) ** 2.0 + ) - # evaluate distance of point + # evaluate if point is inside mesh face_index = int(0) face_u = float(0.0) face_v = float(0.0) sign = float(0.0) - if wp.mesh_query_point( + if wp.mesh_query_point_sign_winding_number( mesh, pos, max_length, sign, face_index, face_u, face_v ): - p = wp.mesh_eval_position(mesh, face_index, face_u, face_v) - delta = pos - p - norm = wp.sqrt(wp.dot(delta, delta)) - # set point to be solid - if norm < wp.min(spacing): - voxels[i, j, k] = wp.uint8(255) - elif sign < 0: # TODO: fix this - voxels[i, j, k] = wp.uint8(material_id) - else: - pass + if sign <= 0: # TODO: fix this + # Stream indices + for l in range(_q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + + # Set the boundary id and mask + boundary_id[ + 0, push_index[0], push_index[1], push_index[2] + ] = wp.uint8(id_number) + mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, start_index, boundary_id, mask, id_number): - # Reuse the jax implementation, TODO: implement a warp version - # Convert to jax - boundary_id = wp.jax.to_jax(boundary_id) - mask = wp.jax.to_jax(mask) - - # Call jax implementation - boundary_id, mask = self.jax_implementation( - start_index, boundary_id, mask, id_number + def warp_implementation( + self, + stl_file, + origin, + spacing, + id_number, + boundary_id, + mask, + start_index=(0, 0, 0), + ): + # Load the mesh + mesh = np_mesh.Mesh.from_file(stl_file) + mesh_points = mesh.points.reshape(-1, 3) + mesh_indices = np.arange(mesh_points.shape[0]) + mesh = wp.Mesh( + points=wp.array(mesh_points, dtype=wp.vec3), + indices=wp.array(mesh_indices, dtype=int), ) - # Convert back to warp - boundary_id = wp.jax.to_warp(boundary_id) - mask = wp.jax.to_warp(mask) + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + mesh.id, + origin, + spacing, + id_number, + boundary_id, + mask, + start_index, + ], + dim=boundary_id.shape[1:], + ) return boundary_id, mask diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 8302978..f3c996b 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -4,11 +4,14 @@ import jax.numpy as jnp from jax import jit -from functools import partial -from xlb.operator import Operator +import warp as wp +from typing import Any + from xlb.velocity_set import VelocitySet, D2Q9, D3Q27 from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision +from xlb.operator import Operator +from functools import partial class KBC(Collision): @@ -25,15 +28,16 @@ def __init__( precision_policy=None, compute_backend=None, ): + self.epsilon = 1e-32 + self.beta = omega * 0.5 + self.inv_beta = 1.0 / self.beta + super().__init__( omega=omega, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, ) - self.epsilon = 1e-32 - self.beta = self.omega * 0.5 - self.inv_beta = 1.0 / self.beta @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) @@ -80,16 +84,6 @@ def jax_implementation( return fout - @Operator.register_backend(ComputeBackend.WARP) - @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) - def warp_implementation( - self, - f: jnp.ndarray, - feq: jnp.ndarray, - rho: jnp.ndarray, - ): - raise NotImplementedError("Warp implementation not yet implemented") - @partial(jit, static_argnums=(0,), inline=True) def entropic_scalar_product(self, x: jnp.ndarray, y: jnp.ndarray, feq: jnp.ndarray): """ @@ -208,3 +202,157 @@ def decompose_shear_d2q9_jax(self, fneq): s = s.at[7, ...].set(Pi[1, ...]) return s + + def _construct_warp(self): + # Raise error if velocity set is not supported + if not isinstance(self.velocity_set, D3Q27): + raise NotImplementedError( + "Velocity set not supported for warp backend: {}".format( + type(self.velocity_set) + ) + ) + + # Set local constants TODO: This is a hack and should be fixed with warp update + _w = self.velocity_set.wp_w + _cc = self.velocity_set.wp_cc + _omega = wp.constant(self.compute_dtype(self.omega)) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _pi_vec = wp.vec( + self.velocity_set.d * (self.velocity_set.d + 1) // 2, + dtype=self.compute_dtype, + ) + _epsilon = wp.constant(self.compute_dtype(self.epsilon)) + _beta = wp.constant(self.compute_dtype(self.beta)) + _inv_beta = wp.constant(self.compute_dtype(1.0 / self.beta)) + + # Construct functional for computing momentum flux + @wp.func + def momentum_flux( + fneq: Any, + ): + # Get momentum flux + pi = _pi_vec() + for d in range(6): + pi[d] = 0.0 + for q in range(self.velocity_set.q): + pi[d] += _cc[q, d] * fneq[q] + return pi + + # Construct functional for decomposing shear + @wp.func + def decompose_shear_d3q27( + fneq: Any, + ): + # Get momentum flux + pi = momentum_flux(fneq) + nxz = pi[0] - pi[5] + nyz = pi[3] - pi[5] + + # set shear components + s = _f_vec() + + # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) + s[9] = (2.0 * nxz - nyz) / 6.0 + s[18] = (2.0 * nxz - nyz) / 6.0 + s[3] = (-nxz + 2.0 * nyz) / 6.0 + s[6] = (-nxz + 2.0 * nyz) / 6.0 + s[1] = (-nxz - nyz) / 6.0 + s[2] = (-nxz - nyz) / 6.0 + + # For c = (i, j, 0) + s[12] = pi[1] / 4.0 + s[24] = pi[1] / 4.0 + s[21] = -pi[1] / 4.0 + s[15] = -pi[1] / 4.0 + + # For c = (i, 0, k) + s[10] = pi[2] / 4.0 + s[20] = pi[2] / 4.0 + s[19] = -pi[2] / 4.0 + s[11] = -pi[2] / 4.0 + + # For c = (0, j, k) + s[8] = pi[4] / 4.0 + s[4] = pi[4] / 4.0 + s[7] = -pi[4] / 4.0 + s[5] = -pi[4] / 4.0 + + return s + + # Construct functional for computing entropic scalar product + @wp.func + def entropic_scalar_product( + x: Any, + y: Any, + feq: Any, + ): + e = wp.cw_div(wp.cw_mul(x, y), feq) + e_sum = wp.float32(0.0) + for i in range(self.velocity_set.q): + e_sum += e[i] + return e_sum + + # Construct the functional + @wp.func + def functional( + f: Any, + feq: Any, + rho: Any, + u: Any, + ): + # Compute shear and delta_s + fneq = f - feq + shear = decompose_shear_d3q27(fneq) + delta_s = shear * rho # TODO: Check this + + # Perform collision + delta_h = fneq - delta_s + gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product( + delta_s, delta_h, feq + ) / (_epsilon + entropic_scalar_product(delta_h, delta_h, feq)) + fout = f - _beta * (2.0 * delta_s + gamma * delta_h) + + return fout + + # Construct the warp kernel + @wp.kernel + def kernel( + f: wp.array4d(dtype=Any), + feq: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + fout: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this + + # Load needed values + _f = _f_vec() + _feq = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _feq[l] = feq[l, index[0], index[1], index[2]] + _u = self._warp_u_vec() + for l in range(_d): + _u[l] = u[l, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] + + # Compute the collision + _fout = functional(_f, _feq, _rho, _u) + + # Write the result + for l in range(self.velocity_set.q): + fout[l, index[0], index[1], index[2]] = _fout[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) + def warp_implementation( + self, + f: jnp.ndarray, + feq: jnp.ndarray, + rho: jnp.ndarray, + ): + raise NotImplementedError("Warp implementation not yet implemented") diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index a137b87..03395c8 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -8,6 +8,7 @@ import warp as wp + class VelocitySet(object): """ Base class for the velocity set of the Lattice Boltzmann Method (LBM), e.g. D2Q9, D3Q27, etc. @@ -46,9 +47,15 @@ def __init__(self, d, q, c, w): # Make warp constants for these vectors # TODO: Following warp updates these may not be necessary self.wp_c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self.c)) - self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow - self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) - + self.wp_w = wp.constant( + wp.vec(self.q, dtype=wp.float32)(self.w) + ) # TODO: Make type optional somehow + self.wp_opp_indices = wp.constant( + wp.vec(self.q, dtype=wp.int32)(self.opp_indices) + ) + self.wp_cc = wp.constant( + wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc) + ) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) From 6836b0932516a068622d470f3443cc5e6aae0d07 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Tue, 9 Apr 2024 16:16:39 -0400 Subject: [PATCH 026/144] Fixed boundaries and added conformance test. --- examples/interfaces/boundary_conditions.py | 156 -- examples/refactor/example_basic.py | 4 +- examples/refactor/example_pallas_3d.py | 8 +- examples/refactor/mlups3d.py | 6 +- examples/refactor/mlups_pallas_3d.py | 6 +- examples/warp_backend/testing.py | 4 +- requirements.txt | 10 +- test.py | 40 - .../boundary_conditions.py | 218 +++ xlb/base.py | 1252 ----------------- xlb/grid/jax_grid.py | 9 +- .../boundary_condition/equilibrium.py | 16 +- .../indices_boundary_masker.py | 65 +- xlb/operator/stepper/stepper.py | 6 +- xlb/precision_policy/precision_policy.py | 22 +- 15 files changed, 281 insertions(+), 1541 deletions(-) delete mode 100644 examples/interfaces/boundary_conditions.py delete mode 100644 test.py create mode 100644 tests/backends_conformance/boundary_conditions.py delete mode 100644 xlb/base.py diff --git a/examples/interfaces/boundary_conditions.py b/examples/interfaces/boundary_conditions.py deleted file mode 100644 index 70648bf..0000000 --- a/examples/interfaces/boundary_conditions.py +++ /dev/null @@ -1,156 +0,0 @@ -# Simple script to run different boundary conditions with jax and warp backends -import time -from tqdm import tqdm -import os -import matplotlib.pyplot as plt -from typing import Any -import numpy as np -import jax.numpy as jnp -import warp as wp - -wp.init() - -import xlb - -def run_boundary_conditions(backend): - - # Set the compute backend - if backend == "warp": - compute_backend = xlb.ComputeBackend.WARP - elif backend == "jax": - compute_backend = xlb.ComputeBackend.JAX - - # Set the precision policy - precision_policy = xlb.PrecisionPolicy.FP32FP32 - - # Set the velocity set - velocity_set = xlb.velocity_set.D3Q19() - - # Make grid - nr = 256 - shape = (nr, nr, nr) - if backend == "jax": - grid = xlb.grid.JaxGrid(shape=shape) - elif backend == "warp": - grid = xlb.grid.WarpGrid(shape=shape) - - # Make feilds - f_pre = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - f_post = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - f = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) - missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) - - # Make needed operators - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - rho=1.0, - u=(0.0, 0.0, 0.0), - equilibrium_operator=equilibrium, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - halfway_bounce_back_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - fullway_bounce_back_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - - # Make indices for boundary conditions (sphere) - sphere_radius = 32 - x = np.arange(nr) - y = np.arange(nr) - z = np.arange(nr) - X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) - indices = np.array(indices).T - if backend == "jax": - indices = jnp.array(indices) - elif backend == "warp": - indices = wp.from_numpy(indices, dtype=wp.int32) - - # Test equilibrium boundary condition - boundary_id, missing_mask = indices_boundary_masker( - indices, - equilibrium_bc.id, - boundary_id, - missing_mask, - (0, 0, 0) - ) - if backend == "jax": - f = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask) - elif backend == "warp": - f = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask, f) - print(f"Equilibrium BC test passed for {backend}") - - # Test do nothing boundary condition - boundary_id, missing_mask = indices_boundary_masker( - indices, - do_nothing_bc.id, - boundary_id, - missing_mask, - (0, 0, 0) - ) - if backend == "jax": - f = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask) - elif backend == "warp": - f = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask, f) - - # Test halfway bounce back boundary condition - boundary_id, missing_mask = indices_boundary_masker( - indices, - halfway_bounce_back_bc.id, - boundary_id, - missing_mask, - (0, 0, 0) - ) - if backend == "jax": - f = halfway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask) - elif backend == "warp": - f = halfway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask, f) - print(f"Halfway bounce back BC test passed for {backend}") - - # Test the full boundary condition - boundary_id, missing_mask = indices_boundary_masker( - indices, - fullway_bounce_back_bc.id, - boundary_id, - missing_mask, - (0, 0, 0) - ) - if backend == "jax": - f = fullway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask) - elif backend == "warp": - f = fullway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask, f) - print(f"Fullway bounce back BC test passed for {backend}") - - -if __name__ == "__main__": - - # Test the boundary conditions - backends = ["warp", "jax"] - for backend in backends: - run_boundary_conditions(backend) diff --git a/examples/refactor/example_basic.py b/examples/refactor/example_basic.py index 9c9b033..5f74d21 100644 --- a/examples/refactor/example_basic.py +++ b/examples/refactor/example_basic.py @@ -1,5 +1,5 @@ import xlb -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.precision_policy import Fp32Fp32 from xlb.solver import IncompressibleNavierStokes @@ -10,7 +10,7 @@ xlb.init( precision_policy=Fp32Fp32, - compute_backend=ComputeBackends.JAX, + compute_backend=ComputeBackend.JAX, velocity_set=xlb.velocity_set.D2Q9, ) diff --git a/examples/refactor/example_pallas_3d.py b/examples/refactor/example_pallas_3d.py index 09b0305..17084d4 100644 --- a/examples/refactor/example_pallas_3d.py +++ b/examples/refactor/example_pallas_3d.py @@ -1,5 +1,5 @@ import xlb -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.precision_policy import Fp32Fp32 from xlb.solver import IncompressibleNavierStokes @@ -13,7 +13,7 @@ # Initialize XLB with Pallas backend for 3D simulation xlb.init( precision_policy=Fp32Fp32, - compute_backend=ComputeBackends.PALLAS, # Changed to Pallas backend + compute_backend=ComputeBackend.PALLAS, # Changed to Pallas backend velocity_set=xlb.velocity_set.D3Q19, # Changed to D3Q19 for 3D ) @@ -45,7 +45,7 @@ def initializer(): rho = jnp.where(inside_sphere, rho.at[0, x, y, z].add(0.001), rho) - func_eq = QuadraticEquilibrium(compute_backend=ComputeBackends.JAX) + func_eq = QuadraticEquilibrium(compute_backend=ComputeBackend.JAX) f_eq = func_eq(rho, u) return f_eq @@ -53,7 +53,7 @@ def initializer(): f = initializer() -compute_macro = Macroscopic(compute_backend=ComputeBackends.JAX) +compute_macro = Macroscopic(compute_backend=ComputeBackend.JAX) solver = IncompressibleNavierStokes(grid, omega=1.0) diff --git a/examples/refactor/mlups3d.py b/examples/refactor/mlups3d.py index 2a37fdb..2e13769 100644 --- a/examples/refactor/mlups3d.py +++ b/examples/refactor/mlups3d.py @@ -2,7 +2,7 @@ import time import jax import argparse -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.precision_policy import Fp32Fp32 from xlb.operator.initializer import EquilibriumInitializer @@ -23,7 +23,7 @@ xlb.init( precision_policy=Fp32Fp32, - compute_backend=ComputeBackends.PALLAS, + compute_backend=ComputeBackend.PALLAS, velocity_set=xlb.velocity_set.D3Q19, ) @@ -37,7 +37,7 @@ solver = IncompressibleNavierStokes(grid, omega=1.0) # Ahead-of-Time Compilation to remove JIT overhead -# if xlb.current_backend() == ComputeBackends.JAX or xlb.current_backend() == ComputeBackends.PALLAS: +# if xlb.current_backend() == ComputeBackend.JAX or xlb.current_backend() == ComputeBackend.PALLAS: # lowered = jax.jit(solver.step).lower(f, timestep=0) # solver_step_compiled = lowered.compile() diff --git a/examples/refactor/mlups_pallas_3d.py b/examples/refactor/mlups_pallas_3d.py index 4c8ff50..71715d9 100644 --- a/examples/refactor/mlups_pallas_3d.py +++ b/examples/refactor/mlups_pallas_3d.py @@ -1,7 +1,7 @@ import xlb import time import argparse -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.precision_policy import Fp32Fp32 from xlb.solver import IncompressibleNavierStokes from xlb.grid import Grid @@ -23,7 +23,7 @@ # Initialize XLB xlb.init( precision_policy=Fp32Fp32, - compute_backend=ComputeBackends.PALLAS, + compute_backend=ComputeBackend.PALLAS, velocity_set=xlb.velocity_set.D3Q19, ) @@ -56,7 +56,7 @@ def initializer(): rho = jnp.where(inside_sphere, rho.at[0, x, y, z].add(0.001), rho) - func_eq = QuadraticEquilibrium(compute_backend=ComputeBackends.JAX) + func_eq = QuadraticEquilibrium(compute_backend=ComputeBackend.JAX) f_eq = func_eq(rho, u) return f_eq diff --git a/examples/warp_backend/testing.py b/examples/warp_backend/testing.py index 3940378..a20feb3 100644 --- a/examples/warp_backend/testing.py +++ b/examples/warp_backend/testing.py @@ -98,11 +98,11 @@ def test_backends(compute_backend): if __name__ == "__main__": # Test backends - compute_backends = [ + compute_backend = [ xlb.ComputeBackend.WARP, xlb.ComputeBackend.JAX ] - for compute_backend in compute_backends: + for compute_backend in compute_backend: test_backends(compute_backend) print(f"Backend {compute_backend} passed all tests.") diff --git a/requirements.txt b/requirements.txt index 11ee0fd..c8c6fc2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,13 @@ jax==0.4.20 jaxlib==0.4.20 -jmp==0.0.4 matplotlib==3.8.0 numpy==1.26.1 -pyvista==0.42.3 +pyvista==0.43.4 Rtree==1.0.1 -trimesh==4.0.0 +trimesh==4.2.4 orbax-checkpoint==0.4.1 termcolor==2.3.0 -PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git \ No newline at end of file +PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git +tqdm==4.66.2 +warp-lang==1.0.2 +numpy-stl==3.1.1 \ No newline at end of file diff --git a/test.py b/test.py deleted file mode 100644 index 5529ac3..0000000 --- a/test.py +++ /dev/null @@ -1,40 +0,0 @@ -import jax.numpy as jnp -from xlb.operator.macroscopic import Macroscopic -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.stream import Stream -from xlb.velocity_set import D2Q9, D3Q27 -from xlb.operator.collision.kbc import KBC, BGK -from xlb.compute_backends import ComputeBackends -from xlb.grid import Grid -import xlb - - -xlb.init(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) - -collision = BGK(omega=0.6) - -# eq = QuadraticEquilibrium(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) - -# macro = Macroscopic(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) - -# s = Stream(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) - -Q = 19 -# create random jnp arrays -f = jnp.ones((Q, 10, 10)) -rho = jnp.ones((1, 10, 10)) -u = jnp.zeros((2, 10, 10)) -# feq = eq(rho, u) - -print(collision(f, f)) - -grid = Grid.create(grid_shape=(10, 10), velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) - -def advection_result(index): - return 1.0 - - -f = grid.initialize_pop(advection_result) - -print(f) -print(f.sharding) \ No newline at end of file diff --git a/tests/backends_conformance/boundary_conditions.py b/tests/backends_conformance/boundary_conditions.py new file mode 100644 index 0000000..e75e249 --- /dev/null +++ b/tests/backends_conformance/boundary_conditions.py @@ -0,0 +1,218 @@ +import unittest +import numpy as np +import jax.numpy as jnp +import warp as wp +import xlb + +wp.init() + + +class TestBoundaryConditions(unittest.TestCase): + def setUp(self): + self.backends = ["warp", "jax"] + self.results = {} + + def run_boundary_conditions(self, backend): + # Set the compute backend + if backend == "warp": + compute_backend = xlb.ComputeBackend.WARP + elif backend == "jax": + compute_backend = xlb.ComputeBackend.JAX + + # Set the precision policy + precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set + velocity_set = xlb.velocity_set.D3Q19() + + # Make grid + nr = 128 + shape = (nr, nr, nr) + if backend == "jax": + grid = xlb.grid.JaxGrid(shape=shape) + elif backend == "warp": + grid = xlb.grid.WarpGrid(shape=shape) + + # Make fields + f_pre = grid.create_field( + cardinality=velocity_set.q, precision=xlb.Precision.FP32 + ) + f_post = grid.create_field( + cardinality=velocity_set.q, precision=xlb.Precision.FP32 + ) + f = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + missing_mask = grid.create_field( + cardinality=velocity_set.q, precision=xlb.Precision.BOOL + ) + + # Make needed operators + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(0.0, 0.0, 0.0), + equilibrium_operator=equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + halfway_bounce_back_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + fullway_bounce_back_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + # Make indices for boundary conditions (sphere) + sphere_radius = 10 + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + indices = np.array(indices).T + if backend == "jax": + indices = jnp.array(indices) + elif backend == "warp": + indices = wp.from_numpy(indices, dtype=wp.int32) + + # Test equilibrium boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, equilibrium_bc.id, boundary_id, missing_mask, (0, 0, 0) + ) + if backend == "jax": + f_equilibrium = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f_equilibrium = grid.create_field( + cardinality=velocity_set.q, precision=xlb.Precision.FP32 + ) + f_equilibrium = equilibrium_bc( + f_pre, f_post, boundary_id, missing_mask, f_equilibrium + ) + + # Test do nothing boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, do_nothing_bc.id, boundary_id, missing_mask, (0, 0, 0) + ) + if backend == "jax": + f_do_nothing = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f_do_nothing = grid.create_field( + cardinality=velocity_set.q, precision=xlb.Precision.FP32 + ) + f_do_nothing = do_nothing_bc( + f_pre, f_post, boundary_id, missing_mask, f_do_nothing + ) + + # Test halfway bounce back boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, halfway_bounce_back_bc.id, boundary_id, missing_mask, (0, 0, 0) + ) + if backend == "jax": + f_halfway_bounce_back = halfway_bounce_back_bc( + f_pre, f_post, boundary_id, missing_mask + ) + elif backend == "warp": + f_halfway_bounce_back = grid.create_field( + cardinality=velocity_set.q, precision=xlb.Precision.FP32 + ) + f_halfway_bounce_back = halfway_bounce_back_bc( + f_pre, f_post, boundary_id, missing_mask, f_halfway_bounce_back + ) + + # Test the full boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, fullway_bounce_back_bc.id, boundary_id, missing_mask, (0, 0, 0) + ) + if backend == "jax": + f_fullway_bounce_back = fullway_bounce_back_bc( + f_pre, f_post, boundary_id, missing_mask + ) + elif backend == "warp": + f_fullway_bounce_back = grid.create_field( + cardinality=velocity_set.q, precision=xlb.Precision.FP32 + ) + f_fullway_bounce_back = fullway_bounce_back_bc( + f_pre, f_post, boundary_id, missing_mask, f_fullway_bounce_back + ) + + return f_equilibrium, f_do_nothing, f_halfway_bounce_back, f_fullway_bounce_back + + def test_boundary_conditions(self): + for backend in self.backends: + ( + f_equilibrium, + f_do_nothing, + f_halfway_bounce_back, + f_fullway_bounce_back, + ) = self.run_boundary_conditions(backend) + self.results[backend] = { + "equilibrium": np.array(f_equilibrium) + if backend == "jax" + else f_equilibrium.numpy(), + "do_nothing": np.array(f_do_nothing) + if backend == "jax" + else f_do_nothing.numpy(), + "halfway_bounce_back": np.array(f_halfway_bounce_back) + if backend == "jax" + else f_halfway_bounce_back.numpy(), + "fullway_bounce_back": np.array(f_fullway_bounce_back) + if backend == "jax" + else f_fullway_bounce_back.numpy(), + } + + for test_name in [ + "equilibrium", + "do_nothing", + "halfway_bounce_back", + "fullway_bounce_back", + ]: + with self.subTest(test_name=test_name): + warp_results = self.results["warp"][test_name] + jax_results = self.results["jax"][test_name] + + is_close = np.allclose(warp_results, jax_results, atol=1e-8, rtol=1e-5) + if not is_close: + diff_indices = np.where( + ~np.isclose(warp_results, jax_results, atol=1e-8, rtol=1e-5) + ) + differences = [ + (idx, warp_results[idx], jax_results[idx]) + for idx in zip(*diff_indices) + ] + difference_str = "\n".join( + [ + f"Index: {idx}, Warp: {w}, JAX: {j}" + for idx, w, j in differences + ] + ) + msg = f"{test_name} test failed: results do not match between backends. Differences:\n{difference_str}" + else: + msg = "" + + self.assertTrue(is_close, msg=msg) + + +if __name__ == "__main__": + unittest.main() diff --git a/xlb/base.py b/xlb/base.py deleted file mode 100644 index fa99d1d..0000000 --- a/xlb/base.py +++ /dev/null @@ -1,1252 +0,0 @@ -# Standard Libraries -import os -import time - -# Third-Party Libraries -import jax -import jax.numpy as jnp -import jmp -import numpy as np -from termcolor import colored - -# JAX-related imports -from jax import jit, lax, vmap -from jax.experimental import mesh_utils -from jax.experimental.multihost_utils import process_allgather -from jax.experimental.shard_map import shard_map -from jax.sharding import NamedSharding, PartitionSpec, PositionalSharding, Mesh -import orbax.checkpoint as orb - -# functools imports -from functools import partial - -# Local/Custom Libraries -from src.utils import downsample_field - -jax.config.update("jax_spmd_mode", "allow_all") -# Disables annoying TF warnings -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - - -class LBMBase(object): - """ - LBMBase: A class that represents a base for Lattice Boltzmann Method simulation. - - Parameters - ---------- - lattice (object): The lattice object that contains the lattice structure and weights. - omega (float): The relaxation parameter for the LBM simulation. - nx (int): Number of grid points in the x-direction. - ny (int): Number of grid points in the y-direction. - nz (int, optional): Number of grid points in the z-direction. Defaults to 0. - precision (str, optional): A string specifying the precision used for the simulation. Defaults to "f32/f32". - """ - - def __init__(self, **kwargs): - self.omega = kwargs.get("omega") - self.nx = kwargs.get("nx") - self.ny = kwargs.get("ny") - self.nz = kwargs.get("nz") - - self.precision = kwargs.get("precision") - computedType, storedType = self.set_precisions(self.precision) - self.precisionPolicy = jmp.Policy( - compute_dtype=computedType, - param_dtype=computedType, - output_dtype=storedType, - ) - - self.lattice = kwargs.get("lattice") - self.checkpointRate = kwargs.get("checkpoint_rate", 0) - self.checkpointDir = kwargs.get("checkpoint_dir", "./checkpoints") - self.downsamplingFactor = kwargs.get("downsampling_factor", 1) - self.printInfoRate = kwargs.get("print_info_rate", 100) - self.ioRate = kwargs.get("io_rate", 0) - self.returnFpost = kwargs.get("return_fpost", False) - self.computeMLUPS = kwargs.get("compute_MLUPS", False) - self.restore_checkpoint = kwargs.get("restore_checkpoint", False) - self.nDevices = jax.device_count() - self.backend = jax.default_backend() - - if self.computeMLUPS: - self.restore_checkpoint = False - self.ioRate = 0 - self.checkpointRate = 0 - self.printInfoRate = 0 - - # Check for distributed mode - if self.nDevices > jax.local_device_count(): - print( - "WARNING: Running in distributed mode. Make sure that jax.distributed.initialize is called before performing any JAX computations." - ) - - self.c = self.lattice.c - self.q = self.lattice.q - self.w = self.lattice.w - self.dim = self.lattice.d - - # Set the checkpoint manager - if self.checkpointRate > 0: - mngr_options = orb.CheckpointManagerOptions( - save_interval_steps=self.checkpointRate, max_to_keep=1 - ) - self.mngr = orb.CheckpointManager( - self.checkpointDir, orb.PyTreeCheckpointer(), options=mngr_options - ) - else: - self.mngr = None - - # Adjust the number of grid points in the x direction, if necessary. - # If the number of grid points is not divisible by the number of devices - # it increases the number of grid points to the next multiple of the number of devices. - # This is done in order to accommodate the domain sharding per XLA device - nx, ny, nz = kwargs.get("nx"), kwargs.get("ny"), kwargs.get("nz") - if None in {nx, ny, nz}: - raise ValueError( - "nx, ny, and nz must be provided. For 2D examples, nz must be set to 0." - ) - self.nx = nx - if nx % self.nDevices: - self.nx = nx + (self.nDevices - nx % self.nDevices) - print( - "WARNING: nx increased from {} to {} in order to accommodate domain sharding per XLA device.".format( - nx, self.nx - ) - ) - self.ny = ny - self.nz = nz - - self.show_simulation_parameters() - - # Store grid information - self.gridInfo = { - "nx": self.nx, - "ny": self.ny, - "nz": self.nz, - "dim": self.lattice.d, - "lattice": self.lattice, - } - - P = PartitionSpec - - # Define the right permutation - self.rightPerm = [(i, (i + 1) % self.nDevices) for i in range(self.nDevices)] - # Define the left permutation - self.leftPerm = [((i + 1) % self.nDevices, i) for i in range(self.nDevices)] - - # Set up the sharding and streaming for 2D and 3D simulations - if self.dim == 2: - self.devices = mesh_utils.create_device_mesh((self.nDevices, 1, 1)) - self.mesh = Mesh(self.devices, axis_names=("x", "y", "value")) - self.sharding = NamedSharding(self.mesh, P("x", "y", "value")) - - self.streaming = jit( - shard_map( - self.streaming_m, - mesh=self.mesh, - in_specs=P("x", None, None), - out_specs=P("x", None, None), - check_rep=False, - ) - ) - - # Set up the sharding and streaming for 2D and 3D simulations - elif self.dim == 3: - self.devices = mesh_utils.create_device_mesh((self.nDevices, 1, 1, 1)) - self.mesh = Mesh(self.devices, axis_names=("x", "y", "z", "value")) - self.sharding = NamedSharding(self.mesh, P("x", "y", "z", "value")) - - self.streaming = jit( - shard_map( - self.streaming_m, - mesh=self.mesh, - in_specs=P("x", None, None, None), - out_specs=P("x", None, None, None), - check_rep=False, - ) - ) - - else: - raise ValueError(f"dim = {self.dim} not supported") - - # Compute the bounding box indices for boundary conditions - self.boundingBoxIndices = self.bounding_box_indices() - # Create boundary data for the simulation - self._create_boundary_data() - self.force = self.get_force() - - @property - def lattice(self): - return self._lattice - - @lattice.setter - def lattice(self, value): - if value is None: - raise ValueError("Lattice type must be provided.") - if self.nz == 0 and value.name not in ["D2Q9"]: - raise ValueError("For 2D simulations, lattice type must be LatticeD2Q9.") - if self.nz != 0 and value.name not in ["D3Q19", "D3Q27"]: - raise ValueError( - "For 3D simulations, lattice type must be LatticeD3Q19, or LatticeD3Q27." - ) - - self._lattice = value - - @property - def omega(self): - return self._omega - - @omega.setter - def omega(self, value): - if value is None: - raise ValueError("omega must be provided") - if not isinstance(value, float): - raise TypeError("omega must be a float") - self._omega = value - - @property - def nx(self): - return self._nx - - @nx.setter - def nx(self, value): - if value is None: - raise ValueError("nx must be provided") - if not isinstance(value, int): - raise TypeError("nx must be an integer") - self._nx = value - - @property - def ny(self): - return self._ny - - @ny.setter - def ny(self, value): - if value is None: - raise ValueError("ny must be provided") - if not isinstance(value, int): - raise TypeError("ny must be an integer") - self._ny = value - - @property - def nz(self): - return self._nz - - @nz.setter - def nz(self, value): - if value is None: - raise ValueError("nz must be provided") - if not isinstance(value, int): - raise TypeError("nz must be an integer") - self._nz = value - - @property - def precision(self): - return self._precision - - @precision.setter - def precision(self, value): - if not isinstance(value, str): - raise TypeError("precision must be a string") - self._precision = value - - @property - def checkpointRate(self): - return self._checkpointRate - - @checkpointRate.setter - def checkpointRate(self, value): - if not isinstance(value, int): - raise TypeError("checkpointRate must be an integer") - self._checkpointRate = value - - @property - def checkpointDir(self): - return self._checkpointDir - - @checkpointDir.setter - def checkpointDir(self, value): - if not isinstance(value, str): - raise TypeError("checkpointDir must be a string") - self._checkpointDir = value - - @property - def downsamplingFactor(self): - return self._downsamplingFactor - - @downsamplingFactor.setter - def downsamplingFactor(self, value): - if not isinstance(value, int): - raise TypeError("downsamplingFactor must be an integer") - self._downsamplingFactor = value - - @property - def printInfoRate(self): - return self._printInfoRate - - @printInfoRate.setter - def printInfoRate(self, value): - if not isinstance(value, int): - raise TypeError("printInfoRate must be an integer") - self._printInfoRate = value - - @property - def ioRate(self): - return self._ioRate - - @ioRate.setter - def ioRate(self, value): - if not isinstance(value, int): - raise TypeError("ioRate must be an integer") - self._ioRate = value - - @property - def returnFpost(self): - return self._returnFpost - - @returnFpost.setter - def returnFpost(self, value): - if not isinstance(value, bool): - raise TypeError("returnFpost must be a boolean") - self._returnFpost = value - - @property - def computeMLUPS(self): - return self._computeMLUPS - - @computeMLUPS.setter - def computeMLUPS(self, value): - if not isinstance(value, bool): - raise TypeError("computeMLUPS must be a boolean") - self._computeMLUPS = value - - @property - def restore_checkpoint(self): - return self._restore_checkpoint - - @restore_checkpoint.setter - def restore_checkpoint(self, value): - if not isinstance(value, bool): - raise TypeError("restore_checkpoint must be a boolean") - self._restore_checkpoint = value - - @property - def nDevices(self): - return self._nDevices - - @nDevices.setter - def nDevices(self, value): - if not isinstance(value, int): - raise TypeError("nDevices must be an integer") - self._nDevices = value - - def show_simulation_parameters(self): - attributes_to_show = [ - "omega", - "nx", - "ny", - "nz", - "dim", - "precision", - "lattice", - "checkpointRate", - "checkpointDir", - "downsamplingFactor", - "printInfoRate", - "ioRate", - "computeMLUPS", - "restore_checkpoint", - "backend", - "nDevices", - ] - - descriptive_names = { - "omega": "Omega", - "nx": "Grid Points in X", - "ny": "Grid Points in Y", - "nz": "Grid Points in Z", - "dim": "Dimensionality", - "precision": "Precision Policy", - "lattice": "Lattice Type", - "checkpointRate": "Checkpoint Rate", - "checkpointDir": "Checkpoint Directory", - "downsamplingFactor": "Downsampling Factor", - "printInfoRate": "Print Info Rate", - "ioRate": "I/O Rate", - "computeMLUPS": "Compute MLUPS", - "restore_checkpoint": "Restore Checkpoint", - "backend": "Backend", - "nDevices": "Number of Devices", - } - simulation_name = self.__class__.__name__ - - print( - colored(f"**** Simulation Parameters for {simulation_name} ****", "green") - ) - - header = f"{colored('Parameter', 'blue'):>30} | {colored('Value', 'yellow')}" - print(header) - print("-" * 50) - - for attr in attributes_to_show: - value = getattr(self, attr, "Attribute not set") - descriptive_name = descriptive_names.get( - attr, attr - ) # Use the attribute name as a fallback - row = ( - f"{colored(descriptive_name, 'blue'):>30} | {colored(value, 'yellow')}" - ) - print(row) - - def _create_boundary_data(self): - """ - Create boundary data for the Lattice Boltzmann simulation by setting boundary conditions, - creating grid mask, and preparing local masks and normal arrays. - """ - self.BCs = [] - self.set_boundary_conditions() - # Accumulate the indices of all BCs to create the grid mask with FALSE along directions that - # stream into a boundary voxel. - solid_halo_list = [np.array(bc.indices).T for bc in self.BCs if bc.isSolid] - solid_halo_voxels = ( - np.unique(np.vstack(solid_halo_list), axis=0) if solid_halo_list else None - ) - - # Create the grid mask on each process - start = time.time() - grid_mask = self.create_grid_mask(solid_halo_voxels) - print("Time to create the grid mask:", time.time() - start) - - start = time.time() - for bc in self.BCs: - assert bc.implementationStep in ["PostStreaming", "PostCollision"] - bc.create_local_mask_and_normal_arrays(grid_mask) - print("Time to create the local masks and normal arrays:", time.time() - start) - - # This is another non-JITed way of creating the distributed arrays. It is not used at the moment. - # def distributed_array_init(self, shape, type, init_val=None): - # sharding_dim = shape[0] // self.nDevices - # sharded_shape = (self.nDevices, sharding_dim, *shape[1:]) - # device_shape = sharded_shape[1:] - # arrays = [] - - # for d, index in self.sharding.addressable_devices_indices_map(sharded_shape).items(): - # jax.default_device = d - # if init_val is None: - # x = jnp.zeros(shape=device_shape, dtype=type) - # else: - # x = jnp.full(shape=device_shape, fill_value=init_val, dtype=type) - # arrays += [jax.device_put(x, d)] - # jax.default_device = jax.devices()[0] - # return jax.make_array_from_single_device_arrays(shape, self.sharding, arrays) - - @partial(jit, static_argnums=(0, 1, 2, 4)) - def distributed_array_init(self, shape, type, init_val=0, sharding=None): - """ - Initialize a distributed array using JAX, with a specified shape, data type, and initial value. - Optionally, provide a custom sharding strategy. - - Parameters - ---------- - shape (tuple): The shape of the array to be created. - type (dtype): The data type of the array to be created. - init_val (scalar, optional): The initial value to fill the array with. Defaults to 0. - sharding (Sharding, optional): The sharding strategy to use. Defaults to `self.sharding`. - - Returns - ------- - jax.numpy.ndarray: A JAX array with the specified shape, data type, initial value, and sharding strategy. - """ - if sharding is None: - sharding = self.sharding - x = jnp.full(shape=shape, fill_value=init_val, dtype=type) - return jax.lax.with_sharding_constraint(x, sharding) - - @partial(jit, static_argnums=(0,)) - def create_grid_mask(self, solid_halo_voxels): - """ - This function creates a mask for the background grid that accounts for the location of the boundaries. - - Parameters - ---------- - solid_halo_voxels: A numpy array representing the voxels in the halo of the solid object. - - Returns - ------- - A JAX array representing the grid mask of the grid. - """ - # Halo width (hw_x is different to accommodate the domain sharding per XLA device) - hw_x = self.nDevices - hw_y = hw_z = 1 - if self.dim == 2: - grid_mask = self.distributed_array_init( - (self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), - jnp.bool_, - init_val=True, - ) - grid_mask = grid_mask.at[ - (slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(None)) - ].set(False) - if solid_halo_voxels is not None: - solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) - solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y) - grid_mask = grid_mask.at[tuple(solid_halo_voxels.T)].set(True) - - grid_mask = self.streaming(grid_mask) - return lax.with_sharding_constraint(grid_mask, self.sharding) - - elif self.dim == 3: - grid_mask = self.distributed_array_init( - ( - self.nx + 2 * hw_x, - self.ny + 2 * hw_y, - self.nz + 2 * hw_z, - self.lattice.q, - ), - jnp.bool_, - init_val=True, - ) - grid_mask = grid_mask.at[ - ( - slice(hw_x, -hw_x), - slice(hw_y, -hw_y), - slice(hw_z, -hw_z), - slice(None), - ) - ].set(False) - if solid_halo_voxels is not None: - solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x) - solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y) - solid_halo_voxels = solid_halo_voxels.at[:, 2].add(hw_z) - grid_mask = grid_mask.at[tuple(solid_halo_voxels.T)].set(True) - grid_mask = self.streaming(grid_mask) - return lax.with_sharding_constraint(grid_mask, self.sharding) - - def bounding_box_indices(self): - """ - This function calculates the indices of the bounding box of a 2D or 3D grid. - The bounding box is defined as the set of grid points on the outer edge of the grid. - - Returns - ------- - boundingBox (dict): A dictionary where keys are the names of the bounding box faces - ("bottom", "top", "left", "right" for 2D; additional "front", "back" for 3D), and values - are numpy arrays of indices corresponding to each face. - """ - if self.dim == 2: - # For a 2D grid, the bounding box consists of four edges: bottom, top, left, and right. - # Each edge is represented as an array of indices. For example, the bottom edge includes - # all points where the y-coordinate is 0, so its indices are [[i, 0] for i in range(self.nx)]. - bounding_box = { - "bottom": np.array([[i, 0] for i in range(self.nx)], dtype=int), - "top": np.array([[i, self.ny - 1] for i in range(self.nx)], dtype=int), - "left": np.array([[0, i] for i in range(self.ny)], dtype=int), - "right": np.array( - [[self.nx - 1, i] for i in range(self.ny)], dtype=int - ), - } - - return bounding_box - - elif self.dim == 3: - # For a 3D grid, the bounding box consists of six faces: bottom, top, left, right, front, and back. - # Each face is represented as an array of indices. For example, the bottom face includes all points - # where the z-coordinate is 0, so its indices are [[i, j, 0] for i in range(self.nx) for j in range(self.ny)]. - bounding_box = { - "bottom": np.array( - [[i, j, 0] for i in range(self.nx) for j in range(self.ny)], - dtype=int, - ), - "top": np.array( - [ - [i, j, self.nz - 1] - for i in range(self.nx) - for j in range(self.ny) - ], - dtype=int, - ), - "left": np.array( - [[0, j, k] for j in range(self.ny) for k in range(self.nz)], - dtype=int, - ), - "right": np.array( - [ - [self.nx - 1, j, k] - for j in range(self.ny) - for k in range(self.nz) - ], - dtype=int, - ), - "front": np.array( - [[i, 0, k] for i in range(self.nx) for k in range(self.nz)], - dtype=int, - ), - "back": np.array( - [ - [i, self.ny - 1, k] - for i in range(self.nx) - for k in range(self.nz) - ], - dtype=int, - ), - } - - return bounding_box - - def set_precisions(self, precision): - """ - This function sets the precision of the computations. The precision is defined by a pair of values, - representing the precision of the computation and the precision of the storage, respectively. - - Parameters - ---------- - precision (str): A string representing the desired precision. The string should be in the format - "computation/storage", where "computation" and "storage" are either "f64", "f32", or "f16", - representing 64-bit, 32-bit, or 16-bit floating point numbers, respectively. - - Returns - ------- - tuple: A pair of jax.numpy data types representing the computation and storage precisions, respectively. - If the input string does not match any of the predefined options, the function defaults to (jnp.float32, jnp.float32). - """ - return { - "f64/f64": (jnp.float64, jnp.float64), - "f32/f32": (jnp.float32, jnp.float32), - "f32/f16": (jnp.float32, jnp.float16), - "f16/f16": (jnp.float16, jnp.float16), - "f64/f32": (jnp.float64, jnp.float32), - "f64/f16": (jnp.float64, jnp.float16), - }.get(precision, (jnp.float32, jnp.float32)) - - def initialize_macroscopic_fields(self): - """ - This function initializes the macroscopic fields (density and velocity) to their default values. - The default density is 1 and the default velocity is 0. - - Note: This function is a placeholder and should be overridden in a subclass or in an instance of the class - to provide specific initial conditions. - - Returns - ------- - None, None: The default density and velocity, both None. This indicates that the actual values should be set elsewhere. - """ - print("WARNING: Default initial conditions assumed: density = 1, velocity = 0") - print( - " To set explicit initial density and velocity, use self.initialize_macroscopic_fields." - ) - return None, None - - def assign_fields_sharded(self): - """ - This function is used to initialize the simulation by assigning the macroscopic fields and populations. - - The function first initializes the macroscopic fields, which are the density (rho0) and velocity (u0). - Depending on the dimension of the simulation (2D or 3D), it then sets the shape of the array that will hold the - distribution functions (f). - - If the density or velocity are not provided, the function initializes the distribution functions with a default - value (self.w), representing density=1 and velocity=0. Otherwise, it uses the provided density and velocity to initialize the populations. - - Parameters - ---------- - None - - Returns - ------- - f: a distributed JAX array of shape (nx, ny, nz, q) or (nx, ny, q) holding the distribution functions for the simulation. - """ - rho0, u0 = self.initialize_macroscopic_fields() - - if self.dim == 2: - shape = (self.nx, self.ny, self.lattice.q) - if self.dim == 3: - shape = (self.nx, self.ny, self.nz, self.lattice.q) - - if rho0 is None or u0 is None: - f = self.distributed_array_init( - shape, self.precisionPolicy.output_dtype, init_val=self.w - ) - else: - f = self.initialize_populations(rho0, u0) - - return f - - def initialize_populations(self, rho0, u0): - """ - This function initializes the populations (distribution functions) for the simulation. - It uses the equilibrium distribution function, which is a function of the macroscopic - density and velocity. - - Parameters - ---------- - rho0: jax.numpy.ndarray - The initial density field. - u0: jax.numpy.ndarray - The initial velocity field. - - Returns - ------- - f: jax.numpy.ndarray - The array holding the initialized distribution functions for the simulation. - """ - return self.equilibrium(rho0, u0) - - def send_right(self, x, axis_name): - """ - This function sends the data to the right neighboring process in a parallel computing environment. - It uses a permutation operation provided by the LAX library. - - Parameters - ---------- - x: jax.numpy.ndarray - The data to be sent. - axis_name: str - The name of the axis along which the data is sent. - - Returns - ------- - jax.numpy.ndarray - The data after being sent to the right neighboring process. - """ - return lax.ppermute(x, perm=self.rightPerm, axis_name=axis_name) - - def send_left(self, x, axis_name): - """ - This function sends the data to the left neighboring process in a parallel computing environment. - It uses a permutation operation provided by the LAX library. - - Parameters - ---------- - x: jax.numpy.ndarray - The data to be sent. - axis_name: str - The name of the axis along which the data is sent. - - Returns - ------- - The data after being sent to the left neighboring process. - """ - return lax.ppermute(x, perm=self.leftPerm, axis_name=axis_name) - - def streaming_m(self, f): - """ - This function performs the streaming step in the Lattice Boltzmann Method, which is - the propagation of the distribution functions in the lattice. - - To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the - distribution functions that need to be communicated to the neighboring processes. - - The function then sends the left boundary slice to the right neighboring process and the right - boundary slice to the left neighboring process. The received data is then set to the - corresponding indices in the receiving domain. - - Parameters - ---------- - f: jax.numpy.ndarray - The array holding the distribution functions for the simulation. - - Returns - ------- - jax.numpy.ndarray - The distribution functions after the streaming operation. - """ - f = self.streaming_p(f) - left_comm, right_comm = ( - f[:1, ..., self.lattice.right_indices], - f[-1:, ..., self.lattice.left_indices], - ) - - left_comm, right_comm = self.send_right(left_comm, "x"), self.send_left( - right_comm, "x" - ) - f = f.at[:1, ..., self.lattice.right_indices].set(left_comm) - f = f.at[-1:, ..., self.lattice.left_indices].set(right_comm) - return f - - @partial(jit, static_argnums=(0,)) - def streaming_p(self, f): - """ - Perform streaming operation on a partitioned (in the x-direction) distribution function. - - The function uses the vmap operation provided by the JAX library to vectorize the computation - over all lattice directions. - - Parameters - ---------- - f: The distribution function. - - Returns - ------- - The updated distribution function after streaming. - """ - - def streaming_i(f, c): - """ - Perform individual streaming operation in a direction. - - Parameters - ---------- - f: The distribution function. - c: The streaming direction vector. - - Returns - ------- - jax.numpy.ndarray - The updated distribution function after streaming. - """ - if self.dim == 2: - return jnp.roll(f, (c[0], c[1]), axis=(0, 1)) - elif self.dim == 3: - return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) - - return vmap(streaming_i, in_axes=(-1, 0), out_axes=-1)(f, self.c.T) - - @partial(jit, static_argnums=(0, 3), inline=True) - def equilibrium(self, rho, u, cast_output=True): - """ - This function computes the equilibrium distribution function in the Lattice Boltzmann Method. - The equilibrium distribution function is a function of the macroscopic density and velocity. - - The function first casts the density and velocity to the compute precision if the cast_output flag is True. - The function finally casts the equilibrium distribution function to the output precision if the cast_output - flag is True. - - Parameters - ---------- - rho: jax.numpy.ndarray - The macroscopic density. - u: jax.numpy.ndarray - The macroscopic velocity. - cast_output: bool, optional - A flag indicating whether to cast the density, velocity, and equilibrium distribution function to the - compute and output precisions. Default is True. - - Returns - ------- - feq: ja.numpy.ndarray - The equilibrium distribution function. - """ - # Cast the density and velocity to the compute precision if the cast_output flag is True - if cast_output: - rho, u = self.precisionPolicy.cast_to_compute((rho, u)) - - # Cast c to compute precision so that XLA call FXX matmul, - # which is faster (it is faster in some older versions of JAX, newer versions are smart enough to do this automatically) - c = jnp.array(self.c, dtype=self.precisionPolicy.compute_dtype) - cu = 3.0 * jnp.dot(u, c) - usqr = 1.5 * jnp.sum(jnp.square(u), axis=-1, keepdims=True) - feq = rho * self.w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) - - if cast_output: - return self.precisionPolicy.cast_to_output(feq) - else: - return feq - - @partial(jit, static_argnums=(0,)) - def momentum_flux(self, fneq): - """ - This function computes the momentum flux, which is the product of the non-equilibrium - distribution functions (fneq) and the lattice moments (cc). - - The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann - Method (LBM). - - Parameters - ---------- - fneq: jax.numpy.ndarray - The non-equilibrium distribution functions. - - Returns - ------- - jax.numpy.ndarray - The computed momentum flux. - """ - return jnp.dot(fneq, self.lattice.cc) - - @partial(jit, static_argnums=(0,), inline=True) - def update_macroscopic(self, f): - """ - This function computes the macroscopic variables (density and velocity) based on the - distribution functions (f). - - The density is computed as the sum of the distribution functions over all lattice directions. - The velocity is computed as the dot product of the distribution functions and the lattice - velocities, divided by the density. - - Parameters - ---------- - f: jax.numpy.ndarray - The distribution functions. - - Returns - ------- - rho: jax.numpy.ndarray - Computed density. - u: jax.numpy.ndarray - Computed velocity. - """ - rho = jnp.sum(f, axis=-1, keepdims=True) - c = jnp.array(self.c, dtype=self.precisionPolicy.compute_dtype).T - u = jnp.dot(f, c) / rho - - return rho, u - - @partial(jit, static_argnums=(0, 4), inline=True) - def apply_bc(self, fout, fin, timestep, implementation_step): - """ - This function applies the boundary conditions to the distribution functions. - - It iterates over all boundary conditions (BCs) and checks if the implementation step of the - boundary condition matches the provided implementation step. If it does, it applies the - boundary condition to the post-streaming distribution functions (fout). - - Parameters - ---------- - fout: jax.numpy.ndarray - The post-collision distribution functions. - fin: jax.numpy.ndarray - The post-streaming distribution functions. - implementation_step: str - The implementation step at which the boundary conditions should be applied. - - Returns - ------- - ja.numpy.ndarray - The output distribution functions after applying the boundary conditions. - """ - for bc in self.BCs: - fout = bc.prepare_populations(fout, fin, implementation_step) - if bc.implementationStep == implementation_step: - if bc.isDynamic: - fout = bc.apply(fout, fin, timestep) - else: - fout = fout.at[bc.indices].set(bc.apply(fout, fin)) - - return fout - - @partial(jit, static_argnums=(0, 3), donate_argnums=(1,)) - def step(self, f_poststreaming, timestep, return_fpost=False): - """ - This function performs a single step of the LBM simulation. - - It first performs the collision step, which is the relaxation of the distribution functions - towards their equilibrium values. It then applies the respective boundary conditions to the - post-collision distribution functions. - - The function then performs the streaming step, which is the propagation of the distribution - functions in the lattice. It then applies the respective boundary conditions to the post-streaming - distribution functions. - - Parameters - ---------- - f_poststreaming: jax.numpy.ndarray - The post-streaming distribution functions. - timestep: int - The current timestep of the simulation. - return_fpost: bool, optional - If True, the function also returns the post-collision distribution functions. - - Returns - ------- - f_poststreaming: jax.numpy.ndarray - The post-streaming distribution functions after the simulation step. - f_postcollision: jax.numpy.ndarray or None - The post-collision distribution functions after the simulation step, or None if - return_fpost is False. - """ - f_postcollision = self.collision(f_poststreaming) - f_postcollision = self.apply_bc( - f_postcollision, f_poststreaming, timestep, "PostCollision" - ) - f_poststreaming = self.streaming(f_postcollision) - f_poststreaming = self.apply_bc( - f_poststreaming, f_postcollision, timestep, "PostStreaming" - ) - - if return_fpost: - return f_poststreaming, f_postcollision - else: - return f_poststreaming, None - - def run(self, t_max): - """ - This function runs the LBM simulation for a specified number of time steps. - - It first initializes the distribution functions and then enters a loop where it performs the - simulation steps (collision, streaming, and boundary conditions) for each time step. - - The function can also print the progress of the simulation, save the simulation data, and - compute the performance of the simulation in million lattice updates per second (MLUPS). - - Parameters - ---------- - t_max: int - The total number of time steps to run the simulation. - Returns - ------- - f: jax.numpy.ndarray - The distribution functions after the simulation. - """ - f = self.assign_fields_sharded() - start_step = 0 - if self.restore_checkpoint: - latest_step = self.mngr.latest_step() - if latest_step is not None: # existing checkpoint present - # Assert that the checkpoint manager is not None - assert self.mngr is not None, "Checkpoint manager does not exist." - state = {"f": f} - shardings = jax.tree_map(lambda x: x.sharding, state) - restore_args = orb.checkpoint_utils.construct_restore_args( - state, shardings - ) - try: - f = self.mngr.restore( - latest_step, restore_kwargs={"restore_args": restore_args} - )["f"] - print(f"Restored checkpoint at step {latest_step}.") - except ValueError: - raise ValueError( - f"Failed to restore checkpoint at step {latest_step}." - ) - - start_step = latest_step + 1 - if not (t_max > start_step): - raise ValueError( - f"Simulation already exceeded maximum allowable steps (t_max = {t_max}). Consider increasing t_max." - ) - if self.computeMLUPS: - start = time.time() - # Loop over all time steps - for timestep in range(start_step, t_max + 1): - io_flag = self.ioRate > 0 and ( - timestep % self.ioRate == 0 or timestep == t_max - ) - print_iter_flag = ( - self.printInfoRate > 0 and timestep % self.printInfoRate == 0 - ) - checkpoint_flag = ( - self.checkpointRate > 0 and timestep % self.checkpointRate == 0 - ) - - if io_flag: - # Update the macroscopic variables and save the previous values (for error computation) - rho_prev, u_prev = self.update_macroscopic(f) - rho_prev = downsample_field(rho_prev, self.downsamplingFactor) - u_prev = downsample_field(u_prev, self.downsamplingFactor) - # Gather the data from all processes and convert it to numpy arrays (move to host memory) - rho_prev = process_allgather(rho_prev) - u_prev = process_allgather(u_prev) - - # Perform one time-step (collision, streaming, and boundary conditions) - f, fstar = self.step(f, timestep, return_fpost=self.returnFpost) - # Print the progress of the simulation - if print_iter_flag: - print( - colored("Timestep ", "blue") - + colored(f"{timestep}", "green") - + colored(" of ", "blue") - + colored(f"{t_max}", "green") - + colored(" completed", "blue") - ) - - if io_flag: - # Save the simulation data - print(f"Saving data at timestep {timestep}/{t_max}") - rho, u = self.update_macroscopic(f) - rho = downsample_field(rho, self.downsamplingFactor) - u = downsample_field(u, self.downsamplingFactor) - - # Gather the data from all processes and convert it to numpy arrays (move to host memory) - rho = process_allgather(rho) - u = process_allgather(u) - - # Save the data - self.handle_io_timestep(timestep, f, fstar, rho, u, rho_prev, u_prev) - - if checkpoint_flag: - # Save the checkpoint - print(f"Saving checkpoint at timestep {timestep}/{t_max}") - state = {"f": f} - self.mngr.save(timestep, state) - - # Start the timer for the MLUPS computation after the first timestep (to remove compilation overhead) - if self.computeMLUPS and timestep == 1: - jax.block_until_ready(f) - start = time.time() - - if self.computeMLUPS: - # Compute and print the performance of the simulation in MLUPS - jax.block_until_ready(f) - end = time.time() - if self.dim == 2: - print( - colored("Domain: ", "blue") - + colored(f"{self.nx} x {self.ny}", "green") - if self.dim == 2 - else colored(f"{self.nx} x {self.ny} x {self.nz}", "green") - ) - print( - colored("Number of voxels: ", "blue") - + colored(f"{self.nx * self.ny}", "green") - if self.dim == 2 - else colored(f"{self.nx * self.ny * self.nz}", "green") - ) - print( - colored("MLUPS: ", "blue") - + colored( - f"{self.nx * self.ny * t_max / (end - start) / 1e6}", "red" - ) - ) - - elif self.dim == 3: - print( - colored("Domain: ", "blue") - + colored(f"{self.nx} x {self.ny} x {self.nz}", "green") - ) - print( - colored("Number of voxels: ", "blue") - + colored(f"{self.nx * self.ny * self.nz}", "green") - ) - print( - colored("MLUPS: ", "blue") - + colored( - f"{self.nx * self.ny * self.nz * t_max / (end - start) / 1e6}", - "red", - ) - ) - - return f - - def handle_io_timestep(self, timestep, f, fstar, rho, u, rho_prev, u_prev): - """ - This function handles the input/output (I/O) operations at each time step of the simulation. - - It prepares the data to be saved and calls the output_data function, which can be overwritten - by the user to customize the I/O operations. - - Parameters - ---------- - timestep: int - The current time step of the simulation. - f: jax.numpy.ndarray - The post-streaming distribution functions at the current time step. - fstar: jax.numpy.ndarray - The post-collision distribution functions at the current time step. - rho: jax.numpy.ndarray - The density field at the current time step. - u: jax.numpy.ndarray - The velocity field at the current time step. - """ - kwargs = { - "timestep": timestep, - "rho": rho, - "rho_prev": rho_prev, - "u": u, - "u_prev": u_prev, - "f_poststreaming": f, - "f_postcollision": fstar, - } - self.output_data(**kwargs) - - def output_data(self, **kwargs): - """ - This function is intended to be overwritten by the user to customize the input/output (I/O) - operations of the simulation. - - By default, it does nothing. When overwritten, it could save the simulation data to files, - display the simulation results in real time, send the data to another process for analysis, etc. - - Parameters - ---------- - **kwargs: dict - A dictionary containing the simulation data to be outputted. The keys are the names of the - data fields, and the values are the data fields themselves. - """ - pass - - def set_boundary_conditions(self): - """ - This function sets the boundary conditions for the simulation. - - It is intended to be overwritten by the user to specify the boundary conditions according to - the specific problem being solved. - - By default, it does nothing. When overwritten, it could set periodic boundaries, no-slip - boundaries, inflow/outflow boundaries, etc. - """ - pass - - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def collision(self, fin): - """ - This function performs the collision step in the Lattice Boltzmann Method. - - It is intended to be overwritten by the user to specify the collision operator according to - the specific LBM model being used. - - By default, it does nothing. When overwritten, it could implement the BGK collision operator, - the MRT collision operator, etc. - - Parameters - ---------- - fin: jax.numpy.ndarray - The pre-collision distribution functions. - - Returns - ------- - fin: jax.numpy.ndarray - The post-collision distribution functions. - """ - pass - - def get_force(self): - """ - This function computes the force to be applied to the fluid in the Lattice Boltzmann Method. - - It is intended to be overwritten by the user to specify the force according to the specific - problem being solved. - - By default, it does nothing and returns None. When overwritten, it could implement a constant - force term. - - Returns - ------- - force: jax.numpy.ndarray - The force to be applied to the fluid. - """ - pass - - @partial(jit, static_argnums=(0,), inline=True) - def apply_force(self, f_postcollision, feq, rho, u): - """ - add force based on exact-difference method due to Kupershtokh - - Parameters - ---------- - f_postcollision: jax.numpy.ndarray - The post-collision distribution functions. - feq: jax.numpy.ndarray - The equilibrium distribution functions. - rho: jax.numpy.ndarray - The density field. - - u: jax.numpy.ndarray - The velocity field. - - Returns - ------- - f_postcollision: jax.numpy.ndarray - The post-collision distribution functions with the force applied. - - References - ---------- - Kupershtokh, A. (2004). New method of incorporating a body force term into the lattice Boltzmann equation. In - Proceedings of the 5th International EHD Workshop (pp. 241-246). University of Poitiers, Poitiers, France. - Chikatamarla, S. S., & Karlin, I. V. (2013). Entropic lattice Boltzmann method for turbulent flow simulations: - Boundary conditions. Physica A, 392, 1925-1930. - Krüger, T., et al. (2017). The lattice Boltzmann method. Springer International Publishing, 10.978-3, 4-15. - """ - delta_u = self.get_force() - feq_force = self.equilibrium(rho, u + delta_u, cast_output=False) - f_postcollision = f_postcollision + feq_force - feq - return f_postcollision diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index d2b579a..f17545b 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -1,6 +1,9 @@ from jax.sharding import PartitionSpec as P from jax.sharding import NamedSharding, Mesh from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +import numpy as np +import jax.numpy as jnp import jax from xlb.grid import Grid @@ -8,6 +11,7 @@ from xlb.operator import Operator from xlb.precision_policy import Precision + class JaxGrid(Grid): def __init__(self, shape): super().__init__(shape) @@ -80,9 +84,8 @@ def create_field(self, cardinality: int, precision: Precision, callback=None): # Create field if callback is None: - f = jax.numpy.full(shape, 0.0, dtype=precision.jax_dtype) - #if self.sharding is not None: - # f = jax.make_sharded_array(self.sharding, f) + f = np.full(shape, 0.0, dtype=precision.jax_dtype) + f = jax.device_put(f, self.sharding) else: f = jax.make_array_from_callback(shape, self.sharding, callback) diff --git a/xlb/operator/boundary_condition/equilibrium.py b/xlb/operator/boundary_condition/equilibrium.py index 6de68ec..4a8ccb5 100644 --- a/xlb/operator/boundary_condition/equilibrium.py +++ b/xlb/operator/boundary_condition/equilibrium.py @@ -53,19 +53,13 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) - @partial(jit, static_argnums=(0)) + # @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): - # TODO: This is unoptimized feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) - feq = jnp.reshape(feq, (self.velocity_set.q, 1, 1, 1)) - feq = jnp.repeat(feq, f_pre.shape[1], axis=1) - feq = jnp.repeat(feq, f_pre.shape[2], axis=2) - feq = jnp.repeat(feq, f_pre.shape[3], axis=3) - boundary = boundary_id == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) - skipped_f = lax.select(boundary, feq, f_post) - return skipped_f + feq = feq[:, None, None, None] + boundary = (boundary_id == self.id) + + return jnp.where(boundary, feq, f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index b9e9f5b..83660c1 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -3,7 +3,7 @@ from functools import partial import numpy as np import jax.numpy as jnp -from jax import jit +from jax import jit, lax import warp as wp from typing import Tuple @@ -36,60 +36,31 @@ def __init__( def _indices_to_tuple(indices): """ Converts a tensor of indices to a tuple for indexing - TODO: Might be better to index """ - return tuple([indices[:, i] for i in range(indices.shape[1])]) + return tuple(indices.T) @Operator.register_backend(ComputeBackend.JAX) - # @partial(jit, static_argnums=(0), inline=True) TODO: Fix this def jax_implementation( self, indices, id_number, boundary_id, mask, start_index=(0, 0, 0) ): - # TODO: This is somewhat untested and unoptimized - - # Get local indices from the meshgrid and the indices local_indices = indices - np.array(start_index)[np.newaxis, :] # Remove any indices that are out of bounds - local_indices = local_indices[ - (local_indices[:, 0] >= 0) - & (local_indices[:, 0] < mask.shape[0]) - & (local_indices[:, 1] >= 0) - & (local_indices[:, 1] < mask.shape[1]) - & (local_indices[:, 2] >= 0) - & (local_indices[:, 2] < mask.shape[2]) - ] - - # Set the boundary id - boundary_id = boundary_id.at[0, self._indices_to_tuple(local_indices)].set( - id_number - ) + indices_mask_x = (local_indices[:, 0] >= 0) & (local_indices[:, 0] < mask.shape[1]) + indices_mask_y = (local_indices[:, 1] >= 0) & (local_indices[:, 1] < mask.shape[2]) + indices_mask_z = (local_indices[:, 2] >= 0) & (local_indices[:, 2] < mask.shape[3]) + indices_mask = indices_mask_x & indices_mask_y & indices_mask_z - # Make mask then stream to get the edge points - pre_stream_mask = jnp.zeros_like(mask) - pre_stream_mask = pre_stream_mask.at[self._indices_to_tuple(local_indices)].set( - True - ) - post_stream_mask = self.stream(pre_stream_mask) - - # Set false for points inside the boundary (NOTE: removing this to be more consistent with the other boundary maskers, maybe add back in later) - # post_stream_mask = post_stream_mask.at[ - # post_stream_mask[0, ...] == True - # ].set(False) - - # Get indices on edges - edge_indices = jnp.argwhere(post_stream_mask) - - # Set the mask - mask = mask.at[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ].set( - post_stream_mask[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ] - ) + local_indices = self._indices_to_tuple(local_indices[indices_mask]) + + @jit + def compute_boundary_id_and_mask(boundary_id, mask): + boundary_id = boundary_id.at[0, local_indices[0], local_indices[1], local_indices[2]].set(id_number) + mask = mask.at[:, local_indices[0], local_indices[1], local_indices[2]].set(True) + mask = self.stream(mask) + return boundary_id, mask - return boundary_id, mask + return compute_boundary_id_and_mask(boundary_id, mask) def _construct_warp(self): # Make constants for warp @@ -131,9 +102,9 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_id[ - 0, push_index[0], push_index[1], push_index[2] - ] = wp.uint8(id_number) + boundary_id[0, index[0], index[1], index[2]] = ( + wp.uint8(id_number) + ) mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index e1eed44..1e89547 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -51,9 +51,9 @@ def __init__( precision_policies = set([op.precision_policy for op in self.operators]) assert len(precision_policies) == 1, "All precision policies must be the same" precision_policy = precision_policies.pop() - compute_backends = set([op.compute_backend for op in self.operators]) - assert len(compute_backends) == 1, "All compute backends must be the same" - compute_backend = compute_backends.pop() + compute_backend = set([op.compute_backend for op in self.operators]) + assert len(compute_backend) == 1, "All compute backends must be the same" + compute_backend = compute_backend.pop() # Add boundary conditions # Warp cannot handle lists of functions currently diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index aacc785..bf40c55 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from xlb.compute_backends import ComputeBackends +from xlb.compute_backend import ComputeBackend from xlb.global_config import GlobalConfig from xlb.precision_policy.jax_precision_policy import ( @@ -14,8 +14,8 @@ class Fp64Fp64: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackends.JAX - or GlobalConfig.compute_backend == ComputeBackends.PALLAS + GlobalConfig.compute_backend == ComputeBackend.JAX + or GlobalConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp64Fp64() else: @@ -27,8 +27,8 @@ def __new__(cls): class Fp64Fp32: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackends.JAX - or GlobalConfig.compute_backend == ComputeBackends.PALLAS + GlobalConfig.compute_backend == ComputeBackend.JAX + or GlobalConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp64Fp32() else: @@ -40,8 +40,8 @@ def __new__(cls): class Fp32Fp32: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackends.JAX - or GlobalConfig.compute_backend == ComputeBackends.PALLAS + GlobalConfig.compute_backend == ComputeBackend.JAX + or GlobalConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp32Fp32() else: @@ -53,8 +53,8 @@ def __new__(cls): class Fp64Fp16: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackends.JAX - or GlobalConfig.compute_backend == ComputeBackends.PALLAS + GlobalConfig.compute_backend == ComputeBackend.JAX + or GlobalConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp64Fp16() else: @@ -66,8 +66,8 @@ def __new__(cls): class Fp32Fp16: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackends.JAX - or GlobalConfig.compute_backend == ComputeBackends.PALLAS + GlobalConfig.compute_backend == ComputeBackend.JAX + or GlobalConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp32Fp16() else: From 19e93c02214b33fda350eade186f582450ac075e Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Mon, 22 Apr 2024 15:57:19 -0400 Subject: [PATCH 027/144] Added tests for grid and equilibrium kernels --- examples/CFD/airfoil3d.py | 195 --- examples/CFD/cavity2d.py | 96 -- examples/CFD/cavity3d.py | 123 -- examples/CFD/channel3d.py | 157 --- examples/CFD/couette2d.py | 79 -- examples/CFD/cylinder2d.py | 148 --- examples/CFD/oscilating_cylinder2d.py | 146 --- examples/CFD/taylor_green_vortex.py | 127 -- examples/CFD/windtunnel3d.py | 147 --- examples/CFD_refactor/windtunnel3d.py | 512 -------- examples/backend_comparisons/README.md | 14 - .../backend_comparisons/lattice_boltzmann.py | 1138 ----------------- examples/backend_comparisons/small_example.py | 327 ----- examples/cfd/example_basic.py | 69 + .../{interfaces => cfd}/flow_past_sphere.py | 16 +- .../ldc.py => cfd/lid_driven_cavity.py} | 0 examples/{interfaces => cfd}/taylor_green.py | 0 examples/performance/MLUPS2d.py | 67 - examples/performance/MLUPS3d.py | 76 -- examples/performance/MLUPS3d_distributed.py | 108 -- examples/{refactor => performance}/mlups3d.py | 0 examples/refactor/README.md | 31 - examples/refactor/example_basic.py | 62 - examples/refactor/example_jax.py | 107 -- examples/refactor/example_jax_out_of_core.py | 336 ----- examples/refactor/example_numba.py | 78 -- examples/refactor/example_pallas_3d.py | 73 -- examples/refactor/mlups_pallas_3d.py | 84 -- examples/warp_backend/testing.py | 108 -- requirements.txt | 3 +- tests/__init__.py | 0 .../boundary_conditions.py | 1 + tests/grids/__init__.py | 0 tests/grids/test_jax_grid.py | 61 + tests/grids/test_warp_grid.py | 48 + tests/kernels/__init__.py | 0 tests/kernels/jax/test_equilibrium_jax.py | 42 + tests/kernels/warp/test_equilibrium_warp.py | 48 + xlb/__init__.py | 2 +- xlb/default_config.py | 40 + xlb/global_config.py | 14 - xlb/grid/__init__.py | 4 +- xlb/grid/grid.py | 48 +- xlb/grid/jax_grid.py | 95 +- xlb/grid/warp_grid.py | 40 +- .../indices_boundary_masker.py | 2 +- .../boundary_masker/planar_boundary_masker.py | 2 +- .../boundary_masker/stl_boundary_masker.py | 2 +- xlb/operator/equilibrium/equilibrium.py | 1 - .../equilibrium/quadratic_equilibrium.py | 27 +- xlb/operator/macroscopic/macroscopic.py | 2 +- xlb/operator/operator.py | 8 +- xlb/precision_policy/precision_policy.py | 32 +- xlb/solver/solver.py | 10 +- 54 files changed, 472 insertions(+), 4484 deletions(-) delete mode 100644 examples/CFD/airfoil3d.py delete mode 100644 examples/CFD/cavity2d.py delete mode 100644 examples/CFD/cavity3d.py delete mode 100644 examples/CFD/channel3d.py delete mode 100644 examples/CFD/couette2d.py delete mode 100644 examples/CFD/cylinder2d.py delete mode 100644 examples/CFD/oscilating_cylinder2d.py delete mode 100644 examples/CFD/taylor_green_vortex.py delete mode 100644 examples/CFD/windtunnel3d.py delete mode 100644 examples/CFD_refactor/windtunnel3d.py delete mode 100644 examples/backend_comparisons/README.md delete mode 100644 examples/backend_comparisons/lattice_boltzmann.py delete mode 100644 examples/backend_comparisons/small_example.py create mode 100644 examples/cfd/example_basic.py rename examples/{interfaces => cfd}/flow_past_sphere.py (95%) rename examples/{interfaces/ldc.py => cfd/lid_driven_cavity.py} (100%) rename examples/{interfaces => cfd}/taylor_green.py (100%) delete mode 100644 examples/performance/MLUPS2d.py delete mode 100644 examples/performance/MLUPS3d.py delete mode 100644 examples/performance/MLUPS3d_distributed.py rename examples/{refactor => performance}/mlups3d.py (100%) delete mode 100644 examples/refactor/README.md delete mode 100644 examples/refactor/example_basic.py delete mode 100644 examples/refactor/example_jax.py delete mode 100644 examples/refactor/example_jax_out_of_core.py delete mode 100644 examples/refactor/example_numba.py delete mode 100644 examples/refactor/example_pallas_3d.py delete mode 100644 examples/refactor/mlups_pallas_3d.py delete mode 100644 examples/warp_backend/testing.py create mode 100644 tests/__init__.py create mode 100644 tests/grids/__init__.py create mode 100644 tests/grids/test_jax_grid.py create mode 100644 tests/grids/test_warp_grid.py create mode 100644 tests/kernels/__init__.py create mode 100644 tests/kernels/jax/test_equilibrium_jax.py create mode 100644 tests/kernels/warp/test_equilibrium_warp.py create mode 100644 xlb/default_config.py delete mode 100644 xlb/global_config.py diff --git a/examples/CFD/airfoil3d.py b/examples/CFD/airfoil3d.py deleted file mode 100644 index 77e4b58..0000000 --- a/examples/CFD/airfoil3d.py +++ /dev/null @@ -1,195 +0,0 @@ -""" -This is a example for simulating fluid flow around a NACA airfoil using the lattice Boltzmann method (LBM). -The LBM is a computational fluid dynamics method for simulating fluid flow and is particularly effective -for complex geometries and multiphase flow. - -In this example you'll be introduced to the following concepts: - -1. Lattice: The example uses a D3Q27 lattice, which is a three-dimensional lattice model that considers - 27 discrete velocity directions. This allows for a more accurate representation of the fluid flow - in three dimensions. - -2. NACA Airfoil Generation: The example includes a function to generate a NACA airfoil shape, which is - common in aerodynamics. The function allows for customization of the length, thickness, and angle - of the airfoil. - -3. Boundary Conditions: The example includes several boundary conditions. These include a "bounce back" - condition on the airfoil surface and the top and bottom of the domain, a "do nothing" condition - at the outlet (right side of the domain), and an "equilibrium" condition at the inlet - (left side of the domain) to simulate a uniform flow. - -4. Simulation Parameters: The example allows for the setting of various simulation parameters, - including the Reynolds number, inlet velocity, and characteristic length. - -5. In-situ visualization: The example outputs rendering images of the q-criterion using - PhantomGaze library (https://github.com/loliverhennigh/PhantomGaze) without any I/O overhead - while the data is still on the GPU. -""" - - -import numpy as np -# from IPython import display -import matplotlib.pylab as plt -from src.models import BGKSim, KBCSim -from src.lattice import LatticeD3Q19, LatticeD3Q27 -from src.boundary_conditions import * -import numpy as np -from src.utils import * -from jax import config -import os -#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax -import scipy - -# PhantomGaze for in-situ rendering -import phantomgaze as pg - -def makeNacaAirfoil(length, thickness=30, angle=0): - def nacaAirfoil(x, thickness, chordLength): - coeffs = [0.2969, -0.1260, -0.3516, 0.2843, -0.1015] - exponents = [0.5, 1, 2, 3, 4] - yt = [coeff * (x / chordLength) ** exp for coeff, exp in zip(coeffs, exponents)] - yt = 5. * thickness / 100 * chordLength * np.sum(yt) - - return yt - - x = np.linspace(0, length, num=length) - yt = np.array([nacaAirfoil(xi, thickness, length) for xi in x]) - - y_max = int(np.max(yt)) + 1 - domain = np.zeros((2 * y_max, len(x)), dtype=int) - - for i, xi in enumerate(x): - upper_bound = int(y_max + yt[i]) - lower_bound = int(y_max - yt[i]) - domain[lower_bound:upper_bound, i] = 1 - - domain = scipy.ndimage.rotate(domain, angle, reshape=True) - domain = np.where(domain > 0.5, 1, 0) - - return domain - -class Airfoil(KBCSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - tx, ty = np.array([self.nx, self.ny], dtype=int) - airfoil.shape - - airfoil_mask = np.pad(airfoil, ((tx // 3, tx - tx // 3), (ty // 2, ty - ty // 2)), 'constant', constant_values=False) - airfoil_mask = np.repeat(airfoil_mask[:, :, np.newaxis], self.nz, axis=2) - - airfoil_indices = np.argwhere(airfoil_mask) - wall = np.concatenate((airfoil_indices, - self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'])) - self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) - - # Store airfoil boundary for visualization - self.visualization_bc = jnp.zeros((self.nx, self.ny, self.nz), dtype=jnp.float32) - self.visualization_bc = self.visualization_bc.at[tuple(airfoil_indices.T)].set(1.0) - - doNothing = self.boundingBoxIndices['right'] - self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy)) - - inlet = self.boundingBoxIndices['left'] - rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_inlet = np.zeros((inlet.shape), dtype=self.precisionPolicy.compute_dtype) - - vel_inlet[:, 0] = prescribed_vel - self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet)) - - def output_data(self, **kwargs): - # Compute q-criterion and vorticity using finite differences - # Get velocity field - u = kwargs['u'][..., 1:-1, :] - # vorticity and q-criterion - norm_mu, q = q_criterion(u) - - # Make phantomgaze volume - dx = 0.01 - origin = (0.0, 0.0, 0.0) - upper_bound = (self.visualization_bc.shape[0] * dx, self.visualization_bc.shape[1] * dx, self.visualization_bc.shape[2] * dx) - q_volume = pg.objects.Volume( - q, - spacing=(dx, dx, dx), - origin=origin, - ) - norm_mu_volume = pg.objects.Volume( - norm_mu, - spacing=(dx, dx, dx), - origin=origin, - ) - boundary_volume = pg.objects.Volume( - self.visualization_bc, - spacing=(dx, dx, dx), - origin=origin, - ) - - # Make colormap for norm_mu - colormap = pg.Colormap("jet", vmin=0.0, vmax=0.05) - - # Get camera parameters - focal_point = (self.visualization_bc.shape[0] * dx / 2, self.visualization_bc.shape[1] * dx / 2, self.visualization_bc.shape[2] * dx / 2) - radius = 5.0 - angle = kwargs['timestep'] * 0.0001 - camera_position = (focal_point[0] + radius * np.sin(angle), focal_point[1], focal_point[2] + radius * np.cos(angle)) - - # Rotate camera - camera = pg.Camera(position=camera_position, focal_point=focal_point, view_up=(0.0, 1.0, 0.0), max_depth=30.0, height=1080, width=1920, background=pg.SolidBackground(color=(0.0, 0.0, 0.0))) - - # Make wireframe - screen_buffer = pg.render.wireframe(lower_bound=origin, upper_bound=upper_bound, thickness=0.01, camera=camera) - - # Render axes - screen_buffer = pg.render.axes(size=0.1, center=(0.0, 0.0, 1.1), camera=camera, screen_buffer=screen_buffer) - - # Render q-criterion - screen_buffer = pg.render.contour(q_volume, threshold=0.00003, color=norm_mu_volume, colormap=colormap, camera=camera, screen_buffer=screen_buffer) - - # Render boundary - boundary_colormap = pg.Colormap("bone_r", vmin=0.0, vmax=3.0, opacity=np.linspace(0.0, 6.0, 256)) - screen_buffer = pg.render.volume(boundary_volume, camera=camera, colormap=boundary_colormap, screen_buffer=screen_buffer) - - # Show the rendered image - plt.imsave('q_criterion_' + str(kwargs['timestep']).zfill(7) + '.png', np.minimum(screen_buffer.image.get(), 1.0)) - - -if __name__ == '__main__': - airfoil_length = 101 - airfoil_thickness = 30 - airfoil_angle = 20 - airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T - precision = 'f32/f32' - - lattice = LatticeD3Q27(precision) - - nx = airfoil.shape[0] - ny = airfoil.shape[1] - - ny = 3 * ny - nx = 5 * nx - nz = 101 - - Re = 30000.0 - prescribed_vel = 0.1 - clength = airfoil_length - - visc = prescribed_vel * clength / Re - omega = 1.0 / (3. * visc + 0.5) - - os.system('rm -rf ./*.vtk && rm -rf ./*.png') - - # Set the parameters for the simulation - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': nz, - 'precision': precision, - 'io_rate': 100, - 'print_info_rate': 100, - } - - sim = Airfoil(**kwargs) - sim.run(20000) diff --git a/examples/CFD/cavity2d.py b/examples/CFD/cavity2d.py deleted file mode 100644 index 8d1ffc8..0000000 --- a/examples/CFD/cavity2d.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -This example implements a 2D Lid-Driven Cavity Flow simulation using the lattice Boltzmann method (LBM). -The Lid-Driven Cavity Flow is a standard test case for numerical schemes applied to fluid dynamics, which involves fluid in a square cavity with a moving lid (top boundary). - -In this example you'll be introduced to the following concepts: - -1. Lattice: The simulation employs a D2Q9 lattice. It's a 2D lattice model with nine discrete velocity directions, which is typically used for 2D simulations. - -2. Boundary Conditions: The code implements two types of boundary conditions: - - BounceBackHalfway: This condition is applied to the stationary walls (left, right, and bottom). It models a no-slip boundary where the velocity of fluid at the wall is zero. - EquilibriumBC: This condition is used for the moving lid (top boundary). It defines a boundary with a set velocity, simulating the "driving" of the cavity by the lid. - -3. Checkpointing: The simulation supports checkpointing. Checkpoints are saved periodically (determined by the 'checkpoint_rate'), allowing the simulation to be stopped and restarted from the last checkpoint. This can be beneficial for long simulations or in case of unexpected interruptions. - -4. Visualization: The simulation outputs data in VTK format for visualization. It also provides images of the velocity field and saves the boundary conditions at each time step. The data can be visualized using software like Paraview. - -""" -from jax import config -import numpy as np -import jax.numpy as jnp -import os - -from src. import * -from src.solver import BGKSim, KBCSim -from src.lattice import LatticeD2Q9 -from src.utils import * - -# Use 8 CPU devices -# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' - -class Cavity(KBCSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - - # concatenate the indices of the left, right, and bottom walls - walls = np.concatenate((self.boundingBoxIndices["left"], self.boundingBoxIndices["right"], self.boundingBoxIndices["bottom"])) - # apply bounce back boundary condition to the walls - self.BCs.append(BounceBackHalfway(tuple(walls.T), self.gridInfo, self.precisionPolicy)) - - # apply inlet equilibrium boundary condition to the top wall - moving_wall = self.boundingBoxIndices["top"] - - rho_wall = np.ones((moving_wall.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype) - vel_wall[:, 0] = prescribed_vel - self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) - - def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs["rho"][1:-1, 1:-1]) - u = np.array(kwargs["u"][1:-1, 1:-1, :]) - timestep = kwargs["timestep"] - - save_image(timestep, u) - fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1]} - save_fields_vtk(timestep, fields) - save_BCs_vtk(timestep, self.BCs, self.gridInfo) - -if __name__ == "__main__": - precision = "f32/f32" - lattice = LatticeD2Q9(precision) - - nx = 200 - ny = 200 - - Re = 200.0 - prescribed_vel = 0.1 - clength = nx - 1 - - checkpoint_rate = 1000 - checkpoint_dir = os.path.abspath("./checkpoints") - - visc = prescribed_vel * clength / Re - omega = 1.0 / (3.0 * visc + 0.5) - - os.system("rm -rf ./*.vtk && rm -rf ./*.png") - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': 0, - 'precision': precision, - 'io_rate': 100, - 'print_info_rate': 100, - 'checkpoint_rate': checkpoint_rate, - 'checkpoint_dir': checkpoint_dir, - 'restore_checkpoint': False, - } - - sim = Cavity(**kwargs) - sim.run(5000) diff --git a/examples/CFD/cavity3d.py b/examples/CFD/cavity3d.py deleted file mode 100644 index 2c30d28..0000000 --- a/examples/CFD/cavity3d.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -This example implements a 3D Lid-Driven Cavity Flow simulation using the lattice Boltzmann method (LBM). -The Lid-Driven Cavity Flow is a standard test case for numerical schemes applied to fluid dynamics, which involves fluid in a square cavity with a moving lid (top boundary). - -In this example you'll be introduced to the following concepts: - -1. Lattice: The simulation employs a D3Q27 lattice. It's a 3D lattice model with 27 discrete velocity directions. - -2. Boundary Conditions: The code implements two types of boundary conditions: - - BounceBack: This condition is applied to the stationary walls, except the top wall. It models a no-slip boundary where the velocity of fluid at the wall is zero. - EquilibriumBC: This condition is used for the moving lid (top boundary). It defines a boundary with a set velocity, simulating the "driving" of the cavity by the lid. - -4. Visualization: The simulation outputs data in VTK format for visualization. The data can be visualized using software like Paraview. - -""" -# Use 8 CPU devices -# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' - -import numpy as np -from src.utils import * -from jax import config -import json, codecs - -from src.models import BGKSim, KBCSim -from src.lattice import LatticeD3Q19, LatticeD3Q27 -from src.boundary_conditions import * - - -config.update('jax_enable_x64', True) - -class Cavity(KBCSim): - # Note: We have used BGK with D3Q19 (or D3Q27) for Re=(1000, 3200) and KBC with D3Q27 for Re=10,000 - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - # Note: - # We have used halfway BB for Re=(1000, 3200) and regularized BC for Re=10,000 - - # apply inlet boundary condition to the top wall - moving_wall = self.boundingBoxIndices['top'] - vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype) - vel_wall[:, 0] = prescribed_vel - # self.BCs.append(BounceBackHalfway(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, vel_wall)) - self.BCs.append(Regularized(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall)) - - # concatenate the indices of the left, right, and bottom walls - walls = np.concatenate( - (self.boundingBoxIndices['left'], self.boundingBoxIndices['right'], - self.boundingBoxIndices['front'], self.boundingBoxIndices['back'], - self.boundingBoxIndices['bottom'])) - # apply bounce back boundary condition to the walls - # self.BCs.append(BounceBackHalfway(tuple(walls.T), self.gridInfo, self.precisionPolicy)) - vel_wall = np.zeros(walls.shape, dtype=self.precisionPolicy.compute_dtype) - self.BCs.append(Regularized(tuple(walls.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall)) - return - - def output_data(self, **kwargs): - # 1: -1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs['rho']) - u = np.array(kwargs['u']) - timestep = kwargs['timestep'] - u_prev = kwargs['u_prev'] - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - - err = np.sum(np.abs(u_old - u_new)) - print('error= {:07.6f}'.format(err)) - fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} - # save_fields_vtk(timestep, fields) - - # output profiles of velocity at mid-plane for benchmarking - output_filename = "./profiles_" + f"{timestep:07d}.json" - ux_mid = 0.5*(u[nx//2, ny//2, :, 0] + u[nx//2+1, ny//2+1, :, 0]) - uz_mid = 0.5*(u[:, ny//2, nz//2, 2] + u[:, ny//2+1, nz//2+1, 2]) - ldc_ref_result = {'ux(x=y=0)': list(ux_mid/prescribed_vel), - 'uz(z=y=0)': list(uz_mid/prescribed_vel)} - json.dump(ldc_ref_result, codecs.open(output_filename, 'w', encoding='utf-8'), - separators=(',', ':'), - sort_keys=True, - indent=4) - - # Calculate the velocity magnitude - # u_mag = np.linalg.norm(u, axis=2) - # live_volume_randering(timestep, u_mag) - -if __name__ == '__main__': - # Note: - # We have used BGK with D3Q19 (or D3Q27) for Re=(1000, 3200) and KBC with D3Q27 for Re=10,000 - precision = 'f64/f64' - lattice = LatticeD3Q27(precision) - - nx = 256 - ny = 256 - nz = 256 - - Re = 10000.0 - prescribed_vel = 0.06 - clength = nx - 2 - - # characteristic time - tc = prescribed_vel/clength - niter_max = int(500//tc) - - visc = prescribed_vel * clength / Re - omega = 1.0 / (3. * visc + 0.5) - os.system("rm -rf ./*.vtk && rm -rf ./*.png") - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': nz, - 'precision': precision, - 'io_rate': int(10//tc), - 'print_info_rate': int(10//tc), - 'downsampling_factor': 1 - } - sim = Cavity(**kwargs) - sim.run(niter_max) \ No newline at end of file diff --git a/examples/CFD/channel3d.py b/examples/CFD/channel3d.py deleted file mode 100644 index 2c7ab73..0000000 --- a/examples/CFD/channel3d.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -This script performs a 3D simulation of turbulent channel flow using the lattice Boltzmann method (LBM). -Turbulent channel flow, also known as plane Couette flow, is a fundamental case in the study of wall-bounded turbulent flows. - -In this example you'll be introduced to the following concepts: - -1. Lattice: A D3Q27 lattice is used, which is a three-dimensional lattice model with 27 discrete velocity directions. This type of lattice allows for a more precise representation of fluid flow in three dimensions. - -2. Initial Conditions: The initial conditions for the flow are randomly generated, and the populations are initialized to be the solution of an advection-diffusion equation. - -3. Boundary Conditions: Bounce back boundary conditions are applied at the top and bottom walls, simulating a no-slip condition typical for wall-bounded flows. - -4. External Force: An external force is applied to drive the flow. - -""" - -from src.boundary_conditions import * -from jax import config -from src.utils import * -import numpy as np -from src.lattice import LatticeD3Q27 -from src.models import KBCSim, AdvectionDiffusionBGK -import jax.numpy as jnp -import os -import matplotlib.pyplot as plt - -# Use 8 CPU devices -# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -import jax - -# disable JIt compilation - -jax.config.update('jax_enable_x64', True) - -def vonKarman_loglaw_wall(yplus): - vonKarmanConst = 0.41 - cplus = 5.5 - uplus = np.log(yplus)/vonKarmanConst + cplus - return uplus - -def get_dns_data(): - """ - Reference: DNS of Turbulent Channel Flow up to Re_tau=590, 1999, - Physics of Fluids, vol 11, 943-945. - https://turbulence.oden.utexas.edu/data/MKM/chan180/profiles/chan180.means - """ - dns_dic = { - "y":[0,0.000301,0.0012,0.00271,0.00482,0.00752,0.0108,0.0147,0.0192,0.0243,0.03,0.0362,0.0431,0.0505,0.0585,0.067,0.0761,0.0858,0.096,0.107,0.118,0.13,0.142,0.155,0.169,0.182,0.197,0.212,0.227,0.243,0.259,0.276,0.293,0.31,0.328,0.347,0.366,0.385,0.404,0.424,0.444,0.465,0.486,0.507,0.529,0.55,0.572,0.595,0.617,0.64,0.663,0.686,0.71,0.733,0.757,0.781,0.805,0.829,0.853,0.878,0.902,0.926,0.951,0.975,1], - "y+":[0,0.053648,0.21456,0.48263,0.85771,1.3396,1.9279,2.6224,3.4226,4.328,5.3381,6.4523,7.67,8.9902,10.412,11.936,13.559,15.281,17.102,19.019,21.033,23.141,25.342,27.635,30.019,32.492,35.053,37.701,40.432,43.247,46.143,49.118,52.171,55.3,58.503,61.778,65.123,68.536,72.016,75.559,79.164,82.828,86.55,90.327,94.157,98.037,101.97,105.94,109.96,114.02,118.12,122.25,126.42,130.62,134.84,139.1,143.37,147.67,151.99,156.32,160.66,165.02,169.38,173.75,178.12], - "Umean":[0,0.053639,0.21443,0.48197,0.85555,1.3339,1.9148,2.5939,3.3632,4.2095,5.1133,6.0493,6.9892,7.9052,8.7741,9.579,10.311,10.967,11.55,12.066,12.52,12.921,13.276,13.59,13.87,14.121,14.349,14.557,14.75,14.931,15.101,15.264,15.419,15.569,15.714,15.855,15.993,16.128,16.26,16.389,16.515,16.637,16.756,16.872,16.985,17.094,17.2,17.302,17.4,17.494,17.585,17.672,17.756,17.835,17.911,17.981,18.045,18.103,18.154,18.198,18.235,18.264,18.285,18.297,18.301], - "dUmean/dy":[178,178,178,178,177,176,175,173,169,163,155,144,131,116,101,87.1,73.9,62.2,52.2,43.8,36.9,31.1,26.4,22.6,19.4,16.9,14.9,13.3,12,10.9,10.1,9.38,8.79,8.29,7.86,7.49,7.19,6.91,6.63,6.35,6.07,5.81,5.58,5.36,5.14,4.92,4.68,4.45,4.23,4.04,3.85,3.66,3.48,3.28,3.06,2.81,2.54,2.25,1.96,1.67,1.35,1.02,0.673,0.33,0], - "Wmean":[0,0.0000707,0.000283,0.000636,0.00113,0.00176,0.00252,0.00339,0.00435,0.00538,0.00643,0.00751,0.00864,0.00986,0.0112,0.0126,0.0141,0.0156,0.017,0.0181,0.0186,0.0184,0.0176,0.0163,0.0149,0.0135,0.0124,0.0116,0.0107,0.00966,0.00843,0.00695,0.00519,0.00329,0.00145,-0.000284,-0.00177,-0.00292,-0.00377,-0.00445,-0.00497,-0.0054,-0.00594,-0.00681,-0.0082,-0.00996,-0.0119,-0.0139,-0.0163,-0.0191,-0.0225,-0.0263,-0.0306,-0.0354,-0.0405,-0.0455,-0.05,-0.0539,-0.0577,-0.0615,-0.0653,-0.0685,-0.071,-0.0724,-0.0729], - "dWmean/dy":[0.235,0.235,0.235,0.234,0.234,0.232,0.228,0.22,0.208,0.194,0.179,0.168,0.164,0.164,0.166,0.167,0.162,0.148,0.121,0.076,0.0159,-0.0439,-0.087,-0.107,-0.106,-0.0871,-0.0643,-0.0546,-0.061,-0.0707,-0.0818,-0.0958,-0.108,-0.106,-0.0989,-0.0881,-0.0697,-0.0506,-0.0379,-0.0303,-0.0221,-0.0216,-0.0314,-0.0522,-0.0756,-0.0841,-0.0884,-0.0974,-0.114,-0.136,-0.154,-0.172,-0.196,-0.214,-0.215,-0.199,-0.174,-0.156,-0.155,-0.159,-0.147,-0.118,-0.0788,-0.0387,0], - "Pmean":[6.2170e-13,-7.3193e-10,-1.5832e-07,-3.7598e-06,-3.3837e-05,-1.7683e-04,-6.5008e-04,-1.8650e-03,-4.4488e-03,-9.2047e-03,-1.7023e-02,-2.8777e-02,-4.5228e-02,-6.6952e-02,-9.4281e-02,-1.2724e-01,-1.6551e-01,-2.0842e-01,-2.5498e-01,-3.0396e-01,-3.5398e-01,-4.0362e-01,-4.5163e-01,-4.9698e-01,-5.3880e-01,-5.7639e-01,-6.0919e-01,-6.3686e-01,-6.5930e-01,-6.7652e-01,-6.8867e-01,-6.9613e-01,-6.9928e-01,-6.9854e-01,-6.9444e-01,-6.8744e-01,-6.7802e-01,-6.6675e-01,-6.5429e-01,-6.4131e-01,-6.2817e-01,-6.1487e-01,-6.0122e-01,-5.8703e-01,-5.7221e-01,-5.5678e-01,-5.4090e-01,-5.2493e-01,-5.0917e-01,-4.9371e-01,-4.7867e-01,-4.6421e-01,-4.5050e-01,-4.3759e-01,-4.2550e-01,-4.1436e-01,-4.0444e-01,-3.9595e-01,-3.8900e-01,-3.8360e-01,-3.7966e-01,-3.7702e-01,-3.7542e-01,-3.7460e-01,-3.7436e-01] - } - return dns_dic - -class TurbulentChannel(KBCSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - # top and bottom sides of the channel are no-slip and the other directions are periodic - wall = np.concatenate((self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'])) - self.BCs.append(Regularized(tuple(wall.T), self.gridInfo, self.precisionPolicy, 'velocity', np.zeros((wall.shape[0], 3)))) - return - - def initialize_macroscopic_fields(self): - rho = self.precisionPolicy.cast_to_output(1.0) - u = self.distributed_array_init((self.nx, self.ny, self.nz, self.dim), - self.precisionPolicy.compute_dtype, init_val=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim))) - u = self.precisionPolicy.cast_to_output(u) - return rho, u - - def initialize_populations(self, rho, u): - omegaADE = 1.0 - lattice = LatticeD3Q27(precision) - - kwargs = {'lattice': lattice, 'nx': self.nx, 'ny': self.ny, 'nz': self.nz, 'precision': precision, 'omega': omegaADE, 'vel': u} - ADE = AdvectionDiffusionBGK(**kwargs) - ADE.initialize_macroscopic_fields = self.initialize_macroscopic_fields - print("Initializing the distribution functions using the specified macroscopic fields....") - f = ADE.run(50000) - return f - - def get_force(self): - # define the external force - force = np.zeros((self.nx, self.ny, self.nz, 3)) - force[..., 0] = Re_tau**2 * visc**2 / h**3 - return self.precisionPolicy.cast_to_output(force) - - def output_data(self, **kwargs): - rho = np.array(kwargs["rho"]) - u = np.array(kwargs["u"]) - timestep = kwargs["timestep"] - u_prev = kwargs['u_prev'] - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - - err = np.sum(np.abs(u_old - u_new)) - print("error= {:07.6f}".format(err)) - - # mean streamwise velocity in wall units u^+(z) - uplus = np.mean(u[..., 0], axis=(0,1))/u_tau - uplus_loglaw = vonKarman_loglaw_wall(yplus) - dns_dic = get_dns_data() - plt.clf() - plt.semilogx(yplus, uplus,'r.', yplus, uplus_loglaw, 'k:', dns_dic['y+'], dns_dic['Umean'], 'b-') - ax = plt.gca() - ax.set_xlim([0.1, 300]) - ax.set_ylim([0, 20]) - fname = "uplus_" + str(timestep//10000).zfill(5) + '.pdf' - plt.savefig(fname, format='pdf') - fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} - save_fields_vtk(timestep, fields) - - - -if __name__ == "__main__": - precision = "f64/f64" - lattice = LatticeD3Q27(precision) - - # h: channel half-width - h = 50 - - # Define channel geometry based on h - nx = 6*h - ny = 3*h - nz = 2*h - - # Define flow regime - Re_tau = 180 - u_tau = 0.001 - DeltaPlus = Re_tau/h # DeltaPlus = u_tau / nu * Delta where u_tau / nu = Re_tau/h - visc = u_tau * h / Re_tau - omega = 1.0 / (3.0 * visc + 0.5) - - # Wall distance in wall units to be used inside output_data - zz = np.arange(nz) - zz = np.minimum(zz, zz.max() - zz) - yplus = zz * u_tau / visc - - os.system("rm -rf ./*.vtk && rm -rf ./*.png") - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': nz, - 'precision': precision, - 'io_rate': 500000, - 'print_info_rate': 100000 - } - sim = turbulentChannel(**kwargs) - sim.run(10000000) diff --git a/examples/CFD/couette2d.py b/examples/CFD/couette2d.py deleted file mode 100644 index 1c15a6a..0000000 --- a/examples/CFD/couette2d.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -This script performs a 2D simulation of Couette flow using the lattice Boltzmann method (LBM). -""" - -import os -import jax.numpy as jnp -import numpy as np -from src.utils import * -from jax import config - - -from src.models import BGKSim -from src.boundary_conditions import * -from src.lattice import LatticeD2Q9 - -# config.update('jax_disable_jit', True) -# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' - -class Couette(BGKSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - walls = np.concatenate((self.boundingBoxIndices["top"], self.boundingBoxIndices["bottom"])) - self.BCs.append(BounceBack(tuple(walls.T), self.gridInfo, self.precisionPolicy)) - - outlet = self.boundingBoxIndices["right"] - inlet = self.boundingBoxIndices["left"] - - rho_wall = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_wall = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype) - vel_wall[:, 0] = prescribed_vel - self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) - - self.BCs.append(DoNothing(tuple(outlet.T), self.gridInfo, self.precisionPolicy)) - - def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs["rho"][..., 1:-1, :]) - u = np.array(kwargs["u"][..., 1:-1, :]) - timestep = kwargs["timestep"] - u_prev = kwargs["u_prev"][..., 1:-1, :] - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - err = np.sum(np.abs(u_old - u_new)) - print("error= {:07.6f}".format(err)) - save_image(timestep, u) - fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1]} - save_fields_vtk(timestep, fields) - -if __name__ == "__main__": - precision = "f32/f32" - lattice = LatticeD2Q9(precision) - nx = 501 - ny = 101 - - Re = 100.0 - prescribed_vel = 0.1 - clength = nx - 1 - - visc = prescribed_vel * clength / Re - - omega = 1.0 / (3.0 * visc + 0.5) - assert omega < 1.98, "omega must be less than 2.0" - os.system("rm -rf ./*.vtk && rm -rf ./*.png") - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': 0, - 'precision': precision, - 'io_rate': 100, - 'print_info_rate': 100 - } - sim = Couette(**kwargs) - sim.run(20000) diff --git a/examples/CFD/cylinder2d.py b/examples/CFD/cylinder2d.py deleted file mode 100644 index 2c9887d..0000000 --- a/examples/CFD/cylinder2d.py +++ /dev/null @@ -1,148 +0,0 @@ -""" -This script conducts a 2D simulation of flow around a cylinder using the lattice Boltzmann method (LBM). This is a classic problem in fluid dynamics and is often used to examine the behavior of fluid flow over a bluff body. - -In this example you'll be introduced to the following concepts: - -1. Lattice: A D2Q9 lattice is used, which is a two-dimensional lattice model with nine discrete velocity directions. This type of lattice allows for a precise representation of fluid flow in two dimensions. - -2. Boundary Conditions: The script implements several types of boundary conditions: - - BounceBackHalfway: This condition is applied to the cylinder surface, simulating a no-slip condition where the fluid at the cylinder surface has zero velocity. - ExtrapolationOutflow: This condition is applied at the outlet (right boundary), where the fluid is allowed to exit the simulation domain freely. - Regularized: This condition is applied at the inlet (left boundary) and models the inflow of fluid into the domain with a specified velocity profile. Another Regularized condition is used for the stationary top and bottom walls. -3. Velocity Profile: The script uses a Poiseuille flow profile for the inlet velocity. This is a parabolic profile commonly seen in pipe flow. - -4. Drag and lift calculation: The script computes the lift and drag on the cylinder, which are important quantities in fluid dynamics and aerodynamics. - -5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView. - -# To run type: -nohup python3 examples/CFD/cylinder2d.py > logfile.log & -""" -import os -import json -import jax -from time import time -from jax import config -import numpy as np -import jax.numpy as jnp - -from src.utils import * -from src.boundary_conditions import * -from src.models import BGKSim, KBCSim -from src.lattice import LatticeD2Q9 - -# Use 8 CPU devices -# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -jax.config.update('jax_enable_x64', True) - -class Cylinder(BGKSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - # Define the cylinder surface - coord = np.array([(i, j) for i in range(self.nx) for j in range(self.ny)]) - xx, yy = coord[:, 0], coord[:, 1] - cx, cy = 2.*diam, 2.*diam - cylinder = (xx - cx)**2 + (yy-cy)**2 <= (diam/2.)**2 - cylinder = coord[cylinder] - implicit_distance = np.reshape((xx - cx)**2 + (yy-cy)**2 - (diam/2.)**2, (self.nx, self.ny)) - self.BCs.append(InterpolatedBounceBackBouzidi(tuple(cylinder.T), implicit_distance, self.gridInfo, self.precisionPolicy)) - - # Outflow BC - outlet = self.boundingBoxIndices['right'] - rho_outlet = np.ones((outlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - self.BCs.append(ExtrapolationOutflow(tuple(outlet.T), self.gridInfo, self.precisionPolicy)) - # self.BCs.append(ZouHe(tuple(outlet.T), self.gridInfo, self.precisionPolicy, 'pressure', rho_outlet)) - - # Inlet BC - inlet = self.boundingBoxIndices['left'] - rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype) - yy_inlet = yy.reshape(self.nx, self.ny)[tuple(inlet.T)] - vel_inlet[:, 0] = poiseuille_profile(yy_inlet, - yy_inlet.min(), - yy_inlet.max()-yy_inlet.min(), 3.0 / 2.0 * prescribed_vel) - self.BCs.append(Regularized(tuple(inlet.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_inlet)) - - # No-slip BC for top and bottom - wall = np.concatenate([self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']]) - vel_wall = np.zeros(wall.shape, dtype=self.precisionPolicy.compute_dtype) - self.BCs.append(Regularized(tuple(wall.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall)) - - def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using bounce-back) - rho = np.array(kwargs["rho"][..., 1:-1, :]) - u = np.array(kwargs["u"][..., 1:-1, :]) - timestep = kwargs["timestep"] - u_prev = kwargs["u_prev"][..., 1:-1, :] - - if timestep == 0: - self.CL_max = 0.0 - self.CD_max = 0.0 - if timestep > 0.5 * niter_max: - # compute lift and drag over the cyliner - cylinder = self.BCs[0] - boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) - boundary_force = np.sum(np.array(boundary_force), axis=0) - drag = boundary_force[0] - lift = boundary_force[1] - cd = 2. * drag / (prescribed_vel ** 2 * diam) - cl = 2. * lift / (prescribed_vel ** 2 * diam) - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - err = np.sum(np.abs(u_old - u_new)) - self.CL_max = max(self.CL_max, cl) - self.CD_max = max(self.CD_max, cd) - print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) - # save_image(timestep, u) - -# Helper function to specify a parabolic poiseuille profile -poiseuille_profile = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2)) - -if __name__ == '__main__': - precision = 'f64/f64' - # diam_list = [10, 20, 30, 40, 60, 80] - diam_list = [80] - CL_list, CD_list = [], [] - result_dict = {} - result_dict['resolution_list'] = diam_list - for diam in diam_list: - scale_factor = 80 / diam - prescribed_vel = 0.003 * scale_factor - lattice = LatticeD2Q9(precision) - - nx = int(22*diam) - ny = int(4.1*diam) - - Re = 100.0 - visc = prescribed_vel * diam / Re - omega = 1.0 / (3. * visc + 0.5) - - os.system('rm -rf ./*.vtk && rm -rf ./*.png') - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': 0, - 'precision': precision, - 'io_rate': int(500 / scale_factor), - 'print_info_rate': int(10000 / scale_factor), - 'return_fpost': True # Need to retain fpost-collision for computation of lift and drag - } - # characteristic time - tc = prescribed_vel/diam - niter_max = int(100//tc) - sim = Cylinder(**kwargs) - sim.run(niter_max) - CL_list.append(sim.CL_max) - CD_list.append(sim.CD_max) - - result_dict['CL'] = CL_list - result_dict['CD'] = CD_list - with open('data.json', 'w') as fp: - json.dump(result_dict, fp) diff --git a/examples/CFD/oscilating_cylinder2d.py b/examples/CFD/oscilating_cylinder2d.py deleted file mode 100644 index 97d6746..0000000 --- a/examples/CFD/oscilating_cylinder2d.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -This script conducts a 2D simulation of flow around a cylinder using the lattice Boltzmann method (LBM). This is a classic problem in fluid dynamics and is often used to examine the behavior of fluid flow over a bluff body. - -In this example you'll be introduced to the following concepts: - -1. Lattice: A D2Q9 lattice is used, which is a two-dimensional lattice model with nine discrete velocity directions. This type of lattice allows for a precise representation of fluid flow in two dimensions. - -2. Boundary Conditions: The script implements several types of boundary conditions: - - BounceBackMoving: This condition is applied to the cylinder surface. Unlike the usual BounceBack condition, this one takes into account the motion of the cylinder. - ExtrapolationOutflow: This condition is applied at the outlet (right boundary), where the fluid is allowed to exit the simulation domain freely. - Regularized: This condition is applied at the inlet (left boundary) and models the inflow of fluid into the domain with a specified velocity profile. Another Regularized condition is used for the stationary top and bottom walls. -3. Velocity Profile: The script uses a Poiseuille flow profile for the inlet velocity. This is a parabolic profile commonly seen in pipe flow. - -4. Drag and lift calculation: The script computes the lift and drag on the cylinder, which are important quantities in fluid dynamics and aerodynamics. - -5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView. - -""" - - -import os -import jax -from time import time -from jax import config -import numpy as np -import jax.numpy as jnp - -from src.utils import * -from src.boundary_conditions import * -from src.models import BGKSim, KBCSim -from src.lattice import LatticeD2Q9 - -# Use 8 CPU devices -# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -jax.config.update('jax_enable_x64', True) - -class Cylinder(KBCSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - wall = np.concatenate([self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']]) - self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) - - coord = np.array([np.unravel_index(i, (self.nx, self.ny)) for i in range(self.nx*self.ny)]) - xx, yy = coord[:, 0], coord[:, 1] - cx, cy = 2.*diam, 2.*diam - cyl = ((xx) - cx)**2 + (yy-cy)**2 <= (diam/2.)**2 - cyl = jnp.array(coord[cyl]) - - # Define update rules for boundary conditions - def update_function(time: int): - # Move the cylinder up and down sinusoidally with time - # Define the scale for the sinusoidal motion - scale = 10000 - - # Amplitude of the motion, a quarter of the y-dimension of the grid - A = ny // 4 - - # Calculate the new y-coordinates of the cylinder. The cylinder moves up and down, - # its motion dictated by the sinusoidal function. We use `astype(int)` to ensure - # the indices are integers, as they will be used for array indexing. - new_y_coords = cyl[:, 1] + jnp.array((jnp.sin(time/scale)*A).astype(int)) - - # Define the indices of the grid points occupied by the cylinder - indices = (cyl[:, 0], new_y_coords) - - # Calculate the velocity of the cylinder. The x-component is always 0 (the cylinder - # doesn't move horizontally), and the y-component is the derivative of the sinusoidal - # function governing the cylinder's motion, scaled by the amplitude and the scale factor. - velocity = jnp.array([0., jnp.cos(time/scale)* A / scale], dtype=self.precisionPolicy.compute_dtype) - - return indices, velocity - - self.BCs.append(BounceBackMoving(self.gridInfo, self.precisionPolicy, update_function=update_function)) - - - outlet = self.boundingBoxIndices['right'] - self.BCs.append(ExtrapolationOutflow(tuple(outlet.T), self.gridInfo, self.precisionPolicy)) - - inlet = self.boundingBoxIndices['left'] - vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype) - yy_inlet = yy.reshape(self.nx, self.ny)[tuple(inlet.T)] - vel_inlet[:, 0] = poiseuille_profile(yy_inlet, - yy_inlet.min(), - yy_inlet.max()-yy_inlet.min(), 3.0 / 2.0 * prescribed_vel) - self.BCs.append(Regularized(tuple(inlet.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_inlet)) - - - def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs["rho"][..., 1:-1, :]) - u = np.array(kwargs["u"][..., 1:-1, :]) - timestep = kwargs["timestep"] - u_prev = kwargs["u_prev"][..., 1:-1, :] - - # compute lift and drag over the cyliner - cylinder = self.BCs[0] - boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) - boundary_force = np.sum(boundary_force, axis=0) - drag = boundary_force[0] - lift = boundary_force[1] - cd = 2. * drag / (prescribed_vel ** 2 * diam) - cl = 2. * lift / (prescribed_vel ** 2 * diam) - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - err = np.sum(np.abs(u_old - u_new)) - print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) - save_image(timestep, u) - # u magnitude - fields = {'rho': rho[..., 0], 'u': np.linalg.norm(u, axis=2)} - save_fields_vtk(timestep, fields) - save_BCs_vtk(timestep, self.BCs, self.gridInfo) - -# Helper function to specify a parabolic poiseuille profile -poiseuille_profile = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2)) - -if __name__ == '__main__': - precision = 'f64/f64' - lattice = LatticeD2Q9(precision) - prescribed_vel = 0.005 - diam = 20 - nx = int(22*diam) - ny = int(4.1*diam) - - Re = 10.0 - visc = prescribed_vel * diam / Re - omega = 1.0 / (3. * visc + 0.5) - - os.system('rm -rf ./*.vtk && rm -rf ./*.png') - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': 0, - 'precision': precision, - 'io_rate': 500, - 'print_info_rate': 500, - 'return_fpost': True, # Need to retain fpost-collision for computation of lift and drag - } - sim = Cylinder(**kwargs) - - sim.run(1000000) diff --git a/examples/CFD/taylor_green_vortex.py b/examples/CFD/taylor_green_vortex.py deleted file mode 100644 index 374c499..0000000 --- a/examples/CFD/taylor_green_vortex.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -The given script sets up a simulation for the Taylor-Green vortex flow. -The Taylor-Green vortex is a type of two-dimensional, incompressible fluid flow with a known analytical solution, making it an ideal test case for fluid dynamics simulations. -The flow is characterized by a pair of counter-rotating vortices. In this script, the initial fields for the Taylor-Green vortex are set using a known function. -""" - - -import os -import json -import jax -import numpy as np -import matplotlib.pyplot as plt - -from src.utils import * -from src.boundary_conditions import * -from src.models import BGKSim, KBCSim, AdvectionDiffusionBGK -from src.lattice import LatticeD2Q9 - - -# Use 8 CPU devices -# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -# disable JIt compilation - -jax.config.update('jax_enable_x64', True) - -def taylor_green_initial_fields(xx, yy, u0, rho0, nu, time): - ux = u0 * np.sin(xx) * np.cos(yy) * np.exp(-2 * nu * time) - uy = -u0 * np.cos(xx) * np.sin(yy) * np.exp(-2 * nu * time) - rho = 1.0 - rho0 * u0 ** 2 / 12. * (np.cos(2. * xx) + np.cos(2. * yy)) * np.exp(-4 * nu * time) - return ux, uy, np.expand_dims(rho, axis=-1) - -class TaylorGreenVortex(KBCSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - # no boundary conditions implying periodic BC in all directions - return - - def initialize_macroscopic_fields(self): - ux, uy, rho = taylor_green_initial_fields(xx, yy, vel_ref, 1, 0., 0.) - rho = self.distributed_array_init(rho.shape, self.precisionPolicy.output_dtype, init_val=1.0, sharding=self.sharding) - u = np.stack([ux, uy], axis=-1) - u = self.distributed_array_init(u.shape, self.precisionPolicy.output_dtype, init_val=u, sharding=self.sharding) - return rho, u - - def initialize_populations(self, rho, u): - omegaADE = 1.0 - kwargs = {'lattice': lattice, 'nx': self.nx, 'ny': self.ny, 'nz': self.nz, 'precision': precision, 'omega': omegaADE, 'vel': u, 'print_info_rate': 0, 'io_rate': 0} - ADE = AdvectionDiffusionBGK(**kwargs) - ADE.initialize_macroscopic_fields = self.initialize_macroscopic_fields - print("Initializing the distribution functions using the specified macroscopic fields....") - f = ADE.run(int(20000*nx/32)) - return f - - def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs["rho"]) - u = np.array(kwargs["u"]) - timestep = kwargs["timestep"] - - # theoretical results - time = timestep * (kx**2 + ky**2)/2. - ux_th, uy_th, rho_th = taylor_green_initial_fields(xx, yy, vel_ref, 1, visc, time) - vel_err_L2 = np.sqrt(np.sum((u[..., 0]-ux_th)**2 + (u[..., 1]-uy_th)**2) / np.sum(ux_th**2 + uy_th**2)) - rho_err_L2 = np.sqrt(np.sum((rho - rho_th)**2) / np.sum(rho_th**2)) - print("Vel error= {:07.6f}, Pressure error= {:07.6f}".format(vel_err_L2, rho_err_L2)) - if timestep == endTime: - ErrL2ResList.append(vel_err_L2) - ErrL2ResListRho.append(rho_err_L2) - # save_image(timestep, u) - - -if __name__ == "__main__": - precision_list = ["f32/f32", "f64/f32", "f64/f64"] - resList = [32, 64, 128, 256, 512, 1024] - result_dict = dict.fromkeys(precision_list) - result_dict['resolution_list'] = resList - - for precision in precision_list: - lattice = LatticeD2Q9(precision) - ErrL2ResList = [] - ErrL2ResListRho = [] - result_dict[precision] = dict.fromkeys(['vel_error', 'rho_error']) - for nx in resList: - ny = nx - twopi = 2.0 * np.pi - coord = np.array([(i, j) for i in range(nx) for j in range(ny)]) - xx, yy = coord[:, 0], coord[:, 1] - kx, ky = twopi / nx, twopi / ny - xx = xx.reshape((nx, ny)) * kx - yy = yy.reshape((nx, ny)) * ky - - Re = 1600.0 - vel_ref = 0.04*32/nx - - visc = vel_ref * nx / Re - omega = 1.0 / (3.0 * visc + 0.5) - os.system("rm -rf ./*.vtk && rm -rf ./*.png") - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': 0, - 'precision': precision, - 'io_rate': 5000, - 'print_info_rate': 1000 - } - sim = TaylorGreenVortex(**kwargs) - tc = 2.0/(2. * visc * (kx**2 + ky**2)) - endTime = int(0.05*tc) - sim.run(endTime) - result_dict[precision]['vel_error'] = ErrL2ResList - result_dict[precision]['rho_error'] = ErrL2ResListRho - - with open('data.json', 'w') as fp: - json.dump(result_dict, fp) - - # plt.loglog(resList, ErrL2ResList, '-o') - # plt.loglog(resList, 1e-3*(np.array(resList)/128)**(-2), '--') - # plt.savefig('ErrorVel.png'); plt.savefig('ErrorVel.pdf', format='pdf') - - # plt.figure() - # plt.loglog(resList, ErrL2ResListRho, '-o') - # plt.loglog(resList, 1e-3*(np.array(resList)/128)**(-2), '--') - # plt.savefig('ErrorRho.png'); plt.savefig('ErrorRho.pdf', format='pdf') diff --git a/examples/CFD/windtunnel3d.py b/examples/CFD/windtunnel3d.py deleted file mode 100644 index 231069b..0000000 --- a/examples/CFD/windtunnel3d.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -This script performs a Lattice Boltzmann Method (LBM) simulation of fluid flow over a car model. Here are the main concepts and steps in the simulation: - -Here are the main concepts introduced simulation: - -1. Lattice: Given the usually high Reynolds number required for these simulations, a D3Q27 lattice is used, which is a three-dimensional lattice model with 27 discrete velocity directions. - -2. Loading geometry and voxelization: The geometry of the car is loaded from a STL file. -This is a file format commonly used for 3D models. The model is then voxelized to a binary matrix which represents the presence or absence of the object in the lattice. We use the DrivAer model, which is a common car model used for aerodynamic simulations. - -3. Output: After each specified number of iterations, the script outputs the state of the simulation. This includes the error (difference between consecutive velocity fields), lift and drag coefficients, and visualization files in the VTK format. -""" - - -import os -import jax -import trimesh -from time import time -import numpy as np -import jax.numpy as jnp -from jax import config - -from src.utils import * -from src.models import BGKSim, KBCSim -from src.lattice import LatticeD3Q19, LatticeD3Q27 -from src.boundary_conditions import * - -# Use 8 CPU devices -# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' - -# disable JIt compilation - -jax.config.update('jax_array', True) - -class Car(KBCSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def voxelize_stl(self, stl_filename, length_lbm_unit): - mesh = trimesh.load_mesh(stl_filename, process=False) - length_phys_unit = mesh.extents.max() - pitch = length_phys_unit/length_lbm_unit - mesh_voxelized = mesh.voxelized(pitch=pitch) - mesh_matrix = mesh_voxelized.matrix - return mesh_matrix, pitch - - def set_boundary_conditions(self): - print('Voxelizing mesh...') - time_start = time() - stl_filename = 'stl-files/DrivAer-Notchback.stl' - car_length_lbm_unit = self.nx / 4 - car_voxelized, pitch = voxelize_stl(stl_filename, car_length_lbm_unit) - car_matrix = car_voxelized.matrix - print('Voxelization time for pitch={}: {} seconds'.format(pitch, time() - time_start)) - print("Car matrix shape: ", car_matrix.shape) - - self.car_area = np.prod(car_matrix.shape[1:]) - tx, ty, tz = np.array([nx, ny, nz]) - car_matrix.shape - shift = [tx//4, ty//2, 0] - car_indices = np.argwhere(car_matrix) + shift - self.BCs.append(BounceBackHalfway(tuple(car_indices.T), self.gridInfo, self.precisionPolicy)) - - wall = np.concatenate((self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'], - self.boundingBoxIndices['front'], self.boundingBoxIndices['back'])) - self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) - - doNothing = self.boundingBoxIndices['right'] - self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy)) - self.BCs[-1].implementationStep = 'PostCollision' - # rho_outlet = np.ones(doNothing.shape[0], dtype=self.precisionPolicy.compute_dtype) - # self.BCs.append(ZouHe(tuple(doNothing.T), - # self.gridInfo, - # self.precisionPolicy, - # 'pressure', rho_outlet)) - - inlet = self.boundingBoxIndices['left'] - rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype) - - vel_inlet[:, 0] = prescribed_vel - self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet)) - # self.BCs.append(ZouHe(tuple(inlet.T), - # self.gridInfo, - # self.precisionPolicy, - # 'velocity', vel_inlet)) - - def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs['rho'][..., 1:-1, 1:-1, :]) - u = np.array(kwargs['u'][..., 1:-1, 1:-1, :]) - timestep = kwargs['timestep'] - u_prev = kwargs['u_prev'][..., 1:-1, 1:-1, :] - - # compute lift and drag over the car - car = self.BCs[0] - boundary_force = car.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) - boundary_force = np.sum(boundary_force, axis=0) - drag = np.sqrt(boundary_force[0]**2 + boundary_force[1]**2) #xy-plane - lift = boundary_force[2] #z-direction - cd = 2. * drag / (prescribed_vel ** 2 * self.car_area) - cl = 2. * lift / (prescribed_vel ** 2 * self.car_area) - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - - err = np.sum(np.abs(u_old - u_new)) - print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) - fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} - save_fields_vtk(timestep, fields) - -class VehicleRecipe(Recipe): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - -if __name__ == '__main__': - precision = 'f32/f32' - lattice = LatticeD3Q27(precision) - - nx = 601 - ny = 351 - nz = 251 - - Re = 50000.0 - prescribed_vel = 0.05 - clength = nx - 1 - - visc = prescribed_vel * clength / Re - omega = 1.0 / (3. * visc + 0.5) - - os.system('rm -rf ./*.vtk && rm -rf ./*.png') - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': nx, - 'ny': ny, - 'nz': nz, - 'precision': precision, - 'io_rate': 100, - 'print_info_rate': 100, - 'return_fpost': True # Need to retain fpost-collision for computation of lift and drag - } - sim = Car(**kwargs) - sim.run(200000) diff --git a/examples/CFD_refactor/windtunnel3d.py b/examples/CFD_refactor/windtunnel3d.py deleted file mode 100644 index 156a219..0000000 --- a/examples/CFD_refactor/windtunnel3d.py +++ /dev/null @@ -1,512 +0,0 @@ -# Wind tunnel simulation using the XLB library - -from typing import Any -import os -import jax -import trimesh -from time import time -import numpy as np -import warp as wp -import pyvista as pv -import tqdm -import matplotlib.pyplot as plt - -wp.init() - -import xlb -from xlb.operator import Operator - -class UniformInitializer(Operator): - - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - vel: float, - ): - # Get the global index - i, j, k = wp.tid() - - # Set the velocity - u[0, i, j, k] = vel - u[1, i, j, k] = 0.0 - u[2, i, j, k] = 0.0 - - # Set the density - rho[0, i, j, k] = 1.0 - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u, vel): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - rho, - u, - vel, - ], - dim=rho.shape[1:], - ) - return rho, u - -class MomentumTransfer(Operator): - - def _construct_warp(self): - # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c - _opp_indices = self.velocity_set.wp_opp_indices - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool - - # Find velocity index for 0, 0, 0 - for l in range(self.velocity_set.q): - if _c[0, l] == 0 and _c[1, l] == 0 and _c[2, l] == 0: - zero_index = l - _zero_index = wp.int32(zero_index) - print(f"Zero index: {_zero_index}") - - # Construct the warp kernel - @wp.kernel - def kernel( - f: wp.array4d(dtype=Any), - boundary_id: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - momentum: wp.array(dtype=Any), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - - # Get the boundary id - _boundary_id = boundary_id[0, index[0], index[1], index[2]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) - - # Determin if boundary is an edge by checking if center is missing - is_edge = wp.bool(False) - if _boundary_id == wp.uint8(xlb.operator.boundary_condition.HalfwayBounceBackBC.id): - if _missing_mask[_zero_index] != wp.uint8(1): - is_edge = wp.bool(True) - - # If the boundary is an edge then add the momentum transfer - m = wp.vec3() - if is_edge: - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - phi = 2.0 * f[_opp_indices[l], index[0], index[1], index[2]] - - # Compute the momentum transfer - for d in range(self.velocity_set.d): - m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) - - wp.atomic_add(momentum, 0, m) - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, f, boundary_id, missing_mask): - - # Allocate the momentum field - momentum = wp.zeros((1), dtype=wp.vec3) - - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f, boundary_id, missing_mask, momentum], - dim=f.shape[1:], - ) - return momentum.numpy() - - -class WindTunnel: - """ - Wind tunnel simulation using the XLB library - """ - - def __init__( - self, - stl_filename: str, - inlet_velocity: float = 27.78, # m/s - lower_bounds: tuple[float, float, float] = (0.0, 0.0, 0.0), # m - upper_bounds: tuple[float, float, float] = (1.0, 0.5, 0.5), # m - dx: float = 0.01, # m - viscosity: float = 1.42e-5, # air at 20 degrees Celsius - density: float = 1.2754, # kg/m^3 - solve_time: float = 1.0, # s - #collision="BGK", - collision="KBC", - equilibrium="Quadratic", - velocity_set="D3Q27", - precision_policy=xlb.PrecisionPolicy.FP32FP32, - compute_backend=xlb.ComputeBackend.WARP, - grid_configs={}, - save_state_frequency=1024, - monitor_frequency=32, - ): - - # Set parameters - self.stl_filename = stl_filename - self.inlet_velocity = inlet_velocity - self.lower_bounds = lower_bounds - self.upper_bounds = upper_bounds - self.dx = dx - self.solve_time = solve_time - self.viscosity = viscosity - self.density = density - self.save_state_frequency = save_state_frequency - self.monitor_frequency = monitor_frequency - - # Get fluid properties needed for the simulation - self.base_velocity = 0.05 # LBM units - self.velocity_conversion = self.base_velocity / inlet_velocity - self.dt = self.dx * self.velocity_conversion - self.lbm_viscosity = self.viscosity * self.dt / (self.dx ** 2) - self.tau = 0.5 + self.lbm_viscosity - self.omega = 1.0 / self.tau - print(f"tau: {self.tau}") - print(f"omega: {self.omega}") - self.lbm_density = 1.0 - self.mass_conversion = self.dx ** 3 * (self.density / self.lbm_density) - self.nr_steps = int(solve_time / self.dt) - - # Get the grid shape - self.nx = int((upper_bounds[0] - lower_bounds[0]) / dx) - self.ny = int((upper_bounds[1] - lower_bounds[1]) / dx) - self.nz = int((upper_bounds[2] - lower_bounds[2]) / dx) - self.shape = (self.nx, self.ny, self.nz) - - # Set the compute backend - self.compute_backend = xlb.ComputeBackend.WARP - - # Set the precision policy - self.precision_policy = xlb.PrecisionPolicy.FP32FP32 - - # Set the velocity set - if velocity_set == "D3Q27": - self.velocity_set = xlb.velocity_set.D3Q27() - elif velocity_set == "D3Q19": - self.velocity_set = xlb.velocity_set.D3Q19() - else: - raise ValueError("Invalid velocity set") - - # Make grid - self.grid = xlb.grid.WarpGrid(shape=self.shape) - - # Make feilds - self.rho = self.grid.create_field(cardinality=1, precision=xlb.Precision.FP32) - self.u = self.grid.create_field(cardinality=self.velocity_set.d, precision=xlb.Precision.FP32) - self.f0 = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.FP32) - self.f1 = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.FP32) - self.boundary_id = self.grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) - self.missing_mask = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.BOOL) - - # Make operators - self.initializer = UniformInitializer( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.momentum_transfer = MomentumTransfer( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - if collision == "BGK": - self.collision = xlb.operator.collision.BGK( - omega=self.omega, - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - elif collision == "KBC": - self.collision = xlb.operator.collision.KBC( - omega=self.omega, - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.stream = xlb.operator.stream.Stream( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - rho=self.lbm_density, - u=(self.base_velocity, 0.0, 0.0), - equilibrium_operator=self.equilibrium, - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=self.collision, - equilibrium=self.equilibrium, - macroscopic=self.macroscopic, - stream=self.stream, - boundary_conditions=[ - self.half_way_bc, - self.full_way_bc, - self.equilibrium_bc, - self.do_nothing_bc - ], - ) - self.planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.stl_boundary_masker = xlb.operator.boundary_masker.STLBoundaryMasker( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - - # Make list to store drag coefficients - self.drag_coefficients = [] - - def initialize_flow(self): - """ - Initialize the flow field - """ - - # Set initial conditions - self.rho, self.u = self.initializer(self.rho, self.u, self.base_velocity) - self.f0 = self.equilibrium(self.rho, self.u, self.f0) - - def initialize_boundary_conditions(self): - """ - Initialize the boundary conditions - """ - - # Set inlet bc (bottom x face) - lower_bound = (0, 1, 1) # no edges - upper_bound = (0, self.ny-1, self.nz-1) - direction = (1, 0, 0) - self.boundary_id, self.missing_mask = self.planar_boundary_masker( - lower_bound, - upper_bound, - direction, - self.equilibrium_bc.id, - self.boundary_id, - self.missing_mask, - (0, 0, 0) - ) - - # Set outlet bc (top x face) - lower_bound = (self.nx-1, 1, 1) - upper_bound = (self.nx-1, self.ny-1, self.nz-1) - direction = (-1, 0, 0) - self.boundary_id, self.missing_mask = self.planar_boundary_masker( - lower_bound, - upper_bound, - direction, - self.do_nothing_bc.id, - self.boundary_id, - self.missing_mask, - (0, 0, 0) - ) - - # Set full way bc (bottom y face) - lower_bound = (0, 0, 0) - upper_bound = (self.nx, 0, self.nz) - direction = (0, 1, 0) - self.boundary_id, self.missing_mask = self.planar_boundary_masker( - lower_bound, - upper_bound, - direction, - self.full_way_bc.id, - self.boundary_id, - self.missing_mask, - (0, 0, 0) - ) - - # Set full way bc (top y face) - lower_bound = (0, self.ny-1, 0) - upper_bound = (self.nx, self.ny-1, self.nz) - direction = (0, -1, 0) - self.boundary_id, self.missing_mask = self.planar_boundary_masker( - lower_bound, - upper_bound, - direction, - self.full_way_bc.id, - self.boundary_id, - self.missing_mask, - (0, 0, 0) - ) - - # Set full way bc (bottom z face) - lower_bound = (0, 0, 0) - upper_bound = (self.nx, self.ny, 0) - direction = (0, 0, 1) - self.boundary_id, self.missing_mask = self.planar_boundary_masker( - lower_bound, - upper_bound, - direction, - self.full_way_bc.id, - self.boundary_id, - self.missing_mask, - (0, 0, 0) - ) - - # Set full way bc (top z face) - lower_bound = (0, 0, self.nz-1) - upper_bound = (self.nx, self.ny, self.nz-1) - direction = (0, 0, -1) - self.boundary_id, self.missing_mask = self.planar_boundary_masker( - lower_bound, - upper_bound, - direction, - self.full_way_bc.id, - self.boundary_id, - self.missing_mask, - (0, 0, 0) - ) - - # Set stl half way bc - self.boundary_id, self.missing_mask = self.stl_boundary_masker( - self.stl_filename, - self.lower_bounds, - (self.dx, self.dx, self.dx), - self.half_way_bc.id, - self.boundary_id, - self.missing_mask, - (0, 0, 0) - ) - - def save_state( - self, - postfix: str, - save_velocity_distribution: bool = False, - ): - """ - Save the solid id array. - """ - - # Create grid - grid = pv.RectilinearGrid( - np.linspace(self.lower_bounds[0], self.upper_bounds[0], self.nx, endpoint=False), - np.linspace(self.lower_bounds[1], self.upper_bounds[1], self.ny, endpoint=False), - np.linspace(self.lower_bounds[2], self.upper_bounds[2], self.nz, endpoint=False), - ) # TODO off by one? - grid["boundary_id"] = self.boundary_id.numpy().flatten("F") - grid["u"] = self.u.numpy().transpose(1, 2, 3, 0).reshape(-1, 3, order="F") - grid["rho"] = self.rho.numpy().flatten("F") - if save_velocity_distribution: - grid["f0"] = self.f0.numpy().transpose(1, 2, 3, 0).reshape(-1, self.velocity_set.q, order="F") - grid.save(f"state_{postfix}.vtk") - - def step(self): - self.f1 = self.stepper(self.f0, self.f1, self.boundary_id, self.missing_mask, 0) - self.f0, self.f1 = self.f1, self.f0 - - def compute_rho_u(self): - self.rho, self.u = self.macroscopic(self.f0, self.rho, self.u) - - def monitor(self): - # Compute the momentum transfer - momentum = self.momentum_transfer(self.f0, self.boundary_id, self.missing_mask)[0] - drag = momentum[0] - lift = momentum[2] - c_d = 2.0 * drag / (self.base_velocity ** 2 * self.cross_section) - c_l = 2.0 * lift / (self.base_velocity ** 2 * self.cross_section) - self.drag_coefficients.append(c_d) - - def plot_drag_coefficient(self): - plt.plot(self.drag_coefficients[-30:]) - plt.xlabel("Time step") - plt.ylabel("Drag coefficient") - plt.savefig("drag_coefficient.png") - plt.close() - - def run(self): - - # Initialize the flow field - self.initialize_flow() - - # Initialize the boundary conditions - self.initialize_boundary_conditions() - - # Compute cross section - np_boundary_id = self.boundary_id.numpy() - cross_section = np.sum(np_boundary_id == self.half_way_bc.id, axis=(0, 1)) - self.cross_section = np.sum(cross_section > 0) - - # Run the simulation - for i in tqdm.tqdm(range(self.nr_steps)): - - # Step - self.step() - - # Monitor - if i % self.monitor_frequency == 0: - self.monitor() - - # Save monitor plot - if i % (self.monitor_frequency * 10) == 0: - self.plot_drag_coefficient() - - # Save state - if i % self.save_state_frequency == 0: - self.compute_rho_u() - self.save_state(str(i).zfill(8)) - -if __name__ == '__main__': - - # Parameters - inlet_velocity = 0.01 # m/s - stl_filename = "fastback_baseline.stl" - lower_bounds = (-4.0, -2.5, -1.5) - upper_bounds = (12.0, 2.5, 2.5) - dx = 0.03 - solve_time = 10000.0 - - # Make wind tunnel - wind_tunnel = WindTunnel( - stl_filename=stl_filename, - inlet_velocity=inlet_velocity, - lower_bounds=lower_bounds, - upper_bounds=upper_bounds, - solve_time=solve_time, - dx=dx, - ) - - # Run the simulation - wind_tunnel.run() - wind_tunnel.save_state("final", save_velocity_distribution=True) - diff --git a/examples/backend_comparisons/README.md b/examples/backend_comparisons/README.md deleted file mode 100644 index a198eb4..0000000 --- a/examples/backend_comparisons/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Performance Comparisons - -This directory contains a minimal LBM implementation in Warp, Numba, and Jax. The -code can be run with the following command: - -```bash -python3 lattice_boltzmann.py -``` - -This will give MLUPs numbers for each implementation. The Warp implementation -is the fastest, followed by Numba, and then Jax. - -This example should be used as a test for properly implementing more backends in -XLB. diff --git a/examples/backend_comparisons/lattice_boltzmann.py b/examples/backend_comparisons/lattice_boltzmann.py deleted file mode 100644 index 57f0e12..0000000 --- a/examples/backend_comparisons/lattice_boltzmann.py +++ /dev/null @@ -1,1138 +0,0 @@ -# Description: This file contains a simple example of using the OOCmap -# decorator to apply a function to a distributed array. -# Solves Lattice Boltzmann Taylor Green vortex decay - -import time -import warp as wp -import matplotlib.pyplot as plt -from tqdm import tqdm -import numpy as np -import cupy as cp -import time -from tqdm import tqdm -from numba import cuda -import numba -import math -import jax.numpy as jnp -import jax -from jax import jit -from functools import partial - -# Initialize Warp -wp.init() - -@wp.func -def warp_set_f( - f: wp.array4d(dtype=float), - value: float, - q: int, - i: int, - j: int, - k: int, - width: int, - height: int, - length: int, -): - # Modulo - if i < 0: - i += width - if j < 0: - j += height - if k < 0: - k += length - if i >= width: - i -= width - if j >= height: - j -= height - if k >= length: - k -= length - f[q, i, j, k] = value - -@wp.kernel -def warp_collide_stream( - f0: wp.array4d(dtype=float), - f1: wp.array4d(dtype=float), - width: int, - height: int, - length: int, - tau: float, -): - - # get index - x, y, z = wp.tid() - - # sample needed points - f_1_1_1 = f0[0, x, y, z] - f_2_1_1 = f0[1, x, y, z] - f_0_1_1 = f0[2, x, y, z] - f_1_2_1 = f0[3, x, y, z] - f_1_0_1 = f0[4, x, y, z] - f_1_1_2 = f0[5, x, y, z] - f_1_1_0 = f0[6, x, y, z] - f_1_2_2 = f0[7, x, y, z] - f_1_0_0 = f0[8, x, y, z] - f_1_2_0 = f0[9, x, y, z] - f_1_0_2 = f0[10, x, y, z] - f_2_1_2 = f0[11, x, y, z] - f_0_1_0 = f0[12, x, y, z] - f_2_1_0 = f0[13, x, y, z] - f_0_1_2 = f0[14, x, y, z] - f_2_2_1 = f0[15, x, y, z] - f_0_0_1 = f0[16, x, y, z] - f_2_0_1 = f0[17, x, y, z] - f_0_2_1 = f0[18, x, y, z] - - # compute u and p - p = (f_1_1_1 - + f_2_1_1 + f_0_1_1 - + f_1_2_1 + f_1_0_1 - + f_1_1_2 + f_1_1_0 - + f_1_2_2 + f_1_0_0 - + f_1_2_0 + f_1_0_2 - + f_2_1_2 + f_0_1_0 - + f_2_1_0 + f_0_1_2 - + f_2_2_1 + f_0_0_1 - + f_2_0_1 + f_0_2_1) - u = (f_2_1_1 - f_0_1_1 - + f_2_1_2 - f_0_1_0 - + f_2_1_0 - f_0_1_2 - + f_2_2_1 - f_0_0_1 - + f_2_0_1 - f_0_2_1) - v = (f_1_2_1 - f_1_0_1 - + f_1_2_2 - f_1_0_0 - + f_1_2_0 - f_1_0_2 - + f_2_2_1 - f_0_0_1 - - f_2_0_1 + f_0_2_1) - w = (f_1_1_2 - f_1_1_0 - + f_1_2_2 - f_1_0_0 - - f_1_2_0 + f_1_0_2 - + f_2_1_2 - f_0_1_0 - - f_2_1_0 + f_0_1_2) - res_p = 1.0 / p - u = u * res_p - v = v * res_p - w = w * res_p - uxu = u * u + v * v + w * w - - # compute e dot u - exu_1_1_1 = 0 - exu_2_1_1 = u - exu_0_1_1 = -u - exu_1_2_1 = v - exu_1_0_1 = -v - exu_1_1_2 = w - exu_1_1_0 = -w - exu_1_2_2 = v + w - exu_1_0_0 = -v - w - exu_1_2_0 = v - w - exu_1_0_2 = -v + w - exu_2_1_2 = u + w - exu_0_1_0 = -u - w - exu_2_1_0 = u - w - exu_0_1_2 = -u + w - exu_2_2_1 = u + v - exu_0_0_1 = -u - v - exu_2_0_1 = u - v - exu_0_2_1 = -u + v - - # compute equilibrium dist - factor_1 = 1.5 - factor_2 = 4.5 - weight_0 = 0.33333333 - weight_1 = 0.05555555 - weight_2 = 0.02777777 - f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + 1.0)) - f_eq_2_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + 1.0)) - f_eq_0_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + 1.0)) - f_eq_1_2_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) - f_eq_1_0_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) - f_eq_1_1_2 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + 1.0)) - f_eq_1_1_0 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + 1.0)) - f_eq_1_2_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + 1.0)) - f_eq_1_0_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_0 - uxu) + factor_2 * (exu_1_0_0 * exu_1_0_0) + 1.0)) - f_eq_1_2_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + 1.0)) - f_eq_1_0_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * exu_1_0_2) + 1.0)) - f_eq_2_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + 1.0)) - f_eq_0_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + 1.0)) - f_eq_2_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + 1.0)) - f_eq_0_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + 1.0)) - f_eq_2_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + 1.0)) - f_eq_0_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + 1.0)) - f_eq_2_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + 1.0)) - f_eq_0_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * exu_0_2_1) + 1.0)) - - # set next lattice state - inv_tau = (1.0 / tau) - warp_set_f(f1, f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1), 0, x, y, z, width, height, length) - warp_set_f(f1, f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1), 1, x + 1, y, z, width, height, length) - warp_set_f(f1, f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1), 2, x - 1, y, z, width, height, length) - warp_set_f(f1, f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1), 3, x, y + 1, z, width, height, length) - warp_set_f(f1, f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1), 4, x, y - 1, z, width, height, length) - warp_set_f(f1, f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2), 5, x, y, z + 1, width, height, length) - warp_set_f(f1, f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0), 6, x, y, z - 1, width, height, length) - warp_set_f(f1, f_1_2_2 - inv_tau * (f_1_2_2 - f_eq_1_2_2), 7, x, y + 1, z + 1, width, height, length) - warp_set_f(f1, f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0), 8, x, y - 1, z - 1, width, height, length) - warp_set_f(f1, f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0), 9, x, y + 1, z - 1, width, height, length) - warp_set_f(f1, f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2), 10, x, y - 1, z + 1, width, height, length) - warp_set_f(f1, f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2), 11, x + 1, y, z + 1, width, height, length) - warp_set_f(f1, f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0), 12, x - 1, y, z - 1, width, height, length) - warp_set_f(f1, f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0), 13, x + 1, y, z - 1, width, height, length) - warp_set_f(f1, f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2), 14, x - 1, y, z + 1, width, height, length) - warp_set_f(f1, f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1), 15, x + 1, y + 1, z, width, height, length) - warp_set_f(f1, f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1), 16, x - 1, y - 1, z, width, height, length) - warp_set_f(f1, f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1), 17, x + 1, y - 1, z, width, height, length) - warp_set_f(f1, f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1), 18, x - 1, y + 1, z, width, height, length) - -@wp.kernel -def warp_initialize_taylor_green( - f: wp.array4d(dtype=wp.float32), - dx: float, - vel: float, - start_x: int, - start_y: int, - start_z: int, -): - - # get index - i, j, k = wp.tid() - - # get real pos - x = wp.float(i + start_x) * dx - y = wp.float(j + start_y) * dx - z = wp.float(k + start_z) * dx - - # compute u - u = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) - v = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) - w = 0.0 - - # compute p - p = ( - 3.0 - * vel - * vel - * (1.0 / 16.0) - * (wp.cos(2.0 * x) + wp.cos(2.0 * y) * (wp.cos(2.0 * z) + 2.0)) - + 1.0 - ) - - # compute u X u - uxu = u * u + v * v + w * w - - # compute e dot u - exu_1_1_1 = 0.0 - exu_2_1_1 = u - exu_0_1_1 = -u - exu_1_2_1 = v - exu_1_0_1 = -v - exu_1_1_2 = w - exu_1_1_0 = -w - exu_1_2_2 = v + w - exu_1_0_0 = -v - w - exu_1_2_0 = v - w - exu_1_0_2 = -v + w - exu_2_1_2 = u + w - exu_0_1_0 = -u - w - exu_2_1_0 = u - w - exu_0_1_2 = -u + w - exu_2_2_1 = u + v - exu_0_0_1 = -u - v - exu_2_0_1 = u - v - exu_0_2_1 = -u + v - - # compute equilibrium dist - factor_1 = 1.5 - factor_2 = 4.5 - weight_0 = 0.33333333 - weight_1 = 0.05555555 - weight_2 = 0.02777777 - f_eq_1_1_1 = weight_0 * (p * (factor_1 * (-uxu) + 1.0)) - f_eq_2_1_1 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_2_1_1 - uxu) - + factor_2 * (exu_2_1_1 * exu_2_1_1) - + 1.0 - ) - ) - f_eq_0_1_1 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_0_1_1 - uxu) - + factor_2 * (exu_0_1_1 * exu_0_1_1) - + 1.0 - ) - ) - f_eq_1_2_1 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_1_2_1 - uxu) - + factor_2 * (exu_1_2_1 * exu_1_2_1) - + 1.0 - ) - ) - f_eq_1_0_1 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_1_0_1 - uxu) - + factor_2 * (exu_1_2_1 * exu_1_2_1) - + 1.0 - ) - ) - f_eq_1_1_2 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_1_1_2 - uxu) - + factor_2 * (exu_1_1_2 * exu_1_1_2) - + 1.0 - ) - ) - f_eq_1_1_0 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_1_1_0 - uxu) - + factor_2 * (exu_1_1_0 * exu_1_1_0) - + 1.0 - ) - ) - f_eq_1_2_2 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_1_2_2 - uxu) - + factor_2 * (exu_1_2_2 * exu_1_2_2) - + 1.0 - ) - ) - f_eq_1_0_0 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_1_0_0 - uxu) - + factor_2 * (exu_1_0_0 * exu_1_0_0) - + 1.0 - ) - ) - f_eq_1_2_0 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_1_2_0 - uxu) - + factor_2 * (exu_1_2_0 * exu_1_2_0) - + 1.0 - ) - ) - f_eq_1_0_2 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_1_0_2 - uxu) - + factor_2 * (exu_1_0_2 * exu_1_0_2) - + 1.0 - ) - ) - f_eq_2_1_2 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_2_1_2 - uxu) - + factor_2 * (exu_2_1_2 * exu_2_1_2) - + 1.0 - ) - ) - f_eq_0_1_0 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_0_1_0 - uxu) - + factor_2 * (exu_0_1_0 * exu_0_1_0) - + 1.0 - ) - ) - f_eq_2_1_0 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_2_1_0 - uxu) - + factor_2 * (exu_2_1_0 * exu_2_1_0) - + 1.0 - ) - ) - f_eq_0_1_2 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_0_1_2 - uxu) - + factor_2 * (exu_0_1_2 * exu_0_1_2) - + 1.0 - ) - ) - f_eq_2_2_1 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_2_2_1 - uxu) - + factor_2 * (exu_2_2_1 * exu_2_2_1) - + 1.0 - ) - ) - f_eq_0_0_1 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_0_0_1 - uxu) - + factor_2 * (exu_0_0_1 * exu_0_0_1) - + 1.0 - ) - ) - f_eq_2_0_1 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_2_0_1 - uxu) - + factor_2 * (exu_2_0_1 * exu_2_0_1) - + 1.0 - ) - ) - f_eq_0_2_1 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_0_2_1 - uxu) - + factor_2 * (exu_0_2_1 * exu_0_2_1) - + 1.0 - ) - ) - - # set next lattice state - f[0, i, j, k] = f_eq_1_1_1 - f[1, i, j, k] = f_eq_2_1_1 - f[2, i, j, k] = f_eq_0_1_1 - f[3, i, j, k] = f_eq_1_2_1 - f[4, i, j, k] = f_eq_1_0_1 - f[5, i, j, k] = f_eq_1_1_2 - f[6, i, j, k] = f_eq_1_1_0 - f[7, i, j, k] = f_eq_1_2_2 - f[8, i, j, k] = f_eq_1_0_0 - f[9, i, j, k] = f_eq_1_2_0 - f[10, i, j, k] = f_eq_1_0_2 - f[11, i, j, k] = f_eq_2_1_2 - f[12, i, j, k] = f_eq_0_1_0 - f[13, i, j, k] = f_eq_2_1_0 - f[14, i, j, k] = f_eq_0_1_2 - f[15, i, j, k] = f_eq_2_2_1 - f[16, i, j, k] = f_eq_0_0_1 - f[17, i, j, k] = f_eq_2_0_1 - f[18, i, j, k] = f_eq_0_2_1 - - -def warp_initialize_f(f, dx: float): - # Get inputs - cs = 1.0 / np.sqrt(3.0) - vel = 0.1 * cs - - # Launch kernel - wp.launch( - kernel=warp_initialize_taylor_green, - dim=list(f.shape[1:]), - inputs=[f, dx, vel, 0, 0, 0], - device=f.device, - ) - - return f - - -def warp_apply_collide_stream(f0, f1, tau: float): - # Apply streaming and collision step - wp.launch( - kernel=warp_collide_stream, - dim=list(f0.shape[1:]), - inputs=[f0, f1, f0.shape[1], f0.shape[2], f0.shape[3], tau], - device=f0.device, - ) - - return f1, f0 - - -@cuda.jit("void(float32[:,:,:,::1], float32, int32, int32, int32, int32, int32, int32, int32)", device=True) -def numba_set_f( - f: numba.cuda.cudadrv.devicearray.DeviceNDArray, - value: float, - q: int, - i: int, - j: int, - k: int, - width: int, - height: int, - length: int, -): - # Modulo - if i < 0: - i += width - if j < 0: - j += height - if k < 0: - k += length - if i >= width: - i -= width - if j >= height: - j -= height - if k >= length: - k -= length - f[i, j, k, q] = value - -#@cuda.jit -@cuda.jit("void(float32[:,:,:,::1], float32[:,:,:,::1], int32, int32, int32, float32)") -def numba_collide_stream( - f0: numba.cuda.cudadrv.devicearray.DeviceNDArray, - f1: numba.cuda.cudadrv.devicearray.DeviceNDArray, - width: int, - height: int, - length: int, - tau: float, -): - - x, y, z = cuda.grid(3) - - # sample needed points - f_1_1_1 = f0[x, y, z, 0] - f_2_1_1 = f0[x, y, z, 1] - f_0_1_1 = f0[x, y, z, 2] - f_1_2_1 = f0[x, y, z, 3] - f_1_0_1 = f0[x, y, z, 4] - f_1_1_2 = f0[x, y, z, 5] - f_1_1_0 = f0[x, y, z, 6] - f_1_2_2 = f0[x, y, z, 7] - f_1_0_0 = f0[x, y, z, 8] - f_1_2_0 = f0[x, y, z, 9] - f_1_0_2 = f0[x, y, z, 10] - f_2_1_2 = f0[x, y, z, 11] - f_0_1_0 = f0[x, y, z, 12] - f_2_1_0 = f0[x, y, z, 13] - f_0_1_2 = f0[x, y, z, 14] - f_2_2_1 = f0[x, y, z, 15] - f_0_0_1 = f0[x, y, z, 16] - f_2_0_1 = f0[x, y, z, 17] - f_0_2_1 = f0[x, y, z, 18] - - # compute u and p - p = (f_1_1_1 - + f_2_1_1 + f_0_1_1 - + f_1_2_1 + f_1_0_1 - + f_1_1_2 + f_1_1_0 - + f_1_2_2 + f_1_0_0 - + f_1_2_0 + f_1_0_2 - + f_2_1_2 + f_0_1_0 - + f_2_1_0 + f_0_1_2 - + f_2_2_1 + f_0_0_1 - + f_2_0_1 + f_0_2_1) - u = (f_2_1_1 - f_0_1_1 - + f_2_1_2 - f_0_1_0 - + f_2_1_0 - f_0_1_2 - + f_2_2_1 - f_0_0_1 - + f_2_0_1 - f_0_2_1) - v = (f_1_2_1 - f_1_0_1 - + f_1_2_2 - f_1_0_0 - + f_1_2_0 - f_1_0_2 - + f_2_2_1 - f_0_0_1 - - f_2_0_1 + f_0_2_1) - w = (f_1_1_2 - f_1_1_0 - + f_1_2_2 - f_1_0_0 - - f_1_2_0 + f_1_0_2 - + f_2_1_2 - f_0_1_0 - - f_2_1_0 + f_0_1_2) - res_p = numba.float32(1.0) / p - u = u * res_p - v = v * res_p - w = w * res_p - uxu = u * u + v * v + w * w - - # compute e dot u - exu_1_1_1 = numba.float32(0.0) - exu_2_1_1 = u - exu_0_1_1 = -u - exu_1_2_1 = v - exu_1_0_1 = -v - exu_1_1_2 = w - exu_1_1_0 = -w - exu_1_2_2 = v + w - exu_1_0_0 = -v - w - exu_1_2_0 = v - w - exu_1_0_2 = -v + w - exu_2_1_2 = u + w - exu_0_1_0 = -u - w - exu_2_1_0 = u - w - exu_0_1_2 = -u + w - exu_2_2_1 = u + v - exu_0_0_1 = -u - v - exu_2_0_1 = u - v - exu_0_2_1 = -u + v - - # compute equilibrium dist - factor_1 = numba.float32(1.5) - factor_2 = numba.float32(4.5) - weight_0 = numba.float32(0.33333333) - weight_1 = numba.float32(0.05555555) - weight_2 = numba.float32(0.02777777) - - f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + numba.float32(1.0))) - f_eq_2_1_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + numba.float32(1.0))) - f_eq_0_1_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + numba.float32(1.0))) - f_eq_1_2_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + numba.float32(1.0))) - f_eq_1_0_1 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + numba.float32(1.0))) - f_eq_1_1_2 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + numba.float32(1.0))) - f_eq_1_1_0 = weight_1 * (p * (factor_1 * (numba.float32(2.0) * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + numba.float32(1.0))) - f_eq_1_2_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + numba.float32(1.0))) - f_eq_1_0_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_0 - uxu) + factor_2 * (exu_1_0_0 * exu_1_0_0) + numba.float32(1.0))) - f_eq_1_2_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + numba.float32(1.0))) - f_eq_1_0_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * exu_1_0_2) + numba.float32(1.0))) - f_eq_2_1_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + numba.float32(1.0))) - f_eq_0_1_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + numba.float32(1.0))) - f_eq_2_1_0 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + numba.float32(1.0))) - f_eq_0_1_2 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + numba.float32(1.0))) - f_eq_2_2_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + numba.float32(1.0))) - f_eq_0_0_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + numba.float32(1.0))) - f_eq_2_0_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + numba.float32(1.0))) - f_eq_0_2_1 = weight_2 * (p * (factor_1 * (numba.float32(2.0) * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * exu_0_2_1) + numba.float32(1.0))) - - # set next lattice state - inv_tau = numba.float32((numba.float32(1.0) / tau)) - numba_set_f(f1, f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1), 0, x, y, z, width, height, length) - numba_set_f(f1, f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1), 1, x + 1, y, z, width, height, length) - numba_set_f(f1, f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1), 2, x - 1, y, z, width, height, length) - numba_set_f(f1, f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1), 3, x, y + 1, z, width, height, length) - numba_set_f(f1, f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1), 4, x, y - 1, z, width, height, length) - numba_set_f(f1, f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2), 5, x, y, z + 1, width, height, length) - numba_set_f(f1, f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0), 6, x, y, z - 1, width, height, length) - numba_set_f(f1, f_1_2_2 - inv_tau * (f_1_2_2 - f_eq_1_2_2), 7, x, y + 1, z + 1, width, height, length) - numba_set_f(f1, f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0), 8, x, y - 1, z - 1, width, height, length) - numba_set_f(f1, f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0), 9, x, y + 1, z - 1, width, height, length) - numba_set_f(f1, f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2), 10, x, y - 1, z + 1, width, height, length) - numba_set_f(f1, f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2), 11, x + 1, y, z + 1, width, height, length) - numba_set_f(f1, f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0), 12, x - 1, y, z - 1, width, height, length) - numba_set_f(f1, f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0), 13, x + 1, y, z - 1, width, height, length) - numba_set_f(f1, f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2), 14, x - 1, y, z + 1, width, height, length) - numba_set_f(f1, f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1), 15, x + 1, y + 1, z, width, height, length) - numba_set_f(f1, f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1), 16, x - 1, y - 1, z, width, height, length) - numba_set_f(f1, f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1), 17, x + 1, y - 1, z, width, height, length) - numba_set_f(f1, f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1), 18, x - 1, y + 1, z, width, height, length) - - -@cuda.jit -def numba_initialize_taylor_green( - f, - dx, - vel, - start_x, - start_y, - start_z, -): - - i, j, k = cuda.grid(3) - - # get real pos - x = numba.float32(i + start_x) * dx - y = numba.float32(j + start_y) * dx - z = numba.float32(k + start_z) * dx - - # compute u - u = vel * math.sin(x) * math.cos(y) * math.cos(z) - v = -vel * math.cos(x) * math.sin(y) * math.cos(z) - w = 0.0 - - # compute p - p = ( - 3.0 - * vel - * vel - * (1.0 / 16.0) - * (math.cos(2.0 * x) + math.cos(2.0 * y) * (math.cos(2.0 * z) + 2.0)) - + 1.0 - ) - - # compute u X u - uxu = u * u + v * v + w * w - - # compute e dot u - exu_1_1_1 = 0.0 - exu_2_1_1 = u - exu_0_1_1 = -u - exu_1_2_1 = v - exu_1_0_1 = -v - exu_1_1_2 = w - exu_1_1_0 = -w - exu_1_2_2 = v + w - exu_1_0_0 = -v - w - exu_1_2_0 = v - w - exu_1_0_2 = -v + w - exu_2_1_2 = u + w - exu_0_1_0 = -u - w - exu_2_1_0 = u - w - exu_0_1_2 = -u + w - exu_2_2_1 = u + v - exu_0_0_1 = -u - v - exu_2_0_1 = u - v - exu_0_2_1 = -u + v - - # compute equilibrium dist - factor_1 = 1.5 - factor_2 = 4.5 - weight_0 = 0.33333333 - weight_1 = 0.05555555 - weight_2 = 0.02777777 - f_eq_1_1_1 = weight_0 * (p * (factor_1 * (-uxu) + 1.0)) - f_eq_2_1_1 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_2_1_1 - uxu) - + factor_2 * (exu_2_1_1 * exu_2_1_1) - + 1.0 - ) - ) - f_eq_0_1_1 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_0_1_1 - uxu) - + factor_2 * (exu_0_1_1 * exu_0_1_1) - + 1.0 - ) - ) - f_eq_1_2_1 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_1_2_1 - uxu) - + factor_2 * (exu_1_2_1 * exu_1_2_1) - + 1.0 - ) - ) - f_eq_1_0_1 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_1_0_1 - uxu) - + factor_2 * (exu_1_2_1 * exu_1_2_1) - + 1.0 - ) - ) - f_eq_1_1_2 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_1_1_2 - uxu) - + factor_2 * (exu_1_1_2 * exu_1_1_2) - + 1.0 - ) - ) - f_eq_1_1_0 = weight_1 * ( - p - * ( - factor_1 * (2.0 * exu_1_1_0 - uxu) - + factor_2 * (exu_1_1_0 * exu_1_1_0) - + 1.0 - ) - ) - f_eq_1_2_2 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_1_2_2 - uxu) - + factor_2 * (exu_1_2_2 * exu_1_2_2) - + 1.0 - ) - ) - f_eq_1_0_0 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_1_0_0 - uxu) - + factor_2 * (exu_1_0_0 * exu_1_0_0) - + 1.0 - ) - ) - f_eq_1_2_0 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_1_2_0 - uxu) - + factor_2 * (exu_1_2_0 * exu_1_2_0) - + 1.0 - ) - ) - f_eq_1_0_2 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_1_0_2 - uxu) - + factor_2 * (exu_1_0_2 * exu_1_0_2) - + 1.0 - ) - ) - f_eq_2_1_2 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_2_1_2 - uxu) - + factor_2 * (exu_2_1_2 * exu_2_1_2) - + 1.0 - ) - ) - f_eq_0_1_0 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_0_1_0 - uxu) - + factor_2 * (exu_0_1_0 * exu_0_1_0) - + 1.0 - ) - ) - f_eq_2_1_0 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_2_1_0 - uxu) - + factor_2 * (exu_2_1_0 * exu_2_1_0) - + 1.0 - ) - ) - f_eq_0_1_2 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_0_1_2 - uxu) - + factor_2 * (exu_0_1_2 * exu_0_1_2) - + 1.0 - ) - ) - f_eq_2_2_1 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_2_2_1 - uxu) - + factor_2 * (exu_2_2_1 * exu_2_2_1) - + 1.0 - ) - ) - f_eq_0_0_1 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_0_0_1 - uxu) - + factor_2 * (exu_0_0_1 * exu_0_0_1) - + 1.0 - ) - ) - f_eq_2_0_1 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_2_0_1 - uxu) - + factor_2 * (exu_2_0_1 * exu_2_0_1) - + 1.0 - ) - ) - f_eq_0_2_1 = weight_2 * ( - p - * ( - factor_1 * (2.0 * exu_0_2_1 - uxu) - + factor_2 * (exu_0_2_1 * exu_0_2_1) - + 1.0 - ) - ) - - # set next lattice state - f[i, j, k, 0] = f_eq_1_1_1 - f[i, j, k, 1] = f_eq_2_1_1 - f[i, j, k, 2] = f_eq_0_1_1 - f[i, j, k, 3] = f_eq_1_2_1 - f[i, j, k, 4] = f_eq_1_0_1 - f[i, j, k, 5] = f_eq_1_1_2 - f[i, j, k, 6] = f_eq_1_1_0 - f[i, j, k, 7] = f_eq_1_2_2 - f[i, j, k, 8] = f_eq_1_0_0 - f[i, j, k, 9] = f_eq_1_2_0 - f[ i, j, k, 10] = f_eq_1_0_2 - f[ i, j, k, 11] = f_eq_2_1_2 - f[ i, j, k, 12] = f_eq_0_1_0 - f[ i, j, k, 13] = f_eq_2_1_0 - f[ i, j, k, 14] = f_eq_0_1_2 - f[ i, j, k, 15] = f_eq_2_2_1 - f[ i, j, k, 16] = f_eq_0_0_1 - f[ i, j, k, 17] = f_eq_2_0_1 - f[ i, j, k, 18] = f_eq_0_2_1 - - -def numba_initialize_f(f, dx: float): - # Get inputs - cs = 1.0 / np.sqrt(3.0) - vel = 0.1 * cs - - # Launch kernel - blockdim = (16, 16, 1) - griddim = ( - int(np.ceil(f.shape[0] / blockdim[0])), - int(np.ceil(f.shape[1] / blockdim[1])), - int(np.ceil(f.shape[2] / blockdim[2])), - ) - numba_initialize_taylor_green[griddim, blockdim]( - f, dx, vel, 0, 0, 0 - ) - - return f - -def numba_apply_collide_stream(f0, f1, tau: float): - # Apply streaming and collision step - blockdim = (8, 8, 8) - griddim = ( - int(np.ceil(f0.shape[0] / blockdim[0])), - int(np.ceil(f0.shape[1] / blockdim[1])), - int(np.ceil(f0.shape[2] / blockdim[2])), - ) - numba_collide_stream[griddim, blockdim]( - f0, f1, f0.shape[0], f0.shape[1], f0.shape[2], tau - ) - - return f1, f0 - -@partial(jit, static_argnums=(1), donate_argnums=(0)) -def jax_apply_collide_stream(f, tau: float): - - # Get f directions - f_1_1_1 = f[:, :, :, 0] - f_2_1_1 = f[:, :, :, 1] - f_0_1_1 = f[:, :, :, 2] - f_1_2_1 = f[:, :, :, 3] - f_1_0_1 = f[:, :, :, 4] - f_1_1_2 = f[:, :, :, 5] - f_1_1_0 = f[:, :, :, 6] - f_1_2_2 = f[:, :, :, 7] - f_1_0_0 = f[:, :, :, 8] - f_1_2_0 = f[:, :, :, 9] - f_1_0_2 = f[:, :, :, 10] - f_2_1_2 = f[:, :, :, 11] - f_0_1_0 = f[:, :, :, 12] - f_2_1_0 = f[:, :, :, 13] - f_0_1_2 = f[:, :, :, 14] - f_2_2_1 = f[:, :, :, 15] - f_0_0_1 = f[:, :, :, 16] - f_2_0_1 = f[:, :, :, 17] - f_0_2_1 = f[:, :, :, 18] - - # compute u and p - p = (f_1_1_1 - + f_2_1_1 + f_0_1_1 - + f_1_2_1 + f_1_0_1 - + f_1_1_2 + f_1_1_0 - + f_1_2_2 + f_1_0_0 - + f_1_2_0 + f_1_0_2 - + f_2_1_2 + f_0_1_0 - + f_2_1_0 + f_0_1_2 - + f_2_2_1 + f_0_0_1 - + f_2_0_1 + f_0_2_1) - u = (f_2_1_1 - f_0_1_1 - + f_2_1_2 - f_0_1_0 - + f_2_1_0 - f_0_1_2 - + f_2_2_1 - f_0_0_1 - + f_2_0_1 - f_0_2_1) - v = (f_1_2_1 - f_1_0_1 - + f_1_2_2 - f_1_0_0 - + f_1_2_0 - f_1_0_2 - + f_2_2_1 - f_0_0_1 - - f_2_0_1 + f_0_2_1) - w = (f_1_1_2 - f_1_1_0 - + f_1_2_2 - f_1_0_0 - - f_1_2_0 + f_1_0_2 - + f_2_1_2 - f_0_1_0 - - f_2_1_0 + f_0_1_2) - res_p = 1.0 / p - u = u * res_p - v = v * res_p - w = w * res_p - uxu = u * u + v * v + w * w - - # compute e dot u - exu_1_1_1 = 0 - exu_2_1_1 = u - exu_0_1_1 = -u - exu_1_2_1 = v - exu_1_0_1 = -v - exu_1_1_2 = w - exu_1_1_0 = -w - exu_1_2_2 = v + w - exu_1_0_0 = -v - w - exu_1_2_0 = v - w - exu_1_0_2 = -v + w - exu_2_1_2 = u + w - exu_0_1_0 = -u - w - exu_2_1_0 = u - w - exu_0_1_2 = -u + w - exu_2_2_1 = u + v - exu_0_0_1 = -u - v - exu_2_0_1 = u - v - exu_0_2_1 = -u + v - - # compute equilibrium dist - factor_1 = 1.5 - factor_2 = 4.5 - weight_0 = 0.33333333 - weight_1 = 0.05555555 - weight_2 = 0.02777777 - f_eq_1_1_1 = weight_0 * (p * (factor_1 * (- uxu) + 1.0)) - f_eq_2_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_2_1_1 - uxu) + factor_2 * (exu_2_1_1 * exu_2_1_1) + 1.0)) - f_eq_0_1_1 = weight_1 * (p * (factor_1 * (2.0 * exu_0_1_1 - uxu) + factor_2 * (exu_0_1_1 * exu_0_1_1) + 1.0)) - f_eq_1_2_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_2_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) - f_eq_1_0_1 = weight_1 * (p * (factor_1 * (2.0 * exu_1_0_1 - uxu) + factor_2 * (exu_1_2_1 * exu_1_2_1) + 1.0)) - f_eq_1_1_2 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_2 - uxu) + factor_2 * (exu_1_1_2 * exu_1_1_2) + 1.0)) - f_eq_1_1_0 = weight_1 * (p * (factor_1 * (2.0 * exu_1_1_0 - uxu) + factor_2 * (exu_1_1_0 * exu_1_1_0) + 1.0)) - f_eq_1_2_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_2 - uxu) + factor_2 * (exu_1_2_2 * exu_1_2_2) + 1.0)) - f_eq_1_0_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_0 - uxu) + factor_2 * (exu_1_0_0 * exu_1_0_0) + 1.0)) - f_eq_1_2_0 = weight_2 * (p * (factor_1 * (2.0 * exu_1_2_0 - uxu) + factor_2 * (exu_1_2_0 * exu_1_2_0) + 1.0)) - f_eq_1_0_2 = weight_2 * (p * (factor_1 * (2.0 * exu_1_0_2 - uxu) + factor_2 * (exu_1_0_2 * exu_1_0_2) + 1.0)) - f_eq_2_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_2 - uxu) + factor_2 * (exu_2_1_2 * exu_2_1_2) + 1.0)) - f_eq_0_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_0 - uxu) + factor_2 * (exu_0_1_0 * exu_0_1_0) + 1.0)) - f_eq_2_1_0 = weight_2 * (p * (factor_1 * (2.0 * exu_2_1_0 - uxu) + factor_2 * (exu_2_1_0 * exu_2_1_0) + 1.0)) - f_eq_0_1_2 = weight_2 * (p * (factor_1 * (2.0 * exu_0_1_2 - uxu) + factor_2 * (exu_0_1_2 * exu_0_1_2) + 1.0)) - f_eq_2_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_2_1 - uxu) + factor_2 * (exu_2_2_1 * exu_2_2_1) + 1.0)) - f_eq_0_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_0_1 - uxu) + factor_2 * (exu_0_0_1 * exu_0_0_1) + 1.0)) - f_eq_2_0_1 = weight_2 * (p * (factor_1 * (2.0 * exu_2_0_1 - uxu) + factor_2 * (exu_2_0_1 * exu_2_0_1) + 1.0)) - f_eq_0_2_1 = weight_2 * (p * (factor_1 * (2.0 * exu_0_2_1 - uxu) + factor_2 * (exu_0_2_1 * exu_0_2_1) + 1.0)) - - # set next lattice state - inv_tau = (1.0 / tau) - f_1_1_1 = f_1_1_1 - inv_tau * (f_1_1_1 - f_eq_1_1_1) - f_2_1_1 = f_2_1_1 - inv_tau * (f_2_1_1 - f_eq_2_1_1) - f_0_1_1 = f_0_1_1 - inv_tau * (f_0_1_1 - f_eq_0_1_1) - f_1_2_1 = f_1_2_1 - inv_tau * (f_1_2_1 - f_eq_1_2_1) - f_1_0_1 = f_1_0_1 - inv_tau * (f_1_0_1 - f_eq_1_0_1) - f_1_1_2 = f_1_1_2 - inv_tau * (f_1_1_2 - f_eq_1_1_2) - f_1_1_0 = f_1_1_0 - inv_tau * (f_1_1_0 - f_eq_1_1_0) - f_1_2_2 = f_1_2_2 - inv_tau * (f_1_2_2 - f_eq_1_2_2) - f_1_0_0 = f_1_0_0 - inv_tau * (f_1_0_0 - f_eq_1_0_0) - f_1_2_0 = f_1_2_0 - inv_tau * (f_1_2_0 - f_eq_1_2_0) - f_1_0_2 = f_1_0_2 - inv_tau * (f_1_0_2 - f_eq_1_0_2) - f_2_1_2 = f_2_1_2 - inv_tau * (f_2_1_2 - f_eq_2_1_2) - f_0_1_0 = f_0_1_0 - inv_tau * (f_0_1_0 - f_eq_0_1_0) - f_2_1_0 = f_2_1_0 - inv_tau * (f_2_1_0 - f_eq_2_1_0) - f_0_1_2 = f_0_1_2 - inv_tau * (f_0_1_2 - f_eq_0_1_2) - f_2_2_1 = f_2_2_1 - inv_tau * (f_2_2_1 - f_eq_2_2_1) - f_0_0_1 = f_0_0_1 - inv_tau * (f_0_0_1 - f_eq_0_0_1) - f_2_0_1 = f_2_0_1 - inv_tau * (f_2_0_1 - f_eq_2_0_1) - f_0_2_1 = f_0_2_1 - inv_tau * (f_0_2_1 - f_eq_0_2_1) - - # Roll fs and concatenate - f_2_1_1 = jnp.roll(f_2_1_1, -1, axis=0) - f_0_1_1 = jnp.roll(f_0_1_1, 1, axis=0) - f_1_2_1 = jnp.roll(f_1_2_1, -1, axis=1) - f_1_0_1 = jnp.roll(f_1_0_1, 1, axis=1) - f_1_1_2 = jnp.roll(f_1_1_2, -1, axis=2) - f_1_1_0 = jnp.roll(f_1_1_0, 1, axis=2) - f_1_2_2 = jnp.roll(jnp.roll(f_1_2_2, -1, axis=1), -1, axis=2) - f_1_0_0 = jnp.roll(jnp.roll(f_1_0_0, 1, axis=1), 1, axis=2) - f_1_2_0 = jnp.roll(jnp.roll(f_1_2_0, -1, axis=1), 1, axis=2) - f_1_0_2 = jnp.roll(jnp.roll(f_1_0_2, 1, axis=1), -1, axis=2) - f_2_1_2 = jnp.roll(jnp.roll(f_2_1_2, -1, axis=0), -1, axis=2) - f_0_1_0 = jnp.roll(jnp.roll(f_0_1_0, 1, axis=0), 1, axis=2) - f_2_1_0 = jnp.roll(jnp.roll(f_2_1_0, -1, axis=0), 1, axis=2) - f_0_1_2 = jnp.roll(jnp.roll(f_0_1_2, 1, axis=0), -1, axis=2) - f_2_2_1 = jnp.roll(jnp.roll(f_2_2_1, -1, axis=0), -1, axis=1) - f_0_0_1 = jnp.roll(jnp.roll(f_0_0_1, 1, axis=0), 1, axis=1) - f_2_0_1 = jnp.roll(jnp.roll(f_2_0_1, -1, axis=0), 1, axis=1) - f_0_2_1 = jnp.roll(jnp.roll(f_0_2_1, 1, axis=0), -1, axis=1) - - return jnp.stack( - [ - f_1_1_1, - f_2_1_1, - f_0_1_1, - f_1_2_1, - f_1_0_1, - f_1_1_2, - f_1_1_0, - f_1_2_2, - f_1_0_0, - f_1_2_0, - f_1_0_2, - f_2_1_2, - f_0_1_0, - f_2_1_0, - f_0_1_2, - f_2_2_1, - f_0_0_1, - f_2_0_1, - f_0_2_1, - ], - axis=-1, - ) - - - -if __name__ == "__main__": - - # Sim Parameters - n = 256 - tau = 0.505 - dx = 2.0 * np.pi / n - nr_steps = 128 - - # Bar plot - backend = [] - mlups = [] - - ######### Warp ######### - # Make f0, f1 - f0 = wp.empty((19, n, n, n), dtype=wp.float32, device="cuda:0") - f1 = wp.empty((19, n, n, n), dtype=wp.float32, device="cuda:0") - - # Initialize f0 - f0 = warp_initialize_f(f0, dx) - - # Apply streaming and collision - t0 = time.time() - for _ in tqdm(range(nr_steps)): - f0, f1 = warp_apply_collide_stream(f0, f1, tau) - wp.synchronize() - t1 = time.time() - - # Compute MLUPS - mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 - backend.append("Warp") - print(mlups) - exit() - mlups.append(mlups) - - # Plot results - np_f = f0.numpy() - plt.imshow(np_f[3, :, :, 0]) - plt.colorbar() - plt.savefig("warp_f_.png") - plt.close() - - ######### Numba ######### - # Make f0, f1 - f0 = cp.ascontiguousarray(cp.empty((n, n, n, 19), dtype=np.float32)) - f1 = cp.ascontiguousarray(cp.empty((n, n, n, 19), dtype=np.float32)) - - # Initialize f0 - f0 = numba_initialize_f(f0, dx) - - # Apply streaming and collision - t0 = time.time() - for _ in tqdm(range(nr_steps)): - f0, f1 = numba_apply_collide_stream(f0, f1, tau) - cp.cuda.Device(0).synchronize() - t1 = time.time() - - # Compute MLUPS - mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 - backend.append("Numba") - mlups.append(mlups) - - # Plot results - np_f = f0 - plt.imshow(np_f[:, :, 0, 3].get()) - plt.colorbar() - plt.savefig("numba_f_.png") - plt.close() - - ######### Jax ######### - # Make f0, f1 - f = jnp.zeros((n, n, n, 19), dtype=jnp.float32) - - # Initialize f0 - # f = jax_initialize_f(f, dx) - - # Apply streaming and collision - t0 = time.time() - for _ in tqdm(range(nr_steps)): - f = jax_apply_collide_stream(f, tau) - t1 = time.time() - - # Compute MLUPS - mlups = (nr_steps * n * n * n) / (t1 - t0) / 1e6 - backend.append("Jax") - mlups.append(mlups) - - # Plot results - np_f = f - plt.imshow(np_f[:, :, 0, 3]) - plt.colorbar() - plt.savefig("jax_f_.png") - plt.close() - - - - diff --git a/examples/backend_comparisons/small_example.py b/examples/backend_comparisons/small_example.py deleted file mode 100644 index 6d6213b..0000000 --- a/examples/backend_comparisons/small_example.py +++ /dev/null @@ -1,327 +0,0 @@ -# Simple example of functions to generate a warp kernel for LBM - -import warp as wp -import numpy as np - -# Initialize Warp -wp.init() - -def make_warp_kernel( - velocity_weight, - velocity_set, - dtype=wp.float32, - dim=3, # slightly hard coded for 3d right now - q=19, -): - - # Make needed vector classes - lattice_vec = wp.vec(q, dtype=dtype) - velocity_vec = wp.vec(dim, dtype=dtype) - - # Make array type - if dim == 2: - array_type = wp.array3d(dtype=dtype) - elif dim == 3: - array_type = wp.array4d(dtype=dtype) - - # Make everything constant - velocity_weight = wp.constant(velocity_weight) - velocity_set = wp.constant(velocity_set) - q = wp.constant(q) - dim = wp.constant(dim) - - # Make function for computing exu - @wp.func - def compute_exu(u: velocity_vec): - exu = lattice_vec() - for _ in range(q): - for d in range(dim): - if velocity_set[_, d] == 1: - exu[_] += u[d] - elif velocity_set[_, d] == -1: - exu[_] -= u[d] - return exu - - # Make function for computing feq - @wp.func - def compute_feq( - p: dtype, - uxu: dtype, - exu: lattice_vec, - ): - factor_1 = 1.5 - factor_2 = 4.5 - feq = lattice_vec() - for _ in range(q): - feq[_] = ( - velocity_weight[_] * p * ( - 1.0 - + factor_1 * (2.0 * exu[_] - uxu) - + factor_2 * exu[_] * exu[_] - ) - ) - return feq - - # Make function for computing u and p - @wp.func - def compute_u_and_p(f: lattice_vec): - p = wp.float32(0.0) - u = velocity_vec() - for d in range(dim): - u[d] = wp.float32(0.0) - for _ in range(q): - p += f[_] - for d in range(dim): - if velocity_set[_, d] == 1: - u[d] += f[_] - elif velocity_set[_, d] == -1: - u[d] -= f[_] - u /= p - return u, p - - # bc function - @wp.func - def bc_0(pre_f: lattice_vec, post_f: lattice_vec): - return pre_f - @wp.func - def bc_1(pre_f: lattice_vec, post_f: lattice_vec): - return post_f - tup_bc = tuple([bc_0, bc_1]) - single_bc = bc_0 - for bc in tup_bc: - def make_bc(bc, prev_bc): - @wp.func - def _bc(pre_f: lattice_vec, post_f: lattice_vec): - pre_f = prev_bc(pre_f, post_f) - post_f = single_bc(pre_f, post_f) - return bc(pre_f, post_f) - return _bc - single_bc = make_bc(bc, single_bc) - - # Make function for getting stream index - @wp.func - def get_streamed_index( - i: int, - x: int, - y: int, - z: int, - width: int, - height: int, - length: int, - ): - streamed_x = x + velocity_set[i, 0] - streamed_y = y + velocity_set[i, 1] - streamed_z = z + velocity_set[i, 2] - if streamed_x == -1: # TODO hacky - streamed_x = width - 1 - if streamed_y == -1: - streamed_y = height - 1 - if streamed_z == -1: - streamed_z = length - 1 - if streamed_x == width: - streamed_x = 0 - if streamed_y == height: - streamed_y = 0 - if streamed_z == length: - streamed_z = 0 - return streamed_x, streamed_y, streamed_z - - # Make kernel for stream and collide - @wp.kernel - def collide_stream( - f0: array_type, - f1: array_type, - width: int, - height: int, - length: int, - tau: float, - ): - - # Get indices (TODO: no good way to do variable dimension indexing) - f = lattice_vec() - x, y, z = wp.tid() - for i in range(q): - f[i] = f0[i, x, y, z] - - # Compute p and u - u, p = compute_u_and_p(f) - - # get uxu - uxu = wp.dot(u, u) - - # Compute velocity_set dot u - exu = compute_exu(u) - - # Compute equilibrium - feq = compute_feq(p, uxu, exu) - - # Set bc - if x == 0: - #tup_bc[0](feq, f) - bc_0(feq, f) - if x == width - 1: - bc_1(feq, f) - #tup_bc[1](feq, f) - - # Set value - new_f = f - (f - feq) / tau - for i in range(q): - (streamed_x, streamed_y, streamed_z) = get_streamed_index( - i, x, y, z, width, height, length - ) - f1[i, streamed_x, streamed_y, streamed_z] = new_f[i] - - # make kernel for initialization - @wp.kernel - def initialize_taylor_green( - f0: array_type, - dx: float, - vel: float, - width: int, - height: int, - length: int, - tau: float, - ): - - # Get indices (TODO: no good way to do variable dimension indexing) - i, j, k = wp.tid() - - # Get real coordinates - x = wp.float(i) * dx - y = wp.float(j) * dx - z = wp.float(k) * dx - - # Compute velocity - u = velocity_vec() - u[0] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) - u[1] = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) - u[2] = 0.0 - - # Compute p - p = ( - 3.0 - * vel - * vel - * (1.0 / 16.0) - * ( - wp.cos(2.0 * x) - + wp.cos(2.0 * y) - + wp.cos(2.0 * z) - ) - + 1.0 - ) - - # Compute uxu - uxu = wp.dot(u, u) - - # Compute velocity_set dot u - exu = compute_exu(u) - - # Compute equilibrium - feq = compute_feq(p, uxu, exu) - - # Set value - for _ in range(q): - f0[_, i, j, k] = feq[_] - - return collide_stream, initialize_taylor_green - -def plt_f(f): - import matplotlib.pyplot as plt - plt.imshow(f.numpy()[3, :, :, f.shape[3] // 4]) - plt.show() - -if __name__ == "__main__": - - # Parameters - n = 256 - tau = 0.505 - dim = 3 - q = 19 - lattice_dtype = wp.float32 - lattice_vec = wp.vec(q, dtype=lattice_dtype) - - # Make arrays - f0 = wp.empty((q, n, n, n), dtype=lattice_dtype, device="cuda:0") - f1 = wp.empty((q, n, n, n), dtype=lattice_dtype, device="cuda:0") - - # Make velocity set - velocity_weight = wp.vec(q, dtype=lattice_dtype)( - [1.0/3.0] + [1.0/18.0] * 6 + [1.0/36.0] * 12 - ) - velocity_set = wp.mat((q, dim), dtype=wp.int32)( - [ - [0, 0, 0], - [1, 0, 0], - [-1, 0, 0], - [0, 1, 0], - [0, -1, 0], - [0, 0, 1], - [0, 0, -1], - [0, 1, 1], - [0, -1, -1], - [0, 1, -1], - [0, -1, 1], - [1, 0, 1], - [-1, 0, -1], - [1, 0, -1], - [-1, 0, 1], - [1, 1, 0], - [-1, -1, 0], - [1, -1, 0], - [-1, 1, 0], - ] - ) - - # Make kernel - collide_stream, initialize = make_warp_kernel( - velocity_weight, - velocity_set, - dtype=lattice_dtype, - dim=dim, - q=q, - ) - - # Initialize - cs = 1.0 / np.sqrt(3.0) - vel = 0.1 * cs - dx = 2.0 * np.pi / n - wp.launch( - initialize, - inputs=[ - f0, - dx, - vel, - n, - n, - n, - tau, - ], - dim=(n, n, n), - ) - - # Compute MLUPS - import time - import tqdm - nr_iterations = 128 - start = time.time() - for i in tqdm.tqdm(range(nr_iterations)): - #if i % 10 == 0: - # plt_f(f0) - - wp.launch( - collide_stream, - inputs=[ - f0, - f1, - n, - n, - n, - tau, - ], - dim=(n, n, n), - ) - f0, f1 = f1, f0 - wp.synchronize() - end = time.time() - print("MLUPS: ", (nr_iterations * n * n * n) / (end - start) / 1e6) diff --git a/examples/cfd/example_basic.py b/examples/cfd/example_basic.py new file mode 100644 index 0000000..c11abbe --- /dev/null +++ b/examples/cfd/example_basic.py @@ -0,0 +1,69 @@ +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.default_config import DefaultConfig +import warp as wp +from xlb.grid import grid +from xlb.precision_policy import Precision +import xlb.velocity_set + +xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=xlb.velocity_set.D3Q19, +) + +grid_size = 50 +grid_shape = (grid_size, grid_size, grid_size) +my_grid = grid(grid_shape) +f = my_grid.create_field(cardinality=9) + +# compute_macro = QuadraticEquilibrium() + +# f_eq = compute_macro(rho, u) + + +# DefaultConfig.velocity_set.w + + + + +# def initializer(): +# rho = grid.create_field(cardinality=1) + 1.0 +# u = grid.create_field(cardinality=2) + +# circle_center = (grid_shape[0] // 2, grid_shape[1] // 2) +# circle_radius = 10 + +# for x in range(grid_shape[0]): +# for y in range(grid_shape[1]): +# if (x - circle_center[0]) ** 2 + ( +# y - circle_center[1] +# ) ** 2 <= circle_radius**2: +# rho = rho.at[0, x, y].add(0.001) + +# func_eq = QuadraticEquilibrium() +# f_eq = func_eq(rho, u) + +# return f_eq + + + +# solver = IncompressibleNavierStokes(grid, omega=1.0) + + +# def perform_io(f, step): +# rho, u = compute_macro(f) +# fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1]} +# save_fields_vtk(fields, step) +# save_image(rho[0], step) +# print(f"Step {step + 1} complete") + + +# num_steps = 1000 +# io_rate = 100 +# for step in range(num_steps): +# f = solver.step(f, timestep=step) + +# if step % io_rate == 0: +# perform_io(f, step) diff --git a/examples/interfaces/flow_past_sphere.py b/examples/cfd/flow_past_sphere.py similarity index 95% rename from examples/interfaces/flow_past_sphere.py rename to examples/cfd/flow_past_sphere.py index 8bfe945..ef24ac8 100644 --- a/examples/interfaces/flow_past_sphere.py +++ b/examples/cfd/flow_past_sphere.py @@ -7,11 +7,19 @@ from typing import Any import numpy as np -import warp as wp +from xlb.compute_backend import ComputeBackend -wp.init() +import warp as wp import xlb + +xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=xlb.velocity_set.D2Q9, +) + + from xlb.operator import Operator class UniformInitializer(Operator): @@ -62,8 +70,8 @@ def warp_implementation(self, rho, u, vel): nr = 256 vel = 0.05 shape = (nr, nr, nr) - grid = xlb.grid.WarpGrid(shape=shape) - rho = grid.create_field(cardinality=1, dtype=wp.float32) + grid = xlb.grid.grid(shape=shape) + rho = grid.create_field(cardinality=1) u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) diff --git a/examples/interfaces/ldc.py b/examples/cfd/lid_driven_cavity.py similarity index 100% rename from examples/interfaces/ldc.py rename to examples/cfd/lid_driven_cavity.py diff --git a/examples/interfaces/taylor_green.py b/examples/cfd/taylor_green.py similarity index 100% rename from examples/interfaces/taylor_green.py rename to examples/cfd/taylor_green.py diff --git a/examples/performance/MLUPS2d.py b/examples/performance/MLUPS2d.py deleted file mode 100644 index 77d32a4..0000000 --- a/examples/performance/MLUPS2d.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -This script computes the MLUPS (Million Lattice Updates per Second) in 2D by simulating fluid flow inside a 2D cavity. -""" - -import os -import argparse -import jax.numpy as jnp -import numpy as np -from jax import config -from time import time - -from src.utils import * -from src.boundary_conditions import * -from src.lattice import LatticeD2Q9 -from src.models import BGKSim - -class Cavity(BGKSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - # concatenate the indices of the left, right, and bottom walls - walls = np.concatenate((self.boundingBoxIndices['left'], self.boundingBoxIndices['right'], self.boundingBoxIndices['bottom'])) - # apply bounce back boundary condition to the walls - self.BCs.append(BounceBack(tuple(walls.T), self.gridInfo, self.precisionPolicy)) - - # apply inlet equilibrium boundary condition to the top wall - moving_wall = self.boundingBoxIndices['top'] - - rho_wall = np.ones((moving_wall.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype) - vel_wall[:, 0] = u_wall - self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) - - -if __name__ == '__main__': - precision = 'f32/f32' - lattice = LatticeD2Q9(precision) - - parser = argparse.ArgumentParser("simple_example") - parser.add_argument("N", help="The total number of voxels will be NxN", type=int) - parser.add_argument("timestep", help="Number of timesteps", type=int) - args = parser.parse_args() - - n = args.N - max_iter = args.timestep - Re = 100.0 - u_wall = 0.1 - clength = n - 1 - - visc = u_wall * clength / Re - omega = 1.0 / (3. * visc + 0.5) - print('omega = ', omega) - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': n, - 'ny': n, - 'nz': 0, - 'precision': precision, - 'compute_MLUPS': True - } - - os.system('rm -rf ./*.vtk && rm -rf ./*.png') - sim = Cavity(**kwargs) - sim.run(max_iter) diff --git a/examples/performance/MLUPS3d.py b/examples/performance/MLUPS3d.py deleted file mode 100644 index 8a9f9e3..0000000 --- a/examples/performance/MLUPS3d.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -This script computes the MLUPS (Million Lattice Updates per Second) in 3D by simulating fluid flow inside a 2D cavity. -""" - -import os -import argparse - -import jax -import jax.numpy as jnp -import numpy as np -from jax import config -from time import time -#config.update('jax_disable_jit', True) -# Use 8 CPU devices -#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -#config.update("jax_enable_x64", True) -from src.utils import * -from src.boundary_conditions import * -from src.models import BGKSim -from src.lattice import LatticeD3Q19 -class Cavity(BGKSim): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - # concatenate the indices of the left, right, and bottom walls - walls = np.concatenate((self.boundingBoxIndices['left'], self.boundingBoxIndices['right'], self.boundingBoxIndices['bottom'], self.boundingBoxIndices['front'], self.boundingBoxIndices['back'])) - # apply bounce back boundary condition to the walls - self.BCs.append(BounceBack(tuple(walls.T), self.gridInfo, self.precisionPolicy)) - - # apply inlet equilibrium boundary condition to the top wall - moving_wall = self.boundingBoxIndices['top'] - - rho_wall = np.ones((moving_wall.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype) - vel_wall[:, 0] = u_wall - self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) - -if __name__ == '__main__': - precision = 'f32/f32' - lattice = LatticeD3Q19(precision) - # Create a parser that will read the command line arguments - parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation") - parser.add_argument("N", help="The total number of voxels all directions. The final dimension will be N*NxN", default=100, type=int) - parser.add_argument("N_ITERS", help="Number of timesteps", default=10000, type=int) - - args = parser.parse_args() - n = args.N - n_iters = args.N_ITERS - - # Store the Reynolds number in the variable Re - Re = 100.0 - # Store the velocity of the lid in the variable u_wall - u_wall = 0.1 - # Store the length of the cavity in the variable clength - clength = n - 1 - - # Compute the viscosity from the Reynolds number, the lid velocity, and the length of the cavity - visc = u_wall * clength / Re - # Compute the relaxation parameter from the viscosity - omega = 1.0 / (3. * visc + 0.5) - - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': n, - 'ny': n, - 'nz': n, - 'precision': precision, - 'compute_MLUPS': True - } - - sim = Cavity(**kwargs) - # Run the simulation - sim.run(n_iters) - \ No newline at end of file diff --git a/examples/performance/MLUPS3d_distributed.py b/examples/performance/MLUPS3d_distributed.py deleted file mode 100644 index 70d5328..0000000 --- a/examples/performance/MLUPS3d_distributed.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -This script computes the MLUPS (Million Lattice Updates per Second) in 3D by simulating fluid flow inside a 2D cavity. -This script is equivalent to MLUPS3d.py, but uses JAX distributed to run the simulation on distributed systems (multi-host, multi-GPUs). -Please refer to https://jax.readthedocs.io/en/latest/multi_process.html for more information on JAX distributed. -""" - - -# Standard Libraries -import argparse -import os -import jax - -import jax.numpy as jnp -import numpy as np - -from jax import config - -from src.boundary_conditions import * -from src.models import BGKSim -from src.lattice import LatticeD3Q19 -from src.utils import * - -#config.update('jax_disable_jit', True) -# Use 8 CPU devices -#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' -#config.update("jax_enable_x64", True) - -class Cavity(BGKSim): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_boundary_conditions(self): - # concatenate the indices of the left, right, and bottom walls - walls = np.concatenate((self.boundingBoxIndices['left'], self.boundingBoxIndices['right'], self.boundingBoxIndices['bottom'], self.boundingBoxIndices['front'], self.boundingBoxIndices['back'])) - # apply bounce back boundary condition to the walls - self.BCs.append(BounceBack(tuple(walls.T), self.gridInfo, self.precisionPolicy)) - - # apply inlet equilibrium boundary condition to the top wall - moving_wall = self.boundingBoxIndices['top'] - - rho_wall = np.ones((moving_wall.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype) - vel_wall[:, 0] = u_wall - self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall)) - -if __name__ == '__main__': - # Create a parser that will read the command line arguments - parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation") - parser.add_argument("N", help="The total number of voxels in one direction. The final dimension will be N*NxN", - default=100, type=int) - parser.add_argument("N_ITERS", help="Number of iterations", default=10000, type=int) - parser.add_argument("N_PROCESSES", help="Number of processes. If >1, call jax.distributed.initialize with that number of process. If -1 will call jax.distributed.initialize without any arsgument. So it should pick up the values from SLURM env variable.", - default=1, type=int) - parser.add_argument("IP", help="IP of the master node for multi-node. Useless if using SLURM.", - default='127.0.0.1', type=str, nargs='?') - parser.add_argument("PROCESS_ID_INCREMENT", help="For multi-node only. Useless if using SLURM.", - default=0, type=int, nargs='?') - - args = parser.parse_args() - n = args.N - n_iters = args.N_ITERS - n_processes = args.N_PROCESSES - # Initialize JAX distributed. The IP, number of processes and process id must be set correctly. - print("N processes, ", n_processes) - print("N iter, ", n_iters) - if n_processes > 1: - process_id = int(os.environ.get('CUDA_VISIBLE_DEVICES', 0)) + args.PROCESS_ID_INCREMENT - print("ip, num_processes, process_id, ", args.IP, n_processes, process_id) - jax.distributed.initialize(args.IP, num_processes=n_processes, - process_id=process_id) - elif n_processes == -1: - print("Will call jax.distributed.initialize()") - jax.distributed.initialize() - print("jax.distributed.initialize() ended") - else: - print("No call to jax.distributed.initialize") - print("JAX local devices: ", jax.local_devices()) - - precision = 'f32/f32' - # Create a 3D lattice with the D3Q19 scheme - lattice = LatticeD3Q19(precision) - - # Store the Reynolds number in the variable Re - Re = 100.0 - # Store the velocity of the lid in the variable u_wall - u_wall = 0.1 - # Store the length of the cavity in the variable clength - clength = n - 1 - - # Compute the viscosity from the Reynolds number, the lid velocity, and the length of the cavity - visc = u_wall * clength / Re - # Compute the relaxation parameter from the viscosity - omega = 1.0 / (3. * visc + 0.5) - - # Create a new instance of the Cavity class - kwargs = { - 'lattice': lattice, - 'omega': omega, - 'nx': n, - 'ny': n, - 'nz': n, - 'precision': precision, - 'compute_MLUPS': True - } - - sim = Cavity(**kwargs) # Run the simulation - sim.run(n_iters) diff --git a/examples/refactor/mlups3d.py b/examples/performance/mlups3d.py similarity index 100% rename from examples/refactor/mlups3d.py rename to examples/performance/mlups3d.py diff --git a/examples/refactor/README.md b/examples/refactor/README.md deleted file mode 100644 index 37b17e7..0000000 --- a/examples/refactor/README.md +++ /dev/null @@ -1,31 +0,0 @@ -# Refactor Examples - -This directory contains several example of using the refactored XLB library. - -These examples are not meant to be veiwed as the new interface to XLB but only how -to expose the compute kernels to a user. Development is still ongoing. - -## Examples - -### JAX Example - -The JAX example is a simple example of using the refactored XLB library -with JAX. The example is located in the `example_jax.py`. It shows -a very basic flow past a cyliner. - -### NUMBA Example - -TODO: Not working yet - -The NUMBA example is a simple example of using the refactored XLB library -with NUMBA. The example is located in the `example_numba.py`. It shows -a very basic flow past a cyliner. This example is not working yet though and -is still under development for numba backend. - -### Out of Core JAX Example - -This shoes how we can use out of core memory with JAX. The example is located -in the `example_jax_out_of_core.py`. It shows a very basic flow past a cyliner. -The basic idea is to create an out of core memory array using the implementation -in XLB. Then we run the simulation using the jax functions implementation obtained -from XLB. Some rendering is done using PhantomGaze. diff --git a/examples/refactor/example_basic.py b/examples/refactor/example_basic.py deleted file mode 100644 index 5f74d21..0000000 --- a/examples/refactor/example_basic.py +++ /dev/null @@ -1,62 +0,0 @@ -import xlb -from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import Fp32Fp32 - -from xlb.solver import IncompressibleNavierStokes -from xlb.grid import Grid -from xlb.operator.macroscopic import Macroscopic -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.utils import save_fields_vtk, save_image - -xlb.init( - precision_policy=Fp32Fp32, - compute_backend=ComputeBackend.JAX, - velocity_set=xlb.velocity_set.D2Q9, -) - -grid_shape = (1000, 1000) -grid = Grid.create(grid_shape) - - -def initializer(): - rho = grid.create_field(cardinality=1) + 1.0 - u = grid.create_field(cardinality=2) - - circle_center = (grid_shape[0] // 2, grid_shape[1] // 2) - circle_radius = 10 - - for x in range(grid_shape[0]): - for y in range(grid_shape[1]): - if (x - circle_center[0]) ** 2 + ( - y - circle_center[1] - ) ** 2 <= circle_radius**2: - rho = rho.at[0, x, y].add(0.001) - - func_eq = QuadraticEquilibrium() - f_eq = func_eq(rho, u) - - return f_eq - - -f = initializer() - -compute_macro = Macroscopic() - -solver = IncompressibleNavierStokes(grid, omega=1.0) - - -def perform_io(f, step): - rho, u = compute_macro(f) - fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1]} - save_fields_vtk(fields, step) - save_image(rho[0], step) - print(f"Step {step + 1} complete") - - -num_steps = 1000 -io_rate = 100 -for step in range(num_steps): - f = solver.step(f, timestep=step) - - if step % io_rate == 0: - perform_io(f, step) diff --git a/examples/refactor/example_jax.py b/examples/refactor/example_jax.py deleted file mode 100644 index 9092022..0000000 --- a/examples/refactor/example_jax.py +++ /dev/null @@ -1,107 +0,0 @@ -# from IPython import display -import numpy as np -import jax -import jax.numpy as jnp -import scipy -import time -from tqdm import tqdm -import matplotlib.pyplot as plt - -import xlb - -if __name__ == "__main__": - # Simulation parameters - nr = 128 - vel = 0.05 - visc = 0.00001 - omega = 1.0 / (3.0 * visc + 0.5) - length = 2 * np.pi - - # Geometry (sphere) - lin = np.linspace(0, length, nr) - X, Y, Z = np.meshgrid(lin, lin, lin, indexing="ij") - XYZ = np.stack([X, Y, Z], axis=-1) - radius = np.pi / 8.0 - - # XLB precision policy - precision_policy = xlb.precision_policy.Fp32Fp32() - - # XLB lattice - velocity_set = xlb.velocity_set.D3Q27() - - # XLB equilibrium - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(velocity_set=velocity_set) - - # XLB macroscopic - macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set) - - # XLB collision - collision = xlb.operator.collision.KBC(omega=omega, velocity_set=velocity_set) - - # XLB stream - stream = xlb.operator.stream.Stream(velocity_set=velocity_set) - - # XLB noslip boundary condition (sphere) - in_cylinder = ((X - np.pi/2.0)**2 + (Y - np.pi)**2 + (Z - np.pi)**2) < radius**2 - indices = np.argwhere(in_cylinder) - bounce_back = xlb.operator.boundary_condition.FullBounceBack.from_indices( - indices=indices, - velocity_set=velocity_set - ) - - # XLB outflow boundary condition - outflow = xlb.operator.boundary_condition.DoNothing.from_indices( - indices=np.argwhere(XYZ[..., 0] == length), - velocity_set=velocity_set - ) - - # XLB inflow boundary condition - inflow = xlb.operator.boundary_condition.EquilibriumBoundary.from_indices( - indices=np.argwhere(XYZ[..., 0] == 0.0), - velocity_set=velocity_set, - rho=1.0, - u=np.array([vel, 0.0, 0.0]), - equilibrium=equilibrium - ) - - # XLB stepper - stepper = xlb.operator.stepper.NSE( - collision=collision, - stream=stream, - equilibrium=equilibrium, - macroscopic=macroscopic, - boundary_conditions=[bounce_back, outflow, inflow], - precision_policy=precision_policy, - ) - - # Make initial condition - u = jnp.stack([vel * jnp.ones_like(X), jnp.zeros_like(X), jnp.zeros_like(X)], axis=-1) - rho = jnp.expand_dims(jnp.ones_like(X), axis=-1) - f = equilibrium(rho, u) - - # Get boundary id and mask - ijk = jnp.meshgrid(jnp.arange(nr), jnp.arange(nr), jnp.arange(nr), indexing="ij") - boundary_id, mask = stepper.set_boundary(jnp.stack(ijk, axis=-1)) - - # Run simulation - tic = time.time() - nr_iter = 4096 - for i in tqdm(range(nr_iter)): - f = stepper(f, boundary_id, mask, i) - - if i % 32 == 0: - # Get u, rho from f - rho, u = macroscopic(f) - norm_u = jnp.linalg.norm(u, axis=-1) - norm_u = (1.0 - jnp.minimum(boundary_id, 1.0)) * norm_u - - # Plot - plt.imshow(norm_u[..., nr//2], cmap="jet") - plt.colorbar() - plt.savefig(f"img_{str(i).zfill(5)}.png") - plt.close() - - # Sync to host - f = f.block_until_ready() - toc = time.time() - print(f"MLUPS: {(nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/examples/refactor/example_jax_out_of_core.py b/examples/refactor/example_jax_out_of_core.py deleted file mode 100644 index aeb604f..0000000 --- a/examples/refactor/example_jax_out_of_core.py +++ /dev/null @@ -1,336 +0,0 @@ -# from IPython import display -import os -os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.7' - -import numpy as np -import jax -import jax.numpy as jnp -import scipy -import time -from tqdm import tqdm -import matplotlib.pyplot as plt -from mpi4py import MPI -import cupy as cp - -import xlb -from xlb.experimental.ooc import OOCmap, OOCArray - -import phantomgaze as pg - -comm = MPI.COMM_WORLD - -@jax.jit -def q_criterion(u): - # Compute derivatives - u_x = u[..., 0] - u_y = u[..., 1] - u_z = u[..., 2] - - # Compute derivatives - u_x_dx = (u_x[2:, 1:-1, 1:-1] - u_x[:-2, 1:-1, 1:-1]) / 2 - u_x_dy = (u_x[1:-1, 2:, 1:-1] - u_x[1:-1, :-2, 1:-1]) / 2 - u_x_dz = (u_x[1:-1, 1:-1, 2:] - u_x[1:-1, 1:-1, :-2]) / 2 - u_y_dx = (u_y[2:, 1:-1, 1:-1] - u_y[:-2, 1:-1, 1:-1]) / 2 - u_y_dy = (u_y[1:-1, 2:, 1:-1] - u_y[1:-1, :-2, 1:-1]) / 2 - u_y_dz = (u_y[1:-1, 1:-1, 2:] - u_y[1:-1, 1:-1, :-2]) / 2 - u_z_dx = (u_z[2:, 1:-1, 1:-1] - u_z[:-2, 1:-1, 1:-1]) / 2 - u_z_dy = (u_z[1:-1, 2:, 1:-1] - u_z[1:-1, :-2, 1:-1]) / 2 - u_z_dz = (u_z[1:-1, 1:-1, 2:] - u_z[1:-1, 1:-1, :-2]) / 2 - - # Compute vorticity - mu_x = u_z_dy - u_y_dz - mu_y = u_x_dz - u_z_dx - mu_z = u_y_dx - u_x_dy - norm_mu = jnp.sqrt(mu_x ** 2 + mu_y ** 2 + mu_z ** 2) - - # Compute strain rate - s_0_0 = u_x_dx - s_0_1 = 0.5 * (u_x_dy + u_y_dx) - s_0_2 = 0.5 * (u_x_dz + u_z_dx) - s_1_0 = s_0_1 - s_1_1 = u_y_dy - s_1_2 = 0.5 * (u_y_dz + u_z_dy) - s_2_0 = s_0_2 - s_2_1 = s_1_2 - s_2_2 = u_z_dz - s_dot_s = ( - s_0_0 ** 2 + s_0_1 ** 2 + s_0_2 ** 2 + - s_1_0 ** 2 + s_1_1 ** 2 + s_1_2 ** 2 + - s_2_0 ** 2 + s_2_1 ** 2 + s_2_2 ** 2 - ) - - # Compute omega - omega_0_0 = 0.0 - omega_0_1 = 0.5 * (u_x_dy - u_y_dx) - omega_0_2 = 0.5 * (u_x_dz - u_z_dx) - omega_1_0 = -omega_0_1 - omega_1_1 = 0.0 - omega_1_2 = 0.5 * (u_y_dz - u_z_dy) - omega_2_0 = -omega_0_2 - omega_2_1 = -omega_1_2 - omega_2_2 = 0.0 - omega_dot_omega = ( - omega_0_0 ** 2 + omega_0_1 ** 2 + omega_0_2 ** 2 + - omega_1_0 ** 2 + omega_1_1 ** 2 + omega_1_2 ** 2 + - omega_2_0 ** 2 + omega_2_1 ** 2 + omega_2_2 ** 2 - ) - - # Compute q-criterion - q = 0.5 * (omega_dot_omega - s_dot_s) - - return norm_mu, q - - -if __name__ == "__main__": - # Simulation parameters - nr = 256 - nx = 3 * nr - ny = nr - nz = nr - vel = 0.05 - visc = 0.00001 - omega = 1.0 / (3.0 * visc + 0.5) - length = 2 * np.pi - dx = length / (ny - 1) - radius = np.pi / 3.0 - - # OOC parameters - sub_steps = 8 - sub_nr = 128 - padding = (sub_steps, sub_steps, sub_steps, 0) - - # XLB precision policy - precision_policy = xlb.precision_policy.Fp32Fp32() - - # XLB lattice - velocity_set = xlb.velocity_set.D3Q27() - - # XLB equilibrium - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(velocity_set=velocity_set) - - # XLB macroscopic - macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set) - - # XLB collision - collision = xlb.operator.collision.KBC(omega=omega, velocity_set=velocity_set) - - # XLB stream - stream = xlb.operator.stream.Stream(velocity_set=velocity_set) - - # XLB noslip boundary condition (sphere) - # Create a mask function - def set_boundary_sphere(ijk, boundary_id, mask, id_number): - # Get XYZ - XYZ = ijk * dx - sphere_mask = jnp.linalg.norm(XYZ - length / 2.0, axis=-1) < radius - boundary_id = boundary_id.at[sphere_mask].set(id_number) - mask = mask.at[sphere_mask].set(True) - return boundary_id, mask - bounce_back = xlb.operator.boundary_condition.FullBounceBack( - set_boundary=set_boundary_sphere, - velocity_set=velocity_set - ) - - # XLB outflow boundary condition - def set_boundary_outflow(ijk, boundary_id, mask, id_number): - # Get XYZ - XYZ = ijk * dx - outflow_mask = XYZ[..., 0] >= (length * 3.0) - dx - boundary_id = boundary_id.at[outflow_mask].set(id_number) - mask = mask.at[outflow_mask].set(True) - return boundary_id, mask - outflow = xlb.operator.boundary_condition.DoNothing( - set_boundary=set_boundary_outflow, - velocity_set=velocity_set - ) - - # XLB inflow boundary condition - def set_boundary_inflow(ijk, boundary_id, mask, id_number): - # Get XYZ - XYZ = ijk * dx - inflow_mask = XYZ[..., 0] == 0.0 - boundary_id = boundary_id.at[inflow_mask].set(id_number) - mask = mask.at[inflow_mask].set(True) - return boundary_id, mask - inflow = xlb.operator.boundary_condition.EquilibriumBoundary( - set_boundary=set_boundary_inflow, - velocity_set=velocity_set, - rho=1.0, - u=np.array([vel, 0.0, 0.0]), - equilibrium=equilibrium - ) - - # XLB stepper - stepper = xlb.operator.stepper.NSE( - collision=collision, - stream=stream, - equilibrium=equilibrium, - macroscopic=macroscopic, - boundary_conditions=[bounce_back, outflow, inflow], - precision_policy=precision_policy, - ) - - # Make OOC arrays - f = OOCArray( - shape=(nx, ny, nz, velocity_set.q), - dtype=np.float32, - tile_shape=(sub_nr, sub_nr, sub_nr, velocity_set.q), - padding=padding, - comm=comm, - devices=[cp.cuda.Device(0) for i in range(comm.size)], - codec=None, - nr_compute_tiles=1, - ) - - camera_radius = length * 2.0 - focal_point = (3.0 * length / 2.0, length / 2.0, length / 2.0) - angle = 1 * 0.0001 - camera_position = (focal_point[0] + camera_radius * np.sin(angle), focal_point[1], focal_point[2] + camera_radius * np.cos(angle)) - camera = pg.Camera( - position=camera_position, - focal_point=focal_point, - view_up=(0.0, 1.0, 0.0), - height=1440, - width=2560, - max_depth=6.0 * length, - ) - screen_buffer = pg.ScreenBuffer.from_camera(camera) - - - # Initialize f - @OOCmap(comm, (0,), backend="jax") - def initialize_f(f): - # Get inputs - shape = f.shape[:-1] - u = jnp.stack([vel * jnp.ones(shape), jnp.zeros(shape), jnp.zeros(shape)], axis=-1) - rho = jnp.expand_dims(jnp.ones(shape), axis=-1) - f = equilibrium(rho, u) - return f - f = initialize_f(f) - - # Stepping function - @OOCmap(comm, (0,), backend="jax", add_index=True) - def ooc_stepper(f): - - # Get tensors - f, global_index = f - - # Get ijk - lin_i = jnp.arange(global_index[0], global_index[0] + f.shape[0]) - lin_j = jnp.arange(global_index[1], global_index[1] + f.shape[1]) - lin_k = jnp.arange(global_index[2], global_index[2] + f.shape[2]) - ijk = jnp.meshgrid(lin_i, lin_j, lin_k, indexing="ij") - ijk = jnp.stack(ijk, axis=-1) - - # Set boundary_id and mask - boundary_id, mask = stepper.set_boundary(ijk) - - # Run stepper - for _ in range(sub_steps): - f = stepper(f, boundary_id, mask, _) - - # Wait till f is computed using jax - f = f.block_until_ready() - - return f - - # Make a render function - @OOCmap(comm, (0,), backend="jax", add_index=True) - def render(f, screen_buffer, camera): - - # Get tensors - f, global_index = f - - # Get ijk - lin_i = jnp.arange(global_index[0], global_index[0] + f.shape[0]) - lin_j = jnp.arange(global_index[1], global_index[1] + f.shape[1]) - lin_k = jnp.arange(global_index[2], global_index[2] + f.shape[2]) - ijk = jnp.meshgrid(lin_i, lin_j, lin_k, indexing="ij") - ijk = jnp.stack(ijk, axis=-1) - - # Set boundary_id and mask - boundary_id, mask = stepper.set_boundary(ijk) - sphere = (boundary_id == 1).astype(jnp.float32)[1:-1, 1:-1, 1:-1] - - # Get rho, u - rho, u = macroscopic(f) - - # Get q-cr - norm_mu, q = q_criterion(u) - - # Make volumes - origin = ((global_index[0] + 1) * dx, (global_index[1] + 1) * dx, (global_index[2] + 1) * dx) - q_volume = pg.objects.Volume( - q, spacing=(dx, dx, dx), origin=origin - ) - norm_mu_volume = pg.objects.Volume( - norm_mu, spacing=(dx, dx, dx), origin=origin - ) - sphere_volume = pg.objects.Volume( - sphere, spacing=(dx, dx, dx), origin=origin - ) - - # Render - screen_buffer = pg.render.contour( - q_volume, - threshold=0.000005, - color=norm_mu_volume, - colormap=pg.Colormap("jet", vmin=0.0, vmax=0.025), - camera=camera, - screen_buffer=screen_buffer, - ) - screen_buffer = pg.render.contour( - sphere_volume, - threshold=0.5, - camera=camera, - screen_buffer=screen_buffer, - ) - - return f - - # Run simulation - tic = time.time() - nr_iter = 128 * nr // sub_steps - nr_frames = 1024 - for i in tqdm(range(nr_iter)): - f = ooc_stepper(f) - - if i % (nr_iter // nr_frames) == 0: - # Rotate camera - camera_radius = length * 1.0 - focal_point = (length / 2.0, length / 2.0, length / 2.0) - angle = (np.pi / nr_iter) * i - camera_position = (focal_point[0] + camera_radius * np.sin(angle), focal_point[1], focal_point[2] + camera_radius * np.cos(angle)) - camera = pg.Camera( - position=camera_position, - focal_point=focal_point, - view_up=(0.0, 1.0, 0.0), - height=1080, - width=1920, - max_depth=6.0 * length, - ) - - # Render global setup - screen_buffer = pg.render.wireframe( - lower_bound=(0.0, 0.0, 0.0), - upper_bound=(3.0*length, length, length), - thickness=length/100.0, - camera=camera, - ) - screen_buffer = pg.render.axes( - size=length/30.0, - center=(0.0, 0.0, length*1.1), - camera=camera, - screen_buffer=screen_buffer - ) - - # Render - render(f, screen_buffer, camera) - - # Save image - plt.imsave('./q_criterion_' + str(i).zfill(7) + '.png', np.minimum(screen_buffer.image.get(), 1.0)) - - # Sync to host - cp.cuda.runtime.deviceSynchronize() - toc = time.time() - print(f"MLUPS: {(sub_steps * nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/examples/refactor/example_numba.py b/examples/refactor/example_numba.py deleted file mode 100644 index 9183067..0000000 --- a/examples/refactor/example_numba.py +++ /dev/null @@ -1,78 +0,0 @@ -# from IPython import display -import numpy as np -import cupy as cp -import scipy -import time -from tqdm import tqdm -import matplotlib.pyplot as plt -from numba import cuda, config - -import xlb - -config.CUDA_ARRAY_INTERFACE_SYNC = False - -if __name__ == "__main__": - # XLB precision policy - precision_policy = xlb.precision_policy.Fp32Fp32() - - # XLB lattice - lattice = xlb.lattice.D3Q19() - - # XLB collision model - collision = xlb.collision.BGK() - - # Make XLB compute kernels - compute = xlb.compute_constructor.NumbaConstructor( - lattice=lattice, - collision=collision, - boundary_conditions=[], - forcing=None, - precision_policy=precision_policy, - ) - - # Make taylor green vortex initial condition - tau = 0.505 - vel = 0.1 * 1.0 / np.sqrt(3.0) - nr = 256 - lin = cp.linspace(0, 2 * np.pi, nr, endpoint=False, dtype=cp.float32) - X, Y, Z = cp.meshgrid(lin, lin, lin, indexing="ij") - X = X[None, ...] - Y = Y[None, ...] - Z = Z[None, ...] - u = vel * cp.sin(X) * cp.cos(Y) * cp.cos(Z) - v = -vel * cp.cos(X) * cp.sin(Y) * cp.cos(Z) - w = cp.zeros_like(X) - rho = ( - 3.0 - * vel**2 - * (1.0 / 16.0) - * (cp.cos(2 * X) + cp.cos(2 * Y) + cp.cos(2 * Z)) - + 1.0) - u = cp.concatenate([u, v, w], axis=-1) - - # Allocate f - f0 = cp.zeros((19, nr, nr, nr), dtype=cp.float32) - f1 = cp.zeros((19, nr, nr, nr), dtype=cp.float32) - - # Get f from u, rho - compute.equilibrium(rho, u, f0) - - # Run compute kernel on f - tic = time.time() - nr_iter = 128 - for i in tqdm(range(nr_iter)): - compute.step(f0, f1, i) - f0, f1 = f1, f0 - - if i % 4 == 0: - ## Get u, rho from f - #rho, u = compute.macroscopic(f) - #norm_u = jnp.linalg.norm(u, axis=-1) - - # Plot - plt.imsave(f"img_{str(i).zfill(5)}.png", f0[8, nr // 2, :, :].get(), cmap="jet") - - # Sync to host - cp.cuda.stream.get_current_stream().synchronize() - toc = time.time() - print(f"MLUPS: {(nr_iter * nr**3) / (toc - tic) / 1e6}") diff --git a/examples/refactor/example_pallas_3d.py b/examples/refactor/example_pallas_3d.py deleted file mode 100644 index 17084d4..0000000 --- a/examples/refactor/example_pallas_3d.py +++ /dev/null @@ -1,73 +0,0 @@ -import xlb -from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import Fp32Fp32 - -from xlb.solver import IncompressibleNavierStokes -from xlb.grid import Grid -from xlb.operator.macroscopic import Macroscopic -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.utils import save_fields_vtk, save_image -import numpy as np -import jax.numpy as jnp - -# Initialize XLB with Pallas backend for 3D simulation -xlb.init( - precision_policy=Fp32Fp32, - compute_backend=ComputeBackend.PALLAS, # Changed to Pallas backend - velocity_set=xlb.velocity_set.D3Q19, # Changed to D3Q19 for 3D -) - -grid_shape = (128, 128, 128) # Adjusted for 3D grid -grid = Grid.create(grid_shape) - - -def initializer(): - rho = grid.create_field(cardinality=1) + 1.0 - u = grid.create_field(cardinality=3) - - sphere_center = np.array([s // 2 for s in grid_shape]) - sphere_radius = 10 - - x, y, z = np.meshgrid( - np.arange(grid_shape[0]), - np.arange(grid_shape[1]), - np.arange(grid_shape[2]), - indexing="ij", - ) - - squared_dist = ( - (x - sphere_center[0]) ** 2 - + (y - sphere_center[1]) ** 2 - + (z - sphere_center[2]) ** 2 - ) - - inside_sphere = squared_dist <= sphere_radius**2 - - rho = jnp.where(inside_sphere, rho.at[0, x, y, z].add(0.001), rho) - - func_eq = QuadraticEquilibrium(compute_backend=ComputeBackend.JAX) - f_eq = func_eq(rho, u) - - return f_eq - - -f = initializer() - -compute_macro = Macroscopic(compute_backend=ComputeBackend.JAX) - -solver = IncompressibleNavierStokes(grid, omega=1.0) - -def perform_io(f, step): - rho, u = compute_macro(f) - fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_z": u[2]} - save_fields_vtk(fields, step) - # save_image function might not be suitable for 3D, consider alternative visualization - print(f"Step {step + 1} complete") - -num_steps = 1000 -io_rate = 100 -for step in range(num_steps): - f = solver.step(f, timestep=step) - - if step % io_rate == 0: - perform_io(f, step) diff --git a/examples/refactor/mlups_pallas_3d.py b/examples/refactor/mlups_pallas_3d.py deleted file mode 100644 index 71715d9..0000000 --- a/examples/refactor/mlups_pallas_3d.py +++ /dev/null @@ -1,84 +0,0 @@ -import xlb -import time -import argparse -from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import Fp32Fp32 -from xlb.solver import IncompressibleNavierStokes -from xlb.grid import Grid -from xlb.operator.macroscopic import Macroscopic -from xlb.operator.equilibrium import QuadraticEquilibrium -import numpy as np -import jax.numpy as jnp - -# Command line argument parsing -parser = argparse.ArgumentParser( - description="3D Lattice Boltzmann Method Simulation using XLB" -) -parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") -parser.add_argument( - "num_steps", type=int, help="Number of timesteps for the simulation" -) -args = parser.parse_args() - -# Initialize XLB -xlb.init( - precision_policy=Fp32Fp32, - compute_backend=ComputeBackend.PALLAS, - velocity_set=xlb.velocity_set.D3Q19, -) - -# Grid initialization -grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge) -grid = Grid.create(grid_shape) - - -def initializer(): - rho = grid.create_field(cardinality=1) + 1.0 - u = grid.create_field(cardinality=3) - - sphere_center = np.array([s // 2 for s in grid_shape]) - sphere_radius = 10 - - x, y, z = np.meshgrid( - np.arange(grid_shape[0]), - np.arange(grid_shape[1]), - np.arange(grid_shape[2]), - indexing="ij", - ) - - squared_dist = ( - (x - sphere_center[0]) ** 2 - + (y - sphere_center[1]) ** 2 - + (z - sphere_center[2]) ** 2 - ) - - inside_sphere = squared_dist <= sphere_radius**2 - - rho = jnp.where(inside_sphere, rho.at[0, x, y, z].add(0.001), rho) - - func_eq = QuadraticEquilibrium(compute_backend=ComputeBackend.JAX) - f_eq = func_eq(rho, u) - - return f_eq - - -f = initializer() - -solver = IncompressibleNavierStokes(grid, omega=1.0) - -# AoT compile -f = solver.step(f, timestep=0) - -# Start the simulation -start_time = time.time() - -for step in range(args.num_steps): - f = solver.step(f, timestep=step) - -end_time = time.time() - -# MLUPS calculation -total_lattice_updates = args.cube_edge**3 * args.num_steps -total_time_seconds = end_time - start_time -mlups = (total_lattice_updates / total_time_seconds) / 1e6 -print(f"MLUPS: {mlups}") diff --git a/examples/warp_backend/testing.py b/examples/warp_backend/testing.py deleted file mode 100644 index a20feb3..0000000 --- a/examples/warp_backend/testing.py +++ /dev/null @@ -1,108 +0,0 @@ -# from IPython import display -import numpy as np -import jax -import jax.numpy as jnp -import scipy -import time -from tqdm import tqdm -import matplotlib.pyplot as plt - -import warp as wp -wp.init() - -import xlb - - -def test_backends(compute_backend): - - # Set parameters - precision_policy = xlb.PrecisionPolicy.FP32FP32 - velocity_set = xlb.velocity_set.D3Q27() - - # Make operators - collision = xlb.operator.collision.BGK( - omega=1.0, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - stream = xlb.operator.stream.Stream( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - bounceback = xlb.operator.boundary_condition.FullBounceBack.from_indices( - indices=np.array([[0, 0, 0], [0, 0, 1]]), - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, - equilibrium=equilibrium, - macroscopic=macroscopic, - stream=stream, - boundary_conditions=[bounceback]) - - # Test operators - if compute_backend == xlb.ComputeBackend.WARP: - # Make warp arrays - nr = 128 - f_0 = wp.zeros((27, nr, nr, nr), dtype=wp.float32) - f_1 = wp.zeros((27, nr, nr, nr), dtype=wp.float32) - f_out = wp.zeros((27, nr, nr, nr), dtype=wp.float32) - u = wp.zeros((3, nr, nr, nr), dtype=wp.float32) - rho = wp.zeros((1, nr, nr, nr), dtype=wp.float32) - boundary_id = wp.zeros((1, nr, nr, nr), dtype=wp.uint8) - boundary = wp.zeros((1, nr, nr, nr), dtype=wp.bool) - mask = wp.zeros((27, nr, nr, nr), dtype=wp.bool) - - # Test operators - collision(f_0, f_1, rho, u, f_out) - equilibrium(rho, u, f_0) - macroscopic(f_0, rho, u) - stream(f_0, f_1) - bounceback(f_0, f_1, f_out, boundary, mask) - #bounceback.boundary_masker((0, 0, 0), boundary_id, mask, 1) - - - - elif compute_backend == xlb.ComputeBackend.JAX: - # Make jax arrays - nr = 128 - f_0 = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) - f_1 = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) - f_out = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) - u = jnp.zeros((3, nr, nr, nr), dtype=jnp.float32) - rho = jnp.zeros((1, nr, nr, nr), dtype=jnp.float32) - boundary_id = jnp.zeros((1, nr, nr, nr), dtype=jnp.uint8) - boundary = jnp.zeros((1, nr, nr, nr), dtype=jnp.bool_) - mask = jnp.zeros((27, nr, nr, nr), dtype=jnp.bool_) - - # Test operators - collision(f_0, f_1, rho, u) - equilibrium(rho, u) - macroscopic(f_0) - stream(f_0) - bounceback(f_0, f_1, boundary, mask) - bounceback.boundary_masker((0, 0, 0), boundary_id, mask, 1) - stepper(f_0, boundary_id, mask, 0) - - - -if __name__ == "__main__": - - # Test backends - compute_backend = [ - xlb.ComputeBackend.WARP, - xlb.ComputeBackend.JAX - ] - - for compute_backend in compute_backend: - test_backends(compute_backend) - print(f"Backend {compute_backend} passed all tests.") diff --git a/requirements.txt b/requirements.txt index c8c6fc2..5f356bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ termcolor==2.3.0 PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git tqdm==4.66.2 warp-lang==1.0.2 -numpy-stl==3.1.1 \ No newline at end of file +numpy-stl==3.1.1 +pydantic==2.7.0 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/backends_conformance/boundary_conditions.py b/tests/backends_conformance/boundary_conditions.py index e75e249..ed2e5c3 100644 --- a/tests/backends_conformance/boundary_conditions.py +++ b/tests/backends_conformance/boundary_conditions.py @@ -216,3 +216,4 @@ def test_boundary_conditions(self): if __name__ == "__main__": unittest.main() + diff --git a/tests/grids/__init__.py b/tests/grids/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/grids/test_jax_grid.py b/tests/grids/test_jax_grid.py new file mode 100644 index 0000000..3c74b6f --- /dev/null +++ b/tests/grids/test_jax_grid.py @@ -0,0 +1,61 @@ +import pytest +import jax +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.grid import grid +from jax.sharding import Mesh +from jax.experimental import mesh_utils + + +def init_xlb_env(): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=xlb.velocity_set.D2Q9, # does not affect the test + ) + + +@pytest.mark.parametrize("grid_size", [50, 100, 150]) +def test_jax_2d_grid_initialization(grid_size): + init_xlb_env() + grid_shape = (grid_size, grid_size) + my_grid = grid(grid_shape) + f = my_grid.create_field(cardinality=9) + n_devices = jax.device_count() + + device_mesh = mesh_utils.create_device_mesh((1, n_devices, 1)) + expected_mesh = Mesh(device_mesh, axis_names=("cardinality", "x", "y")) + + assert f.shape == (9,) + grid_shape, "Field shape is incorrect" + assert f.sharding.mesh == expected_mesh, "Field sharding mesh is incorrect" + assert f.sharding.spec == ("cardinality", "x", "y"), "PartitionSpec is incorrect" + + +@pytest.mark.parametrize("grid_size", [50, 100, 150]) +def test_jax_3d_grid_initialization(grid_size): + init_xlb_env() + grid_shape = (grid_size, grid_size, grid_size) + my_grid = grid(grid_shape) + f = my_grid.create_field(cardinality=9) + n_devices = jax.device_count() + + device_mesh = mesh_utils.create_device_mesh((1, n_devices, 1, 1)) + expected_mesh = Mesh(device_mesh, axis_names=("cardinality", "x", "y", "z")) + + assert f.shape == (9,) + grid_shape, "Field shape is incorrect" + assert f.sharding.mesh == expected_mesh, "Field sharding mesh is incorrect" + assert f.sharding.spec == ( + "cardinality", + "x", + "y", + "z", + ), "PartitionSpec is incorrect" + + +@pytest.fixture(autouse=True) +def setup_xlb_env(): + init_xlb_env() + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/grids/test_warp_grid.py b/tests/grids/test_warp_grid.py new file mode 100644 index 0000000..ee5b905 --- /dev/null +++ b/tests/grids/test_warp_grid.py @@ -0,0 +1,48 @@ +import pytest +import warp as wp +import numpy as np +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.grid import grid +from xlb.precision_policy import Precision + + +def init_xlb_warp_env(): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=xlb.velocity_set.D2Q9, + ) + + +def test_warp_grid_create_field(): + for grid_shape in [(100, 100), (100, 100, 100)]: + init_xlb_warp_env() + my_grid = grid(grid_shape) + f = my_grid.create_field(cardinality=9, dtype=Precision.FP32) + + assert f.shape == (9,) + grid_shape, "Field shape is incorrect" + assert isinstance(f, wp.array), "Field should be a Warp ndarray" + + +def test_warp_grid_create_field_init_val(): + init_xlb_warp_env() + grid_shape = (100, 100) + init_val = 3.14 + my_grid = grid(grid_shape) + + f = my_grid.create_field(cardinality=9, dtype=Precision.FP32, init_val=init_val) + assert isinstance(f, wp.array), "Field should be a Warp ndarray" + + f = f.numpy() + assert f.shape == (9,) + grid_shape, "Field shape is incorrect" + assert np.allclose(f, init_val), "Field not properly initialized with init_val" + + +@pytest.fixture(autouse=True) +def setup_xlb_env(): + init_xlb_warp_env() + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/kernels/__init__.py b/tests/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/kernels/jax/test_equilibrium_jax.py b/tests/kernels/jax/test_equilibrium_jax.py new file mode 100644 index 0000000..7498069 --- /dev/null +++ b/tests/kernels/jax/test_equilibrium_jax.py @@ -0,0 +1,42 @@ +import pytest +import numpy as np +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.grid import grid +from xlb.default_config import DefaultConfig + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=velocity_set, + ) + +@pytest.mark.parametrize("dim,velocity_set,grid_shape", [ + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)) +]) + +def test_quadratic_equilibrium(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid(grid_shape) + + rho = my_grid.create_field(cardinality=1) + 1.0 # Uniform density + u = my_grid.create_field(cardinality=dim) + 0.0 # Zero velocity + + # Compute equilibrium + compute_macro = QuadraticEquilibrium() + f_eq = compute_macro(rho, u) + + # Test sum of f_eq across cardinality at each point + sum_f_eq = np.sum(f_eq, axis=0) + assert np.allclose(sum_f_eq, 1.0), "Sum of f_eq should be 1.0 across all directions at each grid point" + + # Test that each direction matches the expected weights + weights = DefaultConfig.velocity_set.w + for i, weight in enumerate(weights): + assert np.allclose(f_eq[i, ...], weight), f"Direction {i} in f_eq does not match the expected weight" + +if __name__ == "__main__": + pytest.main() diff --git a/tests/kernels/warp/test_equilibrium_warp.py b/tests/kernels/warp/test_equilibrium_warp.py new file mode 100644 index 0000000..ef75624 --- /dev/null +++ b/tests/kernels/warp/test_equilibrium_warp.py @@ -0,0 +1,48 @@ +import pytest +import warp as wp +import numpy as np +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.grid import grid +from xlb.default_config import DefaultConfig + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=velocity_set, + ) + +@pytest.mark.parametrize("dim,velocity_set,grid_shape", [ + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)) +]) +def test_quadratic_equilibrium(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid(grid_shape) + + rho = my_grid.create_field(cardinality=1, init_val=1.0) + u = my_grid.create_field(cardinality=dim, init_val=0.0) + + f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) + + compute_macro = QuadraticEquilibrium() + f_eq = compute_macro(rho, u, f_eq) + + f_eq_np = f_eq.numpy() + + sum_f_eq = np.sum(f_eq_np, axis=0) + assert np.allclose(sum_f_eq, 1.0), "Sum of f_eq should be 1.0 across all directions at each grid point" + + weights = DefaultConfig.velocity_set.w + for i, weight in enumerate(weights): + assert np.allclose(f_eq_np[i, ...], weight), f"Direction {i} in f_eq does not match the expected weight" + +# @pytest.fixture(autouse=True) +# def setup_xlb_env(request): +# dim, velocity_set, grid_shape = request.param +# init_xlb_env(velocity_set) + +if __name__ == "__main__": + pytest.main() diff --git a/xlb/__init__.py b/xlb/__init__.py index 84d38c5..4dc0639 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -4,7 +4,7 @@ from xlb.physics_type import PhysicsType # Config -from .global_config import init, current_backend +from .default_config import init, default_backend # Velocity Set import xlb.velocity_set diff --git a/xlb/default_config.py b/xlb/default_config.py new file mode 100644 index 0000000..7f0e617 --- /dev/null +++ b/xlb/default_config.py @@ -0,0 +1,40 @@ +from xlb.compute_backend import ComputeBackend +from dataclasses import dataclass + + +@dataclass +class DefaultConfig: + default_precision_policy = None + velocity_set = None + default_backend = None + + +def init(velocity_set, default_backend, default_precision_policy): + DefaultConfig.velocity_set = velocity_set() + DefaultConfig.default_backend = default_backend + DefaultConfig.default_precision_policy = default_precision_policy + + if default_backend == ComputeBackend.WARP: + import warp as wp + + wp.init() + elif default_backend == ComputeBackend.JAX: + check_multi_gpu_support() + else: + raise ValueError(f"Unsupported compute backend: {default_backend}") + + +def default_backend() -> ComputeBackend: + return DefaultConfig.default_backend + + +def check_multi_gpu_support(): + import jax + + gpus = jax.devices("gpu") + if len(gpus) > 1: + print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus))) + elif len(gpus) == 1: + print("Single-GPU support is available: 1 GPU detected.") + else: + print("No GPU support is available; CPU fallback will be used.") diff --git a/xlb/global_config.py b/xlb/global_config.py deleted file mode 100644 index 17a1839..0000000 --- a/xlb/global_config.py +++ /dev/null @@ -1,14 +0,0 @@ -class GlobalConfig: - precision_policy = None - velocity_set = None - compute_backend = None - - -def init(velocity_set, compute_backend, precision_policy): - GlobalConfig.velocity_set = velocity_set() - GlobalConfig.compute_backend = compute_backend - GlobalConfig.precision_policy = precision_policy - - -def current_backend(): - return GlobalConfig.compute_backend diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py index a777f8f..47bcef1 100644 --- a/xlb/grid/__init__.py +++ b/xlb/grid/__init__.py @@ -1,3 +1 @@ -from xlb.grid.grid import Grid -from xlb.grid.warp_grid import WarpGrid -from xlb.grid.jax_grid import JaxGrid +from xlb.grid.grid import grid, Grid \ No newline at end of file diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 2276929..6dbdaf6 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -1,27 +1,41 @@ from abc import ABC, abstractmethod +from typing import Any, Literal, Optional, Tuple -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy, Precision -from xlb.velocity_set import VelocitySet -from xlb.operator import Operator +from xlb.precision_policy import Precision -class Grid(ABC): +def grid( + shape: Tuple[int, ...], + compute_backend: ComputeBackend = None, + parallel: bool = False, +): + compute_backend = compute_backend or DefaultConfig.default_backend + if compute_backend == ComputeBackend.WARP: + from xlb.grid.warp_grid import WarpGrid + + return WarpGrid(shape) + elif compute_backend == ComputeBackend.JAX: + from xlb.grid.jax_grid import JaxGrid + + return JaxGrid(shape) + + raise ValueError(f"Compute backend {compute_backend} is not supported") - def __init__( - self, - shape : tuple, - ): - # Set parameters + +class Grid(ABC): + def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend): self.shape = shape self.dim = len(shape) - - def parallelize_operator(self, operator: Operator): - raise NotImplementedError("Parallelization not implemented, child class must implement") + self.parallel = False + self.compute_backend = compute_backend + self._initialize_backend() @abstractmethod - def create_field( - self, name: str, cardinality: int, precision: Precision, callback=None - ): - raise NotImplementedError("create_field not implemented, child class must implement") + def _initialize_backend(self): + pass + + # @abstractmethod + # def parallelize_operator(self, operator: Operator): + # pass diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index f17545b..2d3b04d 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -1,23 +1,25 @@ +from typing import Any, Literal, Optional, Tuple from jax.sharding import PartitionSpec as P from jax.sharding import NamedSharding, Mesh from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map -import numpy as np +from xlb.compute_backend import ComputeBackend + import jax.numpy as jnp +from jax import lax import jax +from xlb.default_config import DefaultConfig from xlb.grid import Grid -from xlb.compute_backend import ComputeBackend from xlb.operator import Operator from xlb.precision_policy import Precision class JaxGrid(Grid): def __init__(self, shape): - super().__init__(shape) - self._initialize_jax_backend() + super().__init__(shape, ComputeBackend.JAX) - def _initialize_jax_backend(self): + def _initialize_backend(self): self.nDevices = jax.device_count() self.backend = jax.default_backend() device_mesh = ( @@ -25,69 +27,40 @@ def _initialize_jax_backend(self): if self.dim == 2 else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) ) - self.global_mesh = ( + global_mesh = ( Mesh(device_mesh, axis_names=("cardinality", "x", "y")) if self.dim == 2 else Mesh(device_mesh, axis_names=("cardinality", "x", "y", "z")) ) self.sharding = ( - NamedSharding(self.global_mesh, P("cardinality", "x", "y")) + NamedSharding(global_mesh, P("cardinality", "x", "y")) if self.dim == 2 - else NamedSharding(self.global_mesh, P("cardinality", "x", "y", "z")) + else NamedSharding(global_mesh, P("cardinality", "x", "y", "z")) ) - self.grid_shape_per_gpu = ( - self.shape[0] // self.nDevices, - ) + self.shape[1:] - - - def parallelize_operator(self, operator: Operator): - # TODO: fix this - # Make parallel function - def _parallel_operator(f): - rightPerm = [ - (i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices) - ] - leftPerm = [ - ((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices) - ] - f = self.func(f) - left_comm, right_comm = ( - f[self.velocity_set.right_indices, :1, ...], - f[self.velocity_set.left_indices, -1:, ...], - ) - left_comm, right_comm = ( - lax.ppermute(left_comm, perm=rightPerm, axis_name="x"), - lax.ppermute(right_comm, perm=leftPerm, axis_name="x"), - ) - f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm) - f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm) - - return f + def create_field( + self, + cardinality: int, + dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None, + init_val=None, + ): + sharding_dim = self.shape[-1] // self.nDevices + device_shape = (cardinality, sharding_dim, *self.shape[1:]) + full_shape = (cardinality, *self.shape) + arrays = [] - in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,))) - out_specs = in_specs + dype = dtype.jax_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.jax_dtype - f = shard_map( - self._parallel_func, - mesh=self.grid.global_mesh, - in_specs=in_specs, - out_specs=out_specs, - check_rep=False, - )(f) - return f - - - def create_field(self, cardinality: int, precision: Precision, callback=None): - # Get shape of the field - shape = (cardinality,) + (self.shape) - - # Create field - if callback is None: - f = np.full(shape, 0.0, dtype=precision.jax_dtype) - f = jax.device_put(f, self.sharding) - else: - f = jax.make_array_from_callback(shape, self.sharding, callback) - - # Add field to the field dictionary - return f + for d, index in self.sharding.addressable_devices_indices_map( + full_shape + ).items(): + jax.default_device = d + if init_val: + x = jnp.full(device_shape, init_val, dtype=dtype) + else: + x = jnp.zeros(shape=device_shape, dtype=dtype) + arrays += [jax.device_put(x, d)] + jax.default_device = jax.devices()[0] + return jax.make_array_from_single_device_arrays( + full_shape, self.sharding, arrays + ) diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index 97b337b..92bed54 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -1,27 +1,43 @@ +from dataclasses import field import warp as wp from xlb.grid import Grid from xlb.operator import Operator from xlb.precision_policy import Precision +from xlb.compute_backend import ComputeBackend +from typing import Literal +from xlb.default_config import DefaultConfig +import numpy as np + class WarpGrid(Grid): def __init__(self, shape): - super().__init__(shape) + super().__init__(shape, ComputeBackend.WARP) + + def _initialize_backend(self): + pass def parallelize_operator(self, operator: Operator): # TODO: Implement parallelization of the operator - raise NotImplementedError("Parallelization of the operator is not implemented yet for the WarpGrid") + raise NotImplementedError( + "Parallelization of the operator is not implemented yet for the WarpGrid" + ) - def create_field(self, cardinality: int, precision: Precision, callback=None): - # Get shape of the field + def create_field( + self, + cardinality: int, + dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None, + init_val=None, + ): + dtype = ( + dtype.wp_dtype + if dtype + else DefaultConfig.default_precision_policy.store_precision.wp_dtype + ) shape = (cardinality,) + (self.shape) - # Create the field - f = wp.zeros(shape, dtype=precision.wp_dtype) - - # Raise error on callback - if callback is not None: - raise ValueError("Callback is not supported in the WarpGrid") - - # Add field to the field dictionary + if init_val is None: + f = wp.zeros(shape, dtype=dtype) + else: + f = wp.full(shape, init_val, dtype=dtype) return f diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 83660c1..cbec6ac 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -7,7 +7,7 @@ import warp as wp from typing import Tuple -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index 572f345..4c886c4 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -7,7 +7,7 @@ import warp as wp from typing import Tuple -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index 148e9b8..8a6f956 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -8,7 +8,7 @@ import warp as wp from typing import Tuple -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend diff --git a/xlb/operator/equilibrium/equilibrium.py b/xlb/operator/equilibrium/equilibrium.py index 726ca37..11f4155 100644 --- a/xlb/operator/equilibrium/equilibrium.py +++ b/xlb/operator/equilibrium/equilibrium.py @@ -1,6 +1,5 @@ # Base class for all equilibriums from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy from xlb.operator.operator import Operator diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index d5bb20a..c84abd1 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -8,7 +8,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator import Operator -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig class QuadraticEquilibrium(Equilibrium): @@ -91,7 +91,7 @@ def functional( # Construct the warp kernel @wp.kernel - def kernel( + def kernel3d( rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), f: wp.array4d(dtype=Any), @@ -111,6 +111,29 @@ def kernel( for l in range(self.velocity_set.q): f[l, index[0], index[1], index[2]] = feq[l] + @wp.kernel + def kernel2d( + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + f: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the equilibrium + _u = _u_vec() + for d in range(self.velocity_set.d): + _u[d] = u[d, index[0], index[1]] + _rho = rho[0, index[0], index[1]] + feq = functional(_rho, _u) + + # Set the output + for l in range(self.velocity_set.q): + f[l, index[0], index[1]] = feq[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 161705e..7cbf38a 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -6,7 +6,7 @@ import warp as wp from typing import Tuple, Any -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 1724ffc..1f3d279 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -4,7 +4,7 @@ from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy, Precision -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig class Operator: @@ -18,9 +18,9 @@ class Operator: def __init__(self, velocity_set, precision_policy, compute_backend): # Set the default values from the global config - self.velocity_set = velocity_set or GlobalConfig.velocity_set - self.precision_policy = precision_policy or GlobalConfig.precision_policy - self.compute_backend = compute_backend or GlobalConfig.compute_backend + self.velocity_set = velocity_set or DefaultConfig.velocity_set + self.precision_policy = precision_policy or DefaultConfig.default_precision_policy + self.compute_backend = compute_backend or DefaultConfig.default_backend # Check if the compute backend is supported if self.compute_backend not in ComputeBackend: diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index bf40c55..fad3c8f 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from xlb.compute_backend import ComputeBackend -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig from xlb.precision_policy.jax_precision_policy import ( JaxFp32Fp32, @@ -14,63 +14,63 @@ class Fp64Fp64: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackend.JAX - or GlobalConfig.compute_backend == ComputeBackend.PALLAS + DefaultConfig.compute_backend == ComputeBackend.JAX + or DefaultConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp64Fp64() else: raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" + f"Unsupported compute backend: {DefaultConfig.compute_backend}" ) class Fp64Fp32: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackend.JAX - or GlobalConfig.compute_backend == ComputeBackend.PALLAS + DefaultConfig.compute_backend == ComputeBackend.JAX + or DefaultConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp64Fp32() else: raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" + f"Unsupported compute backend: {DefaultConfig.compute_backend}" ) class Fp32Fp32: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackend.JAX - or GlobalConfig.compute_backend == ComputeBackend.PALLAS + DefaultConfig.compute_backend == ComputeBackend.JAX + or DefaultConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp32Fp32() else: raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" + f"Unsupported compute backend: {DefaultConfig.compute_backend}" ) class Fp64Fp16: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackend.JAX - or GlobalConfig.compute_backend == ComputeBackend.PALLAS + DefaultConfig.compute_backend == ComputeBackend.JAX + or DefaultConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp64Fp16() else: raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" + f"Unsupported compute backend: {DefaultConfig.compute_backend}" ) class Fp32Fp16: def __new__(cls): if ( - GlobalConfig.compute_backend == ComputeBackend.JAX - or GlobalConfig.compute_backend == ComputeBackend.PALLAS + DefaultConfig.compute_backend == ComputeBackend.JAX + or DefaultConfig.compute_backend == ComputeBackend.PALLAS ): return JaxFp32Fp16() else: raise ValueError( - f"Unsupported compute backend: {GlobalConfig.compute_backend}" + f"Unsupported compute backend: {DefaultConfig.compute_backend}" ) diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py index 335fd72..7979c11 100644 --- a/xlb/solver/solver.py +++ b/xlb/solver/solver.py @@ -1,7 +1,7 @@ # Base class for all stepper operators from xlb.compute_backend import ComputeBackend -from xlb.global_config import GlobalConfig +from xlb.default_config import DefaultConfig from xlb.operator.operator import Operator @@ -23,10 +23,10 @@ def __init__( # Set parameters self.shape = shape - self.velocity_set = velocity_set or GlobalConfig.velocity_set - self.precision_policy = precision_policy or GlobalConfig.precision_policy - self.compute_backend = compute_backend or GlobalConfig.compute_backend - self.grid_backend = grid_backend or GlobalConfig.grid_backend + self.velocity_set = velocity_set or DefaultConfig.velocity_set + self.precision_policy = precision_policy or DefaultConfig.precision_policy + self.compute_backend = compute_backend or DefaultConfig.compute_backend + self.grid_backend = grid_backend or DefaultConfig.grid_backend self.boundary_conditions = boundary_conditions # Make grid From e33827a7c0ac00a553df061cbb0fafff89fdb46b Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Thu, 25 Apr 2024 12:52:24 -0400 Subject: [PATCH 028/144] WIP added tests for all main kernels + changes to operator + 2D kernels for Warp --- .../{test_jax_grid.py => test_grid_jax.py} | 13 +++- .../{test_warp_grid.py => test_grid_warp.py} | 5 +- .../collision/test_bgk_collision_jax.py | 51 ++++++++++++ .../collision/test_bgk_collision_warp.py | 55 +++++++++++++ .../test_equilibrium_jax.py | 27 +++++-- .../test_equilibrium_warp.py | 0 .../macroscopic/test_macroscopic_jax.py | 48 ++++++++++++ .../macroscopic/test_macroscopic_warp.py | 57 ++++++++++++++ tests/kernels/stream/test_stream_jax.py | 67 ++++++++++++++++ tests/kernels/stream/test_stream_warp.py | 78 +++++++++++++++++++ xlb/operator/collision/bgk.py | 56 +++++++------ xlb/operator/macroscopic/macroscopic.py | 25 +++++- xlb/operator/operator.py | 58 ++++++++------ xlb/operator/stream/stream.py | 53 +++++++++++-- 14 files changed, 528 insertions(+), 65 deletions(-) rename tests/grids/{test_jax_grid.py => test_grid_jax.py} (83%) rename tests/grids/{test_warp_grid.py => test_grid_warp.py} (87%) create mode 100644 tests/kernels/collision/test_bgk_collision_jax.py create mode 100644 tests/kernels/collision/test_bgk_collision_warp.py rename tests/kernels/{jax => equilibrium}/test_equilibrium_jax.py (58%) rename tests/kernels/{warp => equilibrium}/test_equilibrium_warp.py (100%) create mode 100644 tests/kernels/macroscopic/test_macroscopic_jax.py create mode 100644 tests/kernels/macroscopic/test_macroscopic_warp.py create mode 100644 tests/kernels/stream/test_stream_jax.py create mode 100644 tests/kernels/stream/test_stream_warp.py diff --git a/tests/grids/test_jax_grid.py b/tests/grids/test_grid_jax.py similarity index 83% rename from tests/grids/test_jax_grid.py rename to tests/grids/test_grid_jax.py index 3c74b6f..9f8a90f 100644 --- a/tests/grids/test_jax_grid.py +++ b/tests/grids/test_grid_jax.py @@ -5,7 +5,7 @@ from xlb.grid import grid from jax.sharding import Mesh from jax.experimental import mesh_utils - +import jax.numpy as jnp def init_xlb_env(): xlb.init( @@ -51,6 +51,17 @@ def test_jax_3d_grid_initialization(grid_size): "z", ), "PartitionSpec is incorrect" +def test_jax_grid_create_field_init_val(): + init_xlb_env() + grid_shape = (100, 100) + init_val = 3.14 + my_grid = grid(grid_shape) + + f = my_grid.create_field(cardinality=9, init_val=init_val) + assert f.shape == (9,) + grid_shape, "Field shape is incorrect" + assert jnp.allclose(f, init_val), "Field not properly initialized with init_val" + + @pytest.fixture(autouse=True) def setup_xlb_env(): diff --git a/tests/grids/test_warp_grid.py b/tests/grids/test_grid_warp.py similarity index 87% rename from tests/grids/test_warp_grid.py rename to tests/grids/test_grid_warp.py index ee5b905..5bdac7c 100644 --- a/tests/grids/test_warp_grid.py +++ b/tests/grids/test_grid_warp.py @@ -15,8 +15,9 @@ def init_xlb_warp_env(): ) -def test_warp_grid_create_field(): - for grid_shape in [(100, 100), (100, 100, 100)]: +@pytest.mark.parametrize("grid_size", [50, 100, 150]) +def test_warp_grid_create_field(grid_size): + for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]: init_xlb_warp_env() my_grid = grid(grid_shape) f = my_grid.create_field(cardinality=9, dtype=Precision.FP32) diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py new file mode 100644 index 0000000..adabd41 --- /dev/null +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -0,0 +1,51 @@ +import pytest +import jax.numpy as jnp +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.collision import BGK +from xlb.grid import grid +from xlb.default_config import DefaultConfig + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape,omega", + [ + (2, xlb.velocity_set.D2Q9, (100, 100), 0.6), + (2, xlb.velocity_set.D2Q9, (100, 100), 1.0), + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 0.6), + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0), + ], +) +def test_bgk_ollision(dim, velocity_set, grid_shape, omega): + init_xlb_env(velocity_set) + my_grid = grid(grid_shape) + + rho = my_grid.create_field(cardinality=1, init_val=1.0) + u = my_grid.create_field(cardinality=dim, init_val=0.0) + + # Compute equilibrium + compute_macro = QuadraticEquilibrium() + f_eq = compute_macro(rho, u) + + # Compute collision + + compute_collision = BGK(omega=omega) + + f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) + + f_out = compute_collision(f_orig, f_eq) + + assert jnp.allclose(f_out, f_orig - omega * (f_orig - f_eq)) + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py new file mode 100644 index 0000000..20e2f48 --- /dev/null +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -0,0 +1,55 @@ +import pytest +import warp as wp +import numpy as np +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.collision import BGK +from xlb.grid import grid +from xlb.default_config import DefaultConfig +from xlb.precision_policy import Precision + +def init_xlb_warp_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=velocity_set, + ) + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape,omega", + [ + (2, xlb.velocity_set.D2Q9, (100, 100), 0.6), + (2, xlb.velocity_set.D2Q9, (100, 100), 1.0), + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 0.6), + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0), + ], +) +def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): + init_xlb_warp_env(velocity_set) + my_grid = grid(grid_shape) + + rho = my_grid.create_field(cardinality=1, init_val=1.0) + u = my_grid.create_field(cardinality=dim, init_val=0.0) + + compute_macro = QuadraticEquilibrium() + + f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) + f_eq = compute_macro(rho, u, f_eq) + + + compute_collision = BGK(omega=omega) + f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) + + f_out = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) + f_out = compute_collision(f_orig, f_eq, f_out) + + f_eq = f_eq.numpy() + f_out = f_out.numpy() + f_orig = f_orig.numpy() + + assert np.allclose(f_out, f_orig - omega * (f_orig - f_eq), atol=1e-5) + +if __name__ == "__main__": + pytest.main() diff --git a/tests/kernels/jax/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py similarity index 58% rename from tests/kernels/jax/test_equilibrium_jax.py rename to tests/kernels/equilibrium/test_equilibrium_jax.py index 7498069..c0f50d7 100644 --- a/tests/kernels/jax/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -6,6 +6,7 @@ from xlb.grid import grid from xlb.default_config import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -13,17 +14,22 @@ def init_xlb_env(velocity_set): velocity_set=velocity_set, ) -@pytest.mark.parametrize("dim,velocity_set,grid_shape", [ - (2, xlb.velocity_set.D2Q9, (100, 100)), - (3, xlb.velocity_set.D3Q19, (50, 50, 50)) -]) +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (100, 100)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + ], +) def test_quadratic_equilibrium(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) my_grid = grid(grid_shape) - rho = my_grid.create_field(cardinality=1) + 1.0 # Uniform density - u = my_grid.create_field(cardinality=dim) + 0.0 # Zero velocity + rho = my_grid.create_field(cardinality=1, init_val=1.0) + u = my_grid.create_field(cardinality=dim, init_val=0.0) # Compute equilibrium compute_macro = QuadraticEquilibrium() @@ -31,12 +37,17 @@ def test_quadratic_equilibrium(dim, velocity_set, grid_shape): # Test sum of f_eq across cardinality at each point sum_f_eq = np.sum(f_eq, axis=0) - assert np.allclose(sum_f_eq, 1.0), "Sum of f_eq should be 1.0 across all directions at each grid point" + assert np.allclose( + sum_f_eq, 1.0 + ), f"Sum of f_eq should be 1.0 across all directions at each grid point" # Test that each direction matches the expected weights weights = DefaultConfig.velocity_set.w for i, weight in enumerate(weights): - assert np.allclose(f_eq[i, ...], weight), f"Direction {i} in f_eq does not match the expected weight" + assert np.allclose( + f_eq[i, ...], weight + ), f"Direction {i} in f_eq does not match the expected weight" + if __name__ == "__main__": pytest.main() diff --git a/tests/kernels/warp/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py similarity index 100% rename from tests/kernels/warp/test_equilibrium_warp.py rename to tests/kernels/equilibrium/test_equilibrium_warp.py diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py new file mode 100644 index 0000000..786063f --- /dev/null +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -0,0 +1,48 @@ +import pytest +import numpy as np +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.macroscopic import Macroscopic +from xlb.grid import grid +from xlb.default_config import DefaultConfig + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape,rho,velocity", + [ + (2, xlb.velocity_set.D2Q9, (100, 100), 1.0, 0.0), + (2, xlb.velocity_set.D2Q9, (100, 100), 1.1, 1.0), + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0, 0.0), + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), + ],) +def test_macroscopic(dim, velocity_set, grid_shape, rho, velocity): + init_xlb_env(velocity_set) + my_grid = grid(grid_shape) + + rho_field = my_grid.create_field(cardinality=1, init_val=rho) + velocity_field = my_grid.create_field(cardinality=dim, init_val=velocity) + + # Compute equilibrium + f_eq = QuadraticEquilibrium()(rho_field, velocity_field) + + + compute_macro = Macroscopic() + + rho_calc, u_calc = compute_macro(f_eq) + + # Test sum of f_eq which should be 1.0 for rho and 0.0 for u + assert np.allclose(rho_calc, rho), "Sum of f_eq should be {rho} for rho" + assert np.allclose(u_calc, velocity), "Sum of f_eq should be {velocity} for u" + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py new file mode 100644 index 0000000..163db95 --- /dev/null +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -0,0 +1,57 @@ +import pytest +import numpy as np +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.macroscopic import Macroscopic +from xlb.grid import grid +from xlb.default_config import DefaultConfig +import warp as wp + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape,rho,velocity", + [ + (2, xlb.velocity_set.D2Q9, (100, 100), 1.0, 0.0), + (2, xlb.velocity_set.D2Q9, (100, 100), 1.1, 1.0), + (2, xlb.velocity_set.D2Q9, (100, 100), 1.1, 2.0), + (2, xlb.velocity_set.D2Q9, (50, 50), 1.1, 2.0), + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0, 0.0), + # (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. + # (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 2.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. + ], +) +def test_macroscopic(dim, velocity_set, grid_shape, rho, velocity): + init_xlb_env(velocity_set) + my_grid = grid(grid_shape) + + rho_field = my_grid.create_field(cardinality=1, init_val=rho) + velocity_field = my_grid.create_field(cardinality=dim, init_val=velocity) + + f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) + f_eq = QuadraticEquilibrium()(rho_field, velocity_field, f_eq) + + compute_macro = Macroscopic() + rho_calc = my_grid.create_field(cardinality=1) + u_calc = my_grid.create_field(cardinality=dim) + + rho_calc, u_calc = compute_macro(f_eq, rho_calc, u_calc) + + assert np.allclose( + rho_calc.numpy(), rho + ), f"Computed density should be close to initialized density {rho}" + assert np.allclose( + u_calc.numpy(), velocity + ), f"Computed velocity should be close to initialized velocity {velocity}" + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/kernels/stream/test_stream_jax.py b/tests/kernels/stream/test_stream_jax.py new file mode 100644 index 0000000..83f1714 --- /dev/null +++ b/tests/kernels/stream/test_stream_jax.py @@ -0,0 +1,67 @@ +import pytest +import jax.numpy as jnp +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.stream import Stream +from xlb.default_config import DefaultConfig + +from xlb.grid import grid + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (100, 100)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + ], +) +def test_stream_operator(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid(grid_shape) + velocity_set = DefaultConfig.velocity_set + + stream_op = Stream() + + f_initial = my_grid.create_field(cardinality=velocity_set.q) + f_initial = f_initial.at[..., f_initial.shape[-1] // 2].set(1) + + f_streamed = stream_op(f_initial) + + expected = [] + + if dim == 2: + for i in range(velocity_set.q): + expected.append( + jnp.roll( + f_initial[i, ...], + (velocity_set.c[0][i], velocity_set.c[1][i]), + axis=(0, 1), + ) + ) + elif dim == 3: + for i in range(velocity_set.q): + expected.append( + jnp.roll( + f_initial[i, ...], + (velocity_set.c[0][i], velocity_set.c[1][i], velocity_set.c[2][i]), + axis=(0, 1, 2), + ) + ) + + expected = jnp.stack(expected, axis=0) + + assert jnp.allclose(f_streamed, expected), "Streaming did not occur as expected" + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py new file mode 100644 index 0000000..643ea70 --- /dev/null +++ b/tests/kernels/stream/test_stream_warp.py @@ -0,0 +1,78 @@ +import pytest +import jax.numpy as jnp +import numpy as np +import warp as wp +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.operator.stream import Stream +from xlb.default_config import DefaultConfig + +from xlb.grid import grid + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (100, 100)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + ], +) +def test_stream_operator(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid_jax = grid(grid_shape, compute_backend=ComputeBackend.JAX) + velocity_set = DefaultConfig.velocity_set + + f_initial = my_grid_jax.create_field(cardinality=velocity_set.q) + f_initial = f_initial.at[..., f_initial.shape[-1] // 2].set(1) + + expected = [] + + if dim == 2: + for i in range(velocity_set.q): + expected.append( + jnp.roll( + f_initial[i, ...], + (velocity_set.c[0][i], velocity_set.c[1][i]), + axis=(0, 1), + ) + ) + elif dim == 3: + for i in range(velocity_set.q): + expected.append( + jnp.roll( + f_initial[i, ...], + (velocity_set.c[0][i], velocity_set.c[1][i], velocity_set.c[2][i]), + axis=(0, 1, 2), + ) + ) + + expected = jnp.stack(expected, axis=0) + + if dim == 2: + f_initial_warp = wp.from_numpy(f_initial, dtype=wp.float32) + + elif dim == 3: + f_initial_warp = wp.from_numpy(f_initial, dtype=wp.float32) + + stream_op = Stream() + my_grid_warp = grid(grid_shape, compute_backend=ComputeBackend.WARP) + f_streamed = my_grid_warp.create_field(cardinality=velocity_set.q) + f_streamed = stream_op(f_initial_warp, f_streamed) + + assert jnp.allclose( + f_streamed.numpy(), np.array(expected) + ), "Streaming did not occur as expected" + + +if __name__ == "__main__": + pytest.main() diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 69718aa..520cf3f 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -17,17 +17,13 @@ class BGK(Collision): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) - def jax_implementation( - self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray - ): + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fneq = f - feq fout = f - self.compute_dtype(self.omega) * fneq return fout @Operator.register_backend(ComputeBackend.PALLAS) - def pallas_implementation( - self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray - ): + def pallas_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fneq = f - feq fout = f - self.omega * fneq return fout @@ -40,23 +36,16 @@ def _construct_warp(self): # Construct the functional @wp.func - def functional( - f: Any, - feq: Any, - rho: Any, - u: Any, - ): + def functional(f: Any, feq: Any): fneq = f - feq fout = f - _omega * fneq return fout # Construct the warp kernel @wp.kernel - def kernel( + def kernel3d( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), ): # Get the global index @@ -69,30 +58,51 @@ def kernel( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] - _u = self._warp_u_vec() - for l in range(_d): - _u[l] = u[l, index[0], index[1], index[2]] - _rho = rho[0, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq, _rho, _u) + _fout = functional(_f, _feq) # Write the result for l in range(self.velocity_set.q): fout[l, index[0], index[1], index[2]] = _fout[l] + # Construct the warp kernel + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + feq: wp.array3d(dtype=Any), + fout: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Load needed values + _f = _f_vec() + _feq = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _feq[l] = feq[l, index[0], index[1]] + + # Compute the collision + _fout = functional(_f, _feq) + + # Write the result + for l in range(self.velocity_set.q): + fout[l, index[0], index[1]] = _fout[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, rho, u, fout): + def warp_implementation(self, f, feq, fout): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ f, feq, - rho, - u, fout, ], dim=f.shape[1:], diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 7cbf38a..9c6ac10 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -98,7 +98,7 @@ def functional(f: _f_vec): # Construct the kernel @wp.kernel - def kernel( + def kernel3d( f: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), @@ -118,6 +118,29 @@ def kernel( for d in range(self.velocity_set.d): u[d, index[0], index[1], index[2]] = _u[d] + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the equilibrium + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + (_rho, _u) = functional(_f) + + # Set the output + rho[0, index[0], index[1]] = _rho + for d in range(self.velocity_set.d): + u[d, index[0], index[1]] = _u[d] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 1f3d279..a8b85d0 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,4 +1,6 @@ # Base class for all operators, (collision, streaming, equilibrium, etc.) + +import inspect import warp as wp from typing import Any @@ -16,10 +18,12 @@ class Operator: _backends = {} - def __init__(self, velocity_set, precision_policy, compute_backend): + def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None): # Set the default values from the global config self.velocity_set = velocity_set or DefaultConfig.velocity_set - self.precision_policy = precision_policy or DefaultConfig.default_precision_policy + self.precision_policy = ( + precision_policy or DefaultConfig.default_precision_policy + ) self.compute_backend = compute_backend or DefaultConfig.default_backend # Check if the compute backend is supported @@ -37,35 +41,41 @@ def register_backend(cls, backend_name): """ def decorator(func): - # Use the combination of operator name and backend name as the key subclass_name = func.__qualname__.split(".")[0] - key = (subclass_name, backend_name) + signature = inspect.signature(func) + key = (subclass_name, backend_name, str(signature)) cls._backends[key] = func return func return decorator - def __call__(self, *args, callback=None, **kwargs): - """ - Calls the operator with the compute backend specified in the constructor. - If a callback is provided, it is called either with the result of the operation - or with the original arguments and keyword arguments if the backend modifies them by reference. - """ - key = (self.__class__.__name__, self.compute_backend) - backend_method = self._backends.get(key) - - if backend_method: - result = backend_method(self, *args, **kwargs) - - # Determine what to pass to the callback based on the backend behavior - callback_arg = result if result is not None else (args, kwargs) - - if callback and callable(callback): - callback(callback_arg) + return decorator - return result - else: - raise NotImplementedError(f"Backend {self.compute_backend} not implemented") + def __call__(self, *args, callback=None, **kwargs): + method_candidates = [ + (key, method) + for key, method in self._backends.items() + if key[0] == self.__class__.__name__ and key[1] == self.compute_backend + ] + bound_arguments = None + for key, backend_method in method_candidates: + try: + # This attempts to bind the provided args and kwargs to the backend method's signature + bound_arguments = inspect.signature(backend_method).bind( + self, *args, **kwargs + ) + bound_arguments.apply_defaults() # This fills in any default values + result = backend_method(self, *args, **kwargs) + callback_arg = result if result is not None else (args, kwargs) + if callback and callable(callback): + callback(callback_arg) + return result + except TypeError: + continue # This skips to the next candidate if binding fails + + raise NotImplementedError( + f"No implementation found for backend with key {key} for operator {self.__class__.__name__}" + ) @property def supported_compute_backend(self): diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 8bb2568..77cf22d 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -45,9 +45,7 @@ def _streaming_jax_i(f, c): The updated distribution function after streaming. """ if self.velocity_set.d == 2: - return jnp.roll( - f, (c[0], c[1]), axis=(0, 1) - ) + return jnp.roll(f, (c[0], c[1]), axis=(0, 1)) elif self.velocity_set.d == 3: return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) @@ -60,9 +58,49 @@ def _construct_warp(self): _c = self.velocity_set.wp_c _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + # Construct the warp functional + @wp.func + def functional2d( + f: wp.array3d(dtype=Any), + index: Any, + ): + # Pull the distribution function + _f = _f_vec() + for l in range(self.velocity_set.q): + # Get pull index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - _c[d, l] + + if pull_index[d] < 0: + pull_index[d] = f.shape[d + 1] - 1 + elif pull_index[d] >= f.shape[d + 1]: + pull_index[d] = 0 + + # Read the distribution function + _f[l] = f[l, pull_index[0], pull_index[1]] + + return _f + + @wp.kernel + def kernel2d( + f_0: wp.array3d(dtype=Any), + f_1: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Set the output + _f = functional2d(f_0, index) + + # Write the output + for l in range(self.velocity_set.q): + f_1[l, index[0], index[1]] = _f[l] + # Construct the funcional to get streamed indices @wp.func - def functional( + def functional3d( f: wp.array4d(dtype=Any), index: Any, ): @@ -86,7 +124,7 @@ def functional( # Construct the warp kernel @wp.kernel - def kernel( + def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), ): @@ -95,12 +133,15 @@ def kernel( index = wp.vec3i(i, j, k) # Set the output - _f = functional(f_0, index) + _f = functional3d(f_0, index) # Write the output for l in range(self.velocity_set.q): f_1[l, index[0], index[1], index[2]] = _f[l] + functional = functional3d if self.velocity_set.d == 3 else functional2d + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) From d37aa1eabf5699d4d58eee0c0ef6fb9652798653 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Mon, 6 May 2024 19:32:47 -0400 Subject: [PATCH 029/144] WIP: Added lots of tests and fixed a bunch of errors --- examples/cfd/example_basic.py | 4 +- examples/cfd/flow_past_sphere.py | 20 ++-- examples/cfd/lid_driven_cavity.py | 34 +++--- examples/cfd/taylor_green.py | 6 +- examples/performance/mlups3d.py | 7 +- .../boundary_conditions.py | 41 +++---- .../bc_equilibrium/test_bc_equilibrium_jax.py | 104 +++++++++++++++++ .../test_bc_equilibrium_warp.py | 108 +++++++++++++++++ .../mask/test_bc_indices_masker_jax.py | 90 +++++++++++++++ .../mask/test_bc_indices_masker_warp.py | 94 +++++++++++++++ tests/grids/test_grid_jax.py | 8 +- tests/grids/test_grid_warp.py | 6 +- .../collision/test_bgk_collision_jax.py | 4 +- .../collision/test_bgk_collision_warp.py | 4 +- .../equilibrium/test_equilibrium_jax.py | 4 +- .../equilibrium/test_equilibrium_warp.py | 4 +- .../macroscopic/test_macroscopic_jax.py | 4 +- .../macroscopic/test_macroscopic_warp.py | 4 +- tests/kernels/stream/test_stream_jax.py | 4 +- tests/kernels/stream/test_stream_warp.py | 10 +- xlb/compute_backend.py | 1 - xlb/grid/__init__.py | 2 +- xlb/grid/grid.py | 11 +- xlb/grid/jax_grid.py | 6 +- xlb/grid/warp_grid.py | 2 +- xlb/operator/boundary_condition/__init__.py | 8 +- .../{do_nothing.py => bc_do_nothing.py} | 19 ++- .../{equilibrium.py => bc_equilibrium.py} | 76 +++++++++--- ...unce_back.py => bc_fullway_bounce_back.py} | 12 +- ...unce_back.py => bc_halfway_bounce_back.py} | 12 +- .../boundary_condition/boundary_condition.py | 11 +- xlb/operator/boundary_masker/__init__.py | 3 - .../boundary_masker/boundary_masker.py | 9 -- .../indices_boundary_masker.py | 109 +++++++++++++----- .../boundary_masker/planar_boundary_masker.py | 20 ++-- .../boundary_masker/stl_boundary_masker.py | 18 +-- xlb/operator/collision/bgk.py | 6 - .../equilibrium/quadratic_equilibrium.py | 30 +---- xlb/operator/macroscopic/macroscopic.py | 27 ----- xlb/operator/operator.py | 7 +- xlb/operator/stepper/__init__.py | 2 +- .../stepper/{nse.py => nse_stepper.py} | 98 ++++------------ xlb/operator/stepper/stepper.py | 12 +- xlb/precision_policy/precision_policy.py | 25 +--- xlb/solver/nse.py | 37 +----- xlb/solver/solver.py | 23 +--- 46 files changed, 712 insertions(+), 434 deletions(-) create mode 100644 tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py create mode 100644 tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py create mode 100644 tests/boundary_conditions/mask/test_bc_indices_masker_jax.py create mode 100644 tests/boundary_conditions/mask/test_bc_indices_masker_warp.py rename xlb/operator/boundary_condition/{do_nothing.py => bc_do_nothing.py} (84%) rename xlb/operator/boundary_condition/{equilibrium.py => bc_equilibrium.py} (58%) rename xlb/operator/boundary_condition/{fullway_bounce_back.py => bc_fullway_bounce_back.py} (90%) rename xlb/operator/boundary_condition/{halfway_bounce_back.py => bc_halfway_bounce_back.py} (91%) delete mode 100644 xlb/operator/boundary_masker/boundary_masker.py rename xlb/operator/stepper/{nse.py => nse_stepper.py} (63%) diff --git a/examples/cfd/example_basic.py b/examples/cfd/example_basic.py index c11abbe..aeac7f7 100644 --- a/examples/cfd/example_basic.py +++ b/examples/cfd/example_basic.py @@ -3,7 +3,7 @@ from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.default_config import DefaultConfig import warp as wp -from xlb.grid import grid +from xlb.grid import grid_factory from xlb.precision_policy import Precision import xlb.velocity_set @@ -15,7 +15,7 @@ grid_size = 50 grid_shape = (grid_size, grid_size, grid_size) -my_grid = grid(grid_shape) +my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9) # compute_macro = QuadraticEquilibrium() diff --git a/examples/cfd/flow_past_sphere.py b/examples/cfd/flow_past_sphere.py index ef24ac8..515b6f0 100644 --- a/examples/cfd/flow_past_sphere.py +++ b/examples/cfd/flow_past_sphere.py @@ -70,12 +70,12 @@ def warp_implementation(self, rho, u, vel): nr = 256 vel = 0.05 shape = (nr, nr, nr) - grid = xlb.grid.grid(shape=shape) + grid = xlb.grid.grid_factory(shape=shape) rho = grid.create_field(cardinality=1) u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) + boundary_id_field = grid.create_field(cardinality=1, dtype=wp.uint8) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) # Make operators @@ -157,10 +157,10 @@ def warp_implementation(self, rho, u, vel): indices = wp.from_numpy(indices, dtype=wp.int32) # Set boundary conditions on the indices - boundary_id, missing_mask = indices_boundary_masker( + boundary_id_field, missing_mask = indices_boundary_masker( indices, half_way_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -169,12 +169,12 @@ def warp_implementation(self, rho, u, vel): lower_bound = (0, 0, 0) upper_bound = (0, nr, nr) direction = (1, 0, 0) - boundary_id, missing_mask = planar_boundary_masker( + boundary_id_field, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, equilibrium_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -183,12 +183,12 @@ def warp_implementation(self, rho, u, vel): lower_bound = (nr-1, 0, 0) upper_bound = (nr-1, nr, nr) direction = (-1, 0, 0) - boundary_id, missing_mask = planar_boundary_masker( + boundary_id_field, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, do_nothing_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -206,7 +206,7 @@ def warp_implementation(self, rho, u, vel): num_steps = 1024 * 8 start = time.time() for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1 = stepper(f0, f1, boundary_id_field, missing_mask, _) f1, f0 = f0, f1 if (_ % plot_freq == 0) and (not compute_mlup): rho, u = macroscopic(f0, rho, u) @@ -216,7 +216,7 @@ def warp_implementation(self, rho, u, vel): plt.imshow(u[0, :, nr // 2, :].numpy()) plt.colorbar() plt.subplot(1, 2, 2) - plt.imshow(boundary_id[0, :, nr // 2, :].numpy()) + plt.imshow(boundary_id_field[0, :, nr // 2, :].numpy()) plt.colorbar() plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() diff --git a/examples/cfd/lid_driven_cavity.py b/examples/cfd/lid_driven_cavity.py index e5ca559..96cf425 100644 --- a/examples/cfd/lid_driven_cavity.py +++ b/examples/cfd/lid_driven_cavity.py @@ -77,7 +77,7 @@ def run_ldc(backend, compute_mlup=True): u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + boundary_id_field = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators @@ -143,12 +143,12 @@ def run_ldc(backend, compute_mlup=True): lower_bound = (0, 1, 1) upper_bound = (0, nr-1, nr-1) direction = (1, 0, 0) - boundary_id, missing_mask = planar_boundary_masker( + boundary_id_field, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, equilibrium_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -157,13 +157,13 @@ def run_ldc(backend, compute_mlup=True): lower_bound = (nr-1, 0, 0) upper_bound = (nr-1, nr, nr) direction = (-1, 0, 0) - boundary_id, missing_mask = planar_boundary_masker( + boundary_id_field, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, half_way_bc.id, #full_way_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -172,13 +172,13 @@ def run_ldc(backend, compute_mlup=True): lower_bound = (0, 0, 0) upper_bound = (nr, 0, nr) direction = (0, 1, 0) - boundary_id, missing_mask = planar_boundary_masker( + boundary_id_field, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, half_way_bc.id, #full_way_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -187,13 +187,13 @@ def run_ldc(backend, compute_mlup=True): lower_bound = (0, nr-1, 0) upper_bound = (nr, nr-1, nr) direction = (0, -1, 0) - boundary_id, missing_mask = planar_boundary_masker( + boundary_id_field, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, half_way_bc.id, #full_way_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -202,13 +202,13 @@ def run_ldc(backend, compute_mlup=True): lower_bound = (0, 0, 0) upper_bound = (nr, nr, 0) direction = (0, 0, 1) - boundary_id, missing_mask = planar_boundary_masker( + boundary_id_field, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, half_way_bc.id, #full_way_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -217,13 +217,13 @@ def run_ldc(backend, compute_mlup=True): lower_bound = (0, 0, nr-1) upper_bound = (nr, nr, nr-1) direction = (0, 0, -1) - boundary_id, missing_mask = planar_boundary_masker( + boundary_id_field, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, half_way_bc.id, #full_way_bc.id, - boundary_id, + boundary_id_field, missing_mask, (0, 0, 0) ) @@ -246,10 +246,10 @@ def run_ldc(backend, compute_mlup=True): for _ in tqdm(range(num_steps)): # Time step if backend == "warp": - f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1 = stepper(f0, f1, boundary_id_field, missing_mask, _) f1, f0 = f0, f1 elif backend == "jax": - f0 = stepper(f0, boundary_id, missing_mask, _) + f0 = stepper(f0, boundary_id_field, missing_mask, _) # Plot if necessary if (_ % plot_freq == 0) and (not compute_mlup): @@ -257,10 +257,10 @@ def run_ldc(backend, compute_mlup=True): rho, u = macroscopic(f0, rho, u) local_rho = rho.numpy() local_u = u.numpy() - local_boundary_id = boundary_id.numpy() + local_boundary_id = boundary_id_field.numpy() elif backend == "jax": local_rho, local_u = macroscopic(f0) - local_boundary_id = boundary_id + local_boundary_id = boundary_id_field # Plot the velocity field, rho and boundary id side by side plt.subplot(1, 3, 1) diff --git a/examples/cfd/taylor_green.py b/examples/cfd/taylor_green.py index f842107..b806225 100644 --- a/examples/cfd/taylor_green.py +++ b/examples/cfd/taylor_green.py @@ -135,7 +135,7 @@ def run_taylor_green(backend, compute_mlup=True): u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + boundary_id_field = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators @@ -187,10 +187,10 @@ def run_taylor_green(backend, compute_mlup=True): for _ in tqdm(range(num_steps)): # Time step if backend == "warp": - f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1 = stepper(f0, f1, boundary_id_field, missing_mask, _) f1, f0 = f0, f1 elif backend == "jax": - f0 = stepper(f0, boundary_id, missing_mask, _) + f0 = stepper(f0, boundary_id_field, missing_mask, _) # Plot if needed if (_ % plot_freq == 0) and (not compute_mlup): diff --git a/examples/performance/mlups3d.py b/examples/performance/mlups3d.py index 2e13769..bce5a0b 100644 --- a/examples/performance/mlups3d.py +++ b/examples/performance/mlups3d.py @@ -7,7 +7,7 @@ from xlb.operator.initializer import EquilibriumInitializer from xlb.solver import IncompressibleNavierStokes -from xlb.grid import Grid +from xlb.grid import grid_factory parser = argparse.ArgumentParser( description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)" @@ -36,11 +36,6 @@ solver = IncompressibleNavierStokes(grid, omega=1.0) -# Ahead-of-Time Compilation to remove JIT overhead -# if xlb.current_backend() == ComputeBackend.JAX or xlb.current_backend() == ComputeBackend.PALLAS: -# lowered = jax.jit(solver.step).lower(f, timestep=0) -# solver_step_compiled = lowered.compile() - # Ahead-of-Time Compilation to remove JIT overhead f = solver.step(f, timestep=0) diff --git a/tests/backends_conformance/boundary_conditions.py b/tests/backends_conformance/boundary_conditions.py index ed2e5c3..5a4f0b0 100644 --- a/tests/backends_conformance/boundary_conditions.py +++ b/tests/backends_conformance/boundary_conditions.py @@ -2,6 +2,7 @@ import numpy as np import jax.numpy as jnp import warp as wp +from xlb.grid import grid_factory import xlb wp.init() @@ -28,11 +29,7 @@ def run_boundary_conditions(self, backend): # Make grid nr = 128 shape = (nr, nr, nr) - if backend == "jax": - grid = xlb.grid.JaxGrid(shape=shape) - elif backend == "warp": - grid = xlb.grid.WarpGrid(shape=shape) - + grid = grid_factory(shape) # Make fields f_pre = grid.create_field( cardinality=velocity_set.q, precision=xlb.Precision.FP32 @@ -41,7 +38,7 @@ def run_boundary_conditions(self, backend): cardinality=velocity_set.q, precision=xlb.Precision.FP32 ) f = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + boundary_id_field = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) missing_mask = grid.create_field( cardinality=velocity_set.q, precision=xlb.Precision.BOOL ) @@ -98,63 +95,63 @@ def run_boundary_conditions(self, backend): indices = wp.from_numpy(indices, dtype=wp.int32) # Test equilibrium boundary condition - boundary_id, missing_mask = indices_boundary_masker( - indices, equilibrium_bc.id, boundary_id, missing_mask, (0, 0, 0) + boundary_id_field, missing_mask = indices_boundary_masker( + indices, equilibrium_bc.id, boundary_id_field, missing_mask, (0, 0, 0) ) if backend == "jax": - f_equilibrium = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask) + f_equilibrium = equilibrium_bc(f_pre, f_post, boundary_id_field, missing_mask) elif backend == "warp": f_equilibrium = grid.create_field( cardinality=velocity_set.q, precision=xlb.Precision.FP32 ) f_equilibrium = equilibrium_bc( - f_pre, f_post, boundary_id, missing_mask, f_equilibrium + f_pre, f_post, boundary_id_field, missing_mask, f_equilibrium ) # Test do nothing boundary condition - boundary_id, missing_mask = indices_boundary_masker( - indices, do_nothing_bc.id, boundary_id, missing_mask, (0, 0, 0) + boundary_id_field, missing_mask = indices_boundary_masker( + indices, do_nothing_bc.id, boundary_id_field, missing_mask, (0, 0, 0) ) if backend == "jax": - f_do_nothing = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask) + f_do_nothing = do_nothing_bc(f_pre, f_post, boundary_id_field, missing_mask) elif backend == "warp": f_do_nothing = grid.create_field( cardinality=velocity_set.q, precision=xlb.Precision.FP32 ) f_do_nothing = do_nothing_bc( - f_pre, f_post, boundary_id, missing_mask, f_do_nothing + f_pre, f_post, boundary_id_field, missing_mask, f_do_nothing ) # Test halfway bounce back boundary condition - boundary_id, missing_mask = indices_boundary_masker( - indices, halfway_bounce_back_bc.id, boundary_id, missing_mask, (0, 0, 0) + boundary_id_field, missing_mask = indices_boundary_masker( + indices, halfway_bounce_back_bc.id, boundary_id_field, missing_mask, (0, 0, 0) ) if backend == "jax": f_halfway_bounce_back = halfway_bounce_back_bc( - f_pre, f_post, boundary_id, missing_mask + f_pre, f_post, boundary_id_field, missing_mask ) elif backend == "warp": f_halfway_bounce_back = grid.create_field( cardinality=velocity_set.q, precision=xlb.Precision.FP32 ) f_halfway_bounce_back = halfway_bounce_back_bc( - f_pre, f_post, boundary_id, missing_mask, f_halfway_bounce_back + f_pre, f_post, boundary_id_field, missing_mask, f_halfway_bounce_back ) # Test the full boundary condition - boundary_id, missing_mask = indices_boundary_masker( - indices, fullway_bounce_back_bc.id, boundary_id, missing_mask, (0, 0, 0) + boundary_id_field, missing_mask = indices_boundary_masker( + indices, fullway_bounce_back_bc.id, boundary_id_field, missing_mask, (0, 0, 0) ) if backend == "jax": f_fullway_bounce_back = fullway_bounce_back_bc( - f_pre, f_post, boundary_id, missing_mask + f_pre, f_post, boundary_id_field, missing_mask ) elif backend == "warp": f_fullway_bounce_back = grid.create_field( cardinality=velocity_set.q, precision=xlb.Precision.FP32 ) f_fullway_bounce_back = fullway_bounce_back_bc( - f_pre, f_post, boundary_id, missing_mask, f_fullway_bounce_back + f_pre, f_post, boundary_id_field, missing_mask, f_fullway_bounce_back ) return f_equilibrium, f_do_nothing, f_halfway_bounce_back, f_fullway_bounce_back diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py new file mode 100644 index 0000000..202f4b4 --- /dev/null +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -0,0 +1,104 @@ +import pytest +import numpy as np +import jax.numpy as jnp +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.grid import grid_factory +from xlb.default_config import DefaultConfig + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (100, 100)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + ], +) +def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid_factory(grid_shape) + velocity_set = DefaultConfig.velocity_set + + missing_mask = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.BOOL + ) + + boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + + # Make indices for boundary conditions (sphere) + sphere_radius = grid_shape[0] // 4 + nr = grid_shape[0] + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + if dim == 2: + X, Y = np.meshgrid(x, y) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) + else: + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + + indices = jnp.array(indices) + + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium() + + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(0.0, 0.0, 0.0) if dim == 3 else (0.0, 0.0), + equilibrium_operator=equilibrium, + ) + + boundary_id_field, missing_mask = indices_boundary_masker( + indices, equilibrium_bc.id, boundary_id_field, missing_mask, start_index=None + ) + + f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) + + f_post = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=2.0 + ) # Arbitrary value so that we can check if the values are changed outside the boundary + + f = equilibrium_bc(f_pre, f_post, boundary_id_field, missing_mask) + + assert f.shape == (velocity_set.q,) + grid_shape + + # Assert that the values are correct in the indices of the sphere + weights = velocity_set.w + for i, weight in enumerate(weights): + if dim == 2: + assert jnp.allclose( + f[i, indices[0], indices[1]], weight + ), f"Direction {i} in f does not match the expected weight" + else: + assert jnp.allclose( + f[i, indices[0], indices[1], indices[2]], weight + ), f"Direction {i} in f does not match the expected weight" + + # Make sure that everywhere else the values are the same as f_post. Note that indices are just int values + mask_outside = np.ones(grid_shape, dtype=bool) + mask_outside[indices] = False # Mark boundary as false + if dim == 2: + for i in range(velocity_set.q): + assert jnp.allclose(f[i, mask_outside], f_post[i, mask_outside]) + else: + for i in range(velocity_set.q): + assert jnp.allclose(f[i, mask_outside], f_post[i, mask_outside]) + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py new file mode 100644 index 0000000..eaaa17f --- /dev/null +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -0,0 +1,108 @@ +import pytest +import numpy as np +import warp as wp +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.grid import grid_factory +from xlb.default_config import DefaultConfig + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (100, 100)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + ], +) +def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid_factory(grid_shape) + velocity_set = DefaultConfig.velocity_set + + missing_mask = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.BOOL + ) + + boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + + # Make indices for boundary conditions (sphere) + sphere_radius = grid_shape[0] // 4 + nr = grid_shape[0] + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + if dim == 2: + X, Y = np.meshgrid(x, y) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) + else: + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + + indices = wp.array(indices, dtype=wp.int32) + + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium() + + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(0.0, 0.0, 0.0) if dim == 3 else (0.0, 0.0), + equilibrium_operator=equilibrium, + ) + + boundary_id_field, missing_mask = indices_boundary_masker( + indices, equilibrium_bc.id, boundary_id_field, missing_mask, start_index=None + ) + + f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) + f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) + f_post = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=2.0 + ) # Arbitrary value so that we can check if the values are changed outside the boundary + + f = equilibrium_bc(f_pre, f_post, boundary_id_field, missing_mask, f) + + f = f.numpy() + f_post = f_post.numpy() + indices = indices.numpy() + + assert f.shape == (velocity_set.q,) + grid_shape + + # Assert that the values are correct in the indices of the sphere + weights = velocity_set.w + for i, weight in enumerate(weights): + if dim == 2: + assert np.allclose( + f[i, indices[0], indices[1]], weight + ), f"Direction {i} in f does not match the expected weight" + else: + assert np.allclose( + f[i, indices[0], indices[1], indices[2]], weight + ), f"Direction {i} in f does not match the expected weight" + + # Make sure that everywhere else the values are the same as f_post. Note that indices are just int values + mask_outside = np.ones(grid_shape, dtype=bool) + mask_outside[indices] = False # Mark boundary as false + if dim == 2: + for i in range(velocity_set.q): + assert np.allclose(f[i, mask_outside], f_post[i, mask_outside]) + else: + for i in range(velocity_set.q): + assert np.allclose(f[i, mask_outside], f_post[i, mask_outside]) + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py new file mode 100644 index 0000000..ea19bc7 --- /dev/null +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -0,0 +1,90 @@ +import pytest +import jax.numpy as jnp +import numpy as np +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.default_config import DefaultConfig + +from xlb.grid import grid_factory + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (50, 50)), + (2, xlb.velocity_set.D2Q9, (50, 50)), + (3, xlb.velocity_set.D3Q19, (20, 20, 20)), + (3, xlb.velocity_set.D3Q19, (20, 20, 20)), + ], +) +def test_indices_masker_jax(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid_factory(grid_shape) + velocity_set = DefaultConfig.velocity_set + + missing_mask = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.BOOL + ) + + boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + + # Make indices for boundary conditions (sphere) + sphere_radius = grid_shape[0] // 4 + nr = grid_shape[0] + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + if dim == 2: + X, Y = np.meshgrid(x, y) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) + else: + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + + indices = jnp.array(indices) + + assert indices.shape[0] == dim + test_id = 5 + boundary_id_field, missing_mask = indices_boundary_masker( + indices, test_id, boundary_id_field, missing_mask, start_index=None + ) + + assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype + + assert boundary_id_field.dtype == xlb.Precision.UINT8.jax_dtype + + assert boundary_id_field.shape == (1,) + grid_shape + + assert missing_mask.shape == (velocity_set.q,) + grid_shape + + if dim == 2: + assert jnp.all(boundary_id_field[0, indices[0], indices[1]] == test_id) + # assert that the rest of the boundary_id_field is zero + boundary_id_field = boundary_id_field.at[0, indices[0], indices[1]].set(0) + assert jnp.all(boundary_id_field == 0) + if dim == 3: + assert jnp.all( + boundary_id_field[0, indices[0], indices[1], indices[2]] == test_id + ) + # assert that the rest of the boundary_id_field is zero + boundary_id_field = boundary_id_field.at[ + 0, indices[0], indices[1], indices[2] + ].set(0) + assert jnp.all(boundary_id_field == 0) + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py new file mode 100644 index 0000000..ce42d7f --- /dev/null +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -0,0 +1,94 @@ +import pytest +import warp as wp +import numpy as np +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.default_config import DefaultConfig + +from xlb.grid import grid_factory + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (50, 50)), + (2, xlb.velocity_set.D2Q9, (50, 50)), + (3, xlb.velocity_set.D3Q19, (20, 20, 20)), + (3, xlb.velocity_set.D3Q19, (20, 20, 20)), + ], +) +def test_indices_masker_warp(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid_factory(grid_shape) + velocity_set = DefaultConfig.velocity_set + + missing_mask = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.BOOL + ) + + boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + + # Make indices for boundary conditions (sphere) + sphere_radius = grid_shape[0] // 4 + nr = grid_shape[0] + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + if dim == 2: + X, Y = np.meshgrid(x, y) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) + else: + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + + indices = wp.array(indices, dtype=wp.int32) + + assert indices.shape[0] == dim + test_id = 5 + boundary_id_field, missing_mask = indices_boundary_masker( + indices, + test_id, + boundary_id_field, + missing_mask, + start_index=(0, 0, 0) if dim == 3 else (0, 0), + ) + assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype + + assert boundary_id_field.dtype == xlb.Precision.UINT8.wp_dtype + + boundary_id_field = boundary_id_field.numpy() + missing_mask = missing_mask.numpy() + indices = indices.numpy() + + assert boundary_id_field.shape == (1,) + grid_shape + + assert missing_mask.shape == (velocity_set.q,) + grid_shape + + if dim == 2: + assert np.all(boundary_id_field[0, indices[0], indices[1]] == test_id) + # assert that the rest of the boundary_id_field is zero + boundary_id_field[0, indices[0], indices[1]]= 0 + assert np.all(boundary_id_field == 0) + if dim == 3: + assert np.all( + boundary_id_field[0, indices[0], indices[1], indices[2]] == test_id + ) + # assert that the rest of the boundary_id_field is zero + boundary_id_field[0, indices[0], indices[1], indices[2]] = 0 + assert np.all(boundary_id_field == 0) + +if __name__ == "__main__": + pytest.main() diff --git a/tests/grids/test_grid_jax.py b/tests/grids/test_grid_jax.py index 9f8a90f..88ca5fe 100644 --- a/tests/grids/test_grid_jax.py +++ b/tests/grids/test_grid_jax.py @@ -2,7 +2,7 @@ import jax import xlb from xlb.compute_backend import ComputeBackend -from xlb.grid import grid +from xlb.grid import grid_factory from jax.sharding import Mesh from jax.experimental import mesh_utils import jax.numpy as jnp @@ -19,7 +19,7 @@ def init_xlb_env(): def test_jax_2d_grid_initialization(grid_size): init_xlb_env() grid_shape = (grid_size, grid_size) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9) n_devices = jax.device_count() @@ -35,7 +35,7 @@ def test_jax_2d_grid_initialization(grid_size): def test_jax_3d_grid_initialization(grid_size): init_xlb_env() grid_shape = (grid_size, grid_size, grid_size) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9) n_devices = jax.device_count() @@ -55,7 +55,7 @@ def test_jax_grid_create_field_init_val(): init_xlb_env() grid_shape = (100, 100) init_val = 3.14 - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9, init_val=init_val) assert f.shape == (9,) + grid_shape, "Field shape is incorrect" diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 5bdac7c..3bf1ae5 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -3,7 +3,7 @@ import numpy as np import xlb from xlb.compute_backend import ComputeBackend -from xlb.grid import grid +from xlb.grid import grid_factory from xlb.precision_policy import Precision @@ -19,7 +19,7 @@ def init_xlb_warp_env(): def test_warp_grid_create_field(grid_size): for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]: init_xlb_warp_env() - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9, dtype=Precision.FP32) assert f.shape == (9,) + grid_shape, "Field shape is incorrect" @@ -30,7 +30,7 @@ def test_warp_grid_create_field_init_val(): init_xlb_warp_env() grid_shape = (100, 100) init_val = 3.14 - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9, dtype=Precision.FP32, init_val=init_val) assert isinstance(f, wp.array), "Field should be a Warp ndarray" diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index adabd41..3329187 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -4,7 +4,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.collision import BGK -from xlb.grid import grid +from xlb.grid import grid_factory from xlb.default_config import DefaultConfig @@ -27,7 +27,7 @@ def init_xlb_env(velocity_set): ) def test_bgk_ollision(dim, velocity_set, grid_shape, omega): init_xlb_env(velocity_set) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) rho = my_grid.create_field(cardinality=1, init_val=1.0) u = my_grid.create_field(cardinality=dim, init_val=0.0) diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 20e2f48..0159462 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -6,7 +6,7 @@ from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.operator.collision import BGK -from xlb.grid import grid +from xlb.grid import grid_factory from xlb.default_config import DefaultConfig from xlb.precision_policy import Precision @@ -28,7 +28,7 @@ def init_xlb_warp_env(velocity_set): ) def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): init_xlb_warp_env(velocity_set) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) rho = my_grid.create_field(cardinality=1, init_val=1.0) u = my_grid.create_field(cardinality=dim, init_val=0.0) diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index c0f50d7..bfea94a 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -3,7 +3,7 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.grid import grid +from xlb.grid import grid_factory from xlb.default_config import DefaultConfig @@ -26,7 +26,7 @@ def init_xlb_env(velocity_set): ) def test_quadratic_equilibrium(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) rho = my_grid.create_field(cardinality=1, init_val=1.0) u = my_grid.create_field(cardinality=dim, init_val=0.0) diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index ef75624..366c8f1 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -4,7 +4,7 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.grid import grid +from xlb.grid import grid_factory from xlb.default_config import DefaultConfig def init_xlb_env(velocity_set): @@ -20,7 +20,7 @@ def init_xlb_env(velocity_set): ]) def test_quadratic_equilibrium(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) rho = my_grid.create_field(cardinality=1, init_val=1.0) u = my_grid.create_field(cardinality=dim, init_val=0.0) diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index 786063f..a97cb1d 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -4,7 +4,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic -from xlb.grid import grid +from xlb.grid import grid_factory from xlb.default_config import DefaultConfig @@ -26,7 +26,7 @@ def init_xlb_env(velocity_set): ],) def test_macroscopic(dim, velocity_set, grid_shape, rho, velocity): init_xlb_env(velocity_set) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) rho_field = my_grid.create_field(cardinality=1, init_val=rho) velocity_field = my_grid.create_field(cardinality=dim, init_val=velocity) diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index 163db95..52d6d88 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -4,7 +4,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic -from xlb.grid import grid +from xlb.grid import grid_factory from xlb.default_config import DefaultConfig import warp as wp @@ -31,7 +31,7 @@ def init_xlb_env(velocity_set): ) def test_macroscopic(dim, velocity_set, grid_shape, rho, velocity): init_xlb_env(velocity_set) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) rho_field = my_grid.create_field(cardinality=1, init_val=rho) velocity_field = my_grid.create_field(cardinality=dim, init_val=velocity) diff --git a/tests/kernels/stream/test_stream_jax.py b/tests/kernels/stream/test_stream_jax.py index 83f1714..3ba4f35 100644 --- a/tests/kernels/stream/test_stream_jax.py +++ b/tests/kernels/stream/test_stream_jax.py @@ -5,7 +5,7 @@ from xlb.operator.stream import Stream from xlb.default_config import DefaultConfig -from xlb.grid import grid +from xlb.grid import grid_factory def init_xlb_env(velocity_set): @@ -27,7 +27,7 @@ def init_xlb_env(velocity_set): ) def test_stream_operator(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) - my_grid = grid(grid_shape) + my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set stream_op = Stream() diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py index 643ea70..afd2d89 100644 --- a/tests/kernels/stream/test_stream_warp.py +++ b/tests/kernels/stream/test_stream_warp.py @@ -7,7 +7,7 @@ from xlb.operator.stream import Stream from xlb.default_config import DefaultConfig -from xlb.grid import grid +from xlb.grid import grid_factory def init_xlb_env(velocity_set): @@ -29,7 +29,7 @@ def init_xlb_env(velocity_set): ) def test_stream_operator(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) - my_grid_jax = grid(grid_shape, compute_backend=ComputeBackend.JAX) + my_grid_jax = grid_factory(grid_shape, compute_backend=ComputeBackend.JAX) velocity_set = DefaultConfig.velocity_set f_initial = my_grid_jax.create_field(cardinality=velocity_set.q) @@ -59,13 +59,13 @@ def test_stream_operator(dim, velocity_set, grid_shape): expected = jnp.stack(expected, axis=0) if dim == 2: - f_initial_warp = wp.from_numpy(f_initial, dtype=wp.float32) + f_initial_warp = wp.array(f_initial) elif dim == 3: - f_initial_warp = wp.from_numpy(f_initial, dtype=wp.float32) + f_initial_warp = wp.array(f_initial) stream_op = Stream() - my_grid_warp = grid(grid_shape, compute_backend=ComputeBackend.WARP) + my_grid_warp = grid_factory(grid_shape, compute_backend=ComputeBackend.WARP) f_streamed = my_grid_warp.create_field(cardinality=velocity_set.q) f_streamed = stream_op(f_initial_warp, f_streamed) diff --git a/xlb/compute_backend.py b/xlb/compute_backend.py index bcefed1..60da291 100644 --- a/xlb/compute_backend.py +++ b/xlb/compute_backend.py @@ -5,5 +5,4 @@ class ComputeBackend(Enum): JAX = auto() - PALLAS = auto() WARP = auto() diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py index 47bcef1..692b453 100644 --- a/xlb/grid/__init__.py +++ b/xlb/grid/__init__.py @@ -1 +1 @@ -from xlb.grid.grid import grid, Grid \ No newline at end of file +from xlb.grid.grid import grid_factory \ No newline at end of file diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 6dbdaf6..4fc0665 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -6,11 +6,9 @@ from xlb.precision_policy import Precision -def grid( +def grid_factory( shape: Tuple[int, ...], - compute_backend: ComputeBackend = None, - parallel: bool = False, -): + compute_backend: ComputeBackend = None): compute_backend = compute_backend or DefaultConfig.default_backend if compute_backend == ComputeBackend.WARP: from xlb.grid.warp_grid import WarpGrid @@ -28,14 +26,9 @@ class Grid(ABC): def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend): self.shape = shape self.dim = len(shape) - self.parallel = False self.compute_backend = compute_backend self._initialize_backend() @abstractmethod def _initialize_backend(self): pass - - # @abstractmethod - # def parallelize_operator(self, operator: Operator): - # pass diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 2d3b04d..b290d53 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -10,7 +10,7 @@ import jax from xlb.default_config import DefaultConfig -from xlb.grid import Grid +from .grid import Grid from xlb.operator import Operator from xlb.precision_policy import Precision @@ -41,7 +41,7 @@ def _initialize_backend(self): def create_field( self, cardinality: int, - dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None, + dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16, Precision.BOOL] = None, init_val=None, ): sharding_dim = self.shape[-1] // self.nDevices @@ -49,7 +49,7 @@ def create_field( full_shape = (cardinality, *self.shape) arrays = [] - dype = dtype.jax_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.jax_dtype + dtype = dtype.jax_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.jax_dtype for d, index in self.sharding.addressable_devices_indices_map( full_shape diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index 92bed54..ae1c4e8 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -1,7 +1,7 @@ from dataclasses import field import warp as wp -from xlb.grid import Grid +from .grid import Grid from xlb.operator import Operator from xlb.precision_policy import Precision from xlb.compute_backend import ComputeBackend diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 27e0472..275a074 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -2,7 +2,7 @@ from xlb.operator.boundary_condition.boundary_condition_registry import ( BoundaryConditionRegistry, ) -from xlb.operator.boundary_condition.equilibrium import EquilibriumBC -from xlb.operator.boundary_condition.do_nothing import DoNothingBC -from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBackBC -from xlb.operator.boundary_condition.fullway_bounce_back import FullwayBounceBackBC +from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC +from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC +from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC +from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py similarity index 84% rename from xlb/operator/boundary_condition/do_nothing.py rename to xlb/operator/boundary_condition/bc_do_nothing.py index 46a6fdd..d1d1a81 100644 --- a/xlb/operator/boundary_condition/do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -45,14 +45,11 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): - # TODO: This is unoptimized - boundary = boundary_id == self.id - flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) - skipped_f = lax.select(flip, f_pre, f_post) - return skipped_f + def jax_implementation(self, f_pre, f_post, boundary_id_field, missing_mask): + boundary = (boundary_id_field == self.id) + boundary = boundary[:, None, None, None] + return jnp.where(boundary, f_pre, f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update @@ -78,7 +75,7 @@ def functional( def kernel( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_id_field: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), f: wp.array4d(dtype=Any), ): @@ -87,7 +84,7 @@ def kernel( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask - _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -111,11 +108,11 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_id, missing_mask, f], + inputs=[f_pre, f_post, boundary_id_field, missing_mask, f], dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py similarity index 58% rename from xlb/operator/boundary_condition/equilibrium.py rename to xlb/operator/boundary_condition/bc_equilibrium.py index 4a8ccb5..7c0505d 100644 --- a/xlb/operator/boundary_condition/equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -35,9 +35,9 @@ def __init__( rho: float, u: Tuple[float, float, float], equilibrium_operator: Operator, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, ): # Store the equilibrium information self.rho = rho @@ -53,11 +53,12 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - # @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, boundary_id_field, missing_mask): feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) - feq = feq[:, None, None, None] - boundary = (boundary_id == self.id) + new_shape = feq.shape + (1,) * self.velocity_set.d + feq = lax.broadcast_in_dim(feq, new_shape, [0]) + boundary = boundary_id_field == self.id return jnp.where(boundary, feq, f_post) @@ -66,27 +67,63 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(self.rho) - _u = _u_vec(self.u[0], self.u[1], self.u[2]) + _u = _u_vec(self.u[0], self.u[1], self.u[2]) if self.velocity_set.d == 3 else _u_vec(self.u[0], self.u[1]) _missing_mask_vec = wp.vec( self.velocity_set.q, dtype=wp.uint8 ) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func - def functional( - f: wp.array4d(dtype=Any), - missing_mask: Any, - index: Any, + def functional2d(): + _f = self.equilibrium_operator.warp_functional(_rho, _u) + return _f + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_id_field: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + f: wp.array3d(dtype=Any), ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the boundary id and missing mask + _boundary_id = boundary_id_field[0, index[0], index[1]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(EquilibriumBC.id): + _f = functional2d() + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1]] + + # Write the result + for l in range(self.velocity_set.q): + f[l, index[0], index[1]] = _f[l] + + @wp.func + def functional3d(): _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f # Construct the warp kernel @wp.kernel - def kernel( + def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_id_field: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), f: wp.array4d(dtype=Any), ): @@ -95,7 +132,7 @@ def kernel( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask - _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -106,7 +143,7 @@ def kernel( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional(f_pre, _missing_mask, index) + _f = functional3d() else: _f = _f_vec() for l in range(self.velocity_set.q): @@ -116,14 +153,17 @@ def kernel( for l in range(self.velocity_set.q): f[l, index[0], index[1], index[2]] = _f[l] + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + functional = functional3d if self.velocity_set.d == 3 else functional2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_id, missing_mask, f], + inputs=[f_pre, f_post, boundary_id_field, missing_mask, f], dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py similarity index 90% rename from xlb/operator/boundary_condition/fullway_bounce_back.py rename to xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 547cde1..80404eb 100644 --- a/xlb/operator/boundary_condition/fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -46,8 +46,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): - boundary = boundary_id == self.id + def apply_jax(self, f_pre, f_post, boundary_id_field, missing_mask): + boundary = boundary_id_field == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) return lax.select(boundary, f_pre[self.velocity_set.opp_indices], f_post) @@ -77,7 +77,7 @@ def functional( def kernel( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_id_field: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), f: wp.array4d(dtype=Any), ): @@ -86,7 +86,7 @@ def kernel( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask - _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] # Make vectors for the lattice _f_pre = _f_vec() @@ -116,11 +116,11 @@ def kernel( @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_id, missing_mask, f], + inputs=[f_pre, f_post, boundary_id_field, missing_mask, f], dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py similarity index 91% rename from xlb/operator/boundary_condition/halfway_bounce_back.py rename to xlb/operator/boundary_condition/bc_halfway_bounce_back.py index e47cc26..c594ca1 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -49,8 +49,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): - boundary = boundary_id == self.id + def apply_jax(self, f_pre, f_post, boundary_id_field, missing_mask): + boundary = boundary_id_field == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) return lax.select(jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], f_post) @@ -98,7 +98,7 @@ def functional( def kernel( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_id_field: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), f: wp.array4d(dtype=Any), ): @@ -107,7 +107,7 @@ def kernel( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask - _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -131,11 +131,11 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_id, missing_mask, f], + inputs=[f_pre, f_post, boundary_id_field, missing_mask, f], dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 92a2c1f..ceb9830 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -12,6 +12,7 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator +from xlb.default_config import DefaultConfig # Enum for implementation step @@ -28,10 +29,14 @@ class BoundaryCondition(Operator): def __init__( self, implementation_step: ImplementationStep, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, ): + velocity_set = velocity_set or DefaultConfig.velocity_set + precision_policy = precision_policy or DefaultConfig.default_precision_policy + compute_backend = compute_backend or DefaultConfig.default_backend + super().__init__(velocity_set, precision_policy, compute_backend) # Set the implementation step diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index f69252f..7f4b803 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -1,6 +1,3 @@ -from xlb.operator.boundary_masker.boundary_masker import ( - BoundaryMasker, -) from xlb.operator.boundary_masker.indices_boundary_masker import ( IndicesBoundaryMasker, ) diff --git a/xlb/operator/boundary_masker/boundary_masker.py b/xlb/operator/boundary_masker/boundary_masker.py deleted file mode 100644 index 6fe487f..0000000 --- a/xlb/operator/boundary_masker/boundary_masker.py +++ /dev/null @@ -1,9 +0,0 @@ -# Base class for all boundary masker operators - -from xlb.operator.operator import Operator - - -class BoundaryMasker(Operator): - """ - Operator for creating a boundary mask - """ diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index cbec6ac..9bd562f 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -22,9 +22,9 @@ class IndicesBoundaryMasker(Operator): def __init__( self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, + velocity_set=None, + precision_policy=None, + compute_backend=None, ): # Make stream operator self.stream = Stream(velocity_set, precision_policy, compute_backend) @@ -41,38 +41,87 @@ def _indices_to_tuple(indices): @Operator.register_backend(ComputeBackend.JAX) def jax_implementation( - self, indices, id_number, boundary_id, mask, start_index=(0, 0, 0) + self, indices, id_number, boundary_id_field, mask, start_index=None ): - local_indices = indices - np.array(start_index)[np.newaxis, :] + dim = mask.ndim - 1 + if start_index is None: + start_index = (0,) * dim - # Remove any indices that are out of bounds - indices_mask_x = (local_indices[:, 0] >= 0) & (local_indices[:, 0] < mask.shape[1]) - indices_mask_y = (local_indices[:, 1] >= 0) & (local_indices[:, 1] < mask.shape[2]) - indices_mask_z = (local_indices[:, 2] >= 0) & (local_indices[:, 2] < mask.shape[3]) - indices_mask = indices_mask_x & indices_mask_y & indices_mask_z + local_indices = indices - np.array(start_index)[:, np.newaxis] - local_indices = self._indices_to_tuple(local_indices[indices_mask]) + indices_mask = [ + (local_indices[i, :] >= 0) & (local_indices[i, :] < mask.shape[i + 1]) + for i in range(mask.ndim - 1) + ] + indices_mask = np.logical_and.reduce(indices_mask) @jit - def compute_boundary_id_and_mask(boundary_id, mask): - boundary_id = boundary_id.at[0, local_indices[0], local_indices[1], local_indices[2]].set(id_number) - mask = mask.at[:, local_indices[0], local_indices[1], local_indices[2]].set(True) + def compute_boundary_id_and_mask(boundary_id_field, mask): + if dim == 2: + boundary_id_field = boundary_id_field.at[ + 0, local_indices[0], local_indices[1] + ].set(id_number) + mask = mask.at[:, local_indices[0], local_indices[1]].set(True) + + if dim == 3: + boundary_id_field = boundary_id_field.at[ + 0, local_indices[0], local_indices[1], local_indices[2] + ].set(id_number) + mask = mask.at[ + :, local_indices[0], local_indices[1], local_indices[2] + ].set(True) + mask = self.stream(mask) - return boundary_id, mask + return boundary_id_field, mask - return compute_boundary_id_and_mask(boundary_id, mask) + return compute_boundary_id_and_mask(boundary_id_field, mask) def _construct_warp(self): # Make constants for warp _c = self.velocity_set.wp_c _q = wp.constant(self.velocity_set.q) - # Construct the warp kernel + # Construct the warp 2D kernel + @wp.kernel + def kernel2d( + indices: wp.array2d(dtype=wp.int32), + id_number: wp.int32, + boundary_id_field: wp.array3d(dtype=wp.uint8), + mask: wp.array3d(dtype=wp.bool), + start_index: wp.vec2i, + ): + # Get the index of indices + ii = wp.tid() + + # Get local indices + index = wp.vec2i() + index[0] = indices[0, ii] - start_index[0] + index[1] = indices[1, ii] - start_index[1] + + # Check if in bounds + if ( + index[0] >= 0 + and index[0] < mask.shape[1] + and index[1] >= 0 + and index[1] < mask.shape[2] + ): + # Stream indices + for l in range(_q): + # Get the index of the streaming direction + push_index = wp.vec2i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + + # Set the boundary id and mask + boundary_id_field[0, index[0], index[1]] = wp.uint8(id_number) + mask[l, push_index[0], push_index[1]] = True + + # Construct the warp 3D kernel @wp.kernel - def kernel( + def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.int32, - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_id_field: wp.array4d(dtype=wp.uint8), mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -81,9 +130,9 @@ def kernel( # Get local indices index = wp.vec3i() - index[0] = indices[ii, 0] - start_index[0] - index[1] = indices[ii, 1] - start_index[1] - index[2] = indices[ii, 2] - start_index[2] + index[0] = indices[0, ii] - start_index[0] + index[1] = indices[1, ii] - start_index[1] + index[2] = indices[2, ii] - start_index[2] # Check if in bounds if ( @@ -102,28 +151,32 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_id[0, index[0], index[1], index[2]] = ( - wp.uint8(id_number) + boundary_id_field[0, index[0], index[1], index[2]] = wp.uint8( + id_number ) mask[l, push_index[0], push_index[1], push_index[2]] = True + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation( - self, indices, id_number, boundary_id, missing_mask, start_index=(0, 0, 0) + self, indices, id_number, boundary_id_field, missing_mask, start_index=None ): + if start_index is None: + start_index = (0,) * self.velocity_set.d # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ indices, id_number, - boundary_id, + boundary_id_field, missing_mask, start_index, ], - dim=indices.shape[0], + dim=indices.shape[1], ) - return boundary_id, missing_mask + return boundary_id_field, missing_mask diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index 4c886c4..f7ffad3 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -37,7 +37,7 @@ def jax_implementation( upper_bound, direction, id_number, - boundary_id, + boundary_id_field, mask, start_index=(0, 0, 0), ): @@ -47,7 +47,7 @@ def jax_implementation( if direction[0] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) + boundary_id_field = boundary_id_field.at[0, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -63,7 +63,7 @@ def jax_implementation( elif direction[1] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) + boundary_id_field = boundary_id_field.at[0, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -79,7 +79,7 @@ def jax_implementation( elif direction[2] != 0: # Set boundary id - boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(id_number) + boundary_id_field = boundary_id_field.at[0, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(id_number) # Set mask for l in range(self.velocity_set.q): @@ -91,7 +91,7 @@ def jax_implementation( if d_dot_c >= 0: mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(True) - return boundary_id, mask + return boundary_id_field, mask def _construct_warp(self): @@ -106,7 +106,7 @@ def kernel( upper_bound: wp.vec3i, direction: wp.vec3i, id_number: wp.int32, - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_id_field: wp.array4d(dtype=wp.uint8), mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -137,7 +137,7 @@ def kernel( and k < mask.shape[3] ): # Set the boundary id - boundary_id[0, i, j, k] = wp.uint8(id_number) + boundary_id_field[0, i, j, k] = wp.uint8(id_number) # Set mask for just directions coming from the boundary for l in range(_q): @@ -158,7 +158,7 @@ def warp_implementation( upper_bound, direction, id_number, - boundary_id, + boundary_id_field, mask, start_index=(0, 0, 0), ): @@ -187,11 +187,11 @@ def warp_implementation( upper_bound, direction, id_number, - boundary_id, + boundary_id_field, mask, start_index, ], dim=dim, ) - return boundary_id, mask + return boundary_id_field, mask diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index 8a6f956..20e630c 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -42,7 +42,7 @@ def kernel( origin: wp.vec3, spacing: wp.vec3, id_number: wp.int32, - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_id_field: wp.array4d(dtype=wp.uint8), mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -64,9 +64,9 @@ def kernel( # Compute the maximum length max_length = wp.sqrt( - (spacing[0] * wp.float32(boundary_id.shape[1])) ** 2.0 - + (spacing[1] * wp.float32(boundary_id.shape[2])) ** 2.0 - + (spacing[2] * wp.float32(boundary_id.shape[3])) ** 2.0 + (spacing[0] * wp.float32(boundary_id_field.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(boundary_id_field.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(boundary_id_field.shape[3])) ** 2.0 ) # evaluate if point is inside mesh @@ -87,7 +87,7 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_id[ + boundary_id_field[ 0, push_index[0], push_index[1], push_index[2] ] = wp.uint8(id_number) mask[l, push_index[0], push_index[1], push_index[2]] = True @@ -101,7 +101,7 @@ def warp_implementation( origin, spacing, id_number, - boundary_id, + boundary_id_field, mask, start_index=(0, 0, 0), ): @@ -122,11 +122,11 @@ def warp_implementation( origin, spacing, id_number, - boundary_id, + boundary_id_field, mask, start_index, ], - dim=boundary_id.shape[1:], + dim=boundary_id_field.shape[1:], ) - return boundary_id, mask + return boundary_id_field, mask diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 520cf3f..3b006ab 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -22,12 +22,6 @@ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fout = f - self.compute_dtype(self.omega) * fneq return fout - @Operator.register_backend(ComputeBackend.PALLAS) - def pallas_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): - fneq = f - feq - fout = f - self.omega * fneq - return fout - def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _w = self.velocity_set.wp_w diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index c84abd1..7afacac 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -18,7 +18,7 @@ class QuadraticEquilibrium(Equilibrium): """ @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) def jax_implementation(self, rho, u): cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0)) usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True) @@ -26,34 +26,6 @@ def jax_implementation(self, rho, u): feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq - @Operator.register_backend(ComputeBackend.PALLAS) - def pallas_implementation(self, rho, u): - u0, u1, u2 = u[0], u[1], u[2] - usqr = 1.5 * (u0**2 + u1**2 + u2**2) - - eq = [ - rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u0 + 4.5 * u0 * u0 - usqr), - rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u1 + 4.5 * u1 * u1 - usqr), - rho[0] * (1.0 / 18.0) * (1.0 - 3.0 * u2 + 4.5 * u2 * u2 - usqr), - ] - - combined_velocities = [u0 + u1, u0 - u1, u0 + u2, u0 - u2, u1 + u2, u1 - u2] - - for vel in combined_velocities: - eq.append( - rho[0] * (1.0 / 36.0) * (1.0 - 3.0 * vel + 4.5 * vel * vel - usqr) - ) - - eq.append(rho[0] * (1.0 / 3.0) * (1.0 - usqr)) - - for i in range(3): - eq.append(eq[i] + rho[0] * (1.0 / 18.0) * 6.0 * u[i]) - - for i, vel in enumerate(combined_velocities, 3): - eq.append(eq[i] + rho[0] * (1.0 / 36.0) * 6.0 * vel) - - return jnp.array(eq) - def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _c = self.velocity_set.wp_c diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 9c6ac10..521c6d4 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -46,33 +46,6 @@ def jax_implementation(self, f): return rho, u - @Operator.register_backend(ComputeBackend.PALLAS) - def pallas_implementation(self, f): - # TODO: Maybe this can be done with jnp.sum - rho = jnp.sum(f, axis=0, keepdims=True) - - u = jnp.zeros((3, *rho.shape[1:])) - u.at[0].set( - -f[9] - - f[10] - - f[11] - - f[12] - - f[13] - + f[14] - + f[15] - + f[16] - + f[17] - + f[18] - ) / rho - u.at[1].set( - -f[3] - f[4] - f[5] + f[6] + f[7] + f[8] - f[12] + f[13] - f[17] + f[18] - ) / rho - u.at[2].set( - -f[1] + f[2] - f[4] + f[5] - f[7] + f[8] - f[10] + f[11] - f[15] + f[16] - ) / rho - - return rho, jnp.array(u) - def _construct_warp(self): # Make constants for warp _c = self.velocity_set.wp_c diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index a8b85d0..c6d8ad0 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -70,11 +70,12 @@ def __call__(self, *args, callback=None, **kwargs): if callback and callable(callback): callback(callback_arg) return result - except TypeError: + except Exception as e: + error = e continue # This skips to the next candidate if binding fails - raise NotImplementedError( - f"No implementation found for backend with key {key} for operator {self.__class__.__name__}" + raise Exception( + f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}" ) @property diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py index 44ff137..e5d159c 100644 --- a/xlb/operator/stepper/__init__.py +++ b/xlb/operator/stepper/__init__.py @@ -1,2 +1,2 @@ from xlb.operator.stepper.stepper import Stepper -from xlb.operator.stepper.nse import IncompressibleNavierStokesStepper +from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse_stepper.py similarity index 63% rename from xlb/operator/stepper/nse.py rename to xlb/operator/stepper/nse_stepper.py index 9c6d56e..d51f637 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -20,23 +20,24 @@ class IncompressibleNavierStokesStepper(Stepper): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0, 4), donate_argnums=(1)) - def apply_jax(self, f, boundary_id, missing_mask, timestep): + def jax_implementation(self, f_0, f_1, boundary_id_field, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ - # Cast to compute precision TODO add this back in - #f_pre_collision = self.precision_policy.cast_to_compute_jax(f) + # Cast to compute precision + f_0 = self.precision_policy.cast_to_compute_jax(f_0) + f_1 = self.precision_policy.cast_to_compute_jax(f_1) # Compute the macroscopic variables - rho, u = self.macroscopic(f) + rho, u = self.macroscopic(f_0) # Compute equilibrium feq = self.equilibrium(rho, u) # Apply collision f_post_collision = self.collision( - f, + f_0, feq, rho, u, @@ -45,89 +46,30 @@ def apply_jax(self, f, boundary_id, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: if bc.implementation_step == ImplementationStep.COLLISION: - f_post_collision = bc( - f, + f_0 = bc( + f_0, f_post_collision, - boundary_id, + boundary_id_field, missing_mask, ) - ## Apply forcing - # if self.forcing_op is not None: - # f = self.forcing_op.apply_jax(f, timestep) - # Apply streaming - f_post_streaming = self.stream(f_post_collision) + f_1 = self.stream(f_0) # Apply boundary conditions for bc in self.boundary_conditions: if bc.implementation_step == ImplementationStep.STREAMING: - f_post_streaming = bc( + f_1 = bc( f_post_collision, - f_post_streaming, - boundary_id, + f_1, + boundary_id_field, missing_mask, ) # Copy back to store precision - #f = self.precision_policy.cast_to_store_jax(f_post_streaming) - - return f_post_streaming - - @Operator.register_backend(ComputeBackend.PALLAS) - @partial(jit, static_argnums=(0,)) - def apply_pallas(self, fin, boundary_id, missing_mask, timestep): - # Raise warning that the boundary conditions are not implemented - warning("Boundary conditions are not implemented for PALLAS backend currently") - - from xlb.operator.parallel_operator import ParallelOperator - - def _pallas_collide(fin, fout): - idx = pl.program_id(0) - - f = pl.load(fin, (slice(None), idx, slice(None), slice(None))) - - print("f shape", f.shape) - - rho, u = self.macroscopic(f) - - print("rho shape", rho.shape) - print("u shape", u.shape) + f_1 = self.precision_policy.cast_to_store_jax(f_1) - feq = self.equilibrium(rho, u) - - print("feq shape", feq.shape) - - for i in range(self.velocity_set.q): - print("f shape", f[i].shape) - f_post_collision = self.collision(f[i], feq[i]) - print("f_post_collision shape", f_post_collision.shape) - pl.store(fout, (i, idx, slice(None), slice(None)), f_post_collision) - # f_post_collision = self.collision(f, feq) - # pl.store(fout, (i, idx, slice(None), slice(None)), f_post_collision) - - @jit - def _pallas_collide_kernel(fin): - return pl.pallas_call( - partial(_pallas_collide), - out_shape=jax.ShapeDtypeStruct( - ((self.velocity_set.q,) + (self.grid.grid_shape_per_gpu)), fin.dtype - ), - # grid=1, - grid=(self.grid.grid_shape_per_gpu[0], 1, 1), - )(fin) - - def _pallas_collide_and_stream(f): - f = _pallas_collide_kernel(f) - # f = self.stream._streaming_jax_p(f) - - return f - - fout = ParallelOperator( - self.grid, _pallas_collide_and_stream, self.velocity_set - )(fin) - - return fout + return f_1 def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update @@ -147,7 +89,7 @@ def _construct_warp(self): def kernel( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), - boundary_id: wp.array4d(dtype=Any), + boundary_id_field: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), timestep: int, ): @@ -156,7 +98,7 @@ def kernel( index = wp.vec3i(i, j, k) # TODO warp should fix this # Get the boundary id and missing mask - _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -184,7 +126,7 @@ def kernel( f_post_stream = self.halfway_bounce_back_bc.warp_functional( f_0, _missing_mask, index ) - + # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -215,14 +157,14 @@ def kernel( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, boundary_id, missing_mask, timestep): + def warp_implementation(self, f_0, f_1, boundary_id_field, missing_mask, timestep): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ f_0, f_1, - boundary_id, + boundary_id_field, missing_mask, timestep, ], diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 1e89547..c4423ea 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -23,7 +23,6 @@ def __init__( equilibrium, macroscopic, boundary_conditions=[], - forcing=None, # TODO: Add forcing later ): # Add operators self.collision = collision @@ -31,7 +30,6 @@ def __init__( self.equilibrium = equilibrium self.macroscopic = macroscopic self.boundary_conditions = boundary_conditions - self.forcing = forcing # Get all operators for checking self.operators = [ @@ -41,8 +39,6 @@ def __init__( macroscopic, *self.boundary_conditions, ] - if forcing is not None: - self.operators.append(forcing) # Get velocity set, precision policy, and compute backend velocity_sets = set([op.velocity_set for op in self.operators]) @@ -61,10 +57,10 @@ def __init__( ############################################ # TODO: Fix this later ############################################ - from xlb.operator.boundary_condition.equilibrium import EquilibriumBC - from xlb.operator.boundary_condition.do_nothing import DoNothingBC - from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBackBC - from xlb.operator.boundary_condition.fullway_bounce_back import FullwayBounceBackBC + from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC + from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC + from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC + from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC self.equilibrium_bc = None self.do_nothing_bc = None self.halfway_bounce_back_bc = None diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index fad3c8f..98f7968 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -13,10 +13,7 @@ class Fp64Fp64: def __new__(cls): - if ( - DefaultConfig.compute_backend == ComputeBackend.JAX - or DefaultConfig.compute_backend == ComputeBackend.PALLAS - ): + if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp64() else: raise ValueError( @@ -26,10 +23,7 @@ def __new__(cls): class Fp64Fp32: def __new__(cls): - if ( - DefaultConfig.compute_backend == ComputeBackend.JAX - or DefaultConfig.compute_backend == ComputeBackend.PALLAS - ): + if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp32() else: raise ValueError( @@ -39,10 +33,7 @@ def __new__(cls): class Fp32Fp32: def __new__(cls): - if ( - DefaultConfig.compute_backend == ComputeBackend.JAX - or DefaultConfig.compute_backend == ComputeBackend.PALLAS - ): + if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp32Fp32() else: raise ValueError( @@ -52,10 +43,7 @@ def __new__(cls): class Fp64Fp16: def __new__(cls): - if ( - DefaultConfig.compute_backend == ComputeBackend.JAX - or DefaultConfig.compute_backend == ComputeBackend.PALLAS - ): + if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp16() else: raise ValueError( @@ -65,10 +53,7 @@ def __new__(cls): class Fp32Fp16: def __new__(cls): - if ( - DefaultConfig.compute_backend == ComputeBackend.JAX - or DefaultConfig.compute_backend == ComputeBackend.PALLAS - ): + if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp32Fp16() else: raise ValueError( diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index 1b0ab43..1ffadf8 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -13,7 +13,6 @@ from xlb.operator.macroscopic import Macroscopic from xlb.solver.solver import Solver from xlb.operator import Operator -from jax.experimental import pallas as pl class IncompressibleNavierStokesSolver(Solver): @@ -29,37 +28,25 @@ class IncompressibleNavierStokesSolver(Solver): def __init__( self, omega: float, - shape: tuple[int, int, int], + domain_shape: tuple[int, int, int], collision="BGK", equilibrium="Quadratic", boundary_conditions=[], - initializer=None, - forcing=None, - velocity_set: VelocitySet = None, + velocity_set = None, precision_policy=None, compute_backend=None, - grid_backend=None, - grid_configs={}, ): super().__init__( - shape=shape, + domain_shape=domain_shape, boundary_conditions=boundary_conditions, velocity_set=velocity_set, compute_backend=compute_backend, precision_policy=precision_policy, - grid_backend=grid_backend, - grid_configs=grid_configs, ) # Set omega self.omega = omega - # Add fields to grid - self.grid.create_field("rho", 1, self.precision_policy.store_precision) - self.grid.create_field("u", 3, self.precision_policy.store_precision) - self.grid.create_field("f0", self.velocity_set.q, self.precision_policy.store_precision) - self.grid.create_field("f1", self.velocity_set.q, self.precision_policy.store_precision) - # Create operators self.collision = self._get_collision(collision)( omega=self.omega, @@ -78,15 +65,6 @@ def __init__( self.macroscopic = Macroscopic( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend ) - if initializer is None: - self.initializer = EquilibriumInitializer( - rho=1.0, u=(0.0, 0.0, 0.0), - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - if forcing is not None: - raise NotImplementedError("Forcing not yet implemented") # Create stepper operator self.stepper = IncompressibleNavierStokesStepper( @@ -98,15 +76,6 @@ def __init__( forcing=None, ) - # Add parrallelization - self.stepper = self.grid.parallelize_operator(self.stepper) - - # Initialize - self.initialize() - - def initialize(self): - self.initializer(f=self.grid.get_field("f0")) - def monitor(self): pass diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py index 7979c11..a826ab7 100644 --- a/xlb/solver/solver.py +++ b/xlb/solver/solver.py @@ -1,6 +1,3 @@ -# Base class for all stepper operators - -from xlb.compute_backend import ComputeBackend from xlb.default_config import DefaultConfig from xlb.operator.operator import Operator @@ -12,29 +9,15 @@ class Solver(Operator): def __init__( self, - shape: tuple[int, int, int], + domain_shape: tuple[int, int, int], boundary_conditions=[], velocity_set=None, precision_policy=None, compute_backend=None, - grid_backend=None, - grid_configs={}, ): - # Set parameters - self.shape = shape + self.domain_shape = domain_shape + self.boundary_conditions = boundary_conditions self.velocity_set = velocity_set or DefaultConfig.velocity_set self.precision_policy = precision_policy or DefaultConfig.precision_policy self.compute_backend = compute_backend or DefaultConfig.compute_backend - self.grid_backend = grid_backend or DefaultConfig.grid_backend - self.boundary_conditions = boundary_conditions - - # Make grid - if self.grid_backend is GridBackend.JAX: - self.grid = JaxGrid(**grid_configs) - elif self.grid_backend is GridBackend.WARP: - self.grid = WarpGrid(**grid_configs) - elif self.grid_backend is GridBackend.OOC: - self.grid = OOCGrid(**grid_configs) - else: - raise ValueError(f"Grid backend {self.grid_backend} not recognized") From 0534eb0c0949c0a71967de73ed770ed223ff231f Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 10 May 2024 15:45:34 -0400 Subject: [PATCH 030/144] WIP added more tests --- .../test_bc_fullway_bounce_back_jax.py | 98 +++++++++++++++++++ .../test_bc_fullway_bounce_back_warp.py | 98 +++++++++++++++++++ .../mask/test_bc_indices_masker_jax.py | 8 +- .../mask/test_bc_indices_masker_warp.py | 8 +- .../collision/test_bgk_collision_jax.py | 2 + .../collision/test_bgk_collision_warp.py | 2 + .../equilibrium/test_equilibrium_jax.py | 8 +- .../equilibrium/test_equilibrium_warp.py | 17 +++- .../macroscopic/test_macroscopic_jax.py | 8 +- .../macroscopic/test_macroscopic_warp.py | 2 +- tests/kernels/stream/test_stream_jax.py | 8 +- tests/kernels/stream/test_stream_warp.py | 8 +- .../bc_fullway_bounce_back.py | 54 ++++++++-- xlb/velocity_set/velocity_set.py | 1 + 14 files changed, 290 insertions(+), 32 deletions(-) create mode 100644 tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py create mode 100644 tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py new file mode 100644 index 0000000..82738e9 --- /dev/null +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -0,0 +1,98 @@ +import pytest +import numpy as np +import jax.numpy as jnp +import xlb +import jax +from xlb.compute_backend import ComputeBackend +from xlb.grid import grid_factory +from xlb.default_config import DefaultConfig + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (50, 50)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), + ], +) +def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid_factory(grid_shape) + velocity_set = DefaultConfig.velocity_set + + missing_mask = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.BOOL + ) + + fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC() + + boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + + # Make indices for boundary conditions (sphere) + sphere_radius = grid_shape[0] // 4 + nr = grid_shape[0] + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + if dim == 2: + X, Y = np.meshgrid(x, y) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) + else: + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + + indices = jnp.array(indices) + + boundary_id_field, missing_mask = indices_boundary_masker( + indices, fullway_bc.id, boundary_id_field, missing_mask, start_index=None + ) + + f_pre = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=0.0 + ) + # Generate a random field with the same shape + key = jax.random.PRNGKey(0) + random_field = jax.random.uniform(key, f_pre.shape) + # Add the random field to f_pre + f_pre += random_field + + f_post = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=2.0 + ) # Arbitrary value so that we can check if the values are changed outside the boundary + + f = fullway_bc(f_pre, f_post, boundary_id_field, missing_mask) + + assert f.shape == (velocity_set.q,) + grid_shape + + for i in range(velocity_set.q): + jnp.allclose( + f[velocity_set.get_opp_index(i)][tuple(indices)], + f_pre[i][tuple(indices)], + ) + + # Make sure that everywhere else the values are the same as f_post. Note that indices are just int values + mask_outside = np.ones(grid_shape, dtype=bool) + mask_outside[indices] = False # Mark boundary as false + if dim == 2: + for i in range(velocity_set.q): + assert jnp.allclose(f[i, mask_outside], f_post[i, mask_outside]) + else: + for i in range(velocity_set.q): + assert jnp.allclose(f[i, mask_outside], f_post[i, mask_outside]) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py new file mode 100644 index 0000000..91a0205 --- /dev/null +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -0,0 +1,98 @@ +import pytest +import numpy as np +import warp as wp +import xlb +import jax +from xlb.compute_backend import ComputeBackend +from xlb.grid import grid_factory +from xlb.default_config import DefaultConfig + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (50, 50)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), + ], +) +def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid_factory(grid_shape) + velocity_set = DefaultConfig.velocity_set + + missing_mask = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.BOOL + ) + + fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC() + + boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + + # Make indices for boundary conditions (sphere) + sphere_radius = grid_shape[0] // 4 + nr = grid_shape[0] + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + if dim == 2: + X, Y = np.meshgrid(x, y) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) + else: + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + + indices = wp.array(indices, dtype=wp.int32) + + boundary_id_field, missing_mask = indices_boundary_masker( + indices, fullway_bc.id, boundary_id_field, missing_mask, start_index=None + ) + + # Generate a random field with the same shape + random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32) + # Add the random field to f_pre + f_pre = wp.array(random_field) + + f_post = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=2.0 + ) # Arbitrary value so that we can check if the values are changed outside the boundary + + f_pre = fullway_bc(f_pre, f_post, boundary_id_field, missing_mask, f_pre) + + f = f_pre.numpy() + f_post = f_post.numpy() + indices = indices.numpy() + + assert f.shape == (velocity_set.q,) + grid_shape + + for i in range(velocity_set.q): + np.allclose( + f[velocity_set.get_opp_index(i)][tuple(indices)], + f_post[i][tuple(indices)], + ) + + # Make sure that everywhere else the values are the same as f_post. Note that indices are just int values + mask_outside = np.ones(grid_shape, dtype=bool) + mask_outside[indices] = False # Mark boundary as false + if dim == 2: + for i in range(velocity_set.q): + assert np.allclose(f[i, mask_outside], f_post[i, mask_outside]) + else: + for i in range(velocity_set.q): + assert np.allclose(f[i, mask_outside], f_post[i, mask_outside]) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index ea19bc7..6a1b05f 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -20,9 +20,11 @@ def init_xlb_env(velocity_set): "dim,velocity_set,grid_shape", [ (2, xlb.velocity_set.D2Q9, (50, 50)), - (2, xlb.velocity_set.D2Q9, (50, 50)), - (3, xlb.velocity_set.D3Q19, (20, 20, 20)), - (3, xlb.velocity_set.D3Q19, (20, 20, 20)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), ], ) def test_indices_masker_jax(dim, velocity_set, grid_shape): diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index ce42d7f..ba00a05 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -20,9 +20,11 @@ def init_xlb_env(velocity_set): "dim,velocity_set,grid_shape", [ (2, xlb.velocity_set.D2Q9, (50, 50)), - (2, xlb.velocity_set.D2Q9, (50, 50)), - (3, xlb.velocity_set.D3Q19, (20, 20, 20)), - (3, xlb.velocity_set.D3Q19, (20, 20, 20)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), ], ) def test_indices_masker_warp(dim, velocity_set, grid_shape): diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index 3329187..0bc4b49 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -23,6 +23,8 @@ def init_xlb_env(velocity_set): (2, xlb.velocity_set.D2Q9, (100, 100), 1.0), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 0.6), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0), + (3, xlb.velocity_set.D3Q27, (50, 50, 50), 0.6), + (3, xlb.velocity_set.D3Q27, (50, 50, 50), 1.0), ], ) def test_bgk_ollision(dim, velocity_set, grid_shape, omega): diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 0159462..94e286b 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -24,6 +24,8 @@ def init_xlb_warp_env(velocity_set): (2, xlb.velocity_set.D2Q9, (100, 100), 1.0), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 0.6), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0), + (3, xlb.velocity_set.D3Q27, (50, 50, 50), 0.6), + (3, xlb.velocity_set.D3Q27, (50, 50, 50), 1.0), ], ) def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index bfea94a..0495613 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -18,13 +18,15 @@ def init_xlb_env(velocity_set): @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ + (2, xlb.velocity_set.D2Q9, (50, 50)), (2, xlb.velocity_set.D2Q9, (100, 100)), - (2, xlb.velocity_set.D2Q9, (100, 100)), - (3, xlb.velocity_set.D3Q19, (50, 50, 50)), (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), ], ) -def test_quadratic_equilibrium(dim, velocity_set, grid_shape): +def test_quadratic_equilibrium_jax(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index 366c8f1..bf47880 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -14,11 +14,18 @@ def init_xlb_env(velocity_set): velocity_set=velocity_set, ) -@pytest.mark.parametrize("dim,velocity_set,grid_shape", [ - (2, xlb.velocity_set.D2Q9, (100, 100)), - (3, xlb.velocity_set.D3Q27, (50, 50, 50)) -]) -def test_quadratic_equilibrium(dim, velocity_set, grid_shape): +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (50, 50)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), + ], +) +def test_quadratic_equilibrium_warp(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index a97cb1d..f1f6349 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -23,8 +23,11 @@ def init_xlb_env(velocity_set): (2, xlb.velocity_set.D2Q9, (100, 100), 1.1, 1.0), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0, 0.0), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), - ],) -def test_macroscopic(dim, velocity_set, grid_shape, rho, velocity): + (3, xlb.velocity_set.D3Q27, (50, 50, 50), 1.0, 0.0), + (3, xlb.velocity_set.D3Q27, (50, 50, 50), 1.1, 1.0), + ], +) +def test_macroscopic_jax(dim, velocity_set, grid_shape, rho, velocity): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) @@ -34,7 +37,6 @@ def test_macroscopic(dim, velocity_set, grid_shape, rho, velocity): # Compute equilibrium f_eq = QuadraticEquilibrium()(rho_field, velocity_field) - compute_macro = Macroscopic() rho_calc, u_calc = compute_macro(f_eq) diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index 52d6d88..ea1aaf4 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -29,7 +29,7 @@ def init_xlb_env(velocity_set): # (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 2.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. ], ) -def test_macroscopic(dim, velocity_set, grid_shape, rho, velocity): +def test_macroscopic_warp(dim, velocity_set, grid_shape, rho, velocity): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) diff --git a/tests/kernels/stream/test_stream_jax.py b/tests/kernels/stream/test_stream_jax.py index 3ba4f35..0d95f5f 100644 --- a/tests/kernels/stream/test_stream_jax.py +++ b/tests/kernels/stream/test_stream_jax.py @@ -19,13 +19,15 @@ def init_xlb_env(velocity_set): @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ + (2, xlb.velocity_set.D2Q9, (50, 50)), (2, xlb.velocity_set.D2Q9, (100, 100)), - (2, xlb.velocity_set.D2Q9, (100, 100)), - (3, xlb.velocity_set.D3Q19, (50, 50, 50)), (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), ], ) -def test_stream_operator(dim, velocity_set, grid_shape): +def test_stream_operator_jax(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py index afd2d89..6ce1329 100644 --- a/tests/kernels/stream/test_stream_warp.py +++ b/tests/kernels/stream/test_stream_warp.py @@ -21,13 +21,15 @@ def init_xlb_env(velocity_set): @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ + (2, xlb.velocity_set.D2Q9, (50, 50)), (2, xlb.velocity_set.D2Q9, (100, 100)), - (2, xlb.velocity_set.D2Q9, (100, 100)), - (3, xlb.velocity_set.D3Q19, (50, 50, 50)), (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), ], ) -def test_stream_operator(dim, velocity_set, grid_shape): +def test_stream_operator_warp(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) my_grid_jax = grid_factory(grid_shape, compute_backend=ComputeBackend.JAX) velocity_set = DefaultConfig.velocity_set diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 80404eb..bb7b2cd 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -32,9 +32,9 @@ class FullwayBounceBackBC(BoundaryCondition): def __init__( self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, ): super().__init__( ImplementationStep.COLLISION, @@ -44,12 +44,11 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_id_field, missing_mask): boundary = boundary_id_field == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) - return lax.select(boundary, f_pre[self.velocity_set.opp_indices], f_post) + return jnp.where(boundary, f_pre[self.velocity_set.opp_indices], f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update @@ -72,9 +71,47 @@ def functional( fliped_f[l] = f_pre[_opp_indices[l]] return fliped_f + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_id_field: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + f: wp.array3d(dtype=Any), + ): # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the boundary id and missing mask + _boundary_id = boundary_id_field[0, index[0], index[1]] + + # Make vectors for the lattice + _f_pre = _f_vec() + _f_post = _f_vec() + _mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] + + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) + + # Check if the boundary is active + if _boundary_id == wp.uint8(FullwayBounceBackBC.id): + _f = functional(_f_pre, _f_post, _mask) + else: + _f = _f_post + + # Write the result to the output + for l in range(self.velocity_set.q): + f[l, index[0], index[1]] = _f[l] + # Construct the warp kernel @wp.kernel - def kernel( + def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), boundary_id_field: wp.array4d(dtype=wp.uint8), @@ -87,7 +124,7 @@ def kernel( # Get the boundary id and missing mask _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] - + # Make vectors for the lattice _f_pre = _f_vec() _f_post = _f_vec() @@ -112,8 +149,9 @@ def kernel( for l in range(self.velocity_set.q): f[l, index[0], index[1], index[2]] = _f[l] - return functional, kernel + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 03395c8..6f2bf4e 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -40,6 +40,7 @@ def __init__(self, d, q, c, w): self.w = w self.cc = self._construct_lattice_moment() self.opp_indices = self._construct_opposite_indices() + self.get_opp_index = lambda i: self.opp_indices[i] self.main_indices = self._construct_main_indices() self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() From b4b825ef6092cce199004553b199f4dcb9467f37 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 10 May 2024 18:35:58 -0400 Subject: [PATCH 031/144] WIP added JAX planar masker and test is passing --- .../bc_equilibrium/test_bc_equilibrium_jax.py | 2 +- .../test_bc_equilibrium_warp.py | 2 +- .../test_bc_fullway_bounce_back_jax.py | 4 +- .../test_bc_fullway_bounce_back_warp.py | 2 +- .../mask/test_bc_indices_masker_jax.py | 1 + .../mask/test_bc_planar_masker_jax.py | 85 +++++++++++++ .../mask/test_bc_planar_masker_warp.py | 0 tests/grids/test_grid_jax.py | 8 +- tests/grids/test_grid_warp.py | 8 +- .../collision/test_bgk_collision_jax.py | 4 +- .../collision/test_bgk_collision_warp.py | 4 +- .../equilibrium/test_equilibrium_jax.py | 4 +- .../equilibrium/test_equilibrium_warp.py | 4 +- .../macroscopic/test_macroscopic_jax.py | 4 +- .../macroscopic/test_macroscopic_warp.py | 4 +- xlb/grid/jax_grid.py | 6 +- xlb/grid/warp_grid.py | 6 +- .../boundary_masker/planar_boundary_masker.py | 113 +++++++++--------- 18 files changed, 175 insertions(+), 86 deletions(-) create mode 100644 tests/boundary_conditions/mask/test_bc_planar_masker_jax.py create mode 100644 tests/boundary_conditions/mask/test_bc_planar_masker_warp.py diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 202f4b4..966fc90 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -70,7 +70,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_post = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=2.0 + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary f = equilibrium_bc(f_pre, f_post, boundary_id_field, missing_mask) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index eaaa17f..21c28f7 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -70,7 +70,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_post = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=2.0 + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary f = equilibrium_bc(f_pre, f_post, boundary_id_field, missing_mask, f) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 82738e9..8219501 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -65,7 +65,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): ) f_pre = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=0.0 + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0 ) # Generate a random field with the same shape key = jax.random.PRNGKey(0) @@ -74,7 +74,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): f_pre += random_field f_post = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=2.0 + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary f = fullway_bc(f_pre, f_post, boundary_id_field, missing_mask) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index 91a0205..f62bc83 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -70,7 +70,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): f_pre = wp.array(random_field) f_post = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.FP32, init_val=2.0 + cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary f_pre = fullway_bc(f_pre, f_post, boundary_id_field, missing_mask, f_pre) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index 6a1b05f..dd5099c 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -19,6 +19,7 @@ def init_xlb_env(velocity_set): @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ + (2, xlb.velocity_set.D2Q9, (4, 4)), (2, xlb.velocity_set.D2Q9, (50, 50)), (2, xlb.velocity_set.D2Q9, (100, 100)), (3, xlb.velocity_set.D3Q19, (50, 50, 50)), diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py new file mode 100644 index 0000000..24236c2 --- /dev/null +++ b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py @@ -0,0 +1,85 @@ +import pytest +import jax.numpy as jnp +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.default_config import DefaultConfig +from xlb.grid import grid_factory + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.JAX, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (4, 4)), + (2, xlb.velocity_set.D2Q9, (50, 50)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), + ], +) +def test_planar_masker_jax(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid_factory(grid_shape) + velocity_set = DefaultConfig.velocity_set + + missing_mask = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.BOOL + ) + + fill_value = 0 + boundary_id_field = my_grid.create_field( + cardinality=1, dtype=xlb.Precision.UINT8, fill_value=0 + ) + + planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker() + + if dim == 2: + lower_bound = (0, 0) + upper_bound = (1, grid_shape[1]) + direction = (1, 0) + else: # dim == 3 + lower_bound = (0, 0, 0) + upper_bound = (1, grid_shape[1], grid_shape[2]) + direction = (1, 0, 0) + + start_index = (0,) * dim + id_number = 1 + + boundary_id_field, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + id_number, + boundary_id_field, + missing_mask, + start_index, + ) + + # Assert that the boundary condition is set on the left side of the domain based on the lower and upper bounds + expected_slice = (slice(None),) + tuple( + slice(lb, ub) for lb, ub in zip(lower_bound, upper_bound) + ) + assert jnp.all( + boundary_id_field[expected_slice] == id_number + ), "Boundary not set correctly" + + # Assert that the rest of the domain is not affected and is equal to fill_value + full_slice = tuple(slice(None) for _ in grid_shape) + mask = jnp.ones_like(boundary_id_field, dtype=bool) + mask = mask.at[expected_slice].set(False) + assert jnp.all( + boundary_id_field[full_slice][mask] == fill_value + ), "Rest of domain incorrectly affected" + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py b/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/grids/test_grid_jax.py b/tests/grids/test_grid_jax.py index 88ca5fe..ce4bc70 100644 --- a/tests/grids/test_grid_jax.py +++ b/tests/grids/test_grid_jax.py @@ -51,15 +51,15 @@ def test_jax_3d_grid_initialization(grid_size): "z", ), "PartitionSpec is incorrect" -def test_jax_grid_create_field_init_val(): +def test_jax_grid_create_field_fill_value(): init_xlb_env() grid_shape = (100, 100) - init_val = 3.14 + fill_value = 3.14 my_grid = grid_factory(grid_shape) - f = my_grid.create_field(cardinality=9, init_val=init_val) + f = my_grid.create_field(cardinality=9, fill_value=fill_value) assert f.shape == (9,) + grid_shape, "Field shape is incorrect" - assert jnp.allclose(f, init_val), "Field not properly initialized with init_val" + assert jnp.allclose(f, fill_value), "Field not properly initialized with fill_value" diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 3bf1ae5..140a64e 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -26,18 +26,18 @@ def test_warp_grid_create_field(grid_size): assert isinstance(f, wp.array), "Field should be a Warp ndarray" -def test_warp_grid_create_field_init_val(): +def test_warp_grid_create_field_fill_value(): init_xlb_warp_env() grid_shape = (100, 100) - init_val = 3.14 + fill_value = 3.14 my_grid = grid_factory(grid_shape) - f = my_grid.create_field(cardinality=9, dtype=Precision.FP32, init_val=init_val) + f = my_grid.create_field(cardinality=9, dtype=Precision.FP32, fill_value=fill_value) assert isinstance(f, wp.array), "Field should be a Warp ndarray" f = f.numpy() assert f.shape == (9,) + grid_shape, "Field shape is incorrect" - assert np.allclose(f, init_val), "Field not properly initialized with init_val" + assert np.allclose(f, fill_value), "Field not properly initialized with fill_value" @pytest.fixture(autouse=True) diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index 0bc4b49..4415e0d 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -31,8 +31,8 @@ def test_bgk_ollision(dim, velocity_set, grid_shape, omega): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) - rho = my_grid.create_field(cardinality=1, init_val=1.0) - u = my_grid.create_field(cardinality=dim, init_val=0.0) + rho = my_grid.create_field(cardinality=1, fill_value=1.0) + u = my_grid.create_field(cardinality=dim, fill_value=0.0) # Compute equilibrium compute_macro = QuadraticEquilibrium() diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 94e286b..7d03cd8 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -32,8 +32,8 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): init_xlb_warp_env(velocity_set) my_grid = grid_factory(grid_shape) - rho = my_grid.create_field(cardinality=1, init_val=1.0) - u = my_grid.create_field(cardinality=dim, init_val=0.0) + rho = my_grid.create_field(cardinality=1, fill_value=1.0) + u = my_grid.create_field(cardinality=dim, fill_value=0.0) compute_macro = QuadraticEquilibrium() diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index 0495613..56d2672 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -30,8 +30,8 @@ def test_quadratic_equilibrium_jax(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) - rho = my_grid.create_field(cardinality=1, init_val=1.0) - u = my_grid.create_field(cardinality=dim, init_val=0.0) + rho = my_grid.create_field(cardinality=1, fill_value=1.0) + u = my_grid.create_field(cardinality=dim, fill_value=0.0) # Compute equilibrium compute_macro = QuadraticEquilibrium() diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index bf47880..10de60d 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -29,8 +29,8 @@ def test_quadratic_equilibrium_warp(dim, velocity_set, grid_shape): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) - rho = my_grid.create_field(cardinality=1, init_val=1.0) - u = my_grid.create_field(cardinality=dim, init_val=0.0) + rho = my_grid.create_field(cardinality=1, fill_value=1.0) + u = my_grid.create_field(cardinality=dim, fill_value=0.0) f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index f1f6349..1004f9c 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -31,8 +31,8 @@ def test_macroscopic_jax(dim, velocity_set, grid_shape, rho, velocity): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) - rho_field = my_grid.create_field(cardinality=1, init_val=rho) - velocity_field = my_grid.create_field(cardinality=dim, init_val=velocity) + rho_field = my_grid.create_field(cardinality=1, fill_value=rho) + velocity_field = my_grid.create_field(cardinality=dim, fill_value=velocity) # Compute equilibrium f_eq = QuadraticEquilibrium()(rho_field, velocity_field) diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index ea1aaf4..e25180f 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -33,8 +33,8 @@ def test_macroscopic_warp(dim, velocity_set, grid_shape, rho, velocity): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) - rho_field = my_grid.create_field(cardinality=1, init_val=rho) - velocity_field = my_grid.create_field(cardinality=dim, init_val=velocity) + rho_field = my_grid.create_field(cardinality=1, fill_value=rho) + velocity_field = my_grid.create_field(cardinality=dim, fill_value=velocity) f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) f_eq = QuadraticEquilibrium()(rho_field, velocity_field, f_eq) diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index b290d53..edeb562 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -42,7 +42,7 @@ def create_field( self, cardinality: int, dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16, Precision.BOOL] = None, - init_val=None, + fill_value=None, ): sharding_dim = self.shape[-1] // self.nDevices device_shape = (cardinality, sharding_dim, *self.shape[1:]) @@ -55,8 +55,8 @@ def create_field( full_shape ).items(): jax.default_device = d - if init_val: - x = jnp.full(device_shape, init_val, dtype=dtype) + if fill_value: + x = jnp.full(device_shape, fill_value, dtype=dtype) else: x = jnp.zeros(shape=device_shape, dtype=dtype) arrays += [jax.device_put(x, d)] diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index ae1c4e8..099361b 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -27,7 +27,7 @@ def create_field( self, cardinality: int, dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None, - init_val=None, + fill_value=None, ): dtype = ( dtype.wp_dtype @@ -36,8 +36,8 @@ def create_field( ) shape = (cardinality,) + (self.shape) - if init_val is None: + if fill_value is None: f = wp.zeros(shape, dtype=dtype) else: - f = wp.full(shape, init_val, dtype=dtype) + f = wp.full(shape, fill_value, dtype=dtype) return f diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index f7ffad3..6de1e61 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -6,6 +6,7 @@ from jax import jit import warp as wp from typing import Tuple +from jax.numpy import where, einsum, full_like from xlb.default_config import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet @@ -22,15 +23,15 @@ class PlanarBoundaryMasker(Operator): def __init__( self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, ): # Call super super().__init__(velocity_set, precision_policy, compute_backend) @Operator.register_backend(ComputeBackend.JAX) - # @partial(jit, static_argnums=(0), inline=True) TODO: Fix this + # @partial(jit, static_argnums=(0, 1, 2, 3, 4, 7)) def jax_implementation( self, lower_bound, @@ -39,69 +40,69 @@ def jax_implementation( id_number, boundary_id_field, mask, - start_index=(0, 0, 0), + start_index=None, ): - # TODO: Optimize this - - # x plane - if direction[0] != 0: - - # Set boundary id - boundary_id_field = boundary_id_field.at[0, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) - - # Set mask - for l in range(self.velocity_set.q): - d_dot_c = ( - direction[0] * self.velocity_set.c[0, l] - + direction[1] * self.velocity_set.c[1, l] - + direction[2] * self.velocity_set.c[2, l] - ) - if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(True) - - # y plane - elif direction[1] != 0: + if start_index is None: + start_index = (0,) * self.velocity_set.d - # Set boundary id - boundary_id_field = boundary_id_field.at[0, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) + _, *dimensions = boundary_id_field.shape - # Set mask - for l in range(self.velocity_set.q): - d_dot_c = ( - direction[0] * self.velocity_set.c[0, l] - + direction[1] * self.velocity_set.c[1, l] - + direction[2] * self.velocity_set.c[2, l] - ) - if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(True) - - # z plane - elif direction[2] != 0: - - # Set boundary id - boundary_id_field = boundary_id_field.at[0, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(id_number) + indices = [ + (max(0, lb + start), min(dim, ub + start)) + for lb, ub, start, dim in zip( + lower_bound, upper_bound, start_index, dimensions + ) + ] - # Set mask - for l in range(self.velocity_set.q): - d_dot_c = ( - direction[0] * self.velocity_set.c[0, l] - + direction[1] * self.velocity_set.c[1, l] - + direction[2] * self.velocity_set.c[2, l] - ) - if d_dot_c >= 0: - mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(True) - - return boundary_id_field, mask + slices = [slice(None)] + slices.extend(slice(lb, ub) for lb, ub in indices) + boundary_id_field = boundary_id_field.at[tuple(slices)].set(id_number) + return boundary_id_field, None def _construct_warp(self): # Make constants for warp _c = self.velocity_set.wp_c _q = wp.constant(self.velocity_set.q) - # Construct the warp kernel @wp.kernel - def kernel( + def kernel2d( + lower_bound: wp.vec3i, + upper_bound: wp.vec3i, + direction: wp.vec2i, + id_number: wp.int32, + boundary_id_field: wp.array3d(dtype=wp.uint8), + mask: wp.array3d(dtype=wp.bool), + start_index: wp.vec2i, + ): + # Get the indices of the plane to mask + plane_i, plane_j = wp.tid() + + # Get local indices + if direction[0] != 0: + i = lower_bound[0] - start_index[0] + j = plane_i + lower_bound[1] - start_index[1] + elif direction[1] != 0: + i = plane_i + lower_bound[0] - start_index[0] + j = lower_bound[1] - start_index[1] + + # Check if in bounds + if i >= 0 and i < mask.shape[1] and j >= 0 and j < mask.shape[2]: + # Set the boundary id + boundary_id_field[0, i, j] = wp.uint8(id_number) + + # Set mask for just directions coming from the boundary + for l in range(_q): + d_dot_c = ( + direction[0] * _c[0, l] + + direction[1] * _c[1, l] + + direction[2] * _c[2, l] + ) + if d_dot_c >= 0: + mask[l, i, j] = wp.bool(True) + + @wp.kernel + def kernel3d( lower_bound: wp.vec3i, upper_bound: wp.vec3i, direction: wp.vec3i, @@ -149,6 +150,8 @@ def kernel( if d_dot_c >= 0: mask[l, i, j, k] = wp.bool(True) + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return None, kernel @Operator.register_backend(ComputeBackend.WARP) From f62d06cba8232e6c78218485343cb5e47a3f7aed Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 10 May 2024 19:15:09 -0400 Subject: [PATCH 032/144] WIP warp test for planar masker is passing now. Mask part is left for both JAX and Warp. --- .../mask/test_bc_planar_masker_jax.py | 2 +- .../mask/test_bc_planar_masker_warp.py | 88 +++++++++++++ .../boundary_masker/planar_boundary_masker.py | 117 +++++------------- 3 files changed, 119 insertions(+), 88 deletions(-) diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py index 24236c2..3af8aac 100644 --- a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py @@ -37,7 +37,7 @@ def test_planar_masker_jax(dim, velocity_set, grid_shape): fill_value = 0 boundary_id_field = my_grid.create_field( - cardinality=1, dtype=xlb.Precision.UINT8, fill_value=0 + cardinality=1, dtype=xlb.Precision.UINT8, fill_value=fill_value ) planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker() diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py b/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py index e69de29..9cb9395 100644 --- a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py @@ -0,0 +1,88 @@ +import pytest +import numpy as np +import xlb +import warp as wp + +from xlb.compute_backend import ComputeBackend +from xlb.default_config import DefaultConfig +from xlb.grid import grid_factory + + +def init_xlb_env(velocity_set): + xlb.init( + default_precision_policy=xlb.PrecisionPolicy.FP32FP32, + default_backend=ComputeBackend.WARP, + velocity_set=velocity_set, + ) + + +@pytest.mark.parametrize( + "dim,velocity_set,grid_shape", + [ + (2, xlb.velocity_set.D2Q9, (4, 4)), + (2, xlb.velocity_set.D2Q9, (50, 50)), + (2, xlb.velocity_set.D2Q9, (100, 100)), + (3, xlb.velocity_set.D3Q19, (50, 50, 50)), + (3, xlb.velocity_set.D3Q19, (100, 100, 100)), + (3, xlb.velocity_set.D3Q27, (50, 50, 50)), + (3, xlb.velocity_set.D3Q27, (100, 100, 100)), + ], +) +def test_planar_masker_jax(dim, velocity_set, grid_shape): + init_xlb_env(velocity_set) + my_grid = grid_factory(grid_shape) + velocity_set = DefaultConfig.velocity_set + + missing_mask = my_grid.create_field( + cardinality=velocity_set.q, dtype=xlb.Precision.BOOL + ) + fill_value = 0 + boundary_id_field = my_grid.create_field( + cardinality=1, dtype=xlb.Precision.UINT8, fill_value=fill_value + ) + + planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker() + + if dim == 2: + lower_bound = (0, 0) + upper_bound = (1, grid_shape[1]) + direction = (1, 0) + else: # dim == 3 + lower_bound = (0, 0, 0) + upper_bound = (1, grid_shape[1], grid_shape[2]) + direction = (1, 0, 0) + + start_index = (0,) * dim + id_number = 1 + + boundary_id_field, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + id_number, + boundary_id_field, + missing_mask, + start_index, + ) + + boundary_id_field = boundary_id_field.numpy() + + # Assert that the boundary condition is set on the left side of the domain based on the lower and upper bounds + expected_slice = (slice(None),) + tuple( + slice(lb, ub) for lb, ub in zip(lower_bound, upper_bound) + ) + assert np.all( + boundary_id_field[expected_slice] == id_number + ), "Boundary not set correctly" + + # Assert that the rest of the domain is not affected and is equal to fill_value + full_slice = tuple(slice(None) for _ in grid_shape) + mask = np.ones_like(boundary_id_field, dtype=bool) + mask[expected_slice] = False + assert np.all( + boundary_id_field[full_slice][mask] == fill_value + ), "Rest of domain incorrectly affected" + + +if __name__ == "__main__": + pytest.main() diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index 6de1e61..e12206b 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -31,7 +31,7 @@ def __init__( super().__init__(velocity_set, precision_policy, compute_backend) @Operator.register_backend(ComputeBackend.JAX) - # @partial(jit, static_argnums=(0, 1, 2, 3, 4, 7)) + @partial(jit, static_argnums=(0, 1, 2, 3, 4, 7)) def jax_implementation( self, lower_bound, @@ -67,88 +67,45 @@ def _construct_warp(self): @wp.kernel def kernel2d( - lower_bound: wp.vec3i, - upper_bound: wp.vec3i, + lower_bound: wp.vec2i, + upper_bound: wp.vec2i, direction: wp.vec2i, - id_number: wp.int32, + id_number: wp.uint8, boundary_id_field: wp.array3d(dtype=wp.uint8), mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): - # Get the indices of the plane to mask - plane_i, plane_j = wp.tid() - - # Get local indices - if direction[0] != 0: - i = lower_bound[0] - start_index[0] - j = plane_i + lower_bound[1] - start_index[1] - elif direction[1] != 0: - i = plane_i + lower_bound[0] - start_index[0] - j = lower_bound[1] - start_index[1] - - # Check if in bounds - if i >= 0 and i < mask.shape[1] and j >= 0 and j < mask.shape[2]: - # Set the boundary id - boundary_id_field[0, i, j] = wp.uint8(id_number) - - # Set mask for just directions coming from the boundary - for l in range(_q): - d_dot_c = ( - direction[0] * _c[0, l] - + direction[1] * _c[1, l] - + direction[2] * _c[2, l] - ) - if d_dot_c >= 0: - mask[l, i, j] = wp.bool(True) + i, j = wp.tid() + lb_x, lb_y = lower_bound.x + start_index.x, lower_bound.y + start_index.y + ub_x, ub_y = upper_bound.x + start_index.x, upper_bound.y + start_index.y + + if lb_x <= i < ub_x and lb_y <= j < ub_y: + boundary_id_field[0, i, j] = id_number @wp.kernel def kernel3d( lower_bound: wp.vec3i, upper_bound: wp.vec3i, direction: wp.vec3i, - id_number: wp.int32, + id_number: wp.uint8, boundary_id_field: wp.array4d(dtype=wp.uint8), mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): - # Get the indices of the plane to mask - plane_i, plane_j = wp.tid() - - # Get local indices - if direction[0] != 0: - i = lower_bound[0] - start_index[0] - j = plane_i + lower_bound[1] - start_index[1] - k = plane_j + lower_bound[2] - start_index[2] - elif direction[1] != 0: - i = plane_i + lower_bound[0] - start_index[0] - j = lower_bound[1] - start_index[1] - k = plane_j + lower_bound[2] - start_index[2] - elif direction[2] != 0: - i = plane_i + lower_bound[0] - start_index[0] - j = plane_j + lower_bound[1] - start_index[1] - k = lower_bound[2] - start_index[2] - - # Check if in bounds - if ( - i >= 0 - and i < mask.shape[1] - and j >= 0 - and j < mask.shape[2] - and k >= 0 - and k < mask.shape[3] - ): - # Set the boundary id - boundary_id_field[0, i, j, k] = wp.uint8(id_number) - - # Set mask for just directions coming from the boundary - for l in range(_q): - d_dot_c = ( - direction[0] * _c[0, l] - + direction[1] * _c[1, l] - + direction[2] * _c[2, l] - ) - if d_dot_c >= 0: - mask[l, i, j, k] = wp.bool(True) + i, j, k = wp.tid() + lb_x, lb_y, lb_z = ( + lower_bound.x + start_index.x, + lower_bound.y + start_index.y, + lower_bound.z + start_index.z, + ) + ub_x, ub_y, ub_z = ( + upper_bound.x + start_index.x, + upper_bound.y + start_index.y, + upper_bound.z + start_index.z, + ) + + if lb_x <= i < ub_x and lb_y <= j < ub_y and lb_z <= k < ub_z: + boundary_id_field[0, i, j, k] = id_number kernel = kernel3d if self.velocity_set.d == 3 else kernel2d @@ -163,25 +120,11 @@ def warp_implementation( id_number, boundary_id_field, mask, - start_index=(0, 0, 0), + start_index=None, ): - # Get plane dimensions - if direction[0] != 0: - dim = ( - upper_bound[1] - lower_bound[1], - upper_bound[2] - lower_bound[2], - ) - elif direction[1] != 0: - dim = ( - upper_bound[0] - lower_bound[0], - upper_bound[2] - lower_bound[2], - ) - elif direction[2] != 0: - dim = ( - upper_bound[0] - lower_bound[0], - upper_bound[1] - lower_bound[1], - ) - + if start_index is None: + start_index = (0,) * self.velocity_set.d + # Launch the warp kernel wp.launch( self.warp_kernel, @@ -194,7 +137,7 @@ def warp_implementation( mask, start_index, ], - dim=dim, + dim=mask.shape[1:], ) return boundary_id_field, mask From 96759e9b812e67f28b3bc3c7264371b6e2bcfc86 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 10 May 2024 19:35:39 -0400 Subject: [PATCH 033/144] WIP: Added more tests for planar masker --- .../mask/test_bc_planar_masker_jax.py | 80 +++++++++++++---- .../mask/test_bc_planar_masker_warp.py | 90 +++++++++++++------ 2 files changed, 127 insertions(+), 43 deletions(-) diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py index 3af8aac..e3d436b 100644 --- a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py @@ -15,18 +15,71 @@ def init_xlb_env(velocity_set): @pytest.mark.parametrize( - "dim,velocity_set,grid_shape", + "dim,velocity_set,grid_shape,lower_bound,upper_bound,direction", [ - (2, xlb.velocity_set.D2Q9, (4, 4)), - (2, xlb.velocity_set.D2Q9, (50, 50)), - (2, xlb.velocity_set.D2Q9, (100, 100)), - (3, xlb.velocity_set.D3Q19, (50, 50, 50)), - (3, xlb.velocity_set.D3Q19, (100, 100, 100)), - (3, xlb.velocity_set.D3Q27, (50, 50, 50)), - (3, xlb.velocity_set.D3Q27, (100, 100, 100)), + # 2D Grids - Different directions + ( + 2, + xlb.velocity_set.D2Q9, + (4, 4), + (0, 0), + (2, 4), + (1, 0), + ), # Horizontal direction + ( + 2, + xlb.velocity_set.D2Q9, + (50, 50), + (0, 0), + (50, 25), + (0, 1), + ), # Vertical direction + ( + 2, + xlb.velocity_set.D2Q9, + (100, 100), + (50, 0), + (100, 50), + (0, 1), + ), # Vertical direction + # 3D Grids - Different directions + ( + 3, + xlb.velocity_set.D3Q19, + (50, 50, 50), + (0, 0, 0), + (25, 50, 50), + (1, 0, 0), + ), # Along x-axis + ( + 3, + xlb.velocity_set.D3Q19, + (100, 100, 100), + (0, 50, 0), + (50, 100, 100), + (0, 1, 0), + ), # Along y-axis + ( + 3, + xlb.velocity_set.D3Q27, + (50, 50, 50), + (0, 0, 0), + (50, 25, 50), + (0, 0, 1), + ), # Along z-axis + ( + 3, + xlb.velocity_set.D3Q27, + (100, 100, 100), + (0, 0, 0), + (50, 100, 50), + (1, 0, 0), + ), # Along x-axis ], ) -def test_planar_masker_jax(dim, velocity_set, grid_shape): +def test_planar_masker_jax( + dim, velocity_set, grid_shape, lower_bound, upper_bound, direction +): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set @@ -42,15 +95,6 @@ def test_planar_masker_jax(dim, velocity_set, grid_shape): planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker() - if dim == 2: - lower_bound = (0, 0) - upper_bound = (1, grid_shape[1]) - direction = (1, 0) - else: # dim == 3 - lower_bound = (0, 0, 0) - upper_bound = (1, grid_shape[1], grid_shape[2]) - direction = (1, 0, 0) - start_index = (0,) * dim id_number = 1 diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py b/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py index 9cb9395..75d1aa6 100644 --- a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py @@ -17,22 +17,76 @@ def init_xlb_env(velocity_set): @pytest.mark.parametrize( - "dim,velocity_set,grid_shape", + "dim,velocity_set,grid_shape,lower_bound,upper_bound,direction", [ - (2, xlb.velocity_set.D2Q9, (4, 4)), - (2, xlb.velocity_set.D2Q9, (50, 50)), - (2, xlb.velocity_set.D2Q9, (100, 100)), - (3, xlb.velocity_set.D3Q19, (50, 50, 50)), - (3, xlb.velocity_set.D3Q19, (100, 100, 100)), - (3, xlb.velocity_set.D3Q27, (50, 50, 50)), - (3, xlb.velocity_set.D3Q27, (100, 100, 100)), + # 2D Grids - Different directions + ( + 2, + xlb.velocity_set.D2Q9, + (4, 4), + (0, 0), + (2, 4), + (1, 0), + ), # Horizontal direction + ( + 2, + xlb.velocity_set.D2Q9, + (50, 50), + (0, 0), + (50, 25), + (0, 1), + ), # Vertical direction + ( + 2, + xlb.velocity_set.D2Q9, + (100, 100), + (50, 0), + (100, 50), + (0, 1), + ), # Vertical direction + # 3D Grids - Different directions + ( + 3, + xlb.velocity_set.D3Q19, + (50, 50, 50), + (0, 0, 0), + (25, 50, 50), + (1, 0, 0), + ), # Along x-axis + ( + 3, + xlb.velocity_set.D3Q19, + (100, 100, 100), + (0, 50, 0), + (50, 100, 100), + (0, 1, 0), + ), # Along y-axis + ( + 3, + xlb.velocity_set.D3Q27, + (50, 50, 50), + (0, 0, 0), + (50, 25, 50), + (0, 0, 1), + ), # Along z-axis + ( + 3, + xlb.velocity_set.D3Q27, + (100, 100, 100), + (0, 0, 0), + (50, 100, 50), + (1, 0, 0), + ), # Along x-axis ], ) -def test_planar_masker_jax(dim, velocity_set, grid_shape): +def test_planar_masker_warp( + dim, velocity_set, grid_shape, lower_bound, upper_bound, direction +): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set + # Create required fields missing_mask = my_grid.create_field( cardinality=velocity_set.q, dtype=xlb.Precision.BOOL ) @@ -42,16 +96,6 @@ def test_planar_masker_jax(dim, velocity_set, grid_shape): ) planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker() - - if dim == 2: - lower_bound = (0, 0) - upper_bound = (1, grid_shape[1]) - direction = (1, 0) - else: # dim == 3 - lower_bound = (0, 0, 0) - upper_bound = (1, grid_shape[1], grid_shape[2]) - direction = (1, 0, 0) - start_index = (0,) * dim id_number = 1 @@ -67,7 +111,7 @@ def test_planar_masker_jax(dim, velocity_set, grid_shape): boundary_id_field = boundary_id_field.numpy() - # Assert that the boundary condition is set on the left side of the domain based on the lower and upper bounds + # Assertions to verify boundary settings expected_slice = (slice(None),) + tuple( slice(lb, ub) for lb, ub in zip(lower_bound, upper_bound) ) @@ -75,14 +119,10 @@ def test_planar_masker_jax(dim, velocity_set, grid_shape): boundary_id_field[expected_slice] == id_number ), "Boundary not set correctly" - # Assert that the rest of the domain is not affected and is equal to fill_value + # Assertions for non-affected areas full_slice = tuple(slice(None) for _ in grid_shape) mask = np.ones_like(boundary_id_field, dtype=bool) mask[expected_slice] = False assert np.all( boundary_id_field[full_slice][mask] == fill_value ), "Rest of domain incorrectly affected" - - -if __name__ == "__main__": - pytest.main() From c05c0e8093c172dd93d7720daaaad51df289d907 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Tue, 28 May 2024 11:25:42 -0400 Subject: [PATCH 034/144] Added half-way. Not tested yet. --- .../bc_halfway_bounce_back.py | 80 +++++++++++++++++-- 1 file changed, 75 insertions(+), 5 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index c594ca1..0560d11 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -47,12 +47,15 @@ def __init__( ) @Operator.register_backend(ComputeBackend.JAX) - #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_id_field, missing_mask): boundary = boundary_id_field == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) - return lax.select(jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], f_post) + return jnp.where( + jnp.logical_and(missing_mask, boundary), + f_pre[self.velocity_set.opp_indices], + f_post, + ) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update @@ -63,9 +66,38 @@ def _construct_warp(self): self.velocity_set.q, dtype=wp.uint8 ) # TODO fix vec bool + @wp.func + def functional2d( + f: wp.array3d(dtype=Any), + missing_mask: Any, + index: Any, + ): + # Pull the distribution function + _f = _f_vec() + for l in range(self.velocity_set.q): + # Get pull index + pull_index = type(index)() + + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + use_l = _opp_indices[l] + for d in range(self.velocity_set.d): + pull_index[d] = index[d] + + # Pull the distribution function + else: + use_l = l + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - _c[d, l] + + # Get the distribution function + _f[l] = f[use_l, pull_index[0], pull_index[1]] + + return _f + # Construct the funcional to get streamed indices @wp.func - def functional( + def functional3d( f: wp.array4d(dtype=Any), missing_mask: Any, index: Any, @@ -95,7 +127,42 @@ def functional( # Construct the warp kernel @wp.kernel - def kernel( + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_id_field: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + f: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec3i(i, j) + + # Get the boundary id and missing mask + _boundary_id = boundary_id_field[0, index[0], index[1]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + _f = functional2d(f_pre, _missing_mask, index) + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1]] + + # Write the distribution function + for l in range(self.velocity_set.q): + f[l, index[0], index[1]] = _f[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), boundary_id_field: wp.array4d(dtype=wp.uint8), @@ -118,7 +185,7 @@ def kernel( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional(f_pre, _missing_mask, index) + _f = functional3d(f_pre, _missing_mask, index) else: _f = _f_vec() for l in range(self.velocity_set.q): @@ -128,6 +195,9 @@ def kernel( for l in range(self.velocity_set.q): f[l, index[0], index[1], index[2]] = _f[l] + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + functional = functional3d if self.velocity_set.d == 3 else functional2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) From 2d0aa5f07538ee7944ecc6cfee03d84104386c1e Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Tue, 28 May 2024 17:23:08 -0400 Subject: [PATCH 035/144] Added trackback to the error. --- xlb/operator/operator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index c6d8ad0..8b50abd 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -3,10 +3,11 @@ import inspect import warp as wp from typing import Any +import traceback from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy, Precision -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig class Operator: @@ -72,10 +73,11 @@ def __call__(self, *args, callback=None, **kwargs): return result except Exception as e: error = e + traceback_str = traceback.format_exc() continue # This skips to the next candidate if binding fails raise Exception( - f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}" + f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}\n {traceback_str}" ) @property From 66a9fc268839e81890754d7db7c554b4c9ffabf7 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 30 May 2024 18:08:09 -0400 Subject: [PATCH 036/144] Refactoring done (single-GPU). Multi-GPU will be pushed soon. --- examples/cfd/example_basic.py | 69 ---- examples/cfd/lid_driven_cavity.py | 322 +++--------------- .../flow_past_sphere.py | 18 +- .../taylor_green.py | 6 +- examples/performance/mlups3d.py | 2 +- .../boundary_conditions.py | 216 ------------ .../bc_equilibrium/test_bc_equilibrium_jax.py | 13 +- .../test_bc_equilibrium_warp.py | 13 +- .../test_bc_fullway_bounce_back_jax.py | 13 +- .../test_bc_fullway_bounce_back_warp.py | 13 +- .../mask/test_bc_indices_masker_jax.py | 31 +- .../mask/test_bc_indices_masker_warp.py | 33 +- .../mask/test_bc_planar_masker_jax.py | 16 +- .../mask/test_bc_planar_masker_warp.py | 18 +- tests/grids/test_grid_warp.py | 8 +- .../collision/test_bgk_collision_jax.py | 5 +- .../collision/test_bgk_collision_warp.py | 8 +- .../equilibrium/test_equilibrium_jax.py | 5 +- .../equilibrium/test_equilibrium_warp.py | 5 +- .../macroscopic/test_macroscopic_jax.py | 5 +- .../macroscopic/test_macroscopic_warp.py | 4 +- tests/kernels/stream/test_stream_jax.py | 5 +- tests/kernels/stream/test_stream_warp.py | 5 +- xlb/__init__.py | 4 +- xlb/default_config.py | 2 +- xlb/grid/grid.py | 2 +- xlb/grid/jax_grid.py | 4 +- xlb/grid/warp_grid.py | 8 +- xlb/helper/__init__.py | 3 + xlb/helper/boundary_conditions.py | 69 ++++ xlb/helper/initializers.py | 18 + xlb/helper/nse_solver.py | 29 ++ xlb/operator/__init__.py | 2 - .../boundary_condition/bc_do_nothing.py | 67 +++- .../boundary_condition/bc_equilibrium.py | 42 ++- .../bc_fullway_bounce_back.py | 18 +- .../bc_halfway_bounce_back.py | 16 +- .../boundary_condition/boundary_condition.py | 3 +- .../indices_boundary_masker.py | 28 +- .../boundary_masker/planar_boundary_masker.py | 24 +- .../boundary_masker/stl_boundary_masker.py | 20 +- xlb/operator/collision/bgk.py | 40 ++- xlb/operator/collision/collision.py | 1 + xlb/operator/equilibrium/__init__.py | 5 +- .../equilibrium/quadratic_equilibrium.py | 3 +- xlb/operator/macroscopic/macroscopic.py | 2 +- xlb/operator/stepper/nse_stepper.py | 134 ++++++-- xlb/operator/stepper/stepper.py | 93 ++--- xlb/precision_policy.py | 16 + xlb/precision_policy/precision_policy.py | 3 +- xlb/solver/__init__.py | 2 - xlb/solver/nse.py | 116 ------- xlb/solver/solver.py | 23 -- xlb/velocity_set/d2q9.py | 1 - xlb/velocity_set/d3q19.py | 1 - xlb/velocity_set/d3q27.py | 1 - 56 files changed, 635 insertions(+), 998 deletions(-) delete mode 100644 examples/cfd/example_basic.py rename examples/{cfd => cfd_old_to_be_migrated}/flow_past_sphere.py (92%) rename examples/{cfd => cfd_old_to_be_migrated}/taylor_green.py (96%) delete mode 100644 tests/backends_conformance/boundary_conditions.py create mode 100644 xlb/helper/__init__.py create mode 100644 xlb/helper/boundary_conditions.py create mode 100644 xlb/helper/initializers.py create mode 100644 xlb/helper/nse_solver.py delete mode 100644 xlb/solver/__init__.py delete mode 100644 xlb/solver/nse.py delete mode 100644 xlb/solver/solver.py diff --git a/examples/cfd/example_basic.py b/examples/cfd/example_basic.py deleted file mode 100644 index aeac7f7..0000000 --- a/examples/cfd/example_basic.py +++ /dev/null @@ -1,69 +0,0 @@ -import xlb -from xlb.compute_backend import ComputeBackend -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.default_config import DefaultConfig -import warp as wp -from xlb.grid import grid_factory -from xlb.precision_policy import Precision -import xlb.velocity_set - -xlb.init( - default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, - velocity_set=xlb.velocity_set.D3Q19, -) - -grid_size = 50 -grid_shape = (grid_size, grid_size, grid_size) -my_grid = grid_factory(grid_shape) -f = my_grid.create_field(cardinality=9) - -# compute_macro = QuadraticEquilibrium() - -# f_eq = compute_macro(rho, u) - - -# DefaultConfig.velocity_set.w - - - - -# def initializer(): -# rho = grid.create_field(cardinality=1) + 1.0 -# u = grid.create_field(cardinality=2) - -# circle_center = (grid_shape[0] // 2, grid_shape[1] // 2) -# circle_radius = 10 - -# for x in range(grid_shape[0]): -# for y in range(grid_shape[1]): -# if (x - circle_center[0]) ** 2 + ( -# y - circle_center[1] -# ) ** 2 <= circle_radius**2: -# rho = rho.at[0, x, y].add(0.001) - -# func_eq = QuadraticEquilibrium() -# f_eq = func_eq(rho, u) - -# return f_eq - - - -# solver = IncompressibleNavierStokes(grid, omega=1.0) - - -# def perform_io(f, step): -# rho, u = compute_macro(f) -# fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1]} -# save_fields_vtk(fields, step) -# save_image(rho[0], step) -# print(f"Step {step + 1} complete") - - -# num_steps = 1000 -# io_rate = 100 -# for step in range(num_steps): -# f = solver.step(f, timestep=step) - -# if step % io_rate == 0: -# perform_io(f, step) diff --git a/examples/cfd/lid_driven_cavity.py b/examples/cfd/lid_driven_cavity.py index 96cf425..202d155 100644 --- a/examples/cfd/lid_driven_cavity.py +++ b/examples/cfd/lid_driven_cavity.py @@ -1,290 +1,76 @@ -# Simple flow past sphere example using the functional interface to xlb - -import time -from tqdm import tqdm -import os -import matplotlib.pyplot as plt -from typing import Any -import numpy as np - -import warp as wp - -wp.init() - import xlb -from xlb.operator import Operator - -class UniformInitializer(Operator): - - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - ): - # Get the global index - i, j, k = wp.tid() - - # Set the velocity - u[0, i, j, k] = 0.0 - u[1, i, j, k] = 0.0 - u[2, i, j, k] = 0.0 - - # Set the density - rho[0, i, j, k] = 1.0 - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - rho, - u, - ], - dim=rho.shape[1:], - ) - return rho, u - - -def run_ldc(backend, compute_mlup=True): +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.operator.stepper import IncompressibleNavierStokesStepper +from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.macroscopic import Macroscopic +from xlb.utils import save_fields_vtk, save_image - # Set the compute backend - if backend == "warp": - compute_backend = xlb.ComputeBackend.WARP - elif backend == "jax": - compute_backend = xlb.ComputeBackend.JAX +backend = ComputeBackend.JAX +velocity_set = xlb.velocity_set.D2Q9() +precision_policy = PrecisionPolicy.FP32FP32 - # Set the precision policy - precision_policy = xlb.PrecisionPolicy.FP32FP32 +xlb.init( + velocity_set=velocity_set, + default_backend=backend, + default_precision_policy=precision_policy, +) - # Set the velocity set - velocity_set = xlb.velocity_set.D3Q19() +grid_size = 512 +grid_shape = (grid_size, grid_size) - # Make grid - nr = 128 - shape = (nr, nr, nr) - if backend == "jax": - grid = xlb.grid.JaxGrid(shape=shape) - elif backend == "warp": - grid = xlb.grid.WarpGrid(shape=shape) +grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) - # Make feilds - rho = grid.create_field(cardinality=1, precision=xlb.Precision.FP32) - u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) - f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_id_field = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) - missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) +# Velocity on top face (2D) +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, missing_mask, grid_shape, EquilibriumBC.id, ["top"] +) - # Make operators - initializer = UniformInitializer( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - collision = xlb.operator.collision.BGK( - omega=1.9, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - stream = xlb.operator.stream.Stream( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - rho=1.0, - u=(0, 0.10, 0.0), - equilibrium_operator=equilibrium, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, - equilibrium=equilibrium, - macroscopic=macroscopic, - stream=stream, - #boundary_conditions=[equilibrium_bc, half_way_bc, full_way_bc], - boundary_conditions=[half_way_bc, full_way_bc, equilibrium_bc], - ) - planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) +bc_eq = QuadraticEquilibrium() - # Set inlet bc (bottom x face) - lower_bound = (0, 1, 1) - upper_bound = (0, nr-1, nr-1) - direction = (1, 0, 0) - boundary_id_field, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - equilibrium_bc.id, - boundary_id_field, - missing_mask, - (0, 0, 0) - ) +bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), equilibrium_operator=bc_eq) - # Set outlet bc (top x face) - lower_bound = (nr-1, 0, 0) - upper_bound = (nr-1, nr, nr) - direction = (-1, 0, 0) - boundary_id_field, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - half_way_bc.id, - #full_way_bc.id, - boundary_id_field, - missing_mask, - (0, 0, 0) - ) - # Set half way bc (bottom y face) - lower_bound = (0, 0, 0) - upper_bound = (nr, 0, nr) - direction = (0, 1, 0) - boundary_id_field, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - half_way_bc.id, - #full_way_bc.id, - boundary_id_field, - missing_mask, - (0, 0, 0) - ) +# Wall on all other faces (2D) +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, + missing_mask, + grid_shape, + FullwayBounceBackBC.id, + ["bottom", "left", "right"], +) - # Set half way bc (top y face) - lower_bound = (0, nr-1, 0) - upper_bound = (nr, nr-1, nr) - direction = (0, -1, 0) - boundary_id_field, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - half_way_bc.id, - #full_way_bc.id, - boundary_id_field, - missing_mask, - (0, 0, 0) - ) +bc_walls = FullwayBounceBackBC() - # Set half way bc (bottom z face) - lower_bound = (0, 0, 0) - upper_bound = (nr, nr, 0) - direction = (0, 0, 1) - boundary_id_field, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - half_way_bc.id, - #full_way_bc.id, - boundary_id_field, - missing_mask, - (0, 0, 0) - ) - # Set half way bc (top z face) - lower_bound = (0, 0, nr-1) - upper_bound = (nr, nr, nr-1) - direction = (0, 0, -1) - boundary_id_field, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - half_way_bc.id, - #full_way_bc.id, - boundary_id_field, - missing_mask, - (0, 0, 0) - ) +f_0 = initialize_eq(f_0, grid, velocity_set, backend) +boundary_conditions = [bc_top, bc_walls] +omega = 1.6 - # Set initial conditions - if backend == "warp": - rho, u = initializer(rho, u) - f0 = equilibrium(rho, u, f0) - elif backend == "jax": - rho = rho + 1.0 - f0 = equilibrium(rho, u) +stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=boundary_conditions +) - # Time stepping - plot_freq = 128 - save_dir = "ldc" - os.makedirs(save_dir, exist_ok=True) - num_steps = nr * 16 - start = time.time() +for i in range(50000): + f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) + f_0, f_1 = f_1, f_0 - for _ in tqdm(range(num_steps)): - # Time step - if backend == "warp": - f1 = stepper(f0, f1, boundary_id_field, missing_mask, _) - f1, f0 = f0, f1 - elif backend == "jax": - f0 = stepper(f0, boundary_id_field, missing_mask, _) - # Plot if necessary - if (_ % plot_freq == 0) and (not compute_mlup): - if backend == "warp": - rho, u = macroscopic(f0, rho, u) - local_rho = rho.numpy() - local_u = u.numpy() - local_boundary_id = boundary_id_field.numpy() - elif backend == "jax": - local_rho, local_u = macroscopic(f0) - local_boundary_id = boundary_id_field +# Write the results +macro = Macroscopic() - # Plot the velocity field, rho and boundary id side by side - plt.subplot(1, 3, 1) - plt.imshow(np.linalg.norm(local_u[:, :, nr // 2, :], axis=0)) - plt.colorbar() - plt.subplot(1, 3, 2) - plt.imshow(local_rho[0, :, nr // 2, :]) - plt.colorbar() - plt.subplot(1, 3, 3) - plt.imshow(local_boundary_id[0, :, nr // 2, :]) - plt.colorbar() - plt.savefig(f"{save_dir}/{backend}_{str(_).zfill(6)}.png") - plt.close() +rho, u = macro(f_0) - wp.synchronize() - end = time.time() +# remove boundary cells +rho = rho[:, 1:-1, 1:-1] +u = u[:, 1:-1, 1:-1] - # Print MLUPS - print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") +u_magnitude = (u[0]**2 + u[1]**2)**0.5 -if __name__ == "__main__": +fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_magnitude": u_magnitude} + +save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity") - # Run the LDC example - backends = ["warp", "jax"] - compute_mlup = False - for backend in backends: - run_ldc(backend, compute_mlup=compute_mlup) +save_image(fields["u_magnitude"], timestep=i, prefix="lid_driven_cavity") \ No newline at end of file diff --git a/examples/cfd/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py similarity index 92% rename from examples/cfd/flow_past_sphere.py rename to examples/cfd_old_to_be_migrated/flow_past_sphere.py index 515b6f0..7e8af30 100644 --- a/examples/cfd/flow_past_sphere.py +++ b/examples/cfd_old_to_be_migrated/flow_past_sphere.py @@ -75,7 +75,7 @@ def warp_implementation(self, rho, u, vel): u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - boundary_id_field = grid.create_field(cardinality=1, dtype=wp.uint8) + boundary_mask = grid.create_field(cardinality=1, dtype=wp.uint8) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) # Make operators @@ -157,10 +157,10 @@ def warp_implementation(self, rho, u, vel): indices = wp.from_numpy(indices, dtype=wp.int32) # Set boundary conditions on the indices - boundary_id_field, missing_mask = indices_boundary_masker( + boundary_mask, missing_mask = indices_boundary_masker( indices, half_way_bc.id, - boundary_id_field, + boundary_mask, missing_mask, (0, 0, 0) ) @@ -169,12 +169,12 @@ def warp_implementation(self, rho, u, vel): lower_bound = (0, 0, 0) upper_bound = (0, nr, nr) direction = (1, 0, 0) - boundary_id_field, missing_mask = planar_boundary_masker( + boundary_mask, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, equilibrium_bc.id, - boundary_id_field, + boundary_mask, missing_mask, (0, 0, 0) ) @@ -183,12 +183,12 @@ def warp_implementation(self, rho, u, vel): lower_bound = (nr-1, 0, 0) upper_bound = (nr-1, nr, nr) direction = (-1, 0, 0) - boundary_id_field, missing_mask = planar_boundary_masker( + boundary_mask, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, do_nothing_bc.id, - boundary_id_field, + boundary_mask, missing_mask, (0, 0, 0) ) @@ -206,7 +206,7 @@ def warp_implementation(self, rho, u, vel): num_steps = 1024 * 8 start = time.time() for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_id_field, missing_mask, _) + f1 = stepper(f0, f1, boundary_mask, missing_mask, _) f1, f0 = f0, f1 if (_ % plot_freq == 0) and (not compute_mlup): rho, u = macroscopic(f0, rho, u) @@ -216,7 +216,7 @@ def warp_implementation(self, rho, u, vel): plt.imshow(u[0, :, nr // 2, :].numpy()) plt.colorbar() plt.subplot(1, 2, 2) - plt.imshow(boundary_id_field[0, :, nr // 2, :].numpy()) + plt.imshow(boundary_mask[0, :, nr // 2, :].numpy()) plt.colorbar() plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() diff --git a/examples/cfd/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py similarity index 96% rename from examples/cfd/taylor_green.py rename to examples/cfd_old_to_be_migrated/taylor_green.py index b806225..10eb54f 100644 --- a/examples/cfd/taylor_green.py +++ b/examples/cfd_old_to_be_migrated/taylor_green.py @@ -135,7 +135,7 @@ def run_taylor_green(backend, compute_mlup=True): u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_id_field = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + boundary_mask = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators @@ -187,10 +187,10 @@ def run_taylor_green(backend, compute_mlup=True): for _ in tqdm(range(num_steps)): # Time step if backend == "warp": - f1 = stepper(f0, f1, boundary_id_field, missing_mask, _) + f1 = stepper(f0, f1, boundary_mask, missing_mask, _) f1, f0 = f0, f1 elif backend == "jax": - f0 = stepper(f0, boundary_id_field, missing_mask, _) + f0 = stepper(f0, boundary_mask, missing_mask, _) # Plot if needed if (_ % plot_freq == 0) and (not compute_mlup): diff --git a/examples/performance/mlups3d.py b/examples/performance/mlups3d.py index bce5a0b..f044c33 100644 --- a/examples/performance/mlups3d.py +++ b/examples/performance/mlups3d.py @@ -6,7 +6,7 @@ from xlb.precision_policy import Fp32Fp32 from xlb.operator.initializer import EquilibriumInitializer -from xlb.solver import IncompressibleNavierStokes +from xlb.helper import IncompressibleNavierStokes from xlb.grid import grid_factory parser = argparse.ArgumentParser( diff --git a/tests/backends_conformance/boundary_conditions.py b/tests/backends_conformance/boundary_conditions.py deleted file mode 100644 index 5a4f0b0..0000000 --- a/tests/backends_conformance/boundary_conditions.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import numpy as np -import jax.numpy as jnp -import warp as wp -from xlb.grid import grid_factory -import xlb - -wp.init() - - -class TestBoundaryConditions(unittest.TestCase): - def setUp(self): - self.backends = ["warp", "jax"] - self.results = {} - - def run_boundary_conditions(self, backend): - # Set the compute backend - if backend == "warp": - compute_backend = xlb.ComputeBackend.WARP - elif backend == "jax": - compute_backend = xlb.ComputeBackend.JAX - - # Set the precision policy - precision_policy = xlb.PrecisionPolicy.FP32FP32 - - # Set the velocity set - velocity_set = xlb.velocity_set.D3Q19() - - # Make grid - nr = 128 - shape = (nr, nr, nr) - grid = grid_factory(shape) - # Make fields - f_pre = grid.create_field( - cardinality=velocity_set.q, precision=xlb.Precision.FP32 - ) - f_post = grid.create_field( - cardinality=velocity_set.q, precision=xlb.Precision.FP32 - ) - f = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_id_field = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) - missing_mask = grid.create_field( - cardinality=velocity_set.q, precision=xlb.Precision.BOOL - ) - - # Make needed operators - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - rho=1.0, - u=(0.0, 0.0, 0.0), - equilibrium_operator=equilibrium, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - halfway_bounce_back_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - fullway_bounce_back_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - - # Make indices for boundary conditions (sphere) - sphere_radius = 10 - x = np.arange(nr) - y = np.arange(nr) - z = np.arange(nr) - X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) - indices = np.array(indices).T - if backend == "jax": - indices = jnp.array(indices) - elif backend == "warp": - indices = wp.from_numpy(indices, dtype=wp.int32) - - # Test equilibrium boundary condition - boundary_id_field, missing_mask = indices_boundary_masker( - indices, equilibrium_bc.id, boundary_id_field, missing_mask, (0, 0, 0) - ) - if backend == "jax": - f_equilibrium = equilibrium_bc(f_pre, f_post, boundary_id_field, missing_mask) - elif backend == "warp": - f_equilibrium = grid.create_field( - cardinality=velocity_set.q, precision=xlb.Precision.FP32 - ) - f_equilibrium = equilibrium_bc( - f_pre, f_post, boundary_id_field, missing_mask, f_equilibrium - ) - - # Test do nothing boundary condition - boundary_id_field, missing_mask = indices_boundary_masker( - indices, do_nothing_bc.id, boundary_id_field, missing_mask, (0, 0, 0) - ) - if backend == "jax": - f_do_nothing = do_nothing_bc(f_pre, f_post, boundary_id_field, missing_mask) - elif backend == "warp": - f_do_nothing = grid.create_field( - cardinality=velocity_set.q, precision=xlb.Precision.FP32 - ) - f_do_nothing = do_nothing_bc( - f_pre, f_post, boundary_id_field, missing_mask, f_do_nothing - ) - - # Test halfway bounce back boundary condition - boundary_id_field, missing_mask = indices_boundary_masker( - indices, halfway_bounce_back_bc.id, boundary_id_field, missing_mask, (0, 0, 0) - ) - if backend == "jax": - f_halfway_bounce_back = halfway_bounce_back_bc( - f_pre, f_post, boundary_id_field, missing_mask - ) - elif backend == "warp": - f_halfway_bounce_back = grid.create_field( - cardinality=velocity_set.q, precision=xlb.Precision.FP32 - ) - f_halfway_bounce_back = halfway_bounce_back_bc( - f_pre, f_post, boundary_id_field, missing_mask, f_halfway_bounce_back - ) - - # Test the full boundary condition - boundary_id_field, missing_mask = indices_boundary_masker( - indices, fullway_bounce_back_bc.id, boundary_id_field, missing_mask, (0, 0, 0) - ) - if backend == "jax": - f_fullway_bounce_back = fullway_bounce_back_bc( - f_pre, f_post, boundary_id_field, missing_mask - ) - elif backend == "warp": - f_fullway_bounce_back = grid.create_field( - cardinality=velocity_set.q, precision=xlb.Precision.FP32 - ) - f_fullway_bounce_back = fullway_bounce_back_bc( - f_pre, f_post, boundary_id_field, missing_mask, f_fullway_bounce_back - ) - - return f_equilibrium, f_do_nothing, f_halfway_bounce_back, f_fullway_bounce_back - - def test_boundary_conditions(self): - for backend in self.backends: - ( - f_equilibrium, - f_do_nothing, - f_halfway_bounce_back, - f_fullway_bounce_back, - ) = self.run_boundary_conditions(backend) - self.results[backend] = { - "equilibrium": np.array(f_equilibrium) - if backend == "jax" - else f_equilibrium.numpy(), - "do_nothing": np.array(f_do_nothing) - if backend == "jax" - else f_do_nothing.numpy(), - "halfway_bounce_back": np.array(f_halfway_bounce_back) - if backend == "jax" - else f_halfway_bounce_back.numpy(), - "fullway_bounce_back": np.array(f_fullway_bounce_back) - if backend == "jax" - else f_fullway_bounce_back.numpy(), - } - - for test_name in [ - "equilibrium", - "do_nothing", - "halfway_bounce_back", - "fullway_bounce_back", - ]: - with self.subTest(test_name=test_name): - warp_results = self.results["warp"][test_name] - jax_results = self.results["jax"][test_name] - - is_close = np.allclose(warp_results, jax_results, atol=1e-8, rtol=1e-5) - if not is_close: - diff_indices = np.where( - ~np.isclose(warp_results, jax_results, atol=1e-8, rtol=1e-5) - ) - differences = [ - (idx, warp_results[idx], jax_results[idx]) - for idx in zip(*diff_indices) - ] - difference_str = "\n".join( - [ - f"Index: {idx}, Warp: {w}, JAX: {j}" - for idx, w, j in differences - ] - ) - msg = f"{test_name} test failed: results do not match between backends. Differences:\n{difference_str}" - else: - msg = "" - - self.assertTrue(is_close, msg=msg) - - -if __name__ == "__main__": - unittest.main() - diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 966fc90..e937295 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -4,14 +4,13 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @@ -33,7 +32,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.BOOL ) - boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -63,8 +62,8 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): equilibrium_operator=equilibrium, ) - boundary_id_field, missing_mask = indices_boundary_masker( - indices, equilibrium_bc.id, boundary_id_field, missing_mask, start_index=None + boundary_mask, missing_mask = indices_boundary_masker( + indices, equilibrium_bc.id, boundary_mask, missing_mask, start_index=None ) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -73,7 +72,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_id_field, missing_mask) + f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 21c28f7..a94fcbe 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -4,14 +4,13 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @@ -33,7 +32,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.BOOL ) - boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -63,8 +62,8 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): equilibrium_operator=equilibrium, ) - boundary_id_field, missing_mask = indices_boundary_masker( - indices, equilibrium_bc.id, boundary_id_field, missing_mask, start_index=None + boundary_mask, missing_mask = indices_boundary_masker( + indices, equilibrium_bc.id, boundary_mask, missing_mask, start_index=None ) f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -73,7 +72,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_id_field, missing_mask, f) + f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask, f) f = f.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 8219501..6ec3edc 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -5,14 +5,13 @@ import jax from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @@ -38,7 +37,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC() - boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -60,8 +59,8 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): indices = jnp.array(indices) - boundary_id_field, missing_mask = indices_boundary_masker( - indices, fullway_bc.id, boundary_id_field, missing_mask, start_index=None + boundary_mask, missing_mask = indices_boundary_masker( + indices, fullway_bc.id, boundary_mask, missing_mask, start_index=None ) f_pre = my_grid.create_field( @@ -77,7 +76,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = fullway_bc(f_pre, f_post, boundary_id_field, missing_mask) + f = fullway_bc(f_pre, f_post, boundary_mask, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index f62bc83..a9b8c11 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -5,14 +5,13 @@ import jax from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @@ -38,7 +37,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC() - boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -60,8 +59,8 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): indices = wp.array(indices, dtype=wp.int32) - boundary_id_field, missing_mask = indices_boundary_masker( - indices, fullway_bc.id, boundary_id_field, missing_mask, start_index=None + boundary_mask, missing_mask = indices_boundary_masker( + indices, fullway_bc.id, boundary_mask, missing_mask, start_index=None ) # Generate a random field with the same shape @@ -73,7 +72,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f_pre = fullway_bc(f_pre, f_post, boundary_id_field, missing_mask, f_pre) + f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask, f_pre) f = f_pre.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index dd5099c..0480605 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -3,8 +3,7 @@ import numpy as np import xlb from xlb.compute_backend import ComputeBackend -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig from xlb.grid import grid_factory @@ -12,7 +11,7 @@ def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @@ -37,7 +36,7 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.BOOL ) - boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -61,32 +60,32 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): assert indices.shape[0] == dim test_id = 5 - boundary_id_field, missing_mask = indices_boundary_masker( - indices, test_id, boundary_id_field, missing_mask, start_index=None + boundary_mask, missing_mask = indices_boundary_masker( + indices, test_id, boundary_mask, missing_mask, start_index=None ) assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype - assert boundary_id_field.dtype == xlb.Precision.UINT8.jax_dtype + assert boundary_mask.dtype == xlb.Precision.UINT8.jax_dtype - assert boundary_id_field.shape == (1,) + grid_shape + assert boundary_mask.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert jnp.all(boundary_id_field[0, indices[0], indices[1]] == test_id) - # assert that the rest of the boundary_id_field is zero - boundary_id_field = boundary_id_field.at[0, indices[0], indices[1]].set(0) - assert jnp.all(boundary_id_field == 0) + assert jnp.all(boundary_mask[0, indices[0], indices[1]] == test_id) + # assert that the rest of the boundary_mask is zero + boundary_mask = boundary_mask.at[0, indices[0], indices[1]].set(0) + assert jnp.all(boundary_mask == 0) if dim == 3: assert jnp.all( - boundary_id_field[0, indices[0], indices[1], indices[2]] == test_id + boundary_mask[0, indices[0], indices[1], indices[2]] == test_id ) - # assert that the rest of the boundary_id_field is zero - boundary_id_field = boundary_id_field.at[ + # assert that the rest of the boundary_mask is zero + boundary_mask = boundary_mask.at[ 0, indices[0], indices[1], indices[2] ].set(0) - assert jnp.all(boundary_id_field == 0) + assert jnp.all(boundary_mask == 0) if __name__ == "__main__": diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index ba00a05..782551c 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -3,8 +3,7 @@ import numpy as np import xlb from xlb.compute_backend import ComputeBackend -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig from xlb.grid import grid_factory @@ -12,7 +11,7 @@ def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @@ -36,7 +35,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.BOOL ) - boundary_id_field = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -60,37 +59,37 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): assert indices.shape[0] == dim test_id = 5 - boundary_id_field, missing_mask = indices_boundary_masker( + boundary_mask, missing_mask = indices_boundary_masker( indices, test_id, - boundary_id_field, + boundary_mask, missing_mask, start_index=(0, 0, 0) if dim == 3 else (0, 0), ) assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype - assert boundary_id_field.dtype == xlb.Precision.UINT8.wp_dtype + assert boundary_mask.dtype == xlb.Precision.UINT8.wp_dtype - boundary_id_field = boundary_id_field.numpy() + boundary_mask = boundary_mask.numpy() missing_mask = missing_mask.numpy() indices = indices.numpy() - assert boundary_id_field.shape == (1,) + grid_shape + assert boundary_mask.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert np.all(boundary_id_field[0, indices[0], indices[1]] == test_id) - # assert that the rest of the boundary_id_field is zero - boundary_id_field[0, indices[0], indices[1]]= 0 - assert np.all(boundary_id_field == 0) + assert np.all(boundary_mask[0, indices[0], indices[1]] == test_id) + # assert that the rest of the boundary_mask is zero + boundary_mask[0, indices[0], indices[1]]= 0 + assert np.all(boundary_mask == 0) if dim == 3: assert np.all( - boundary_id_field[0, indices[0], indices[1], indices[2]] == test_id + boundary_mask[0, indices[0], indices[1], indices[2]] == test_id ) - # assert that the rest of the boundary_id_field is zero - boundary_id_field[0, indices[0], indices[1], indices[2]] = 0 - assert np.all(boundary_id_field == 0) + # assert that the rest of the boundary_mask is zero + boundary_mask[0, indices[0], indices[1], indices[2]] = 0 + assert np.all(boundary_mask == 0) if __name__ == "__main__": pytest.main() diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py index e3d436b..96a382e 100644 --- a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py @@ -2,7 +2,7 @@ import jax.numpy as jnp import xlb from xlb.compute_backend import ComputeBackend -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig from xlb.grid import grid_factory @@ -10,7 +10,7 @@ def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @@ -89,7 +89,7 @@ def test_planar_masker_jax( ) fill_value = 0 - boundary_id_field = my_grid.create_field( + boundary_mask = my_grid.create_field( cardinality=1, dtype=xlb.Precision.UINT8, fill_value=fill_value ) @@ -98,12 +98,12 @@ def test_planar_masker_jax( start_index = (0,) * dim id_number = 1 - boundary_id_field, missing_mask = planar_boundary_masker( + boundary_mask, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, id_number, - boundary_id_field, + boundary_mask, missing_mask, start_index, ) @@ -113,15 +113,15 @@ def test_planar_masker_jax( slice(lb, ub) for lb, ub in zip(lower_bound, upper_bound) ) assert jnp.all( - boundary_id_field[expected_slice] == id_number + boundary_mask[expected_slice] == id_number ), "Boundary not set correctly" # Assert that the rest of the domain is not affected and is equal to fill_value full_slice = tuple(slice(None) for _ in grid_shape) - mask = jnp.ones_like(boundary_id_field, dtype=bool) + mask = jnp.ones_like(boundary_mask, dtype=bool) mask = mask.at[expected_slice].set(False) assert jnp.all( - boundary_id_field[full_slice][mask] == fill_value + boundary_mask[full_slice][mask] == fill_value ), "Rest of domain incorrectly affected" diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py b/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py index 75d1aa6..deee70b 100644 --- a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py @@ -4,7 +4,7 @@ import warp as wp from xlb.compute_backend import ComputeBackend -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig from xlb.grid import grid_factory @@ -12,7 +12,7 @@ def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @@ -91,7 +91,7 @@ def test_planar_masker_warp( cardinality=velocity_set.q, dtype=xlb.Precision.BOOL ) fill_value = 0 - boundary_id_field = my_grid.create_field( + boundary_mask = my_grid.create_field( cardinality=1, dtype=xlb.Precision.UINT8, fill_value=fill_value ) @@ -99,30 +99,30 @@ def test_planar_masker_warp( start_index = (0,) * dim id_number = 1 - boundary_id_field, missing_mask = planar_boundary_masker( + boundary_mask, missing_mask = planar_boundary_masker( lower_bound, upper_bound, direction, id_number, - boundary_id_field, + boundary_mask, missing_mask, start_index, ) - boundary_id_field = boundary_id_field.numpy() + boundary_mask = boundary_mask.numpy() # Assertions to verify boundary settings expected_slice = (slice(None),) + tuple( slice(lb, ub) for lb, ub in zip(lower_bound, upper_bound) ) assert np.all( - boundary_id_field[expected_slice] == id_number + boundary_mask[expected_slice] == id_number ), "Boundary not set correctly" # Assertions for non-affected areas full_slice = tuple(slice(None) for _ in grid_shape) - mask = np.ones_like(boundary_id_field, dtype=bool) + mask = np.ones_like(boundary_mask, dtype=bool) mask[expected_slice] = False assert np.all( - boundary_id_field[full_slice][mask] == fill_value + boundary_mask[full_slice][mask] == fill_value ), "Rest of domain incorrectly affected" diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 140a64e..22445cc 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -7,7 +7,7 @@ from xlb.precision_policy import Precision -def init_xlb_warp_env(): +def init_xlb_env(): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, @@ -18,7 +18,7 @@ def init_xlb_warp_env(): @pytest.mark.parametrize("grid_size", [50, 100, 150]) def test_warp_grid_create_field(grid_size): for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]: - init_xlb_warp_env() + init_xlb_env() my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9, dtype=Precision.FP32) @@ -27,7 +27,7 @@ def test_warp_grid_create_field(grid_size): def test_warp_grid_create_field_fill_value(): - init_xlb_warp_env() + init_xlb_env() grid_shape = (100, 100) fill_value = 3.14 my_grid = grid_factory(grid_shape) @@ -42,7 +42,7 @@ def test_warp_grid_create_field_fill_value(): @pytest.fixture(autouse=True) def setup_xlb_env(): - init_xlb_warp_env() + init_xlb_env() if __name__ == "__main__": diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index 4415e0d..76d7233 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -5,14 +5,13 @@ from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.collision import BGK from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 7d03cd8..f0fff4a 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -7,14 +7,14 @@ from xlb.operator.macroscopic import Macroscopic from xlb.operator.collision import BGK from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig from xlb.precision_policy import Precision -def init_xlb_warp_env(velocity_set): +def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @pytest.mark.parametrize( @@ -29,7 +29,7 @@ def init_xlb_warp_env(velocity_set): ], ) def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): - init_xlb_warp_env(velocity_set) + init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) rho = my_grid.create_field(cardinality=1, fill_value=1.0) diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index 56d2672..fbdadb6 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -4,14 +4,13 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index 10de60d..ef2287f 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -5,13 +5,12 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) @pytest.mark.parametrize( diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index 1004f9c..89ef393 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -5,14 +5,13 @@ from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index e25180f..69c09ac 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -5,7 +5,7 @@ from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.grid import grid_factory -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig import warp as wp @@ -13,7 +13,7 @@ def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) diff --git a/tests/kernels/stream/test_stream_jax.py b/tests/kernels/stream/test_stream_jax.py index 0d95f5f..ef635ea 100644 --- a/tests/kernels/stream/test_stream_jax.py +++ b/tests/kernels/stream/test_stream_jax.py @@ -3,8 +3,7 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.operator.stream import Stream -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig from xlb.grid import grid_factory @@ -12,7 +11,7 @@ def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py index 6ce1329..af70b4c 100644 --- a/tests/kernels/stream/test_stream_warp.py +++ b/tests/kernels/stream/test_stream_warp.py @@ -5,8 +5,7 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.operator.stream import Stream -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig from xlb.grid import grid_factory @@ -14,7 +13,7 @@ def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set, + velocity_set=velocity_set(), ) diff --git a/xlb/__init__.py b/xlb/__init__.py index 4dc0639..06e9f11 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -4,7 +4,7 @@ from xlb.physics_type import PhysicsType # Config -from .default_config import init, default_backend +from .default_config import init, DefaultConfig # Velocity Set import xlb.velocity_set @@ -20,7 +20,7 @@ import xlb.grid # Solvers -import xlb.solver +import xlb.helper # Utils import xlb.utils diff --git a/xlb/default_config.py b/xlb/default_config.py index 7f0e617..f1ca25f 100644 --- a/xlb/default_config.py +++ b/xlb/default_config.py @@ -10,7 +10,7 @@ class DefaultConfig: def init(velocity_set, default_backend, default_precision_policy): - DefaultConfig.velocity_set = velocity_set() + DefaultConfig.velocity_set = velocity_set DefaultConfig.default_backend = default_backend DefaultConfig.default_precision_policy = default_precision_policy diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 4fc0665..5a796f1 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Literal, Optional, Tuple -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig from xlb.compute_backend import ComputeBackend from xlb.precision_policy import Precision diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index edeb562..64dc33a 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -9,7 +9,9 @@ from jax import lax import jax -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig + + from .grid import Grid from xlb.operator import Operator from xlb.precision_policy import Precision diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index 099361b..75c3f14 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -6,7 +6,7 @@ from xlb.precision_policy import Precision from xlb.compute_backend import ComputeBackend from typing import Literal -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig import numpy as np @@ -17,12 +17,6 @@ def __init__(self, shape): def _initialize_backend(self): pass - def parallelize_operator(self, operator: Operator): - # TODO: Implement parallelization of the operator - raise NotImplementedError( - "Parallelization of the operator is not implemented yet for the WarpGrid" - ) - def create_field( self, cardinality: int, diff --git a/xlb/helper/__init__.py b/xlb/helper/__init__.py new file mode 100644 index 0000000..2dab37b --- /dev/null +++ b/xlb/helper/__init__.py @@ -0,0 +1,3 @@ +from xlb.helper.nse_solver import create_nse_fields +from xlb.helper.initializers import initialize_eq +from xlb.helper.boundary_conditions import assign_bc_id_box_faces \ No newline at end of file diff --git a/xlb/helper/boundary_conditions.py b/xlb/helper/boundary_conditions.py new file mode 100644 index 0000000..3178bc2 --- /dev/null +++ b/xlb/helper/boundary_conditions.py @@ -0,0 +1,69 @@ +from xlb.operator.boundary_masker import PlanarBoundaryMasker + + +def assign_bc_id_box_faces(boundary_mask, missing_mask, shape, bc_id, sides): + """ + Assign boundary conditions for specified sides of 2D and 3D boxes using planar_boundary_masker function. + + Parameters: + boundary_mask: ndarray + The field containing boundary IDs. + missing_mask: ndarray + The mask indicating missing boundary IDs. + shape: tuple + The shape of the grid (extent of the grid in each dimension). + bc_id: int + The boundary condition ID to assign to the specified boundaries. + sides: list of str + The list of sides to apply conditions to. Valid values for 2D are 'bottom', 'top', 'left', 'right'. + Valid values for 3D are 'bottom', 'top', 'front', 'back', 'left', 'right'. + """ + + planar_boundary_masker = PlanarBoundaryMasker() + + def apply(lower_bound, upper_bound, direction, reference=(0, 0, 0)): + nonlocal boundary_mask, missing_mask, planar_boundary_masker + boundary_mask, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + bc_id, + boundary_mask, + missing_mask, + reference, + ) + + dimensions = len(shape) + + if dimensions == 2: + nr, nc = shape + for boundary in sides: + if boundary == "bottom": + apply((0, 0), (nr, 1), (1, 0), (0, 0)) + elif boundary == "top": + apply((0, nc - 1), (nr, nc), (1, 0), (0, 0)) + elif boundary == "left": + apply((0, 0), (1, nc), (0, 1), (0, 0)) + elif boundary == "right": + apply((nr - 1, 0), (nr, nc), (0, 1), (0, 0)) + + elif dimensions == 3: + nr, nc, nz = shape + for boundary in sides: + if boundary == "bottom": + apply((0, 0, 0), (nr, 1, nz), (1, 0, 0), (0, 0, 0)) + elif boundary == "top": + apply((0, nc - 1, 0), (nr, nc, nz), (1, 0, 0), (0, 0, 0)) + elif boundary == "front": + apply((0, 0, 0), (nr, nc, 1), (0, 1, 0), (0, 0, 0)) + elif boundary == "back": + apply((0, 0, nz - 1), (nr, nc, nz), (0, 1, 0), (0, 0, 0)) + elif boundary == "left": + apply((0, 0, 0), (1, nc, nz), (0, 0, 1), (0, 0, 0)) + elif boundary == "right": + apply((nr - 1, 0, 0), (nr, nc, nz), (0, 0, 1), (0, 0, 0)) + + else: + raise ValueError("Unsupported dimensions: {}".format(dimensions)) + + return boundary_mask, missing_mask diff --git a/xlb/helper/initializers.py b/xlb/helper/initializers.py new file mode 100644 index 0000000..aff3ee8 --- /dev/null +++ b/xlb/helper/initializers.py @@ -0,0 +1,18 @@ +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium + + +def initialize_eq(f, grid, velocity_set, backend, rho=None, u=None): + rho = rho or grid.create_field(cardinality=1, fill_value=1.0) + u = u or grid.create_field(cardinality=velocity_set.d, fill_value=0.0) + equilibrium = QuadraticEquilibrium() + + if backend == ComputeBackend.JAX: + f = equilibrium(rho, u) + + elif backend == ComputeBackend.WARP: + f = equilibrium(rho, u, f) + + del rho, u + + return f diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py new file mode 100644 index 0000000..7fc11fc --- /dev/null +++ b/xlb/helper/nse_solver.py @@ -0,0 +1,29 @@ +import xlb +from xlb.compute_backend import ComputeBackend +from xlb import DefaultConfig +from xlb.grid import grid_factory +from xlb.precision_policy import Precision +from typing import Tuple + + +def create_nse_fields( + grid_shape: Tuple[int, int, int], velocity_set=None, compute_backend=None, precision_policy=None +): + velocity_set = velocity_set if velocity_set else DefaultConfig.velocity_set + compute_backend = ( + compute_backend if compute_backend else DefaultConfig.default_backend + ) + precision_policy = ( + precision_policy if precision_policy else DefaultConfig.default_precision_policy + ) + grid = grid_factory(grid_shape, compute_backend=compute_backend) + + # Create fields + f_0 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) + f_1 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) + missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL) + boundary_mask = grid.create_field(cardinality=1, dtype=Precision.UINT8) + + return grid, f_0, f_1, missing_mask, boundary_mask + + diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index c1232a3..02b8a59 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1,4 +1,2 @@ from xlb.operator.operator import Operator from xlb.operator.parallel_operator import ParallelOperator -import xlb.operator.stepper -import xlb.operator.boundary_masker diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index d1d1a81..37a91ab 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -46,8 +46,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, boundary_id_field, missing_mask): - boundary = (boundary_id_field == self.id) + def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + boundary = boundary_mask == self.id boundary = boundary[:, None, None, None] return jnp.where(boundary, f_pre, f_post) @@ -59,8 +59,20 @@ def _construct_warp(self): ) # TODO fix vec bool # Construct the funcional to get streamed indices + + @wp.func + def functional2d( + f: wp.array3d(dtype=Any), + missing_mask: Any, + index: Any, + ): + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + return _f + @wp.func - def functional( + def functional3d( f: wp.array4d(dtype=Any), missing_mask: Any, index: Any, @@ -70,12 +82,46 @@ def functional( _f[l] = f[l, index[0], index[1], index[2]] return _f + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.uint8), + f: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the boundary id and missing mask + _boundary_id = boundary_mask[0, index[0], index[1]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(DoNothingBC.id): + _f = functional3d(f_pre, _missing_mask, index) + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1]] + + # Write the result + for l in range(self.velocity_set.q): + f[l, index[0], index[1]] = _f[l] + # Construct the warp kernel @wp.kernel - def kernel( + def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_id_field: wp.array4d(dtype=wp.uint8), + boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), f: wp.array4d(dtype=Any), ): @@ -84,7 +130,7 @@ def kernel( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask - _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -95,7 +141,7 @@ def kernel( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional(f_pre, _missing_mask, index) + _f = functional3d(f_pre, _missing_mask, index) else: _f = _f_vec() for l in range(self.velocity_set.q): @@ -105,14 +151,17 @@ def kernel( for l in range(self.velocity_set.q): f[l, index[0], index[1], index[2]] = _f[l] + functional = functional3d if self.velocity_set.d == 3 else functional2d + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_id_field, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask, f], dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 7c0505d..ee06fcd 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -13,6 +13,7 @@ from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator.operator import Operator from xlb.operator.boundary_condition.boundary_condition import ( ImplementationStep, @@ -43,6 +44,9 @@ def __init__( self.rho = rho self.u = u self.equilibrium_operator = equilibrium_operator + # Raise error if equilibrium operator is not a subclass of Equilibrium + if not issubclass(type(self.equilibrium_operator), Equilibrium): + raise ValueError("Equilibrium operator must be a subclass of Equilibrium") # Call the parent constructor super().__init__( @@ -54,11 +58,11 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, boundary_id_field, missing_mask): + def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) new_shape = feq.shape + (1,) * self.velocity_set.d feq = lax.broadcast_in_dim(feq, new_shape, [0]) - boundary = boundary_id_field == self.id + boundary = boundary_mask == self.id return jnp.where(boundary, feq, f_post) @@ -67,14 +71,22 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(self.rho) - _u = _u_vec(self.u[0], self.u[1], self.u[2]) if self.velocity_set.d == 3 else _u_vec(self.u[0], self.u[1]) + _u = ( + _u_vec(self.u[0], self.u[1], self.u[2]) + if self.velocity_set.d == 3 + else _u_vec(self.u[0], self.u[1]) + ) _missing_mask_vec = wp.vec( self.velocity_set.q, dtype=wp.uint8 ) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func - def functional2d(): + def functional2d( + f: wp.array3d(dtype=Any), + missing_mask: Any, + index: Any, + ): _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f @@ -83,7 +95,7 @@ def functional2d(): def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_id_field: wp.array3d(dtype=wp.uint8), + boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), f: wp.array3d(dtype=Any), ): @@ -92,7 +104,7 @@ def kernel2d( index = wp.vec2i(i, j) # Get the boundary id and missing mask - _boundary_id = boundary_id_field[0, index[0], index[1]] + _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -103,7 +115,7 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional2d() + _f = functional2d(f_post, _missing_mask, index) else: _f = _f_vec() for l in range(self.velocity_set.q): @@ -114,7 +126,11 @@ def kernel2d( f[l, index[0], index[1]] = _f[l] @wp.func - def functional3d(): + def functional3d( + f: wp.array4d(dtype=Any), + missing_mask: Any, + index: Any, + ): _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f @@ -123,7 +139,7 @@ def functional3d(): def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_id_field: wp.array4d(dtype=wp.uint8), + boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), f: wp.array4d(dtype=Any), ): @@ -132,7 +148,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask - _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -143,7 +159,7 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional3d() + _f = functional3d(f_post, _missing_mask, index) else: _f = _f_vec() for l in range(self.velocity_set.q): @@ -159,11 +175,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_id_field, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask, f], dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index bb7b2cd..31b2a1f 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -45,10 +45,10 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_id_field, missing_mask): - boundary = boundary_id_field == self.id + def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + boundary = boundary_mask == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) - return jnp.where(boundary, f_pre[self.velocity_set.opp_indices], f_post) + return jnp.where(boundary, f_pre[self.velocity_set.opp_indices,...], f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update @@ -75,7 +75,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_id_field: wp.array3d(dtype=wp.uint8), + boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), f: wp.array3d(dtype=Any), ): # Get the global index @@ -83,7 +83,7 @@ def kernel2d( index = wp.vec2i(i, j) # Get the boundary id and missing mask - _boundary_id = boundary_id_field[0, index[0], index[1]] + _boundary_id = boundary_mask[0, index[0], index[1]] # Make vectors for the lattice _f_pre = _f_vec() @@ -114,7 +114,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_id_field: wp.array4d(dtype=wp.uint8), + boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), f: wp.array4d(dtype=Any), ): @@ -123,7 +123,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask - _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] # Make vectors for the lattice _f_pre = _f_vec() @@ -154,11 +154,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_id_field, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask, f], dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 0560d11..d748bdd 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -48,8 +48,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_id_field, missing_mask): - boundary = boundary_id_field == self.id + def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + boundary = boundary_mask == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) return jnp.where( jnp.logical_and(missing_mask, boundary), @@ -130,7 +130,7 @@ def functional3d( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_id_field: wp.array3d(dtype=wp.uint8), + boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), f: wp.array3d(dtype=Any), ): @@ -139,7 +139,7 @@ def kernel2d( index = wp.vec3i(i, j) # Get the boundary id and missing mask - _boundary_id = boundary_id_field[0, index[0], index[1]] + _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -165,7 +165,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_id_field: wp.array4d(dtype=wp.uint8), + boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), f: wp.array4d(dtype=Any), ): @@ -174,7 +174,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask - _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -201,11 +201,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_id_field, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_id_field, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask, f], dim=f_pre.shape[1:], ) return f diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index ceb9830..288565d 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -12,8 +12,7 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig # Enum for implementation step class ImplementationStep(Enum): diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 9bd562f..cf8feab 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -7,7 +7,7 @@ import warp as wp from typing import Tuple -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend @@ -41,7 +41,7 @@ def _indices_to_tuple(indices): @Operator.register_backend(ComputeBackend.JAX) def jax_implementation( - self, indices, id_number, boundary_id_field, mask, start_index=None + self, indices, id_number, boundary_mask, mask, start_index=None ): dim = mask.ndim - 1 if start_index is None: @@ -56,15 +56,15 @@ def jax_implementation( indices_mask = np.logical_and.reduce(indices_mask) @jit - def compute_boundary_id_and_mask(boundary_id_field, mask): + def compute_boundary_id_and_mask(boundary_mask, mask): if dim == 2: - boundary_id_field = boundary_id_field.at[ + boundary_mask = boundary_mask.at[ 0, local_indices[0], local_indices[1] ].set(id_number) mask = mask.at[:, local_indices[0], local_indices[1]].set(True) if dim == 3: - boundary_id_field = boundary_id_field.at[ + boundary_mask = boundary_mask.at[ 0, local_indices[0], local_indices[1], local_indices[2] ].set(id_number) mask = mask.at[ @@ -72,9 +72,9 @@ def compute_boundary_id_and_mask(boundary_id_field, mask): ].set(True) mask = self.stream(mask) - return boundary_id_field, mask + return boundary_mask, mask - return compute_boundary_id_and_mask(boundary_id_field, mask) + return compute_boundary_id_and_mask(boundary_mask, mask) def _construct_warp(self): # Make constants for warp @@ -86,7 +86,7 @@ def _construct_warp(self): def kernel2d( indices: wp.array2d(dtype=wp.int32), id_number: wp.int32, - boundary_id_field: wp.array3d(dtype=wp.uint8), + boundary_mask: wp.array3d(dtype=wp.uint8), mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): @@ -113,7 +113,7 @@ def kernel2d( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_id_field[0, index[0], index[1]] = wp.uint8(id_number) + boundary_mask[0, index[0], index[1]] = wp.uint8(id_number) mask[l, push_index[0], push_index[1]] = True # Construct the warp 3D kernel @@ -121,7 +121,7 @@ def kernel2d( def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.int32, - boundary_id_field: wp.array4d(dtype=wp.uint8), + boundary_mask: wp.array4d(dtype=wp.uint8), mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -151,7 +151,7 @@ def kernel3d( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_id_field[0, index[0], index[1], index[2]] = wp.uint8( + boundary_mask[0, index[0], index[1], index[2]] = wp.uint8( id_number ) mask[l, push_index[0], push_index[1], push_index[2]] = True @@ -162,7 +162,7 @@ def kernel3d( @Operator.register_backend(ComputeBackend.WARP) def warp_implementation( - self, indices, id_number, boundary_id_field, missing_mask, start_index=None + self, indices, id_number, boundary_mask, missing_mask, start_index=None ): if start_index is None: start_index = (0,) * self.velocity_set.d @@ -172,11 +172,11 @@ def warp_implementation( inputs=[ indices, id_number, - boundary_id_field, + boundary_mask, missing_mask, start_index, ], dim=indices.shape[1], ) - return boundary_id_field, missing_mask + return boundary_mask, missing_mask diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index e12206b..fa83c7d 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -8,7 +8,7 @@ from typing import Tuple from jax.numpy import where, einsum, full_like -from xlb.default_config import DefaultConfig + from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend @@ -38,14 +38,14 @@ def jax_implementation( upper_bound, direction, id_number, - boundary_id_field, + boundary_mask, mask, start_index=None, ): if start_index is None: start_index = (0,) * self.velocity_set.d - _, *dimensions = boundary_id_field.shape + _, *dimensions = boundary_mask.shape indices = [ (max(0, lb + start), min(dim, ub + start)) @@ -56,9 +56,9 @@ def jax_implementation( slices = [slice(None)] slices.extend(slice(lb, ub) for lb, ub in indices) - boundary_id_field = boundary_id_field.at[tuple(slices)].set(id_number) + boundary_mask = boundary_mask.at[tuple(slices)].set(id_number) - return boundary_id_field, None + return boundary_mask, None def _construct_warp(self): # Make constants for warp @@ -71,7 +71,7 @@ def kernel2d( upper_bound: wp.vec2i, direction: wp.vec2i, id_number: wp.uint8, - boundary_id_field: wp.array3d(dtype=wp.uint8), + boundary_mask: wp.array3d(dtype=wp.uint8), mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): @@ -80,7 +80,7 @@ def kernel2d( ub_x, ub_y = upper_bound.x + start_index.x, upper_bound.y + start_index.y if lb_x <= i < ub_x and lb_y <= j < ub_y: - boundary_id_field[0, i, j] = id_number + boundary_mask[0, i, j] = id_number @wp.kernel def kernel3d( @@ -88,7 +88,7 @@ def kernel3d( upper_bound: wp.vec3i, direction: wp.vec3i, id_number: wp.uint8, - boundary_id_field: wp.array4d(dtype=wp.uint8), + boundary_mask: wp.array4d(dtype=wp.uint8), mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -105,7 +105,7 @@ def kernel3d( ) if lb_x <= i < ub_x and lb_y <= j < ub_y and lb_z <= k < ub_z: - boundary_id_field[0, i, j, k] = id_number + boundary_mask[0, i, j, k] = id_number kernel = kernel3d if self.velocity_set.d == 3 else kernel2d @@ -118,7 +118,7 @@ def warp_implementation( upper_bound, direction, id_number, - boundary_id_field, + boundary_mask, mask, start_index=None, ): @@ -133,11 +133,11 @@ def warp_implementation( upper_bound, direction, id_number, - boundary_id_field, + boundary_mask, mask, start_index, ], dim=mask.shape[1:], ) - return boundary_id_field, mask + return boundary_mask, mask diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index 20e630c..c2cfc30 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -8,7 +8,7 @@ import warp as wp from typing import Tuple -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend @@ -42,7 +42,7 @@ def kernel( origin: wp.vec3, spacing: wp.vec3, id_number: wp.int32, - boundary_id_field: wp.array4d(dtype=wp.uint8), + boundary_mask: wp.array4d(dtype=wp.uint8), mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -64,9 +64,9 @@ def kernel( # Compute the maximum length max_length = wp.sqrt( - (spacing[0] * wp.float32(boundary_id_field.shape[1])) ** 2.0 - + (spacing[1] * wp.float32(boundary_id_field.shape[2])) ** 2.0 - + (spacing[2] * wp.float32(boundary_id_field.shape[3])) ** 2.0 + (spacing[0] * wp.float32(boundary_mask.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(boundary_mask.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(boundary_mask.shape[3])) ** 2.0 ) # evaluate if point is inside mesh @@ -87,7 +87,7 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_id_field[ + boundary_mask[ 0, push_index[0], push_index[1], push_index[2] ] = wp.uint8(id_number) mask[l, push_index[0], push_index[1], push_index[2]] = True @@ -101,7 +101,7 @@ def warp_implementation( origin, spacing, id_number, - boundary_id_field, + boundary_mask, mask, start_index=(0, 0, 0), ): @@ -122,11 +122,11 @@ def warp_implementation( origin, spacing, id_number, - boundary_id_field, + boundary_mask, mask, start_index, ], - dim=boundary_id_field.shape[1:], + dim=boundary_mask.shape[1:], ) - return boundary_id_field, mask + return boundary_mask, mask diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 3b006ab..c4fa62d 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -3,7 +3,6 @@ import warp as wp from typing import Any -from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision from xlb.operator import Operator @@ -37,67 +36,66 @@ def functional(f: Any, feq: Any): # Construct the warp kernel @wp.kernel - def kernel3d( - f: wp.array4d(dtype=Any), - feq: wp.array4d(dtype=Any), - fout: wp.array4d(dtype=Any), + def kernel2d( + f: wp.array3d(dtype=Any), + feq: wp.array3d(dtype=Any), + fout: wp.array3d(dtype=Any), ): # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this + i, j = wp.tid() + index = wp.vec2i(i, j) # Load needed values _f = _f_vec() _feq = _f_vec() for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1], index[2]] - _feq[l] = feq[l, index[0], index[1], index[2]] + _f[l] = f[l, index[0], index[1]] + _feq[l] = feq[l, index[0], index[1]] # Compute the collision _fout = functional(_f, _feq) # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1]] = _fout[l] # Construct the warp kernel @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - feq: wp.array3d(dtype=Any), - fout: wp.array3d(dtype=Any), + def kernel3d( + f: wp.array4d(dtype=Any), + feq: wp.array4d(dtype=Any), + fout: wp.array4d(dtype=Any), ): # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this # Load needed values _f = _f_vec() _feq = _f_vec() for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - _feq[l] = feq[l, index[0], index[1]] + _f[l] = f[l, index[0], index[1], index[2]] + _feq[l] = feq[l, index[0], index[1], index[2]] # Compute the collision _fout = functional(_f, _feq) # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1], index[2]] = _fout[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout): + def warp_implementation(self, f, feq): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ f, feq, - fout, ], dim=f.shape[1:], ) diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py index acf4538..00a8dfd 100644 --- a/xlb/operator/collision/collision.py +++ b/xlb/operator/collision/collision.py @@ -1,6 +1,7 @@ """ Base class for Collision operators """ + from xlb.velocity_set import VelocitySet from xlb.operator import Operator diff --git a/xlb/operator/equilibrium/__init__.py b/xlb/operator/equilibrium/__init__.py index 1cf7459..42b601e 100644 --- a/xlb/operator/equilibrium/__init__.py +++ b/xlb/operator/equilibrium/__init__.py @@ -1 +1,4 @@ -from xlb.operator.equilibrium.quadratic_equilibrium import QuadraticEquilibrium +from xlb.operator.equilibrium.quadratic_equilibrium import ( + Equilibrium, + QuadraticEquilibrium, +) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 7afacac..794d78d 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -8,8 +8,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator import Operator -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig class QuadraticEquilibrium(Equilibrium): """ diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 521c6d4..7fa309f 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -6,7 +6,7 @@ import warp as wp from typing import Tuple, Any -from xlb.default_config import DefaultConfig +from xlb import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index d51f637..0206416 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -6,26 +6,48 @@ import warp as wp from typing import Any +from xlb import DefaultConfig from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator +from xlb.operator.stream import Stream +from xlb.operator.collision import BGK, KBC +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.macroscopic import Macroscopic from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition.boundary_condition import ImplementationStep class IncompressibleNavierStokesStepper(Stepper): - """ - Class that handles the construction of lattice boltzmann stepping operator - """ + def __init__(self, omega, boundary_conditions=[], collision_type="BGK"): + velocity_set = DefaultConfig.velocity_set + precision_policy = DefaultConfig.default_precision_policy + compute_backend = DefaultConfig.default_backend + + # Construct the collision operator + if collision_type == "BGK": + self.collision = BGK(omega, velocity_set, precision_policy, compute_backend) + elif collision_type == "KBC": + self.collision = KBC(omega, velocity_set, precision_policy, compute_backend) + + # Construct the operators + self.stream = Stream(velocity_set, precision_policy, compute_backend) + self.equilibrium = QuadraticEquilibrium( + velocity_set, precision_policy, compute_backend + ) + self.macroscopic = Macroscopic(velocity_set, precision_policy, compute_backend) + + operators = [self.macroscopic, self.equilibrium, self.collision, self.stream] + + super().__init__(operators, boundary_conditions) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0, 4), donate_argnums=(1)) - def jax_implementation(self, f_0, f_1, boundary_id_field, missing_mask, timestep): + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ - - # Cast to compute precision + # Cast to compute precisioni f_0 = self.precision_policy.cast_to_compute_jax(f_0) f_1 = self.precision_policy.cast_to_compute_jax(f_1) @@ -36,12 +58,7 @@ def jax_implementation(self, f_0, f_1, boundary_id_field, missing_mask, timestep feq = self.equilibrium(rho, u) # Apply collision - f_post_collision = self.collision( - f_0, - feq, - rho, - u, - ) + f_post_collision = self.collision(f_0, feq) # Apply collision type boundary conditions for bc in self.boundary_conditions: @@ -49,7 +66,7 @@ def jax_implementation(self, f_0, f_1, boundary_id_field, missing_mask, timestep f_0 = bc( f_0, f_post_collision, - boundary_id_field, + boundary_mask, missing_mask, ) @@ -62,7 +79,7 @@ def jax_implementation(self, f_0, f_1, boundary_id_field, missing_mask, timestep f_1 = bc( f_post_collision, f_1, - boundary_id_field, + boundary_mask, missing_mask, ) @@ -84,12 +101,79 @@ def _construct_warp(self): _halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id) _fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id) + @wp.kernel + def kernel2d( + f_0: wp.array3d(dtype=Any), + f_1: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=Any), + missing_mask: wp.array3d(dtype=Any), + timestep: int, + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) # TODO warp should fix this + + # Get the boundary id and missing mask + _boundary_id = boundary_mask[0, index[0], index[1]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply streaming boundary conditions + if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: + # Regular streaming + f_post_stream = self.stream.warp_functional(f_0, index) + elif _boundary_id == _equilibrium_bc: + # Equilibrium boundary condition + f_post_stream = self.equilibrium_bc.warp_functional( + f_0, _missing_mask, index + ) + elif _boundary_id == _do_nothing_bc: + # Do nothing boundary condition + f_post_stream = self.do_nothing_bc.warp_functional( + f_0, _missing_mask, index + ) + elif _boundary_id == _halfway_bounce_back_bc: + # Half way boundary condition + f_post_stream = self.halfway_bounce_back_bc.warp_functional( + f_0, _missing_mask, index + ) + + # Compute rho and u + rho, u = self.macroscopic.warp_functional(f_post_stream) + + # Compute equilibrium + feq = self.equilibrium.warp_functional(rho, u) + + # Apply collision + f_post_collision = self.collision.warp_functional( + f_post_stream, + feq, + ) + + # Apply collision type boundary conditions + if _boundary_id == _fullway_bounce_back_bc: + # Full way boundary condition + f_post_collision = self.fullway_bounce_back_bc.warp_functional( + f_post_stream, + f_post_collision, + _missing_mask, + ) + + # Set the output + for l in range(self.velocity_set.q): + f_1[l, index[0], index[1]] = f_post_collision[l] + # Construct the kernel @wp.kernel - def kernel( + def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), - boundary_id_field: wp.array4d(dtype=Any), + boundary_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), timestep: int, ): @@ -98,7 +182,7 @@ def kernel( index = wp.vec3i(i, j, k) # TODO warp should fix this # Get the boundary id and missing mask - _boundary_id = boundary_id_field[0, index[0], index[1], index[2]] + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -134,12 +218,7 @@ def kernel( feq = self.equilibrium.warp_functional(rho, u) # Apply collision - f_post_collision = self.collision.warp_functional( - f_post_stream, - feq, - rho, - u, - ) + f_post_collision = self.collision.warp_functional(f_post_stream, feq) # Apply collision type boundary conditions if _boundary_id == _fullway_bounce_back_bc: @@ -154,17 +233,20 @@ def kernel( for l in range(self.velocity_set.q): f_1[l, index[0], index[1], index[2]] = f_post_collision[l] + # Return the correct kernel + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, boundary_id_field, missing_mask, timestep): + def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ f_0, f_1, - boundary_id_field, + boundary_mask, missing_mask, timestep, ], diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index c4423ea..c11b39b 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,14 +1,18 @@ # Base class for all stepper operators +from ast import Raise from functools import partial import jax.numpy as jnp from jax import jit import warp as wp +from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator from xlb.operator.precision_caster import PrecisionCaster +from xlb.operator.equilibrium import Equilibrium +from xlb import DefaultConfig class Stepper(Operator): @@ -16,40 +20,43 @@ class Stepper(Operator): Class that handles the construction of lattice boltzmann stepping operator """ - def __init__( - self, - collision, - stream, - equilibrium, - macroscopic, - boundary_conditions=[], - ): - # Add operators - self.collision = collision - self.stream = stream - self.equilibrium = equilibrium - self.macroscopic = macroscopic + def __init__(self, operators, boundary_conditions): + self.operators = operators self.boundary_conditions = boundary_conditions + # Get velocity set, precision policy, and compute backend + velocity_sets = set( + [op.velocity_set for op in self.operators if op is not None] + ) + assert ( + len(velocity_sets) < 2 + ), "All velocity sets must be the same. Got {}".format(velocity_sets) + velocity_set = ( + DefaultConfig.velocity_set if not velocity_sets else velocity_sets.pop() + ) - # Get all operators for checking - self.operators = [ - collision, - stream, - equilibrium, - macroscopic, - *self.boundary_conditions, - ] + precision_policies = set( + [op.precision_policy for op in self.operators if op is not None] + ) + assert ( + len(precision_policies) < 2 + ), "All precision policies must be the same. Got {}".format(precision_policies) + precision_policy = ( + DefaultConfig.default_precision_policy + if not precision_policies + else precision_policies.pop() + ) - # Get velocity set, precision policy, and compute backend - velocity_sets = set([op.velocity_set for op in self.operators]) - assert len(velocity_sets) == 1, "All velocity sets must be the same" - velocity_set = velocity_sets.pop() - precision_policies = set([op.precision_policy for op in self.operators]) - assert len(precision_policies) == 1, "All precision policies must be the same" - precision_policy = precision_policies.pop() - compute_backend = set([op.compute_backend for op in self.operators]) - assert len(compute_backend) == 1, "All compute backends must be the same" - compute_backend = compute_backend.pop() + compute_backends = set( + [op.compute_backend for op in self.operators if op is not None] + ) + assert ( + len(compute_backends) < 2 + ), "All compute backends must be the same. Got {}".format(compute_backends) + compute_backend = ( + DefaultConfig.default_backend + if not compute_backends + else compute_backends.pop() + ) # Add boundary conditions # Warp cannot handle lists of functions currently @@ -59,12 +66,18 @@ def __init__( ############################################ from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC - from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC - from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC + from xlb.operator.boundary_condition.bc_halfway_bounce_back import ( + HalfwayBounceBackBC, + ) + from xlb.operator.boundary_condition.bc_fullway_bounce_back import ( + FullwayBounceBackBC, + ) + self.equilibrium_bc = None self.do_nothing_bc = None self.halfway_bounce_back_bc = None self.fullway_bounce_back_bc = None + for bc in boundary_conditions: if isinstance(bc, EquilibriumBC): self.equilibrium_bc = bc @@ -74,32 +87,36 @@ def __init__( self.halfway_bounce_back_bc = bc elif isinstance(bc, FullwayBounceBackBC): self.fullway_bounce_back_bc = bc + if self.equilibrium_bc is None: + # Select the equilibrium operator based on its type self.equilibrium_bc = EquilibriumBC( rho=1.0, u=(0.0, 0.0, 0.0), - equilibrium_operator=self.equilibrium, + equilibrium_operator=next( + (op for op in self.operators if isinstance(op, Equilibrium)), None + ), velocity_set=velocity_set, precision_policy=precision_policy, - compute_backend=compute_backend + compute_backend=compute_backend, ) if self.do_nothing_bc is None: self.do_nothing_bc = DoNothingBC( velocity_set=velocity_set, precision_policy=precision_policy, - compute_backend=compute_backend + compute_backend=compute_backend, ) if self.halfway_bounce_back_bc is None: self.halfway_bounce_back_bc = HalfwayBounceBackBC( velocity_set=velocity_set, precision_policy=precision_policy, - compute_backend=compute_backend + compute_backend=compute_backend, ) if self.fullway_bounce_back_bc is None: self.fullway_bounce_back_bc = FullwayBounceBackBC( velocity_set=velocity_set, precision_policy=precision_policy, - compute_backend=compute_backend + compute_backend=compute_backend, ) ############################################ diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index db8a422..25b0583 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -78,3 +78,19 @@ def store_precision(self): return Precision.FP16 else: raise ValueError("Invalid precision policy") + + def cast_to_compute_jax(self, array): + compute_precision = self.compute_precision + return jnp.array(array, dtype=compute_precision.jax_dtype) + + def cast_to_store_jax(self, array): + store_precision = self.store_precision + return jnp.array(array, dtype=store_precision.jax_dtype) + + def cast_to_compute_warp(self, array): + compute_precision = self.compute_precision + return wp.array(array, dtype=compute_precision.wp_dtype) + + def cast_to_store_warp(self, array): + store_precision = self.store_precision + return wp.array(array, dtype=store_precision.wp_dtype) \ No newline at end of file diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index 98f7968..57f65c0 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from xlb.compute_backend import ComputeBackend -from xlb.default_config import DefaultConfig - +from xlb import DefaultConfig from xlb.precision_policy.jax_precision_policy import ( JaxFp32Fp32, JaxFp32Fp16, diff --git a/xlb/solver/__init__.py b/xlb/solver/__init__.py deleted file mode 100644 index 0304fda..0000000 --- a/xlb/solver/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from xlb.solver.solver import Solver -from xlb.solver.nse import IncompressibleNavierStokesSolver diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py deleted file mode 100644 index 1ffadf8..0000000 --- a/xlb/solver/nse.py +++ /dev/null @@ -1,116 +0,0 @@ -# Base class for all stepper operators - -from functools import partial -from jax import jit -import jax - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend -from xlb.operator.equilibrium.quadratic_equilibrium import QuadraticEquilibrium -from xlb.operator.collision.bgk import BGK -from xlb.operator.collision.kbc import KBC -from xlb.operator.stream import Stream -from xlb.operator.macroscopic import Macroscopic -from xlb.solver.solver import Solver -from xlb.operator import Operator - - -class IncompressibleNavierStokesSolver(Solver): - - _equilibrium_registry = { - "Quadratic": QuadraticEquilibrium, - } - _collision_registry = { - "BGK": BGK, - "KBC": KBC, - } - - def __init__( - self, - omega: float, - domain_shape: tuple[int, int, int], - collision="BGK", - equilibrium="Quadratic", - boundary_conditions=[], - velocity_set = None, - precision_policy=None, - compute_backend=None, - ): - super().__init__( - domain_shape=domain_shape, - boundary_conditions=boundary_conditions, - velocity_set=velocity_set, - compute_backend=compute_backend, - precision_policy=precision_policy, - ) - - # Set omega - self.omega = omega - - # Create operators - self.collision = self._get_collision(collision)( - omega=self.omega, - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.stream = Stream( - velocity_set=self.velocity_set, - precision_policy=self.precision_policy, - compute_backend=self.compute_backend, - ) - self.equilibrium = self._get_equilibrium(equilibrium)( - velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend - ) - self.macroscopic = Macroscopic( - velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend - ) - - # Create stepper operator - self.stepper = IncompressibleNavierStokesStepper( - collision=self.collision, - stream=self.stream, - equilibrium=self.equilibrium, - macroscopic=self.macroscopic, - boundary_conditions=self.boundary_conditions, - forcing=None, - ) - - def monitor(self): - pass - - def run(self, steps: int, monitor_frequency: int = 1, compute_mlups: bool = False): - - # Run steps - for _ in range(steps): - # Run step - self.stepper( - f0=self.grid.get_field("f0"), - f1=self.grid.get_field("f1") - ) - self.grid.swap_fields("f0", "f1") - - def checkpoint(self): - raise NotImplementedError("Checkpointing not yet implemented") - - def _get_collision(self, collision: str): - if isinstance(collision, str): - try: - return self._collision_registry[collision] - except KeyError: - raise ValueError(f"Collision {collision} not recognized for incompressible Navier-Stokes solver") - elif issubclass(collision, Operator): - return collision - else: - raise ValueError(f"Collision {collision} not recognized for incompressible Navier-Stokes solver") - - def _get_equilibrium(self, equilibrium: str): - if isinstance(equilibrium, str): - try: - return self._equilibrium_registry[equilibrium] - except KeyError: - raise ValueError(f"Equilibrium {equilibrium} not recognized for incompressible Navier-Stokes solver") - elif issubclass(equilibrium, Operator): - return equilibrium - else: - raise ValueError(f"Equilibrium {equilibrium} not recognized for incompressible Navier-Stokes solver") diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py deleted file mode 100644 index a826ab7..0000000 --- a/xlb/solver/solver.py +++ /dev/null @@ -1,23 +0,0 @@ -from xlb.default_config import DefaultConfig -from xlb.operator.operator import Operator - - -class Solver(Operator): - """ - Abstract class for the construction of lattice boltzmann solvers - """ - - def __init__( - self, - domain_shape: tuple[int, int, int], - boundary_conditions=[], - velocity_set=None, - precision_policy=None, - compute_backend=None, - ): - # Set parameters - self.domain_shape = domain_shape - self.boundary_conditions = boundary_conditions - self.velocity_set = velocity_set or DefaultConfig.velocity_set - self.precision_policy = precision_policy or DefaultConfig.precision_policy - self.compute_backend = compute_backend or DefaultConfig.compute_backend diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py index 9b09dd4..178c89e 100644 --- a/xlb/velocity_set/d2q9.py +++ b/xlb/velocity_set/d2q9.py @@ -12,7 +12,6 @@ class D2Q9(VelocitySet): D2Q9 stands for two-dimensional nine-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in two dimensions. """ - def __init__(self): # Construct the velocity vectors and weights cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py index 5debdea..7f69019 100644 --- a/xlb/velocity_set/d3q19.py +++ b/xlb/velocity_set/d3q19.py @@ -13,7 +13,6 @@ class D3Q19(VelocitySet): D3Q19 stands for three-dimensional nineteen-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ - def __init__(self): # Construct the velocity vectors and weights c = np.array( diff --git a/xlb/velocity_set/d3q27.py b/xlb/velocity_set/d3q27.py index 702acf4..ac908eb 100644 --- a/xlb/velocity_set/d3q27.py +++ b/xlb/velocity_set/d3q27.py @@ -13,7 +13,6 @@ class D3Q27(VelocitySet): D3Q27 stands for three-dimensional twenty-seven-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ - def __init__(self): # Construct the velocity vectors and weights c = np.array(list(itertools.product([0, -1, 1], repeat=3))).T From b36b053a7f8931d5419d408758f3e8eabc35e0d0 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Tue, 18 Jun 2024 14:25:17 -0400 Subject: [PATCH 037/144] WIP distribute function. Warp backend closure issues are mostly fixed but one remaining. --- examples/cfd/lid_driven_cavity.py | 19 +++-- .../macroscopic/test_macroscopic_warp.py | 4 +- xlb/__init__.py | 4 +- xlb/distribute/__init__.py | 1 + xlb/distribute/distribute.py | 79 +++++++++++++++++++ xlb/operator/collision/bgk.py | 3 +- 6 files changed, 98 insertions(+), 12 deletions(-) create mode 100644 xlb/distribute/__init__.py create mode 100644 xlb/distribute/distribute.py diff --git a/examples/cfd/lid_driven_cavity.py b/examples/cfd/lid_driven_cavity.py index 202d155..962e2e6 100644 --- a/examples/cfd/lid_driven_cavity.py +++ b/examples/cfd/lid_driven_cavity.py @@ -7,8 +7,10 @@ from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image +import warp as wp +import jax.numpy as jnp -backend = ComputeBackend.JAX +backend = ComputeBackend.WARP velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 @@ -58,19 +60,20 @@ f_0, f_1 = f_1, f_0 -# Write the results -macro = Macroscopic() +# Write the results. We'll use JAX backend for the post-processing +if not isinstance(f_0, jnp.ndarray): + f_0 = wp.to_jax(f_0) + +macro = Macroscopic(compute_backend=ComputeBackend.JAX) rho, u = macro(f_0) # remove boundary cells rho = rho[:, 1:-1, 1:-1] u = u[:, 1:-1, 1:-1] - -u_magnitude = (u[0]**2 + u[1]**2)**0.5 +u_magnitude = (u[0] ** 2 + u[1] ** 2) ** 0.5 fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_magnitude": u_magnitude} - -save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity") -save_image(fields["u_magnitude"], timestep=i, prefix="lid_driven_cavity") \ No newline at end of file +save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity") +save_image(fields["u_magnitude"], timestep=i, prefix="lid_driven_cavity") diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index 69c09ac..7a4a8cd 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -25,8 +25,8 @@ def init_xlb_env(velocity_set): (2, xlb.velocity_set.D2Q9, (100, 100), 1.1, 2.0), (2, xlb.velocity_set.D2Q9, (50, 50), 1.1, 2.0), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0, 0.0), - # (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. - # (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 2.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 2.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. ], ) def test_macroscopic_warp(dim, velocity_set, grid_shape, rho, velocity): diff --git a/xlb/__init__.py b/xlb/__init__.py index 06e9f11..be63d06 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -15,7 +15,6 @@ import xlb.operator.stream import xlb.operator.boundary_condition import xlb.operator.macroscopic - # Grids import xlb.grid @@ -24,3 +23,6 @@ # Utils import xlb.utils + +# Distributed computing +import xlb.distribute \ No newline at end of file diff --git a/xlb/distribute/__init__.py b/xlb/distribute/__init__.py new file mode 100644 index 0000000..33e0d2b --- /dev/null +++ b/xlb/distribute/__init__.py @@ -0,0 +1 @@ +from .distribute import distribute \ No newline at end of file diff --git a/xlb/distribute/distribute.py b/xlb/distribute/distribute.py new file mode 100644 index 0000000..96352b3 --- /dev/null +++ b/xlb/distribute/distribute.py @@ -0,0 +1,79 @@ +from jax.sharding import PartitionSpec as P +from xlb.operator import Operator +from xlb import DefaultConfig +from jax import lax, sharding +from jax import jit +import warp as wp + + +def distribute( + operator: Operator, grid, velocity_set, num_results=2, ops="permute" +) -> Operator: + # Define the sharded operator + def _sharded_operator(*args): + results = operator(*args) + + if not isinstance(results, tuple): + results = (results,) + + if DefaultConfig.default_backend == DefaultConfig.ComputeBackend.WARP: + for i, result in enumerate(results): + if isinstance(result, wp.array): + # Convert to jax array (zero copy) + results[i] = wp.to_jax(result) + + if ops == "permute": + # Define permutation rules for right and left communication + rightPerm = [(i, (i + 1) % grid.nDevices) for i in range(grid.nDevices)] + leftPerm = [((i + 1) % grid.nDevices, i) for i in range(grid.nDevices)] + + right_comm = [ + lax.ppermute( + arg[velocity_set.right_indices, :1, ...], + perm=rightPerm, + axis_name="x", + ) + for arg in results + ] + left_comm = [ + lax.ppermute( + arg[velocity_set.left_indices, -1:, ...], + perm=leftPerm, + axis_name="x", + ) + for arg in results + ] + + updated_results = [] + for result in results: + result = result.at[velocity_set.right_indices, :1, ...].set( + right_comm.pop(0) + ) + result = result.at[velocity_set.left_indices, -1:, ...].set( + left_comm.pop(0) + ) + updated_results.append(result) + + return ( + tuple(updated_results) + if len(updated_results) > 1 + else updated_results[0] + ) + else: + raise NotImplementedError(f"Operation {ops} not implemented") + + in_specs = (P(*((None, "x") + (grid.dim - 1) * (None,)))) * len(num_results) + out_specs = in_specs + + distributed_operator = sharding.shard_map( + _sharded_operator, + mesh=grid.global_mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + + if DefaultConfig.default_backend == DefaultConfig.ComputeBackend.JAX: + distributed_operator = jit(distributed_operator) + + return distributed_operator diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index c4fa62d..15c0ad0 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -89,13 +89,14 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq): + def warp_implementation(self, f, feq, fout): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ f, feq, + fout, ], dim=f.shape[1:], ) From 346b2e5191a9241fd95b9feaa01a179016ae8e1f Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 27 Jun 2024 15:27:02 -0400 Subject: [PATCH 038/144] Fixed distributed errors and added mlups 3D --- ...iven_cavity.py => lid_driven_cavity_2d.py} | 6 +- .../cfd/lid_driven_cavity_2d_distributed.py | 87 ++++++++++++++ examples/performance/mlups3d.py | 51 -------- examples/performance/mlups_3d.py | 109 ++++++++++++++++++ xlb/distribute/distribute.py | 88 +++++++------- xlb/grid/jax_grid.py | 14 +-- xlb/helper/boundary_conditions.py | 6 +- xlb/helper/nse_solver.py | 1 - .../boundary_masker/planar_boundary_masker.py | 16 +-- 9 files changed, 258 insertions(+), 120 deletions(-) rename examples/cfd/{lid_driven_cavity.py => lid_driven_cavity_2d.py} (96%) create mode 100644 examples/cfd/lid_driven_cavity_2d_distributed.py delete mode 100644 examples/performance/mlups3d.py create mode 100644 examples/performance/mlups_3d.py diff --git a/examples/cfd/lid_driven_cavity.py b/examples/cfd/lid_driven_cavity_2d.py similarity index 96% rename from examples/cfd/lid_driven_cavity.py rename to examples/cfd/lid_driven_cavity_2d.py index 962e2e6..84ed1a7 100644 --- a/examples/cfd/lid_driven_cavity.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -10,7 +10,7 @@ import warp as wp import jax.numpy as jnp -backend = ComputeBackend.WARP +backend = ComputeBackend.JAX velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 @@ -20,7 +20,7 @@ default_precision_policy=precision_policy, ) -grid_size = 512 +grid_size = 128 grid_shape = (grid_size, grid_size) grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) @@ -55,7 +55,7 @@ omega, boundary_conditions=boundary_conditions ) -for i in range(50000): +for i in range(500): f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) f_0, f_1 = f_1, f_0 diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py new file mode 100644 index 0000000..3f1d50e --- /dev/null +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -0,0 +1,87 @@ +from math import dist +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.operator.stepper import IncompressibleNavierStokesStepper +from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.macroscopic import Macroscopic +from xlb.utils import save_fields_vtk, save_image +from xlb.distribute import distribute +import warp as wp +import jax.numpy as jnp + +backend = ComputeBackend.JAX +velocity_set = xlb.velocity_set.D2Q9() +precision_policy = PrecisionPolicy.FP32FP32 + +xlb.init( + velocity_set=velocity_set, + default_backend=backend, + default_precision_policy=precision_policy, +) + +grid_size = 512 +grid_shape = (grid_size, grid_size) + +grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) + +# Velocity on top face (2D) +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, + missing_mask, + grid_shape, + EquilibriumBC.id, + ["top"], +) + +# Wall on all other faces (2D) +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, + missing_mask, + grid_shape, + FullwayBounceBackBC.id, + ["bottom", "left", "right"], + backend=ComputeBackend.JAX, +) + +bc_eq = QuadraticEquilibrium(compute_backend=backend) + +bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), equilibrium_operator=bc_eq) + +bc_walls = FullwayBounceBackBC(compute_backend=backend) + + +f_0 = initialize_eq(f_0, grid, velocity_set, backend=ComputeBackend.JAX) +boundary_conditions = [bc_top, bc_walls] +omega = 1.6 + +stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=boundary_conditions +) +distributed_stepper = distribute( + stepper, grid, velocity_set, sharding_flags=(True, True, True, True, False) +) +for i in range(5000): + f_1 = distributed_stepper(f_0, f_1, boundary_mask, missing_mask, i) + f_0, f_1 = f_1, f_0 + + +# Write the results. We'll use JAX backend for the post-processing +if not isinstance(f_0, jnp.ndarray): + f_0 = wp.to_jax(f_0) + +macro = Macroscopic(compute_backend=ComputeBackend.JAX) + +rho, u = macro(f_0) + +# remove boundary cells +rho = rho[:, 1:-1, 1:-1] +u = u[:, 1:-1, 1:-1] +u_magnitude = (u[0] ** 2 + u[1] ** 2) ** 0.5 + +fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_magnitude": u_magnitude} + +save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity") +save_image(fields["u_magnitude"], timestep=i, prefix="lid_driven_cavity") diff --git a/examples/performance/mlups3d.py b/examples/performance/mlups3d.py deleted file mode 100644 index f044c33..0000000 --- a/examples/performance/mlups3d.py +++ /dev/null @@ -1,51 +0,0 @@ -import xlb -import time -import jax -import argparse -from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import Fp32Fp32 -from xlb.operator.initializer import EquilibriumInitializer - -from xlb.helper import IncompressibleNavierStokes -from xlb.grid import grid_factory - -parser = argparse.ArgumentParser( - description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)" -) -parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") -parser.add_argument("num_steps", type=int, help="Timestep for the simulation") - -args = parser.parse_args() - -cube_edge = args.cube_edge -num_steps = args.num_steps - - -xlb.init( - precision_policy=Fp32Fp32, - compute_backend=ComputeBackend.PALLAS, - velocity_set=xlb.velocity_set.D3Q19, -) - -grid_shape = (cube_edge, cube_edge, cube_edge) -grid = Grid.create(grid_shape) - -f = grid.create_field(cardinality=19) - -print("f shape", f.shape) - -solver = IncompressibleNavierStokes(grid, omega=1.0) - -# Ahead-of-Time Compilation to remove JIT overhead -f = solver.step(f, timestep=0) - -start_time = time.time() - -for step in range(num_steps): - f = solver.step(f, timestep=step) - -end_time = time.time() -total_lattice_updates = cube_edge**3 * num_steps -total_time_seconds = end_time - start_time -mlups = (total_lattice_updates / total_time_seconds) / 1e6 -print(f"MLUPS: {mlups}") diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py new file mode 100644 index 0000000..6131a0d --- /dev/null +++ b/examples/performance/mlups_3d.py @@ -0,0 +1,109 @@ +from turtle import back +import xlb +import argparse +import time +import warp as wp +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.operator.stepper import IncompressibleNavierStokesStepper +from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.distribute import distribute + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)" + ) + parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") + parser.add_argument("num_steps", type=int, help="Timestep for the simulation") + parser.add_argument("backend", type=str, help="Backend for the simulation (jax or warp)") + parser.add_argument("precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)") + return parser.parse_args() + +def setup_simulation(args): + backend = ComputeBackend.JAX if args.backend == "jax" else ComputeBackend.WARP + precision_policy_map = { + "fp32/fp32": PrecisionPolicy.FP32FP32, + "fp64/fp64": PrecisionPolicy.FP64FP64, + "fp64/fp32": PrecisionPolicy.FP64FP32, + "fp32/fp16": PrecisionPolicy.FP32FP16 + } + precision_policy = precision_policy_map.get(args.precision) + if precision_policy is None: + raise ValueError("Invalid precision") + + xlb.init( + velocity_set=xlb.velocity_set.D3Q19(), + default_backend=backend, + default_precision_policy=precision_policy, + ) + + return backend, precision_policy + +def create_grid_and_fields(cube_edge): + grid_shape = (cube_edge, cube_edge, cube_edge) + grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) + + # Velocity on top face (3D) + boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, missing_mask, grid_shape, EquilibriumBC.id, ["top"] + ) + + # Wall on all other faces (3D) + boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, + missing_mask, + grid_shape, + FullwayBounceBackBC.id, + ["bottom", "left", "right", "front", "back"], + ) + + return grid, f_0, f_1, missing_mask, boundary_mask + +def setup_boundary_conditions(): + bc_eq = QuadraticEquilibrium() + bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=bc_eq) + bc_walls = FullwayBounceBackBC() + return [bc_top, bc_walls] + +def run_simulation(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): + omega = 1.0 + stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=setup_boundary_conditions() + ) + + if backend == ComputeBackend.JAX: + stepper = distribute( + stepper, grid, xlb.velocity_set.D3Q19(), sharding_flags=(True, True, True, True, False) + ) + + start_time = time.time() + + for i in range(num_steps): + f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) + f_0, f_1 = f_1, f_0 + wp.synchronize() + + end_time = time.time() + return end_time - start_time + +def calculate_mlups(cube_edge, num_steps, elapsed_time): + total_lattice_updates = cube_edge**3 * num_steps + mlups = (total_lattice_updates / elapsed_time) / 1e6 + return mlups + +def main(): + args = parse_arguments() + backend, precision_policy = setup_simulation(args) + grid, f_0, f_1, missing_mask, boundary_mask = create_grid_and_fields(args.cube_edge) + f_0 = initialize_eq(f_0, grid, xlb.velocity_set.D3Q19(), backend) + + elapsed_time = run_simulation(f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps) + mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) + + print(f"Simulation completed in {elapsed_time:.2f} seconds") + print(f"MLUPs: {mlups:.2f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/xlb/distribute/distribute.py b/xlb/distribute/distribute.py index 96352b3..bb72072 100644 --- a/xlb/distribute/distribute.py +++ b/xlb/distribute/distribute.py @@ -1,79 +1,73 @@ from jax.sharding import PartitionSpec as P from xlb.operator import Operator from xlb import DefaultConfig -from jax import lax, sharding +from xlb import ComputeBackend +from jax import lax +from jax.experimental.shard_map import shard_map from jax import jit +import jax.numpy as jnp import warp as wp +from typing import Tuple def distribute( - operator: Operator, grid, velocity_set, num_results=2, ops="permute" + operator: Operator, + grid, + velocity_set, + sharding_flags: Tuple[bool, ...], + num_results=1, + ops="permute", ) -> Operator: # Define the sharded operator def _sharded_operator(*args): - results = operator(*args) - - if not isinstance(results, tuple): - results = (results,) - - if DefaultConfig.default_backend == DefaultConfig.ComputeBackend.WARP: - for i, result in enumerate(results): - if isinstance(result, wp.array): - # Convert to jax array (zero copy) - results[i] = wp.to_jax(result) + result = operator(*args) if ops == "permute": # Define permutation rules for right and left communication rightPerm = [(i, (i + 1) % grid.nDevices) for i in range(grid.nDevices)] leftPerm = [((i + 1) % grid.nDevices, i) for i in range(grid.nDevices)] - right_comm = [ - lax.ppermute( - arg[velocity_set.right_indices, :1, ...], - perm=rightPerm, - axis_name="x", - ) - for arg in results - ] - left_comm = [ - lax.ppermute( - arg[velocity_set.left_indices, -1:, ...], - perm=leftPerm, - axis_name="x", - ) - for arg in results - ] + right_comm = lax.ppermute( + result[velocity_set.right_indices, :1, ...], + perm=rightPerm, + axis_name="x", + ) - updated_results = [] - for result in results: - result = result.at[velocity_set.right_indices, :1, ...].set( - right_comm.pop(0) - ) - result = result.at[velocity_set.left_indices, -1:, ...].set( - left_comm.pop(0) - ) - updated_results.append(result) + left_comm = lax.ppermute( + result[velocity_set.left_indices, -1:, ...], + perm=leftPerm, + axis_name="x", + ) - return ( - tuple(updated_results) - if len(updated_results) > 1 - else updated_results[0] + result = result.at[velocity_set.right_indices, :1, ...].set( + right_comm + ) + result = result.at[velocity_set.left_indices, -1:, ...].set( + left_comm ) + + return result else: raise NotImplementedError(f"Operation {ops} not implemented") - in_specs = (P(*((None, "x") + (grid.dim - 1) * (None,)))) * len(num_results) - out_specs = in_specs + in_specs = tuple( + P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() + for flag in sharding_flags + ) + out_specs = tuple( + P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results) + ) - distributed_operator = sharding.shard_map( + if len(out_specs) == 1: + out_specs = out_specs[0] + + distributed_operator = shard_map( _sharded_operator, mesh=grid.global_mesh, in_specs=in_specs, out_specs=out_specs, check_rep=False, ) - - if DefaultConfig.default_backend == DefaultConfig.ComputeBackend.JAX: - distributed_operator = jit(distributed_operator) + distributed_operator = jit(distributed_operator) return distributed_operator diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 64dc33a..d890fe3 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -24,20 +24,20 @@ def __init__(self, shape): def _initialize_backend(self): self.nDevices = jax.device_count() self.backend = jax.default_backend() - device_mesh = ( + self.device_mesh = ( mesh_utils.create_device_mesh((1, self.nDevices, 1)) if self.dim == 2 else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) ) - global_mesh = ( - Mesh(device_mesh, axis_names=("cardinality", "x", "y")) + self.global_mesh = ( + Mesh(self.device_mesh, axis_names=("cardinality", "x", "y")) if self.dim == 2 - else Mesh(device_mesh, axis_names=("cardinality", "x", "y", "z")) + else Mesh(self.device_mesh, axis_names=("cardinality", "x", "y", "z")) ) self.sharding = ( - NamedSharding(global_mesh, P("cardinality", "x", "y")) + NamedSharding(self.global_mesh, P("cardinality", "x", "y")) if self.dim == 2 - else NamedSharding(global_mesh, P("cardinality", "x", "y", "z")) + else NamedSharding(self.global_mesh, P("cardinality", "x", "y", "z")) ) def create_field( @@ -65,4 +65,4 @@ def create_field( jax.default_device = jax.devices()[0] return jax.make_array_from_single_device_arrays( full_shape, self.sharding, arrays - ) + ) \ No newline at end of file diff --git a/xlb/helper/boundary_conditions.py b/xlb/helper/boundary_conditions.py index 3178bc2..d27a842 100644 --- a/xlb/helper/boundary_conditions.py +++ b/xlb/helper/boundary_conditions.py @@ -1,7 +1,7 @@ from xlb.operator.boundary_masker import PlanarBoundaryMasker -def assign_bc_id_box_faces(boundary_mask, missing_mask, shape, bc_id, sides): +def assign_bc_id_box_faces(boundary_mask, missing_mask, shape, bc_id, sides, backend=None): """ Assign boundary conditions for specified sides of 2D and 3D boxes using planar_boundary_masker function. @@ -19,7 +19,7 @@ def assign_bc_id_box_faces(boundary_mask, missing_mask, shape, bc_id, sides): Valid values for 3D are 'bottom', 'top', 'front', 'back', 'left', 'right'. """ - planar_boundary_masker = PlanarBoundaryMasker() + planar_boundary_masker = PlanarBoundaryMasker(compute_backend=backend) def apply(lower_bound, upper_bound, direction, reference=(0, 0, 0)): nonlocal boundary_mask, missing_mask, planar_boundary_masker @@ -66,4 +66,4 @@ def apply(lower_bound, upper_bound, direction, reference=(0, 0, 0)): else: raise ValueError("Unsupported dimensions: {}".format(dimensions)) - return boundary_mask, missing_mask + return boundary_mask, missing_mask \ No newline at end of file diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py index 7fc11fc..4462edc 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_solver.py @@ -26,4 +26,3 @@ def create_nse_fields( return grid, f_0, f_1, missing_mask, boundary_mask - diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py index fa83c7d..ed20604 100644 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -39,7 +39,7 @@ def jax_implementation( direction, id_number, boundary_mask, - mask, + missing_mask, start_index=None, ): if start_index is None: @@ -58,7 +58,7 @@ def jax_implementation( slices.extend(slice(lb, ub) for lb, ub in indices) boundary_mask = boundary_mask.at[tuple(slices)].set(id_number) - return boundary_mask, None + return boundary_mask, missing_mask def _construct_warp(self): # Make constants for warp @@ -72,7 +72,7 @@ def kernel2d( direction: wp.vec2i, id_number: wp.uint8, boundary_mask: wp.array3d(dtype=wp.uint8), - mask: wp.array3d(dtype=wp.bool), + missing_mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): i, j = wp.tid() @@ -89,7 +89,7 @@ def kernel3d( direction: wp.vec3i, id_number: wp.uint8, boundary_mask: wp.array4d(dtype=wp.uint8), - mask: wp.array4d(dtype=wp.bool), + missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): i, j, k = wp.tid() @@ -119,7 +119,7 @@ def warp_implementation( direction, id_number, boundary_mask, - mask, + missing_mask, start_index=None, ): if start_index is None: @@ -134,10 +134,10 @@ def warp_implementation( direction, id_number, boundary_mask, - mask, + missing_mask, start_index, ], - dim=mask.shape[1:], + dim=missing_mask.shape[1:], ) - return boundary_mask, mask + return boundary_mask, missing_mask From 4c5b1c2da0b0a8dff9df2283461b91bdba2e49be Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 28 Jun 2024 13:09:23 -0400 Subject: [PATCH 039/144] Fixed bugs + wind tunnel and flow over sphere examples --- examples/cfd/flow_past_sphere_3d.py | 111 +++++++++++++++ examples/cfd/windtunnel_3d.py | 127 ++++++++++++++++++ requirements.txt | 2 +- xlb/grid/jax_grid.py | 2 +- .../boundary_condition/bc_do_nothing.py | 7 +- .../bc_halfway_bounce_back.py | 6 +- .../indices_boundary_masker.py | 20 ++- xlb/utils/utils.py | 10 +- 8 files changed, 260 insertions(+), 25 deletions(-) create mode 100644 examples/cfd/flow_past_sphere_3d.py create mode 100644 examples/cfd/windtunnel_3d.py diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py new file mode 100644 index 0000000..3175cfa --- /dev/null +++ b/examples/cfd/flow_past_sphere_3d.py @@ -0,0 +1,111 @@ +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.operator.stepper import IncompressibleNavierStokesStepper +from xlb.operator.boundary_condition import ( + FullwayBounceBackBC, + EquilibriumBC, + DoNothingBC, +) +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.macroscopic import Macroscopic +from xlb.utils import save_fields_vtk, save_image +import warp as wp +import numpy as np +import jax.numpy as jnp + +backend = ComputeBackend.WARP +velocity_set = xlb.velocity_set.D3Q19() +precision_policy = PrecisionPolicy.FP32FP32 + +xlb.init( + velocity_set=velocity_set, + default_backend=backend, + default_precision_policy=precision_policy, +) + +grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 +grid_shape = (grid_size_x, grid_size_y, grid_size_z) + +grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) + +# Velocity on left face (3D) +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, missing_mask, grid_shape, EquilibriumBC.id, ["left"] +) + + +# Wall on all other faces (3D) except right +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, + missing_mask, + grid_shape, + FullwayBounceBackBC.id, + ["bottom", "right", "front", "back"], +) + +# Do nothing on right face +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, missing_mask, grid_shape, DoNothingBC.id, ["right"] +) + +bc_eq = QuadraticEquilibrium() +bc_left = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=bc_eq) +bc_walls = FullwayBounceBackBC() +bc_do_nothing = DoNothingBC() + + +sphere_radius = grid_size_y // 12 +x = np.arange(grid_size_x) +y = np.arange(grid_size_y) +z = np.arange(grid_size_z) +X, Y, Z = np.meshgrid(x, y, z, indexing='ij') +indices = np.where( + (X - grid_size_x // 6) ** 2 + + (Y - grid_size_y // 2) ** 2 + + (Z - grid_size_z // 2) ** 2 + < sphere_radius**2 +) +indices = np.array(indices) + +# Set boundary conditions on the indices +indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=backend, +) + +boundary_mask, missing_mask = indices_boundary_masker( + indices, FullwayBounceBackBC.id, boundary_mask, missing_mask, (0, 0, 0) +) + +f_0 = initialize_eq(f_0, grid, velocity_set, backend) +boundary_conditions = [bc_left, bc_walls, bc_do_nothing] +omega = 1.8 + +stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=boundary_conditions +) + +for i in range(10000): + f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) + f_0, f_1 = f_1, f_0 + + +# Write the results. We'll use JAX backend for the post-processing +if not isinstance(f_0, jnp.ndarray): + f_0 = wp.to_jax(f_0) + +macro = Macroscopic(compute_backend=ComputeBackend.JAX) + +rho, u = macro(f_0) + +# remove boundary cells +u = u[:, 1:-1, 1:-1, 1:-1] +u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 + +fields = {"u_magnitude": u_magnitude} + +save_fields_vtk(fields, timestep=i) +save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py new file mode 100644 index 0000000..7611c0e --- /dev/null +++ b/examples/cfd/windtunnel_3d.py @@ -0,0 +1,127 @@ +import xlb +import trimesh +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.operator.stepper import IncompressibleNavierStokesStepper +from xlb.operator.boundary_condition import ( + FullwayBounceBackBC, + EquilibriumBC, + DoNothingBC, +) +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.macroscopic import Macroscopic +from xlb.utils import save_fields_vtk, save_image +import warp as wp +import numpy as np +import jax.numpy as jnp + +backend = ComputeBackend.WARP +velocity_set = xlb.velocity_set.D3Q19() +precision_policy = PrecisionPolicy.FP32FP32 + +xlb.init( + velocity_set=velocity_set, + default_backend=backend, + default_precision_policy=precision_policy, +) + +grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 +grid_shape = (grid_size_x, grid_size_y, grid_size_z) + +grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) + +# Velocity on left face (3D) +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, missing_mask, grid_shape, EquilibriumBC.id, ["left"] +) + + +# Wall on all other faces (3D) except right +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, + missing_mask, + grid_shape, + FullwayBounceBackBC.id, + ["bottom", "right", "front", "back"], +) + +# Do nothing on right face +boundary_mask, missing_mask = assign_bc_id_box_faces( + boundary_mask, missing_mask, grid_shape, DoNothingBC.id, ["right"] +) + +prescribed_vel = 0.02 +bc_eq = QuadraticEquilibrium() +bc_left = EquilibriumBC( + rho=1.0, u=(prescribed_vel, 0.0, 0.0), equilibrium_operator=bc_eq +) +bc_walls = FullwayBounceBackBC() +bc_do_nothing = DoNothingBC() + + +def voxelize_stl(stl_filename, length_lbm_unit): + mesh = trimesh.load_mesh(stl_filename, process=False) + length_phys_unit = mesh.extents.max() + pitch = length_phys_unit / length_lbm_unit + mesh_voxelized = mesh.voxelized(pitch=pitch) + mesh_matrix = mesh_voxelized.matrix + return mesh_matrix, pitch + + +stl_filename = "../stl-files/DrivAer-Notchback.stl" +car_length_lbm_unit = grid_size_x / 4 +car_voxelized, pitch = voxelize_stl(stl_filename, car_length_lbm_unit) + +car_area = np.prod(car_voxelized.shape[1:]) +tx, ty, tz = np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape +shift = [tx // 4, ty // 2, 0] +indices = np.argwhere(car_voxelized) + shift + +indices = np.array(indices).T + +# Set boundary conditions on the indices +indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=backend, +) + +boundary_mask, missing_mask = indices_boundary_masker( + indices, FullwayBounceBackBC.id, boundary_mask, missing_mask, (0, 0, 0) +) + +f_0 = initialize_eq(f_0, grid, velocity_set, backend) +boundary_conditions = [bc_left, bc_walls, bc_do_nothing] + +clength = grid_size_x - 1 +Re = 10000.0 + +visc = prescribed_vel * clength / Re +omega = 1.0 / (3.0 * visc + 0.5) + +stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=boundary_conditions +) + +for i in range(100000): + f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) + f_0, f_1 = f_1, f_0 + + +# Write the results. We'll use JAX backend for the post-processing +if not isinstance(f_0, jnp.ndarray): + f_0 = wp.to_jax(f_0) + +macro = Macroscopic(compute_backend=ComputeBackend.JAX) + +rho, u = macro(f_0) + +# remove boundary cells +u = u[:, 1:-1, 1:-1, 1:-1] +u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 + +fields = {"u_magnitude": u_magnitude} + +save_fields_vtk(fields, timestep=i) +save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) diff --git a/requirements.txt b/requirements.txt index 5f356bd..ebae946 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ matplotlib==3.8.0 numpy==1.26.1 pyvista==0.43.4 Rtree==1.0.1 -trimesh==4.2.4 +trimesh==4.4.1 orbax-checkpoint==0.4.1 termcolor==2.3.0 PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index d890fe3..289790a 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -46,7 +46,7 @@ def create_field( dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16, Precision.BOOL] = None, fill_value=None, ): - sharding_dim = self.shape[-1] // self.nDevices + sharding_dim = self.shape[0] // self.nDevices device_shape = (cardinality, sharding_dim, *self.shape[1:]) full_shape = (cardinality, *self.shape) arrays = [] diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 37a91ab..9992898 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -33,9 +33,9 @@ class DoNothingBC(BoundaryCondition): def __init__( self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, ): super().__init__( ImplementationStep.STREAMING, @@ -48,7 +48,6 @@ def __init__( @partial(jit, static_argnums=(0)) def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id - boundary = boundary[:, None, None, None] return jnp.where(boundary, f_pre, f_post) def _construct_warp(self): diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index d748bdd..b8ded63 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -34,9 +34,9 @@ class HalfwayBounceBackBC(BoundaryCondition): def __init__( self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, ): # Call the parent constructor super().__init__( diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index cf8feab..ffc2ff7 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -47,13 +47,7 @@ def jax_implementation( if start_index is None: start_index = (0,) * dim - local_indices = indices - np.array(start_index)[:, np.newaxis] - - indices_mask = [ - (local_indices[i, :] >= 0) & (local_indices[i, :] < mask.shape[i + 1]) - for i in range(mask.ndim - 1) - ] - indices_mask = np.logical_and.reduce(indices_mask) + local_indices = indices - np.array(start_index)[:, np.newaxis].T @jit def compute_boundary_id_and_mask(boundary_mask, mask): @@ -113,9 +107,10 @@ def kernel2d( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_mask[0, index[0], index[1]] = wp.uint8(id_number) mask[l, push_index[0], push_index[1]] = True + boundary_mask[0, index[0], index[1]] = wp.uint8(id_number) + # Construct the warp 3D kernel @wp.kernel def kernel3d( @@ -150,12 +145,11 @@ def kernel3d( for d in range(self.velocity_set.d): push_index[d] = index[d] + _c[d, l] - # Set the boundary id and mask - boundary_mask[0, index[0], index[1], index[2]] = wp.uint8( - id_number - ) + # Set the mask mask[l, push_index[0], push_index[1], push_index[2]] = True + boundary_mask[0, index[0], index[1], index[2]] = wp.uint8(id_number) + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return None, kernel @@ -166,6 +160,8 @@ def warp_implementation( ): if start_index is None: start_index = (0,) * self.velocity_set.d + + indices = wp.array(indices, dtype=wp.int32) # Launch the warp kernel wp.launch( self.warp_kernel, diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index 2752314..7b7ac78 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -69,10 +69,12 @@ def save_image(fld, timestep, prefix=None): This function saves the field as an image in the PNG format. The filename is based on the name of the main script file, the provided prefix, and the timestep number. If the field is 3D, the magnitude of the field is calculated and saved. The image is saved with the 'nipy_spectral' colormap and the origin set to 'lower'. """ - fname = os.path.basename(__main__.__file__) - fname = os.path.splitext(fname)[0] - if prefix is not None: - fname = prefix + fname + if prefix is None: + fname = os.path.basename(__main__.__file__) + fname = os.path.splitext(fname)[0] + else: + fname = prefix + fname = fname + "_" + str(timestep).zfill(4) if len(fld.shape) > 3: From 179ae0bef77fd8b2a98ca0dabd1c6a5b01b1f602 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 28 Jun 2024 13:11:13 -0400 Subject: [PATCH 040/144] Fixed a minor bug for JAX backend --- xlb/operator/boundary_masker/indices_boundary_masker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index ffc2ff7..fa0c5e7 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -47,7 +47,7 @@ def jax_implementation( if start_index is None: start_index = (0,) * dim - local_indices = indices - np.array(start_index)[:, np.newaxis].T + local_indices = indices - np.array(start_index)[:, np.newaxis] @jit def compute_boundary_id_and_mask(boundary_mask, mask): From a10e4a853a7eef10161a1fd6018fa0a43c73e3ef Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 28 Jun 2024 14:34:12 -0400 Subject: [PATCH 041/144] Added KBC and wind tunnel at high Re --- examples/cfd/windtunnel_3d.py | 32 +++++++-- xlb/operator/collision/bgk.py | 14 ++-- xlb/operator/collision/kbc.py | 106 ++++++++++++++++++++++++---- xlb/operator/stepper/nse_stepper.py | 6 +- 4 files changed, 133 insertions(+), 25 deletions(-) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 7611c0e..72bc854 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -1,5 +1,6 @@ import xlb import trimesh +import time from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces @@ -16,9 +17,15 @@ import numpy as np import jax.numpy as jnp +# Configuration backend = ComputeBackend.WARP -velocity_set = xlb.velocity_set.D3Q19() +velocity_set = xlb.velocity_set.D3Q27() precision_policy = PrecisionPolicy.FP32FP32 +grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 +prescribed_vel = 0.02 +Re = 50000.0 +max_iterations = 100000 +print_interval = 1000 xlb.init( velocity_set=velocity_set, @@ -26,6 +33,18 @@ default_precision_policy=precision_policy, ) +# Print simulation info +print("Simulation Configuration:") +print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") +print(f"Backend: {backend}") +print(f"Velocity set: {velocity_set}") +print(f"Precision policy: {precision_policy}") +print(f"Prescribed velocity: {prescribed_vel}") +print(f"Reynolds number: {Re}") +print(f"Max iterations: {max_iterations}") +print("\n" + "=" * 50 + "\n") + + grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 grid_shape = (grid_size_x, grid_size_y, grid_size_z) @@ -51,7 +70,6 @@ boundary_mask, missing_mask, grid_shape, DoNothingBC.id, ["right"] ) -prescribed_vel = 0.02 bc_eq = QuadraticEquilibrium() bc_left = EquilibriumBC( rho=1.0, u=(prescribed_vel, 0.0, 0.0), equilibrium_operator=bc_eq @@ -95,19 +113,23 @@ def voxelize_stl(stl_filename, length_lbm_unit): boundary_conditions = [bc_left, bc_walls, bc_do_nothing] clength = grid_size_x - 1 -Re = 10000.0 visc = prescribed_vel * clength / Re omega = 1.0 / (3.0 * visc + 0.5) stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=boundary_conditions + omega, boundary_conditions=boundary_conditions, collision_type="KBC" ) -for i in range(100000): +start_time = time.time() +for i in range(max_iterations): f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) f_0, f_1 = f_1, f_0 + if (i + 1) % print_interval == 0: + elapsed_time = time.time() - start_time + print(f"Iteration: {i+1}/{max_iterations} | Time elapsed: {elapsed_time:.2f}s") + # Write the results. We'll use JAX backend for the post-processing if not isinstance(f_0, jnp.ndarray): diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 15c0ad0..deb5fd0 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -16,7 +16,7 @@ class BGK(Collision): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) - def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u): fneq = f - feq fout = f - self.compute_dtype(self.omega) * fneq return fout @@ -29,7 +29,7 @@ def _construct_warp(self): # Construct the functional @wp.func - def functional(f: Any, feq: Any): + def functional(f: Any, feq: Any, rho: Any, u: Any): fneq = f - feq fout = f - _omega * fneq return fout @@ -40,6 +40,8 @@ def kernel2d( f: wp.array3d(dtype=Any), feq: wp.array3d(dtype=Any), fout: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() @@ -65,6 +67,8 @@ def kernel3d( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() @@ -78,7 +82,7 @@ def kernel3d( _feq[l] = feq[l, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq) + _fout = functional(_f, _feq, rho, u) # Write the result for l in range(self.velocity_set.q): @@ -89,7 +93,7 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout): + def warp_implementation(self, f, feq, fout, rho, u): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -97,6 +101,8 @@ def warp_implementation(self, f, feq, fout): f, feq, fout, + rho, + u, ], dim=f.shape[1:], ) diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index f3c996b..ec40b56 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -217,8 +217,9 @@ def _construct_warp(self): _cc = self.velocity_set.wp_cc _omega = wp.constant(self.compute_dtype(self.omega)) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2 _pi_vec = wp.vec( - self.velocity_set.d * (self.velocity_set.d + 1) // 2, + _pi_dim, dtype=self.compute_dtype, ) _epsilon = wp.constant(self.compute_dtype(self.epsilon)) @@ -227,24 +228,39 @@ def _construct_warp(self): # Construct functional for computing momentum flux @wp.func - def momentum_flux( + def momentum_flux_warp( fneq: Any, ): # Get momentum flux pi = _pi_vec() - for d in range(6): + for d in range(_pi_dim): pi[d] = 0.0 for q in range(self.velocity_set.q): pi[d] += _cc[q, d] * fneq[q] return pi + @wp.func + def decompose_shear_d2q9(fneq: Any): + pi = momentum_flux_warp(fneq) + N = pi[0] - pi[1] + s = wp.vec9(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + s[3] = N + s[6] = N + s[2] = -N + s[1] = -N + s[8] = pi[2] + s[4] = -pi[2] + s[5] = -pi[2] + s[7] = pi[2] + return s + # Construct functional for decomposing shear @wp.func def decompose_shear_d3q27( fneq: Any, ): # Get momentum flux - pi = momentum_flux(fneq) + pi = momentum_flux_warp(fneq) nxz = pi[0] - pi[5] nyz = pi[3] - pi[5] @@ -294,7 +310,29 @@ def entropic_scalar_product( # Construct the functional @wp.func - def functional( + def functional2d( + f: Any, + feq: Any, + rho: Any, + u: Any, + ): + # Compute shear and delta_s + fneq = f - feq + shear = decompose_shear_d2q9(fneq) + delta_s = shear * rho # TODO: Check this + + # Perform collision + delta_h = fneq - delta_s + gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product( + delta_s, delta_h, feq + ) / (_epsilon + entropic_scalar_product(delta_h, delta_h, feq)) + fout = f - _beta * (2.0 * delta_s + gamma * delta_h) + + return fout + + # Construct the functional + @wp.func + def functional3d( f: Any, feq: Any, rho: Any, @@ -316,7 +354,38 @@ def functional( # Construct the warp kernel @wp.kernel - def kernel( + def kernel2d( + f: wp.array3d(dtype=Any), + feq: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + fout: wp.array3d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j) # TODO: Warp needs to fix this + + # Load needed values + _f = _f_vec() + _feq = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _feq[l] = feq[l, index[0], index[1]] + _u = self._warp_u_vec() + for l in range(_d): + _u[l] = u[l, index[0], index[1]] + _rho = rho[0, index[0], index[1]] + + # Compute the collision + _fout = functional(_f, _feq, _rho, _u) + + # Write the result + for l in range(self.velocity_set.q): + fout[l, index[0], index[1]] = _fout[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), @@ -345,14 +414,23 @@ def kernel( for l in range(self.velocity_set.q): fout[l, index[0], index[1], index[2]] = _fout[l] + functional = functional3d if self.velocity_set.d == 3 else functional2d + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) - def warp_implementation( - self, - f: jnp.ndarray, - feq: jnp.ndarray, - rho: jnp.ndarray, - ): - raise NotImplementedError("Warp implementation not yet implemented") + def warp_implementation(self, f, feq, fout, rho, u): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + f, + feq, + fout, + rho, + u, + ], + dim=f.shape[1:], + ) + return fout diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 0206416..d52f41d 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -58,7 +58,7 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): feq = self.equilibrium(rho, u) # Apply collision - f_post_collision = self.collision(f_0, feq) + f_post_collision = self.collision(f_0, feq, rho, u) # Apply collision type boundary conditions for bc in self.boundary_conditions: @@ -153,6 +153,8 @@ def kernel2d( f_post_collision = self.collision.warp_functional( f_post_stream, feq, + rho, + u, ) # Apply collision type boundary conditions @@ -218,7 +220,7 @@ def kernel3d( feq = self.equilibrium.warp_functional(rho, u) # Apply collision - f_post_collision = self.collision.warp_functional(f_post_stream, feq) + f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply collision type boundary conditions if _boundary_id == _fullway_bounce_back_bc: From 327b19f372562a5144cb92c4f1a682299d26532b Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 12 Jul 2024 17:09:41 -0400 Subject: [PATCH 042/144] WIP: testing a new way of setting up BCs --- examples/cfd/flow_past_sphere_3d.py | 68 +++++---- xlb/grid/grid.py | 44 ++++++ xlb/helper/__init__.py | 3 +- xlb/helper/boundary_conditions.py | 69 --------- .../boundary_condition/bc_do_nothing.py | 8 +- .../boundary_condition/bc_equilibrium.py | 7 +- .../bc_fullway_bounce_back.py | 8 +- .../boundary_condition/boundary_condition.py | 5 + xlb/operator/boundary_masker/__init__.py | 3 - .../indices_boundary_masker.py | 23 +-- .../boundary_masker/planar_boundary_masker.py | 143 ------------------ xlb/operator/operator.py | 2 - 12 files changed, 110 insertions(+), 273 deletions(-) delete mode 100644 xlb/helper/boundary_conditions.py delete mode 100644 xlb/operator/boundary_masker/planar_boundary_masker.py diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 3175cfa..12d3a00 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -1,7 +1,7 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import ( FullwayBounceBackBC, @@ -10,12 +10,14 @@ ) from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic +from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np import jax.numpy as jnp -backend = ComputeBackend.WARP +# Initial setup and backend configuration +backend = ComputeBackend.JAX velocity_set = xlb.velocity_set.D3Q19() precision_policy = PrecisionPolicy.FP32FP32 @@ -25,37 +27,26 @@ default_precision_policy=precision_policy, ) +#TODO HS: check inconsistency between grid_shape and velocity_set +#TODO HS: why is boundary_mask and missing_mask in the same function?! they should be separated +#TODO HS: missing_mask needs to be created based on ALL boundary indices and a SINGLE streaming operation not one streaming call per bc! +#TODO HS: why bc operatores need to be stated twice: once in making boundary_mask and missing_mask and one in making bc list. +#TODO HS: proposal: we should include indices as part of the construction of the bc operators and then have a single call to construcut boundary_mask and missing_mask fields based on bc_list. + +# Define grid grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 grid_shape = (grid_size_x, grid_size_y, grid_size_z) +# Define fields on the grid grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) -# Velocity on left face (3D) -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, missing_mask, grid_shape, EquilibriumBC.id, ["left"] -) - - -# Wall on all other faces (3D) except right -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, - missing_mask, - grid_shape, - FullwayBounceBackBC.id, - ["bottom", "right", "front", "back"], -) - -# Do nothing on right face -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, missing_mask, grid_shape, DoNothingBC.id, ["right"] -) - -bc_eq = QuadraticEquilibrium() -bc_left = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=bc_eq) -bc_walls = FullwayBounceBackBC() -bc_do_nothing = DoNothingBC() - +# Specify BC indices +inlet = grid.boundingBoxIndices['left'] +outlet = grid.boundingBoxIndices['right'] +walls = [grid.boundingBoxIndices['bottom'][i] + grid.boundingBoxIndices['top'][i] + + grid.boundingBoxIndices['front'][i] + grid.boundingBoxIndices['back'][i] for i in range(velocity_set.d)] +# indices for sphere sphere_radius = grid_size_y // 12 x = np.arange(grid_size_x) y = np.arange(grid_size_y) @@ -67,25 +58,36 @@ + (Z - grid_size_z // 2) ** 2 < sphere_radius**2 ) -indices = np.array(indices) +sphere = [tuple(indices[i]) for i in range(velocity_set.d)] -# Set boundary conditions on the indices -indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( +# Instantiate BC objects +bc_left = EquilibriumBC(inlet, rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=QuadraticEquilibrium()) +bc_walls = FullwayBounceBackBC(walls) +bc_do_nothing = DoNothingBC(outlet) +bc_sphere = FullwayBounceBackBC(sphere) + +# Set boundary_id and missing_mask for all BCs in boundary_conditions list +boundary_condition_list = [bc_left, bc_walls, bc_do_nothing, bc_sphere] +indices_boundary_masker = IndicesBoundaryMasker( velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=backend, ) boundary_mask, missing_mask = indices_boundary_masker( - indices, FullwayBounceBackBC.id, boundary_mask, missing_mask, (0, 0, 0) + boundary_condition_list, boundary_mask, missing_mask, (0, 0, 0) ) +# Note: In case we want to remove indices from BC objects +# for bc in boundary_condition_list: +# bc.__dict__.pop('indices', None) + +# Initialize fields to start the run f_0 = initialize_eq(f_0, grid, velocity_set, backend) -boundary_conditions = [bc_left, bc_walls, bc_do_nothing] omega = 1.8 stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=boundary_conditions + omega, boundary_conditions=boundary_condition_list ) for i in range(10000): diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 5a796f1..483386d 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -27,8 +27,52 @@ def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend): self.shape = shape self.dim = len(shape) self.compute_backend = compute_backend + self._bounding_box_indices() self._initialize_backend() @abstractmethod def _initialize_backend(self): pass + + def _bounding_box_indices(self): + """ + This function calculates the indices of the bounding box of a 2D or 3D grid. + The bounding box is defined as the set of grid points on the outer edge of the grid. + + Returns + ------- + boundingBox (dict): A dictionary where keys are the names of the bounding box faces + ("bottom", "top", "left", "right" for 2D; additional "front", "back" for 3D), and values + are numpy arrays of indices corresponding to each face. + """ + def to_tuple(list): + d = len(list[0]) + return [tuple([sublist[i] for sublist in list]) for i in range(d)] + + if self.dim == 2: + # For a 2D grid, the bounding box consists of four edges: bottom, top, left, and right. + # Each edge is represented as an array of indices. For example, the bottom edge includes + # all points where the y-coordinate is 0, so its indices are [[i, 0] for i in range(nx)]. + nx, ny = self.shape + self.boundingBoxIndices = { + "bottom": to_tuple([[i, 0] for i in range(nx)]), + "top": to_tuple([[i, ny - 1] for i in range(nx)]), + "left": to_tuple([[0, i] for i in range(ny)]), + "right": to_tuple([[nx - 1, i] for i in range(ny)]) + } + + elif self.dim == 3: + # For a 3D grid, the bounding box consists of six faces: bottom, top, left, right, front, and back. + # Each face is represented as an array of indices. For example, the bottom face includes all points + # where the z-coordinate is 0, so its indices are [[i, j, 0] for i in range(nx) for j in range(ny)]. + nx, ny, nz = self.shape + self.boundingBoxIndices = { + "bottom": to_tuple([[i, j, 0] for i in range(nx) for j in range(ny)]), + "top": to_tuple([[i, j, nz - 1] for i in range(nx) for j in range(ny)]), + "left": to_tuple([[0, j, k] for j in range(ny) for k in range(nz)]), + "right": to_tuple([[nx - 1, j, k] for j in range(ny) for k in range(nz)]), + "front": to_tuple([[i, 0, k] for i in range(nx) for k in range(nz)]), + "back": to_tuple([[i, ny - 1, k] for i in range(nx) for k in range(nz)]) + } + return + diff --git a/xlb/helper/__init__.py b/xlb/helper/__init__.py index 2dab37b..29ac3f6 100644 --- a/xlb/helper/__init__.py +++ b/xlb/helper/__init__.py @@ -1,3 +1,2 @@ from xlb.helper.nse_solver import create_nse_fields -from xlb.helper.initializers import initialize_eq -from xlb.helper.boundary_conditions import assign_bc_id_box_faces \ No newline at end of file +from xlb.helper.initializers import initialize_eq \ No newline at end of file diff --git a/xlb/helper/boundary_conditions.py b/xlb/helper/boundary_conditions.py deleted file mode 100644 index d27a842..0000000 --- a/xlb/helper/boundary_conditions.py +++ /dev/null @@ -1,69 +0,0 @@ -from xlb.operator.boundary_masker import PlanarBoundaryMasker - - -def assign_bc_id_box_faces(boundary_mask, missing_mask, shape, bc_id, sides, backend=None): - """ - Assign boundary conditions for specified sides of 2D and 3D boxes using planar_boundary_masker function. - - Parameters: - boundary_mask: ndarray - The field containing boundary IDs. - missing_mask: ndarray - The mask indicating missing boundary IDs. - shape: tuple - The shape of the grid (extent of the grid in each dimension). - bc_id: int - The boundary condition ID to assign to the specified boundaries. - sides: list of str - The list of sides to apply conditions to. Valid values for 2D are 'bottom', 'top', 'left', 'right'. - Valid values for 3D are 'bottom', 'top', 'front', 'back', 'left', 'right'. - """ - - planar_boundary_masker = PlanarBoundaryMasker(compute_backend=backend) - - def apply(lower_bound, upper_bound, direction, reference=(0, 0, 0)): - nonlocal boundary_mask, missing_mask, planar_boundary_masker - boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - bc_id, - boundary_mask, - missing_mask, - reference, - ) - - dimensions = len(shape) - - if dimensions == 2: - nr, nc = shape - for boundary in sides: - if boundary == "bottom": - apply((0, 0), (nr, 1), (1, 0), (0, 0)) - elif boundary == "top": - apply((0, nc - 1), (nr, nc), (1, 0), (0, 0)) - elif boundary == "left": - apply((0, 0), (1, nc), (0, 1), (0, 0)) - elif boundary == "right": - apply((nr - 1, 0), (nr, nc), (0, 1), (0, 0)) - - elif dimensions == 3: - nr, nc, nz = shape - for boundary in sides: - if boundary == "bottom": - apply((0, 0, 0), (nr, 1, nz), (1, 0, 0), (0, 0, 0)) - elif boundary == "top": - apply((0, nc - 1, 0), (nr, nc, nz), (1, 0, 0), (0, 0, 0)) - elif boundary == "front": - apply((0, 0, 0), (nr, nc, 1), (0, 1, 0), (0, 0, 0)) - elif boundary == "back": - apply((0, 0, nz - 1), (nr, nc, nz), (0, 1, 0), (0, 0, 0)) - elif boundary == "left": - apply((0, 0, 0), (1, nc, nz), (0, 0, 1), (0, 0, 0)) - elif boundary == "right": - apply((nr - 1, 0, 0), (nr, nc, nz), (0, 0, 1), (0, 0, 0)) - - else: - raise ValueError("Unsupported dimensions: {}".format(dimensions)) - - return boundary_mask, missing_mask \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 9992898..1938207 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -3,12 +3,10 @@ """ import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax +from jax import jit from functools import partial -import numpy as np import warp as wp -from typing import Tuple, Any +from typing import Any, List from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -33,11 +31,13 @@ class DoNothingBC(BoundaryCondition): def __init__( self, + indices: List[int], velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, ): super().__init__( + indices, ImplementationStep.STREAMING, velocity_set, precision_policy, diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index ee06fcd..2ac632b 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -3,12 +3,11 @@ """ import jax.numpy as jnp -from jax import jit, device_count +from jax import jit import jax.lax as lax from functools import partial -import numpy as np import warp as wp -from typing import Tuple, Any +from typing import Tuple, Any, List from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -33,6 +32,7 @@ class EquilibriumBC(BoundaryCondition): def __init__( self, + indices: List[int], rho: float, u: Tuple[float, float, float], equilibrium_operator: Operator, @@ -50,6 +50,7 @@ def __init__( # Call the parent constructor super().__init__( + indices, ImplementationStep.STREAMING, velocity_set, precision_policy, diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 31b2a1f..f528634 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -3,12 +3,10 @@ """ import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax +from jax import jit from functools import partial -import numpy as np import warp as wp -from typing import Any +from typing import Any, List from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -32,11 +30,13 @@ class FullwayBounceBackBC(BoundaryCondition): def __init__( self, + indices: List[int], velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, ): super().__init__( + indices, ImplementationStep.COLLISION, velocity_set, precision_policy, diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 288565d..ec6ba1e 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -7,6 +7,7 @@ from functools import partial import numpy as np from enum import Enum, auto +from typing import List from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -27,6 +28,7 @@ class BoundaryCondition(Operator): def __init__( self, + indices: List[int], implementation_step: ImplementationStep, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, @@ -38,5 +40,8 @@ def __init__( super().__init__(velocity_set, precision_policy, compute_backend) + # Set the BC indices + self.indices = indices + # Set the implementation step self.implementation_step = implementation_step diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index 7f4b803..cc80b85 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -1,9 +1,6 @@ from xlb.operator.boundary_masker.indices_boundary_masker import ( IndicesBoundaryMasker, ) -from xlb.operator.boundary_masker.planar_boundary_masker import ( - PlanarBoundaryMasker, -) from xlb.operator.boundary_masker.stl_boundary_masker import ( STLBoundaryMasker, ) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index fa0c5e7..3559e33 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -41,14 +41,9 @@ def _indices_to_tuple(indices): @Operator.register_backend(ComputeBackend.JAX) def jax_implementation( - self, indices, id_number, boundary_mask, mask, start_index=None + self, bclist, boundary_mask, mask, start_index=None ): - dim = mask.ndim - 1 - if start_index is None: - start_index = (0,) * dim - - local_indices = indices - np.array(start_index)[:, np.newaxis] - + # define a helper function @jit def compute_boundary_id_and_mask(boundary_mask, mask): if dim == 2: @@ -64,11 +59,19 @@ def compute_boundary_id_and_mask(boundary_mask, mask): mask = mask.at[ :, local_indices[0], local_indices[1], local_indices[2] ].set(True) - - mask = self.stream(mask) return boundary_mask, mask + + dim = mask.ndim - 1 + if start_index is None: + start_index = (0,) * dim + + for bc in bclist: + id_number = bc.id + local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] + boundary_mask, mask = compute_boundary_id_and_mask(boundary_mask, mask) - return compute_boundary_id_and_mask(boundary_mask, mask) + mask = self.stream(mask) + return boundary_mask, mask def _construct_warp(self): # Make constants for warp diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py deleted file mode 100644 index ed20604..0000000 --- a/xlb/operator/boundary_masker/planar_boundary_masker.py +++ /dev/null @@ -1,143 +0,0 @@ -# Base class for all equilibriums - -from functools import partial -import numpy as np -import jax.numpy as jnp -from jax import jit -import warp as wp -from typing import Tuple -from jax.numpy import where, einsum, full_like - - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator -from xlb.operator.stream.stream import Stream - - -class PlanarBoundaryMasker(Operator): - """ - Operator for creating a boundary mask on a plane of the domain - """ - - def __init__( - self, - velocity_set: VelocitySet = None, - precision_policy: PrecisionPolicy = None, - compute_backend: ComputeBackend = None, - ): - # Call super - super().__init__(velocity_set, precision_policy, compute_backend) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0, 1, 2, 3, 4, 7)) - def jax_implementation( - self, - lower_bound, - upper_bound, - direction, - id_number, - boundary_mask, - missing_mask, - start_index=None, - ): - if start_index is None: - start_index = (0,) * self.velocity_set.d - - _, *dimensions = boundary_mask.shape - - indices = [ - (max(0, lb + start), min(dim, ub + start)) - for lb, ub, start, dim in zip( - lower_bound, upper_bound, start_index, dimensions - ) - ] - - slices = [slice(None)] - slices.extend(slice(lb, ub) for lb, ub in indices) - boundary_mask = boundary_mask.at[tuple(slices)].set(id_number) - - return boundary_mask, missing_mask - - def _construct_warp(self): - # Make constants for warp - _c = self.velocity_set.wp_c - _q = wp.constant(self.velocity_set.q) - - @wp.kernel - def kernel2d( - lower_bound: wp.vec2i, - upper_bound: wp.vec2i, - direction: wp.vec2i, - id_number: wp.uint8, - boundary_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - start_index: wp.vec2i, - ): - i, j = wp.tid() - lb_x, lb_y = lower_bound.x + start_index.x, lower_bound.y + start_index.y - ub_x, ub_y = upper_bound.x + start_index.x, upper_bound.y + start_index.y - - if lb_x <= i < ub_x and lb_y <= j < ub_y: - boundary_mask[0, i, j] = id_number - - @wp.kernel - def kernel3d( - lower_bound: wp.vec3i, - upper_bound: wp.vec3i, - direction: wp.vec3i, - id_number: wp.uint8, - boundary_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - start_index: wp.vec3i, - ): - i, j, k = wp.tid() - lb_x, lb_y, lb_z = ( - lower_bound.x + start_index.x, - lower_bound.y + start_index.y, - lower_bound.z + start_index.z, - ) - ub_x, ub_y, ub_z = ( - upper_bound.x + start_index.x, - upper_bound.y + start_index.y, - upper_bound.z + start_index.z, - ) - - if lb_x <= i < ub_x and lb_y <= j < ub_y and lb_z <= k < ub_z: - boundary_mask[0, i, j, k] = id_number - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - - return None, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation( - self, - lower_bound, - upper_bound, - direction, - id_number, - boundary_mask, - missing_mask, - start_index=None, - ): - if start_index is None: - start_index = (0,) * self.velocity_set.d - - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - lower_bound, - upper_bound, - direction, - id_number, - boundary_mask, - missing_mask, - start_index, - ], - dim=missing_mask.shape[1:], - ) - - return boundary_mask, missing_mask diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 8b50abd..38e8e15 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -50,8 +50,6 @@ def decorator(func): return decorator - return decorator - def __call__(self, *args, callback=None, **kwargs): method_candidates = [ (key, method) From 6be34958cad5f6d2e4c28bd27dbb05a705375266 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 15 Jul 2024 12:06:19 -0400 Subject: [PATCH 043/144] Fixed a few bugs. JAX works. Need to look into Warp next --- examples/cfd/flow_past_sphere_3d.py | 2 +- .../indices_boundary_masker.py | 1 - xlb/operator/stepper/nse_stepper.py | 4 +- xlb/operator/stepper/stepper.py | 60 +++++++++---------- 4 files changed, 33 insertions(+), 34 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 12d3a00..cf2f7f6 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -34,7 +34,7 @@ #TODO HS: proposal: we should include indices as part of the construction of the bc operators and then have a single call to construcut boundary_mask and missing_mask fields based on bc_list. # Define grid -grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 +grid_size_x, grid_size_y, grid_size_z = 512//4, 128//4, 128//4 grid_shape = (grid_size_x, grid_size_y, grid_size_z) # Define fields on the grid diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 3559e33..b9eb60e 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -44,7 +44,6 @@ def jax_implementation( self, bclist, boundary_mask, mask, start_index=None ): # define a helper function - @jit def compute_boundary_id_and_mask(boundary_mask, mask): if dim == 2: boundary_mask = boundary_mask.at[ diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index d52f41d..11b5615 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -63,7 +63,7 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: if bc.implementation_step == ImplementationStep.COLLISION: - f_0 = bc( + f_post_collision = bc( f_0, f_post_collision, boundary_mask, @@ -71,7 +71,7 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): ) # Apply streaming - f_1 = self.stream(f_0) + f_1 = self.stream(f_post_collision) # Apply boundary conditions for bc in self.boundary_conditions: diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index c11b39b..553e20c 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -88,36 +88,36 @@ def __init__(self, operators, boundary_conditions): elif isinstance(bc, FullwayBounceBackBC): self.fullway_bounce_back_bc = bc - if self.equilibrium_bc is None: - # Select the equilibrium operator based on its type - self.equilibrium_bc = EquilibriumBC( - rho=1.0, - u=(0.0, 0.0, 0.0), - equilibrium_operator=next( - (op for op in self.operators if isinstance(op, Equilibrium)), None - ), - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.do_nothing_bc is None: - self.do_nothing_bc = DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.halfway_bounce_back_bc is None: - self.halfway_bounce_back_bc = HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.fullway_bounce_back_bc is None: - self.fullway_bounce_back_bc = FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) + # if self.equilibrium_bc is None: + # # Select the equilibrium operator based on its type + # self.equilibrium_bc = EquilibriumBC( + # rho=1.0, + # u=(0.0, 0.0, 0.0), + # equilibrium_operator=next( + # (op for op in self.operators if isinstance(op, Equilibrium)), None + # ), + # velocity_set=velocity_set, + # precision_policy=precision_policy, + # compute_backend=compute_backend, + # ) + # if self.do_nothing_bc is None: + # self.do_nothing_bc = DoNothingBC( + # velocity_set=velocity_set, + # precision_policy=precision_policy, + # compute_backend=compute_backend, + # ) + # if self.halfway_bounce_back_bc is None: + # self.halfway_bounce_back_bc = HalfwayBounceBackBC( + # velocity_set=velocity_set, + # precision_policy=precision_policy, + # compute_backend=compute_backend, + # ) + # if self.fullway_bounce_back_bc is None: + # self.fullway_bounce_back_bc = FullwayBounceBackBC( + # velocity_set=velocity_set, + # precision_policy=precision_policy, + # compute_backend=compute_backend, + # ) ############################################ # Initialize operator From 3a9a97e32a53b3e84e0b90c931f94d82f9c03fcf Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 16 Jul 2024 00:03:36 -0400 Subject: [PATCH 044/144] Warp working with the new BC setup --- examples/cfd/flow_past_sphere_3d.py | 4 +- .../bc_halfway_bounce_back.py | 4 +- .../indices_boundary_masker.py | 33 ++++++---- xlb/operator/stepper/stepper.py | 64 ++++++++++--------- 4 files changed, 59 insertions(+), 46 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index cf2f7f6..9ab0d6d 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -17,7 +17,7 @@ import jax.numpy as jnp # Initial setup and backend configuration -backend = ComputeBackend.JAX +backend = ComputeBackend.WARP velocity_set = xlb.velocity_set.D3Q19() precision_policy = PrecisionPolicy.FP32FP32 @@ -34,7 +34,7 @@ #TODO HS: proposal: we should include indices as part of the construction of the bc operators and then have a single call to construcut boundary_mask and missing_mask fields based on bc_list. # Define grid -grid_size_x, grid_size_y, grid_size_z = 512//4, 128//4, 128//4 +grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 grid_shape = (grid_size_x, grid_size_y, grid_size_z) # Define fields on the grid diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index b8ded63..c4f323e 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -8,7 +8,7 @@ from functools import partial import numpy as np import warp as wp -from typing import Tuple, Any +from typing import Tuple, Any, List from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -34,12 +34,14 @@ class HalfwayBounceBackBC(BoundaryCondition): def __init__( self, + indices: List[int], velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, ): # Call the parent constructor super().__init__( + indices, ImplementationStep.STREAMING, velocity_set, precision_policy, diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index b9eb60e..80f0c02 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -32,14 +32,10 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) - @staticmethod - def _indices_to_tuple(indices): - """ - Converts a tensor of indices to a tuple for indexing - """ - return tuple(indices.T) @Operator.register_backend(ComputeBackend.JAX) + # TODO HS: figure out why uncommenting the line below fails unlike other operators! + # @partial(jit, static_argnums=(0)) def jax_implementation( self, bclist, boundary_mask, mask, start_index=None ): @@ -81,7 +77,7 @@ def _construct_warp(self): @wp.kernel def kernel2d( indices: wp.array2d(dtype=wp.int32), - id_number: wp.int32, + id_number: wp.array1d(dtype=wp.uint8), boundary_mask: wp.array3d(dtype=wp.uint8), mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, @@ -111,13 +107,13 @@ def kernel2d( # Set the boundary id and mask mask[l, push_index[0], push_index[1]] = True - boundary_mask[0, index[0], index[1]] = wp.uint8(id_number) + boundary_mask[0, index[0], index[1]] = id_number[ii] # Construct the warp 3D kernel @wp.kernel def kernel3d( indices: wp.array2d(dtype=wp.int32), - id_number: wp.int32, + id_number: wp.array1d(dtype=wp.uint8), boundary_mask: wp.array4d(dtype=wp.uint8), mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, @@ -150,7 +146,7 @@ def kernel3d( # Set the mask mask[l, push_index[0], push_index[1], push_index[2]] = True - boundary_mask[0, index[0], index[1], index[2]] = wp.uint8(id_number) + boundary_mask[0, index[0], index[1], index[2]] = id_number[ii] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d @@ -158,12 +154,23 @@ def kernel3d( @Operator.register_backend(ComputeBackend.WARP) def warp_implementation( - self, indices, id_number, boundary_mask, missing_mask, start_index=None + self, bclist, boundary_mask, missing_mask, start_index=None ): + + dim = self.velocity_set.d + index_list = [[] for _ in range(dim)] + id_list = [] + for bc in bclist: + for d in range(dim): + index_list[d] += bc.indices[d] + id_list += [bc.id] * len(bc.indices[0]) + + indices = wp.array2d(index_list, dtype = wp.int32) + id_number = wp.array1d(id_list, dtype = wp.uint8) + if start_index is None: - start_index = (0,) * self.velocity_set.d + start_index = (0,) * dim - indices = wp.array(indices, dtype=wp.int32) # Launch the warp kernel wp.launch( self.warp_kernel, diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 553e20c..6af51c2 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -88,36 +88,40 @@ def __init__(self, operators, boundary_conditions): elif isinstance(bc, FullwayBounceBackBC): self.fullway_bounce_back_bc = bc - # if self.equilibrium_bc is None: - # # Select the equilibrium operator based on its type - # self.equilibrium_bc = EquilibriumBC( - # rho=1.0, - # u=(0.0, 0.0, 0.0), - # equilibrium_operator=next( - # (op for op in self.operators if isinstance(op, Equilibrium)), None - # ), - # velocity_set=velocity_set, - # precision_policy=precision_policy, - # compute_backend=compute_backend, - # ) - # if self.do_nothing_bc is None: - # self.do_nothing_bc = DoNothingBC( - # velocity_set=velocity_set, - # precision_policy=precision_policy, - # compute_backend=compute_backend, - # ) - # if self.halfway_bounce_back_bc is None: - # self.halfway_bounce_back_bc = HalfwayBounceBackBC( - # velocity_set=velocity_set, - # precision_policy=precision_policy, - # compute_backend=compute_backend, - # ) - # if self.fullway_bounce_back_bc is None: - # self.fullway_bounce_back_bc = FullwayBounceBackBC( - # velocity_set=velocity_set, - # precision_policy=precision_policy, - # compute_backend=compute_backend, - # ) + if self.equilibrium_bc is None: + # Select the equilibrium operator based on its type + self.equilibrium_bc = EquilibriumBC( + [], + rho=1.0, + u=(0.0, 0.0, 0.0), + equilibrium_operator=next( + (op for op in self.operators if isinstance(op, Equilibrium)), None + ), + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + if self.do_nothing_bc is None: + self.do_nothing_bc = DoNothingBC( + [], + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + if self.halfway_bounce_back_bc is None: + self.halfway_bounce_back_bc = HalfwayBounceBackBC( + [], + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + if self.fullway_bounce_back_bc is None: + self.fullway_bounce_back_bc = FullwayBounceBackBC( + [], + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) ############################################ # Initialize operator From 9b9b184b777abc3e6c03d31ad946aa582d6da3d5 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 16 Jul 2024 11:16:03 -0400 Subject: [PATCH 045/144] converting the example script to be class-based --- examples/cfd/flow_past_sphere_3d.py | 190 ++++++++++++++-------------- 1 file changed, 98 insertions(+), 92 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 9ab0d6d..be9863a 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -16,98 +16,104 @@ import numpy as np import jax.numpy as jnp -# Initial setup and backend configuration -backend = ComputeBackend.WARP +class FlowOverSphere: + def __init__(self, grid_shape, velocity_set, backend, precision_policy): + + # initialize backend + xlb.init( + velocity_set=velocity_set, + default_backend=backend, + default_precision_policy=precision_policy, + ) + + self.grid_shape = grid_shape + self.velocity_set = velocity_set + self.backend = backend + self.precision_policy = precision_policy + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.stepper = None + self.boundary_conditions = [] + + def define_boundary_indices(self): + inlet = self.grid.boundingBoxIndices['left'] + outlet = self.grid.boundingBoxIndices['right'] + walls = [self.grid.boundingBoxIndices['bottom'][i] + self.grid.boundingBoxIndices['top'][i] + + self.grid.boundingBoxIndices['front'][i] + self.grid.boundingBoxIndices['back'][i] for i in range(self.velocity_set.d)] + + sphere_radius = self.grid_shape[1] // 12 + x = np.arange(self.grid_shape[0]) + y = np.arange(self.grid_shape[1]) + z = np.arange(self.grid_shape[2]) + X, Y, Z = np.meshgrid(x, y, z, indexing='ij') + indices = np.where( + (X - self.grid_shape[0] // 6) ** 2 + + (Y - self.grid_shape[1] // 2) ** 2 + + (Z - self.grid_shape[2] // 2) ** 2 + < sphere_radius**2 + ) + sphere = [tuple(indices[i]) for i in range(self.velocity_set.d)] + + return inlet, outlet, walls, sphere + + def instantiate_boundary_conditions(self, inlet, outlet, walls, sphere): + bc_left = EquilibriumBC(inlet, rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=QuadraticEquilibrium()) + bc_walls = FullwayBounceBackBC(walls) + bc_do_nothing = DoNothingBC(outlet) + bc_sphere = FullwayBounceBackBC(sphere) + self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_sphere] + + def set_boundary_masks(self): + indices_boundary_masker = IndicesBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.backend, + ) + self.boundary_mask, self.missing_mask = indices_boundary_masker( + self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0) + ) + + def initialize_fields(self): + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + + def setup_stepper(self, omega): + self.stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=self.boundary_conditions + ) + + def run_simulation(self, num_steps): + for i in range(num_steps): + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.f_1, self.f_0 + + def post_process(self, i): + # Write the results. We'll use JAX backend for the post-processing + if not isinstance(self.f_0, jnp.ndarray): + self.f_0 = wp.to_jax(self.f_0) + + macro = Macroscopic(compute_backend=ComputeBackend.JAX) + rho, u = macro(self.f_0) + + # remove boundary cells + u = u[:, 1:-1, 1:-1, 1:-1] + u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 + + fields = {"u_magnitude": u_magnitude} + + save_fields_vtk(fields, timestep=i) + save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) + + +# Running the simulation +grid_shape = (512, 128, 128) velocity_set = xlb.velocity_set.D3Q19() +backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 -xlb.init( - velocity_set=velocity_set, - default_backend=backend, - default_precision_policy=precision_policy, -) - -#TODO HS: check inconsistency between grid_shape and velocity_set -#TODO HS: why is boundary_mask and missing_mask in the same function?! they should be separated -#TODO HS: missing_mask needs to be created based on ALL boundary indices and a SINGLE streaming operation not one streaming call per bc! -#TODO HS: why bc operatores need to be stated twice: once in making boundary_mask and missing_mask and one in making bc list. -#TODO HS: proposal: we should include indices as part of the construction of the bc operators and then have a single call to construcut boundary_mask and missing_mask fields based on bc_list. - -# Define grid -grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 -grid_shape = (grid_size_x, grid_size_y, grid_size_z) - -# Define fields on the grid -grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) - -# Specify BC indices -inlet = grid.boundingBoxIndices['left'] -outlet = grid.boundingBoxIndices['right'] -walls = [grid.boundingBoxIndices['bottom'][i] + grid.boundingBoxIndices['top'][i] + - grid.boundingBoxIndices['front'][i] + grid.boundingBoxIndices['back'][i] for i in range(velocity_set.d)] - -# indices for sphere -sphere_radius = grid_size_y // 12 -x = np.arange(grid_size_x) -y = np.arange(grid_size_y) -z = np.arange(grid_size_z) -X, Y, Z = np.meshgrid(x, y, z, indexing='ij') -indices = np.where( - (X - grid_size_x // 6) ** 2 - + (Y - grid_size_y // 2) ** 2 - + (Z - grid_size_z // 2) ** 2 - < sphere_radius**2 -) -sphere = [tuple(indices[i]) for i in range(velocity_set.d)] - -# Instantiate BC objects -bc_left = EquilibriumBC(inlet, rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=QuadraticEquilibrium()) -bc_walls = FullwayBounceBackBC(walls) -bc_do_nothing = DoNothingBC(outlet) -bc_sphere = FullwayBounceBackBC(sphere) - -# Set boundary_id and missing_mask for all BCs in boundary_conditions list -boundary_condition_list = [bc_left, bc_walls, bc_do_nothing, bc_sphere] -indices_boundary_masker = IndicesBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=backend, -) - -boundary_mask, missing_mask = indices_boundary_masker( - boundary_condition_list, boundary_mask, missing_mask, (0, 0, 0) -) -# Note: In case we want to remove indices from BC objects -# for bc in boundary_condition_list: -# bc.__dict__.pop('indices', None) - - -# Initialize fields to start the run -f_0 = initialize_eq(f_0, grid, velocity_set, backend) -omega = 1.8 - -stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=boundary_condition_list -) - -for i in range(10000): - f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) - f_0, f_1 = f_1, f_0 - - -# Write the results. We'll use JAX backend for the post-processing -if not isinstance(f_0, jnp.ndarray): - f_0 = wp.to_jax(f_0) - -macro = Macroscopic(compute_backend=ComputeBackend.JAX) - -rho, u = macro(f_0) - -# remove boundary cells -u = u[:, 1:-1, 1:-1, 1:-1] -u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - -fields = {"u_magnitude": u_magnitude} - -save_fields_vtk(fields, timestep=i) -save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) +simulation = FlowOverSphere(grid_shape, velocity_set, backend, precision_policy) +inlet, outlet, walls, sphere = simulation.define_boundary_indices() +simulation.instantiate_boundary_conditions(inlet, outlet, walls, sphere) +simulation.set_boundary_masks() +simulation.initialize_fields() +simulation.setup_stepper(omega=1.8) +simulation.run_simulation(num_steps=10000) +simulation.post_process(i=10000) From 1b31ccf030d021f083acbda07a82565197a66a35 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 16 Jul 2024 13:35:00 -0400 Subject: [PATCH 046/144] updated all examples to run properly --- examples/cfd/flow_past_sphere_3d.py | 29 +- examples/cfd/lid_driven_cavity_2d.py | 157 +++++----- .../cfd/lid_driven_cavity_2d_distributed.py | 115 +++----- examples/cfd/windtunnel_3d.py | 270 +++++++++--------- examples/performance/mlups_3d.py | 33 +-- 5 files changed, 291 insertions(+), 313 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index be9863a..2ea71a8 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -55,7 +55,8 @@ def define_boundary_indices(self): return inlet, outlet, walls, sphere - def instantiate_boundary_conditions(self, inlet, outlet, walls, sphere): + def setup_boundary_conditions(self): + inlet, outlet, walls, sphere = self.define_boundary_indices() bc_left = EquilibriumBC(inlet, rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=QuadraticEquilibrium()) bc_walls = FullwayBounceBackBC(walls) bc_do_nothing = DoNothingBC(outlet) @@ -103,17 +104,17 @@ def post_process(self, i): save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) -# Running the simulation -grid_shape = (512, 128, 128) -velocity_set = xlb.velocity_set.D3Q19() -backend = ComputeBackend.WARP -precision_policy = PrecisionPolicy.FP32FP32 +if __name__ == "__main__": + # Running the simulation + grid_shape = (512, 128, 128) + velocity_set = xlb.velocity_set.D3Q19() + backend = ComputeBackend.WARP + precision_policy = PrecisionPolicy.FP32FP32 -simulation = FlowOverSphere(grid_shape, velocity_set, backend, precision_policy) -inlet, outlet, walls, sphere = simulation.define_boundary_indices() -simulation.instantiate_boundary_conditions(inlet, outlet, walls, sphere) -simulation.set_boundary_masks() -simulation.initialize_fields() -simulation.setup_stepper(omega=1.8) -simulation.run_simulation(num_steps=10000) -simulation.post_process(i=10000) + simulation = FlowOverSphere(grid_shape, velocity_set, backend, precision_policy) + simulation.setup_boundary_conditions() + simulation.set_boundary_masks() + simulation.initialize_fields() + simulation.setup_stepper(omega=1.8) + simulation.run_simulation(num_steps=10000) + simulation.post_process(i=10000) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 84ed1a7..08c6580 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -1,7 +1,8 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.helper import create_nse_fields, initialize_eq +from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC from xlb.operator.equilibrium import QuadraticEquilibrium @@ -10,70 +11,92 @@ import warp as wp import jax.numpy as jnp -backend = ComputeBackend.JAX -velocity_set = xlb.velocity_set.D2Q9() -precision_policy = PrecisionPolicy.FP32FP32 -xlb.init( - velocity_set=velocity_set, - default_backend=backend, - default_precision_policy=precision_policy, -) - -grid_size = 128 -grid_shape = (grid_size, grid_size) - -grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) - -# Velocity on top face (2D) -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, missing_mask, grid_shape, EquilibriumBC.id, ["top"] -) - -bc_eq = QuadraticEquilibrium() - -bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), equilibrium_operator=bc_eq) - - -# Wall on all other faces (2D) -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, - missing_mask, - grid_shape, - FullwayBounceBackBC.id, - ["bottom", "left", "right"], -) - -bc_walls = FullwayBounceBackBC() - - -f_0 = initialize_eq(f_0, grid, velocity_set, backend) -boundary_conditions = [bc_top, bc_walls] -omega = 1.6 - -stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=boundary_conditions -) - -for i in range(500): - f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) - f_0, f_1 = f_1, f_0 - - -# Write the results. We'll use JAX backend for the post-processing -if not isinstance(f_0, jnp.ndarray): - f_0 = wp.to_jax(f_0) - -macro = Macroscopic(compute_backend=ComputeBackend.JAX) - -rho, u = macro(f_0) - -# remove boundary cells -rho = rho[:, 1:-1, 1:-1] -u = u[:, 1:-1, 1:-1] -u_magnitude = (u[0] ** 2 + u[1] ** 2) ** 0.5 - -fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_magnitude": u_magnitude} - -save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity") -save_image(fields["u_magnitude"], timestep=i, prefix="lid_driven_cavity") +class LidDrivenCavity2D: + def __init__(self, grid_shape, velocity_set, backend, precision_policy): + + # initialize backend + xlb.init( + velocity_set=velocity_set, + default_backend=backend, + default_precision_policy=precision_policy, + ) + + self.grid_shape = grid_shape + self.velocity_set = velocity_set + self.backend = backend + self.precision_policy = precision_policy + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.stepper = None + self.boundary_conditions = [] + + def define_boundary_indices(self): + lid = self.grid.boundingBoxIndices['top'] + walls = [self.grid.boundingBoxIndices['bottom'][i] + self.grid.boundingBoxIndices['left'][i] + + self.grid.boundingBoxIndices['right'][i] for i in range(self.velocity_set.d)] + return lid, walls + + def setup_boundary_conditions(self): + lid, walls = self.define_boundary_indices() + bc_top = EquilibriumBC(lid, rho=1.0, u=(0.02, 0.0), equilibrium_operator=QuadraticEquilibrium()) + bc_walls = FullwayBounceBackBC(walls) + self.boundary_conditions = [bc_top, bc_walls] + + def set_boundary_masks(self): + indices_boundary_masker = IndicesBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.backend, + ) + self.boundary_mask, self.missing_mask = indices_boundary_masker( + self.boundary_conditions, self.boundary_mask, self.missing_mask + ) + + def initialize_fields(self): + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + + def setup_stepper(self, omega): + self.stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=self.boundary_conditions + ) + + def run_simulation(self, num_steps): + for i in range(num_steps): + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.f_1, self.f_0 + + def post_process(self, i): + # Write the results. We'll use JAX backend for the post-processing + if not isinstance(self.f_0, jnp.ndarray): + self.f_0 = wp.to_jax(self.f_0) + + macro = Macroscopic(compute_backend=ComputeBackend.JAX) + + rho, u = macro(self.f_0) + + # remove boundary cells + rho = rho[:, 1:-1, 1:-1] + u = u[:, 1:-1, 1:-1] + u_magnitude = (u[0] ** 2 + u[1] ** 2) ** 0.5 + + fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_magnitude": u_magnitude} + + save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity") + save_image(fields["u_magnitude"], timestep=i, prefix="lid_driven_cavity") + + +if __name__ == "__main__": + # Running the simulation + grid_size = 128 + grid_shape = (grid_size, grid_size) + backend = ComputeBackend.JAX + velocity_set = xlb.velocity_set.D2Q9() + precision_policy = PrecisionPolicy.FP32FP32 + + simulation = LidDrivenCavity2D(grid_shape, velocity_set, backend, precision_policy) + simulation.setup_boundary_conditions() + simulation.set_boundary_masks() + simulation.initialize_fields() + simulation.setup_stepper(omega=1.6) + simulation.run_simulation(num_steps=500) + simulation.post_process(i=500) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 3f1d50e..c974f6e 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -1,87 +1,38 @@ -from math import dist import xlb from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces from xlb.operator.stepper import IncompressibleNavierStokesStepper -from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.macroscopic import Macroscopic -from xlb.utils import save_fields_vtk, save_image from xlb.distribute import distribute -import warp as wp -import jax.numpy as jnp - -backend = ComputeBackend.JAX -velocity_set = xlb.velocity_set.D2Q9() -precision_policy = PrecisionPolicy.FP32FP32 - -xlb.init( - velocity_set=velocity_set, - default_backend=backend, - default_precision_policy=precision_policy, -) - -grid_size = 512 -grid_shape = (grid_size, grid_size) - -grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) - -# Velocity on top face (2D) -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, - missing_mask, - grid_shape, - EquilibriumBC.id, - ["top"], -) - -# Wall on all other faces (2D) -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, - missing_mask, - grid_shape, - FullwayBounceBackBC.id, - ["bottom", "left", "right"], - backend=ComputeBackend.JAX, -) - -bc_eq = QuadraticEquilibrium(compute_backend=backend) - -bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), equilibrium_operator=bc_eq) - -bc_walls = FullwayBounceBackBC(compute_backend=backend) - - -f_0 = initialize_eq(f_0, grid, velocity_set, backend=ComputeBackend.JAX) -boundary_conditions = [bc_top, bc_walls] -omega = 1.6 - -stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=boundary_conditions -) -distributed_stepper = distribute( - stepper, grid, velocity_set, sharding_flags=(True, True, True, True, False) -) -for i in range(5000): - f_1 = distributed_stepper(f_0, f_1, boundary_mask, missing_mask, i) - f_0, f_1 = f_1, f_0 - - -# Write the results. We'll use JAX backend for the post-processing -if not isinstance(f_0, jnp.ndarray): - f_0 = wp.to_jax(f_0) - -macro = Macroscopic(compute_backend=ComputeBackend.JAX) - -rho, u = macro(f_0) - -# remove boundary cells -rho = rho[:, 1:-1, 1:-1] -u = u[:, 1:-1, 1:-1] -u_magnitude = (u[0] ** 2 + u[1] ** 2) ** 0.5 - -fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_magnitude": u_magnitude} - -save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity") -save_image(fields["u_magnitude"], timestep=i, prefix="lid_driven_cavity") +from lid_driven_cavity_2d import LidDrivenCavity2D + + +class LidDrivenCavity2D_distributed(LidDrivenCavity2D): + def __init__(self, grid_shape, velocity_set, backend, precision_policy): + super().__init__(grid_shape, velocity_set, backend, precision_policy) + + def setup_stepper(self, omega): + stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=self.boundary_conditions + ) + distributed_stepper = distribute( + stepper, self.grid, self.velocity_set, sharding_flags=(True, True, True, True, False) + ) + self.stepper = distributed_stepper + return + + +if __name__ == "__main__": + # Running the simulation + grid_size = 512 + grid_shape = (grid_size, grid_size) + backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! + velocity_set = xlb.velocity_set.D2Q9() + precision_policy = PrecisionPolicy.FP32FP32 + + simulation = LidDrivenCavity2D_distributed(grid_shape, velocity_set, backend, precision_policy) + simulation.setup_boundary_conditions() + simulation.set_boundary_masks() + simulation.initialize_fields() + simulation.setup_stepper(omega=1.6) + simulation.run_simulation(num_steps=5000) + simulation.post_process(i=5000) \ No newline at end of file diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 72bc854..a610306 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -3,7 +3,7 @@ import time from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import ( FullwayBounceBackBC, @@ -12,138 +12,148 @@ ) from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic +from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np import jax.numpy as jnp -# Configuration -backend = ComputeBackend.WARP -velocity_set = xlb.velocity_set.D3Q27() -precision_policy = PrecisionPolicy.FP32FP32 -grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 -prescribed_vel = 0.02 -Re = 50000.0 -max_iterations = 100000 -print_interval = 1000 - -xlb.init( - velocity_set=velocity_set, - default_backend=backend, - default_precision_policy=precision_policy, -) - -# Print simulation info -print("Simulation Configuration:") -print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") -print(f"Backend: {backend}") -print(f"Velocity set: {velocity_set}") -print(f"Precision policy: {precision_policy}") -print(f"Prescribed velocity: {prescribed_vel}") -print(f"Reynolds number: {Re}") -print(f"Max iterations: {max_iterations}") -print("\n" + "=" * 50 + "\n") - - -grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 -grid_shape = (grid_size_x, grid_size_y, grid_size_z) - -grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) - -# Velocity on left face (3D) -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, missing_mask, grid_shape, EquilibriumBC.id, ["left"] -) - - -# Wall on all other faces (3D) except right -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, - missing_mask, - grid_shape, - FullwayBounceBackBC.id, - ["bottom", "right", "front", "back"], -) - -# Do nothing on right face -boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, missing_mask, grid_shape, DoNothingBC.id, ["right"] -) - -bc_eq = QuadraticEquilibrium() -bc_left = EquilibriumBC( - rho=1.0, u=(prescribed_vel, 0.0, 0.0), equilibrium_operator=bc_eq -) -bc_walls = FullwayBounceBackBC() -bc_do_nothing = DoNothingBC() - - -def voxelize_stl(stl_filename, length_lbm_unit): - mesh = trimesh.load_mesh(stl_filename, process=False) - length_phys_unit = mesh.extents.max() - pitch = length_phys_unit / length_lbm_unit - mesh_voxelized = mesh.voxelized(pitch=pitch) - mesh_matrix = mesh_voxelized.matrix - return mesh_matrix, pitch - - -stl_filename = "../stl-files/DrivAer-Notchback.stl" -car_length_lbm_unit = grid_size_x / 4 -car_voxelized, pitch = voxelize_stl(stl_filename, car_length_lbm_unit) - -car_area = np.prod(car_voxelized.shape[1:]) -tx, ty, tz = np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape -shift = [tx // 4, ty // 2, 0] -indices = np.argwhere(car_voxelized) + shift - -indices = np.array(indices).T - -# Set boundary conditions on the indices -indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=backend, -) - -boundary_mask, missing_mask = indices_boundary_masker( - indices, FullwayBounceBackBC.id, boundary_mask, missing_mask, (0, 0, 0) -) - -f_0 = initialize_eq(f_0, grid, velocity_set, backend) -boundary_conditions = [bc_left, bc_walls, bc_do_nothing] - -clength = grid_size_x - 1 - -visc = prescribed_vel * clength / Re -omega = 1.0 / (3.0 * visc + 0.5) - -stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=boundary_conditions, collision_type="KBC" -) - -start_time = time.time() -for i in range(max_iterations): - f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) - f_0, f_1 = f_1, f_0 - - if (i + 1) % print_interval == 0: - elapsed_time = time.time() - start_time - print(f"Iteration: {i+1}/{max_iterations} | Time elapsed: {elapsed_time:.2f}s") - - -# Write the results. We'll use JAX backend for the post-processing -if not isinstance(f_0, jnp.ndarray): - f_0 = wp.to_jax(f_0) - -macro = Macroscopic(compute_backend=ComputeBackend.JAX) - -rho, u = macro(f_0) - -# remove boundary cells -u = u[:, 1:-1, 1:-1, 1:-1] -u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - -fields = {"u_magnitude": u_magnitude} -save_fields_vtk(fields, timestep=i) -save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) +class WindTunnel3D: + def __init__(self, grid_shape, velocity_set, backend, precision_policy): + + # initialize backend + xlb.init( + velocity_set=velocity_set, + default_backend=backend, + default_precision_policy=precision_policy, + ) + + self.grid_shape = grid_shape + self.velocity_set = velocity_set + self.backend = backend + self.precision_policy = precision_policy + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.stepper = None + self.boundary_conditions = [] + + def voxelize_stl(self, stl_filename, length_lbm_unit): + mesh = trimesh.load_mesh(stl_filename, process=False) + length_phys_unit = mesh.extents.max() + pitch = length_phys_unit / length_lbm_unit + mesh_voxelized = mesh.voxelized(pitch=pitch) + mesh_matrix = mesh_voxelized.matrix + return mesh_matrix, pitch + + def define_boundary_indices(self): + inlet = self.grid.boundingBoxIndices['left'] + outlet = self.grid.boundingBoxIndices['right'] + walls = [self.grid.boundingBoxIndices['bottom'][i] + self.grid.boundingBoxIndices['top'][i] + + self.grid.boundingBoxIndices['front'][i] + self.grid.boundingBoxIndices['back'][i] for i in range(self.velocity_set.d)] + + stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl" + grid_size_x = self.grid_shape[0] + car_length_lbm_unit = grid_size_x / 4 + car_voxelized, pitch = self.voxelize_stl(stl_filename, car_length_lbm_unit) + + car_area = np.prod(car_voxelized.shape[1:]) + tx, ty, tz = np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape + shift = [tx // 4, ty // 2, 0] + car = np.argwhere(car_voxelized) + shift + car = np.array(car).T + car = [tuple(car[i]) for i in range(self.velocity_set.d)] + + return inlet, outlet, walls, car + + def setup_boundary_conditions(self, wind_speed): + inlet, outlet, walls, car = self.define_boundary_indices() + bc_left = EquilibriumBC(inlet, rho=1.0, u=(wind_speed, 0.0, 0.0), equilibrium_operator=QuadraticEquilibrium()) + bc_walls = FullwayBounceBackBC(walls) + bc_do_nothing = DoNothingBC(outlet) + bc_car= FullwayBounceBackBC(car) + self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_car] + + def set_boundary_masks(self): + indices_boundary_masker = IndicesBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.backend, + ) + self.boundary_mask, self.missing_mask = indices_boundary_masker( + self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0) + ) + + def initialize_fields(self): + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + + def setup_stepper(self, omega): + self.stepper = IncompressibleNavierStokesStepper( + omega, boundary_conditions=self.boundary_conditions, collision_type="KBC" + ) + + def run_simulation(self, num_steps, print_interval): + start_time = time.time() + for i in range(num_steps): + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.f_1, self.f_0 + + if (i + 1) % print_interval == 0: + elapsed_time = time.time() - start_time + print(f"Iteration: {i+1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") + + def post_process(self, i): + # Write the results. We'll use JAX backend for the post-processing + if not isinstance(self.f_0, jnp.ndarray): + f_0 = wp.to_jax(self.f_0) + + macro = Macroscopic(compute_backend=ComputeBackend.JAX) + + rho, u = macro(f_0) + + # remove boundary cells + u = u[:, 1:-1, 1:-1, 1:-1] + u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 + + fields = {"u_magnitude": u_magnitude} + + save_fields_vtk(fields, timestep=i) + save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) + + +if __name__ == "__main__": + # Grid parameters + grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 + grid_shape = (grid_size_x, grid_size_y, grid_size_z) + + # Configuration + backend = ComputeBackend.WARP + velocity_set = xlb.velocity_set.D3Q27() + precision_policy = PrecisionPolicy.FP32FP32 + wind_speed = 0.02 + num_steps = 100000 + print_interval = 1000 + + # Set up Reynolds number and deduce relaxation time (omega) + Re = 50000.0 + clength = grid_size_x - 1 + visc = wind_speed * clength / Re + omega = 1.0 / (3.0 * visc + 0.5) + + # Print simulation info + print("Simulation Configuration:") + print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") + print(f"Backend: {backend}") + print(f"Velocity set: {velocity_set}") + print(f"Precision policy: {precision_policy}") + print(f"Prescribed velocity: {wind_speed}") + print(f"Reynolds number: {Re}") + print(f"Max iterations: {num_steps}") + print("\n" + "=" * 50 + "\n") + + simulation = WindTunnel3D(grid_shape, velocity_set, backend, precision_policy) + simulation.setup_boundary_conditions(wind_speed) + simulation.set_boundary_masks() + simulation.initialize_fields() + simulation.setup_stepper(omega) + simulation.run_simulation(num_steps, print_interval) + simulation.post_process(i=num_steps) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 6131a0d..4257bf7 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -1,11 +1,10 @@ -from turtle import back import xlb import argparse import time import warp as wp from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq, assign_bc_id_box_faces +from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC from xlb.operator.equilibrium import QuadraticEquilibrium @@ -45,32 +44,26 @@ def create_grid_and_fields(cube_edge): grid_shape = (cube_edge, cube_edge, cube_edge) grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) - # Velocity on top face (3D) - boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, missing_mask, grid_shape, EquilibriumBC.id, ["top"] - ) - - # Wall on all other faces (3D) - boundary_mask, missing_mask = assign_bc_id_box_faces( - boundary_mask, - missing_mask, - grid_shape, - FullwayBounceBackBC.id, - ["bottom", "left", "right", "front", "back"], - ) - return grid, f_0, f_1, missing_mask, boundary_mask -def setup_boundary_conditions(): +def define_boundary_indices(grid): + lid = grid.boundingBoxIndices['top'] + walls = [grid.boundingBoxIndices['bottom'][i] + grid.boundingBoxIndices['left'][i] + + grid.boundingBoxIndices['right'][i] + grid.boundingBoxIndices['front'][i] + + grid.boundingBoxIndices['back'][i] for i in range(xlb.velocity_set.D3Q19().d)] + return lid, walls + +def setup_boundary_conditions(grid): + lid, walls = define_boundary_indices(grid) bc_eq = QuadraticEquilibrium() - bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=bc_eq) - bc_walls = FullwayBounceBackBC() + bc_top = EquilibriumBC(lid, rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=bc_eq) + bc_walls = FullwayBounceBackBC(walls) return [bc_top, bc_walls] def run_simulation(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): omega = 1.0 stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=setup_boundary_conditions() + omega, boundary_conditions=setup_boundary_conditions(grid) ) if backend == ComputeBackend.JAX: From f42f245ea88d732aa83c53ac07f0468b38d42002 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 16 Jul 2024 17:42:33 -0400 Subject: [PATCH 047/144] All pytest unit tests pass now! --- .../bc_equilibrium/test_bc_equilibrium_jax.py | 12 +- .../test_bc_equilibrium_warp.py | 10 +- .../test_bc_fullway_bounce_back_jax.py | 7 +- .../test_bc_fullway_bounce_back_warp.py | 10 +- .../mask/test_bc_indices_masker_jax.py | 13 +- .../mask/test_bc_indices_masker_warp.py | 15 +- .../mask/test_bc_planar_masker_jax.py | 129 ------------------ .../mask/test_bc_planar_masker_warp.py | 128 ----------------- .../collision/test_bgk_collision_jax.py | 2 +- .../collision/test_bgk_collision_warp.py | 2 +- .../indices_boundary_masker.py | 10 -- xlb/operator/collision/bgk.py | 2 +- 12 files changed, 35 insertions(+), 305 deletions(-) delete mode 100644 tests/boundary_conditions/mask/test_bc_planar_masker_jax.py delete mode 100644 tests/boundary_conditions/mask/test_bc_planar_masker_warp.py diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index e937295..c22674b 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -5,6 +5,7 @@ from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig +from xlb.operator.boundary_masker import IndicesBoundaryMasker def init_xlb_env(velocity_set): xlb.init( @@ -34,7 +35,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + indices_boundary_masker = IndicesBoundaryMasker() # Make indices for boundary conditions (sphere) sphere_radius = grid_shape[0] // 4 @@ -52,18 +53,17 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): < sphere_radius**2 ) - indices = jnp.array(indices) - - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium() + indices = [tuple(indices[i]) for i in range(velocity_set.d)] equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + indices, rho=1.0, u=(0.0, 0.0, 0.0) if dim == 3 else (0.0, 0.0), - equilibrium_operator=equilibrium, + equilibrium_operator=xlb.operator.equilibrium.QuadraticEquilibrium(), ) boundary_mask, missing_mask = indices_boundary_masker( - indices, equilibrium_bc.id, boundary_mask, missing_mask, start_index=None + [equilibrium_bc], boundary_mask, missing_mask, start_index=None ) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index a94fcbe..f57e5eb 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -5,6 +5,7 @@ from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig +from xlb.operator.boundary_masker import IndicesBoundaryMasker def init_xlb_env(velocity_set): xlb.init( @@ -34,7 +35,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + indices_boundary_masker = IndicesBoundaryMasker() # Make indices for boundary conditions (sphere) sphere_radius = grid_shape[0] // 4 @@ -52,18 +53,18 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): < sphere_radius**2 ) - indices = wp.array(indices, dtype=wp.int32) - + indices = [tuple(indices[i]) for i in range(velocity_set.d)] equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium() equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + indices, rho=1.0, u=(0.0, 0.0, 0.0) if dim == 3 else (0.0, 0.0), equilibrium_operator=equilibrium, ) boundary_mask, missing_mask = indices_boundary_masker( - indices, equilibrium_bc.id, boundary_mask, missing_mask, start_index=None + [equilibrium_bc], boundary_mask, missing_mask, start_index=None ) f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -76,7 +77,6 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): f = f.numpy() f_post = f_post.numpy() - indices = indices.numpy() assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 6ec3edc..40c5528 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -35,8 +35,6 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.BOOL ) - fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC() - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -57,10 +55,11 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): < sphere_radius**2 ) - indices = jnp.array(indices) + indices = [tuple(indices[i]) for i in range(velocity_set.d)] + fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices) boundary_mask, missing_mask = indices_boundary_masker( - indices, fullway_bc.id, boundary_mask, missing_mask, start_index=None + [fullway_bc], boundary_mask, missing_mask, start_index=None ) f_pre = my_grid.create_field( diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index a9b8c11..c994cbb 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -35,8 +35,6 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.BOOL ) - fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC() - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -57,10 +55,11 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): < sphere_radius**2 ) - indices = wp.array(indices, dtype=wp.int32) - + indices = [tuple(indices[i]) for i in range(velocity_set.d)] + fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices) + boundary_mask, missing_mask = indices_boundary_masker( - indices, fullway_bc.id, boundary_mask, missing_mask, start_index=None + [fullway_bc], boundary_mask, missing_mask, start_index=None ) # Generate a random field with the same shape @@ -76,7 +75,6 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): f = f_pre.numpy() f_post = f_post.numpy() - indices = indices.numpy() assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index 0480605..fc4043c 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -56,12 +56,13 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): < sphere_radius**2 ) - indices = jnp.array(indices) + indices = [tuple(indices[i]) for i in range(velocity_set.d)] - assert indices.shape[0] == dim - test_id = 5 + assert len(indices) == dim + test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices) + test_bc.id = 5 boundary_mask, missing_mask = indices_boundary_masker( - indices, test_id, boundary_mask, missing_mask, start_index=None + [test_bc], boundary_mask, missing_mask, start_index=None ) assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype @@ -73,13 +74,13 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert jnp.all(boundary_mask[0, indices[0], indices[1]] == test_id) + assert jnp.all(boundary_mask[0, indices[0], indices[1]] == test_bc.id) # assert that the rest of the boundary_mask is zero boundary_mask = boundary_mask.at[0, indices[0], indices[1]].set(0) assert jnp.all(boundary_mask == 0) if dim == 3: assert jnp.all( - boundary_mask[0, indices[0], indices[1], indices[2]] == test_id + boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id ) # assert that the rest of the boundary_mask is zero boundary_mask = boundary_mask.at[ diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 782551c..be0ce40 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -55,13 +55,13 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): < sphere_radius**2 ) - indices = wp.array(indices, dtype=wp.int32) + indices = [tuple(indices[i]) for i in range(velocity_set.d)] - assert indices.shape[0] == dim - test_id = 5 + assert len(indices) == dim + test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices) + test_bc.id = 5 boundary_mask, missing_mask = indices_boundary_masker( - indices, - test_id, + [test_bc], boundary_mask, missing_mask, start_index=(0, 0, 0) if dim == 3 else (0, 0), @@ -72,20 +72,19 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): boundary_mask = boundary_mask.numpy() missing_mask = missing_mask.numpy() - indices = indices.numpy() assert boundary_mask.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert np.all(boundary_mask[0, indices[0], indices[1]] == test_id) + assert np.all(boundary_mask[0, indices[0], indices[1]] == test_bc.id) # assert that the rest of the boundary_mask is zero boundary_mask[0, indices[0], indices[1]]= 0 assert np.all(boundary_mask == 0) if dim == 3: assert np.all( - boundary_mask[0, indices[0], indices[1], indices[2]] == test_id + boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id ) # assert that the rest of the boundary_mask is zero boundary_mask[0, indices[0], indices[1], indices[2]] = 0 diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py b/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py deleted file mode 100644 index 96a382e..0000000 --- a/tests/boundary_conditions/mask/test_bc_planar_masker_jax.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest -import jax.numpy as jnp -import xlb -from xlb.compute_backend import ComputeBackend -from xlb import DefaultConfig -from xlb.grid import grid_factory - - -def init_xlb_env(velocity_set): - xlb.init( - default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), - ) - - -@pytest.mark.parametrize( - "dim,velocity_set,grid_shape,lower_bound,upper_bound,direction", - [ - # 2D Grids - Different directions - ( - 2, - xlb.velocity_set.D2Q9, - (4, 4), - (0, 0), - (2, 4), - (1, 0), - ), # Horizontal direction - ( - 2, - xlb.velocity_set.D2Q9, - (50, 50), - (0, 0), - (50, 25), - (0, 1), - ), # Vertical direction - ( - 2, - xlb.velocity_set.D2Q9, - (100, 100), - (50, 0), - (100, 50), - (0, 1), - ), # Vertical direction - # 3D Grids - Different directions - ( - 3, - xlb.velocity_set.D3Q19, - (50, 50, 50), - (0, 0, 0), - (25, 50, 50), - (1, 0, 0), - ), # Along x-axis - ( - 3, - xlb.velocity_set.D3Q19, - (100, 100, 100), - (0, 50, 0), - (50, 100, 100), - (0, 1, 0), - ), # Along y-axis - ( - 3, - xlb.velocity_set.D3Q27, - (50, 50, 50), - (0, 0, 0), - (50, 25, 50), - (0, 0, 1), - ), # Along z-axis - ( - 3, - xlb.velocity_set.D3Q27, - (100, 100, 100), - (0, 0, 0), - (50, 100, 50), - (1, 0, 0), - ), # Along x-axis - ], -) -def test_planar_masker_jax( - dim, velocity_set, grid_shape, lower_bound, upper_bound, direction -): - init_xlb_env(velocity_set) - my_grid = grid_factory(grid_shape) - velocity_set = DefaultConfig.velocity_set - - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) - - fill_value = 0 - boundary_mask = my_grid.create_field( - cardinality=1, dtype=xlb.Precision.UINT8, fill_value=fill_value - ) - - planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker() - - start_index = (0,) * dim - id_number = 1 - - boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - id_number, - boundary_mask, - missing_mask, - start_index, - ) - - # Assert that the boundary condition is set on the left side of the domain based on the lower and upper bounds - expected_slice = (slice(None),) + tuple( - slice(lb, ub) for lb, ub in zip(lower_bound, upper_bound) - ) - assert jnp.all( - boundary_mask[expected_slice] == id_number - ), "Boundary not set correctly" - - # Assert that the rest of the domain is not affected and is equal to fill_value - full_slice = tuple(slice(None) for _ in grid_shape) - mask = jnp.ones_like(boundary_mask, dtype=bool) - mask = mask.at[expected_slice].set(False) - assert jnp.all( - boundary_mask[full_slice][mask] == fill_value - ), "Rest of domain incorrectly affected" - - -if __name__ == "__main__": - pytest.main() diff --git a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py b/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py deleted file mode 100644 index deee70b..0000000 --- a/tests/boundary_conditions/mask/test_bc_planar_masker_warp.py +++ /dev/null @@ -1,128 +0,0 @@ -import pytest -import numpy as np -import xlb -import warp as wp - -from xlb.compute_backend import ComputeBackend -from xlb import DefaultConfig -from xlb.grid import grid_factory - - -def init_xlb_env(velocity_set): - xlb.init( - default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), - ) - - -@pytest.mark.parametrize( - "dim,velocity_set,grid_shape,lower_bound,upper_bound,direction", - [ - # 2D Grids - Different directions - ( - 2, - xlb.velocity_set.D2Q9, - (4, 4), - (0, 0), - (2, 4), - (1, 0), - ), # Horizontal direction - ( - 2, - xlb.velocity_set.D2Q9, - (50, 50), - (0, 0), - (50, 25), - (0, 1), - ), # Vertical direction - ( - 2, - xlb.velocity_set.D2Q9, - (100, 100), - (50, 0), - (100, 50), - (0, 1), - ), # Vertical direction - # 3D Grids - Different directions - ( - 3, - xlb.velocity_set.D3Q19, - (50, 50, 50), - (0, 0, 0), - (25, 50, 50), - (1, 0, 0), - ), # Along x-axis - ( - 3, - xlb.velocity_set.D3Q19, - (100, 100, 100), - (0, 50, 0), - (50, 100, 100), - (0, 1, 0), - ), # Along y-axis - ( - 3, - xlb.velocity_set.D3Q27, - (50, 50, 50), - (0, 0, 0), - (50, 25, 50), - (0, 0, 1), - ), # Along z-axis - ( - 3, - xlb.velocity_set.D3Q27, - (100, 100, 100), - (0, 0, 0), - (50, 100, 50), - (1, 0, 0), - ), # Along x-axis - ], -) -def test_planar_masker_warp( - dim, velocity_set, grid_shape, lower_bound, upper_bound, direction -): - init_xlb_env(velocity_set) - my_grid = grid_factory(grid_shape) - velocity_set = DefaultConfig.velocity_set - - # Create required fields - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) - fill_value = 0 - boundary_mask = my_grid.create_field( - cardinality=1, dtype=xlb.Precision.UINT8, fill_value=fill_value - ) - - planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker() - start_index = (0,) * dim - id_number = 1 - - boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - id_number, - boundary_mask, - missing_mask, - start_index, - ) - - boundary_mask = boundary_mask.numpy() - - # Assertions to verify boundary settings - expected_slice = (slice(None),) + tuple( - slice(lb, ub) for lb, ub in zip(lower_bound, upper_bound) - ) - assert np.all( - boundary_mask[expected_slice] == id_number - ), "Boundary not set correctly" - - # Assertions for non-affected areas - full_slice = tuple(slice(None) for _ in grid_shape) - mask = np.ones_like(boundary_mask, dtype=bool) - mask[expected_slice] = False - assert np.all( - boundary_mask[full_slice][mask] == fill_value - ), "Rest of domain incorrectly affected" diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index 76d7233..cce5ca4 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -43,7 +43,7 @@ def test_bgk_ollision(dim, velocity_set, grid_shape, omega): f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) - f_out = compute_collision(f_orig, f_eq) + f_out = compute_collision(f_orig, f_eq, rho, u) assert jnp.allclose(f_out, f_orig - omega * (f_orig - f_eq)) diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index f0fff4a..7509c1d 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -45,7 +45,7 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) f_out = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) - f_out = compute_collision(f_orig, f_eq, f_out) + f_out = compute_collision(f_orig, f_eq, f_out, rho, u) f_eq = f_eq.numpy() f_out = f_out.numpy() diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 80f0c02..d8b1d39 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -1,15 +1,5 @@ -# Base class for all equilibriums - -from functools import partial import numpy as np -import jax.numpy as jnp -from jax import jit, lax import warp as wp -from typing import Tuple - -from xlb import DefaultConfig -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator from xlb.operator.stream.stream import Stream diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index deb5fd0..9dbfabd 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -55,7 +55,7 @@ def kernel2d( _feq[l] = feq[l, index[0], index[1]] # Compute the collision - _fout = functional(_f, _feq) + _fout = functional(_f, _feq, rho, u) # Write the result for l in range(self.velocity_set.q): From 83abb384bbd8ac0e45c969822fbb160f91bae67a Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 19 Jul 2024 17:20:59 -0400 Subject: [PATCH 048/144] addressing comments during the PR review and more clean-up --- examples/cfd/flow_past_sphere_3d.py | 33 +++++++++++-------- examples/cfd/lid_driven_cavity_2d.py | 29 +++++++++------- .../cfd/lid_driven_cavity_2d_distributed.py | 13 +++----- examples/cfd/windtunnel_3d.py | 33 +++++++++++-------- examples/performance/mlups_3d.py | 10 +++--- .../bc_equilibrium/test_bc_equilibrium_jax.py | 2 +- .../test_bc_equilibrium_warp.py | 2 +- .../test_bc_fullway_bounce_back_jax.py | 2 +- .../test_bc_fullway_bounce_back_warp.py | 2 +- .../mask/test_bc_indices_masker_jax.py | 2 +- .../mask/test_bc_indices_masker_warp.py | 2 +- .../boundary_condition/bc_do_nothing.py | 4 +-- .../boundary_condition/bc_equilibrium.py | 11 ++++--- .../bc_fullway_bounce_back.py | 4 +-- .../bc_halfway_bounce_back.py | 4 +-- .../boundary_condition/boundary_condition.py | 2 +- .../indices_boundary_masker.py | 6 ++++ xlb/operator/stepper/stepper.py | 4 --- 18 files changed, 89 insertions(+), 76 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 2ea71a8..a1682b8 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -8,7 +8,6 @@ EquilibriumBC, DoNothingBC, ) -from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image @@ -17,7 +16,7 @@ import jax.numpy as jnp class FlowOverSphere: - def __init__(self, grid_shape, velocity_set, backend, precision_policy): + def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): # initialize backend xlb.init( @@ -33,6 +32,15 @@ def __init__(self, grid_shape, velocity_set, backend, precision_policy): self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] + + # Setup the simulation BC, its initial conditions, and the stepper + self._setup(omega) + + def _setup(self, omega): + self.setup_boundary_conditions() + self.setup_boundary_masks() + self.initialize_fields() + self.setup_stepper(omega) def define_boundary_indices(self): inlet = self.grid.boundingBoxIndices['left'] @@ -57,13 +65,13 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): inlet, outlet, walls, sphere = self.define_boundary_indices() - bc_left = EquilibriumBC(inlet, rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=QuadraticEquilibrium()) - bc_walls = FullwayBounceBackBC(walls) - bc_do_nothing = DoNothingBC(outlet) - bc_sphere = FullwayBounceBackBC(sphere) + bc_left = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=inlet) + bc_walls = FullwayBounceBackBC(indices=walls) + bc_do_nothing = DoNothingBC(indices=outlet) + bc_sphere = FullwayBounceBackBC(indices=sphere) self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_sphere] - def set_boundary_masks(self): + def setup_boundary_masks(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, @@ -81,7 +89,7 @@ def setup_stepper(self, omega): omega, boundary_conditions=self.boundary_conditions ) - def run_simulation(self, num_steps): + def run(self, num_steps): for i in range(num_steps): self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 @@ -110,11 +118,8 @@ def post_process(self, i): velocity_set = xlb.velocity_set.D3Q19() backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 + omega = 1.6 - simulation = FlowOverSphere(grid_shape, velocity_set, backend, precision_policy) - simulation.setup_boundary_conditions() - simulation.set_boundary_masks() - simulation.initialize_fields() - simulation.setup_stepper(omega=1.8) - simulation.run_simulation(num_steps=10000) + simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy) + simulation.run(num_steps=10000) simulation.post_process(i=10000) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 08c6580..eac2909 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -5,7 +5,6 @@ from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC -from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image import warp as wp @@ -13,7 +12,7 @@ class LidDrivenCavity2D: - def __init__(self, grid_shape, velocity_set, backend, precision_policy): + def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): # initialize backend xlb.init( @@ -30,6 +29,15 @@ def __init__(self, grid_shape, velocity_set, backend, precision_policy): self.stepper = None self.boundary_conditions = [] + # Setup the simulation BC, its initial conditions, and the stepper + self._setup(omega) + + def _setup(self, omega): + self.setup_boundary_conditions() + self.setup_boundary_masks() + self.initialize_fields() + self.setup_stepper(omega) + def define_boundary_indices(self): lid = self.grid.boundingBoxIndices['top'] walls = [self.grid.boundingBoxIndices['bottom'][i] + self.grid.boundingBoxIndices['left'][i] + @@ -38,11 +46,11 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): lid, walls = self.define_boundary_indices() - bc_top = EquilibriumBC(lid, rho=1.0, u=(0.02, 0.0), equilibrium_operator=QuadraticEquilibrium()) - bc_walls = FullwayBounceBackBC(walls) + bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid) + bc_walls = FullwayBounceBackBC(indices=walls) self.boundary_conditions = [bc_top, bc_walls] - def set_boundary_masks(self): + def setup_boundary_masks(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, @@ -60,7 +68,7 @@ def setup_stepper(self, omega): omega, boundary_conditions=self.boundary_conditions ) - def run_simulation(self, num_steps): + def run(self, num_steps): for i in range(num_steps): self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 @@ -92,11 +100,8 @@ def post_process(self, i): backend = ComputeBackend.JAX velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 + omega = 1.6 - simulation = LidDrivenCavity2D(grid_shape, velocity_set, backend, precision_policy) - simulation.setup_boundary_conditions() - simulation.set_boundary_masks() - simulation.initialize_fields() - simulation.setup_stepper(omega=1.6) - simulation.run_simulation(num_steps=500) + simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy) + simulation.run(num_steps=500) simulation.post_process(i=500) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index c974f6e..72c72a2 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -7,8 +7,8 @@ class LidDrivenCavity2D_distributed(LidDrivenCavity2D): - def __init__(self, grid_shape, velocity_set, backend, precision_policy): - super().__init__(grid_shape, velocity_set, backend, precision_policy) + def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): + super().__init__(omega, grid_shape, velocity_set, backend, precision_policy) def setup_stepper(self, omega): stepper = IncompressibleNavierStokesStepper( @@ -28,11 +28,8 @@ def setup_stepper(self, omega): backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 + omega=1.6 - simulation = LidDrivenCavity2D_distributed(grid_shape, velocity_set, backend, precision_policy) - simulation.setup_boundary_conditions() - simulation.set_boundary_masks() - simulation.initialize_fields() - simulation.setup_stepper(omega=1.6) - simulation.run_simulation(num_steps=5000) + simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy) + simulation.run(num_steps=5000) simulation.post_process(i=5000) \ No newline at end of file diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index a610306..49e26f7 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -10,7 +10,6 @@ EquilibriumBC, DoNothingBC, ) -from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image @@ -20,7 +19,7 @@ class WindTunnel3D: - def __init__(self, grid_shape, velocity_set, backend, precision_policy): + def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precision_policy): # initialize backend xlb.init( @@ -37,6 +36,15 @@ def __init__(self, grid_shape, velocity_set, backend, precision_policy): self.stepper = None self.boundary_conditions = [] + # Setup the simulation BC, its initial conditions, and the stepper + self._setup(omega, wind_speed) + + def _setup(self, omega, wind_speed): + self.setup_boundary_conditions(wind_speed) + self.setup_boundary_masks() + self.initialize_fields() + self.setup_stepper(omega) + def voxelize_stl(self, stl_filename, length_lbm_unit): mesh = trimesh.load_mesh(stl_filename, process=False) length_phys_unit = mesh.extents.max() @@ -67,13 +75,13 @@ def define_boundary_indices(self): def setup_boundary_conditions(self, wind_speed): inlet, outlet, walls, car = self.define_boundary_indices() - bc_left = EquilibriumBC(inlet, rho=1.0, u=(wind_speed, 0.0, 0.0), equilibrium_operator=QuadraticEquilibrium()) - bc_walls = FullwayBounceBackBC(walls) - bc_do_nothing = DoNothingBC(outlet) - bc_car= FullwayBounceBackBC(car) + bc_left = EquilibriumBC(rho=1.0, u=(wind_speed, 0.0, 0.0), indices=inlet) + bc_walls = FullwayBounceBackBC(indices=walls) + bc_do_nothing = DoNothingBC(indices=outlet) + bc_car= FullwayBounceBackBC(indices=car) self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_car] - def set_boundary_masks(self): + def setup_boundary_masks(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, @@ -91,7 +99,7 @@ def setup_stepper(self, omega): omega, boundary_conditions=self.boundary_conditions, collision_type="KBC" ) - def run_simulation(self, num_steps, print_interval): + def run(self, num_steps, print_interval): start_time = time.time() for i in range(num_steps): self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) @@ -140,6 +148,7 @@ def post_process(self, i): omega = 1.0 / (3.0 * visc + 0.5) # Print simulation info + print("\n" + "=" * 50 + "\n") print("Simulation Configuration:") print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") print(f"Backend: {backend}") @@ -150,10 +159,6 @@ def post_process(self, i): print(f"Max iterations: {num_steps}") print("\n" + "=" * 50 + "\n") - simulation = WindTunnel3D(grid_shape, velocity_set, backend, precision_policy) - simulation.setup_boundary_conditions(wind_speed) - simulation.set_boundary_masks() - simulation.initialize_fields() - simulation.setup_stepper(omega) - simulation.run_simulation(num_steps, print_interval) + simulation = WindTunnel3D(omega, wind_speed, grid_shape, velocity_set, backend, precision_policy) + simulation.run(num_steps, print_interval) simulation.post_process(i=num_steps) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 4257bf7..b8f238c 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -7,7 +7,6 @@ from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC -from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.distribute import distribute def parse_arguments(): @@ -55,12 +54,11 @@ def define_boundary_indices(grid): def setup_boundary_conditions(grid): lid, walls = define_boundary_indices(grid) - bc_eq = QuadraticEquilibrium() - bc_top = EquilibriumBC(lid, rho=1.0, u=(0.02, 0.0, 0.0), equilibrium_operator=bc_eq) - bc_walls = FullwayBounceBackBC(walls) + bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid) + bc_walls = FullwayBounceBackBC(indices=walls) return [bc_top, bc_walls] -def run_simulation(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): +def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): omega = 1.0 stepper = IncompressibleNavierStokesStepper( omega, boundary_conditions=setup_boundary_conditions(grid) @@ -92,7 +90,7 @@ def main(): grid, f_0, f_1, missing_mask, boundary_mask = create_grid_and_fields(args.cube_edge) f_0 = initialize_eq(f_0, grid, xlb.velocity_set.D3Q19(), backend) - elapsed_time = run_simulation(f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps) + elapsed_time = run(f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index c22674b..9017a9c 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -56,10 +56,10 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - indices, rho=1.0, u=(0.0, 0.0, 0.0) if dim == 3 else (0.0, 0.0), equilibrium_operator=xlb.operator.equilibrium.QuadraticEquilibrium(), + indices=indices, ) boundary_mask, missing_mask = indices_boundary_masker( diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index f57e5eb..7bb78cf 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -57,10 +57,10 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium() equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - indices, rho=1.0, u=(0.0, 0.0, 0.0) if dim == 3 else (0.0, 0.0), equilibrium_operator=equilibrium, + indices=indices, ) boundary_mask, missing_mask = indices_boundary_masker( diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 40c5528..b6ce4c3 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -56,7 +56,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): ) indices = [tuple(indices[i]) for i in range(velocity_set.d)] - fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices) + fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) boundary_mask, missing_mask = indices_boundary_masker( [fullway_bc], boundary_mask, missing_mask, start_index=None diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index c994cbb..3f8f0d0 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -56,7 +56,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): ) indices = [tuple(indices[i]) for i in range(velocity_set.d)] - fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices) + fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) boundary_mask, missing_mask = indices_boundary_masker( [fullway_bc], boundary_mask, missing_mask, start_index=None diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index fc4043c..0de8805 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -59,7 +59,7 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] assert len(indices) == dim - test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices) + test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 boundary_mask, missing_mask = indices_boundary_masker( [test_bc], boundary_mask, missing_mask, start_index=None diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index be0ce40..43911f6 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -58,7 +58,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] assert len(indices) == dim - test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices) + test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 boundary_mask, missing_mask = indices_boundary_masker( [test_bc], diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 1938207..38697c3 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -31,17 +31,17 @@ class DoNothingBC(BoundaryCondition): def __init__( self, - indices: List[int], velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, + indices = None, ): super().__init__( - indices, ImplementationStep.STREAMING, velocity_set, precision_policy, compute_backend, + indices, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 2ac632b..0af61e1 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -7,12 +7,13 @@ import jax.lax as lax from functools import partial import warp as wp -from typing import Tuple, Any, List +from typing import Tuple, Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium.equilibrium import Equilibrium +from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.operator import Operator from xlb.operator.boundary_condition.boundary_condition import ( ImplementationStep, @@ -32,29 +33,29 @@ class EquilibriumBC(BoundaryCondition): def __init__( self, - indices: List[int], rho: float, u: Tuple[float, float, float], - equilibrium_operator: Operator, + equilibrium_operator : Operator = None, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, + indices = None, ): # Store the equilibrium information self.rho = rho self.u = u - self.equilibrium_operator = equilibrium_operator + self.equilibrium_operator = QuadraticEquilibrium() if equilibrium_operator is None else equilibrium_operator # Raise error if equilibrium operator is not a subclass of Equilibrium if not issubclass(type(self.equilibrium_operator), Equilibrium): raise ValueError("Equilibrium operator must be a subclass of Equilibrium") # Call the parent constructor super().__init__( - indices, ImplementationStep.STREAMING, velocity_set, precision_policy, compute_backend, + indices, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index f528634..a445e07 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -30,17 +30,17 @@ class FullwayBounceBackBC(BoundaryCondition): def __init__( self, - indices: List[int], velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, + indices = None, ): super().__init__( - indices, ImplementationStep.COLLISION, velocity_set, precision_policy, compute_backend, + indices, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index c4f323e..55472a3 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -34,18 +34,18 @@ class HalfwayBounceBackBC(BoundaryCondition): def __init__( self, - indices: List[int], velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, + indices = None, ): # Call the parent constructor super().__init__( - indices, ImplementationStep.STREAMING, velocity_set, precision_policy, compute_backend, + indices, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index ec6ba1e..dbeadbc 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -28,11 +28,11 @@ class BoundaryCondition(Operator): def __init__( self, - indices: List[int], implementation_step: ImplementationStep, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, + indices = None, ): velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index d8b1d39..460bd3b 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -51,9 +51,12 @@ def compute_boundary_id_and_mask(boundary_mask, mask): start_index = (0,) * dim for bc in bclist: + assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC!' id_number = bc.id local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] boundary_mask, mask = compute_boundary_id_and_mask(boundary_mask, mask) + # We are done with bc.indices. Remove them from BC objects + bc.__dict__.pop('indices', None) mask = self.stream(mask) return boundary_mask, mask @@ -151,9 +154,12 @@ def warp_implementation( index_list = [[] for _ in range(dim)] id_list = [] for bc in bclist: + assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC using keyword "indices"!' for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) + # We are done with bc.indices. Remove them from BC objects + bc.__dict__.pop('indices', None) indices = wp.array2d(index_list, dtype = wp.int32) id_number = wp.array1d(id_list, dtype = wp.uint8) diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 6af51c2..c11b39b 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -91,7 +91,6 @@ def __init__(self, operators, boundary_conditions): if self.equilibrium_bc is None: # Select the equilibrium operator based on its type self.equilibrium_bc = EquilibriumBC( - [], rho=1.0, u=(0.0, 0.0, 0.0), equilibrium_operator=next( @@ -103,21 +102,18 @@ def __init__(self, operators, boundary_conditions): ) if self.do_nothing_bc is None: self.do_nothing_bc = DoNothingBC( - [], velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, ) if self.halfway_bounce_back_bc is None: self.halfway_bounce_back_bc = HalfwayBounceBackBC( - [], velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, ) if self.fullway_bounce_back_bc is None: self.fullway_bounce_back_bc = FullwayBounceBackBC( - [], velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, From 6f7397eac196021b67fcc8b1745c1b13c74453ea Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Mon, 29 Jul 2024 16:58:47 -0400 Subject: [PATCH 049/144] Removed the need to pass sharding_flag in for distributed workload --- .../cfd/lid_driven_cavity_2d_distributed.py | 2 +- xlb/distribute/distribute.py | 70 +++++++++++-------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 72c72a2..e71da52 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -15,7 +15,7 @@ def setup_stepper(self, omega): omega, boundary_conditions=self.boundary_conditions ) distributed_stepper = distribute( - stepper, self.grid, self.velocity_set, sharding_flags=(True, True, True, True, False) + stepper, self.grid, self.velocity_set, ) self.stepper = distributed_stepper return diff --git a/xlb/distribute/distribute.py b/xlb/distribute/distribute.py index bb72072..ad07dd0 100644 --- a/xlb/distribute/distribute.py +++ b/xlb/distribute/distribute.py @@ -14,7 +14,6 @@ def distribute( operator: Operator, grid, velocity_set, - sharding_flags: Tuple[bool, ...], num_results=1, ops="permute", ) -> Operator: @@ -27,47 +26,62 @@ def _sharded_operator(*args): rightPerm = [(i, (i + 1) % grid.nDevices) for i in range(grid.nDevices)] leftPerm = [((i + 1) % grid.nDevices, i) for i in range(grid.nDevices)] - right_comm = lax.ppermute( + left_comm, right_comm = ( result[velocity_set.right_indices, :1, ...], + result[velocity_set.left_indices, -1:, ...], + ) + + left_comm = lax.ppermute( + left_comm, perm=rightPerm, axis_name="x", ) - left_comm = lax.ppermute( - result[velocity_set.left_indices, -1:, ...], + right_comm = lax.ppermute( + right_comm, perm=leftPerm, axis_name="x", ) - result = result.at[velocity_set.right_indices, :1, ...].set( - right_comm - ) - result = result.at[velocity_set.left_indices, -1:, ...].set( - left_comm - ) + result = result.at[velocity_set.right_indices, :1, ...].set(left_comm) + result = result.at[velocity_set.left_indices, -1:, ...].set(right_comm) return result else: raise NotImplementedError(f"Operation {ops} not implemented") - in_specs = tuple( - P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() - for flag in sharding_flags - ) - out_specs = tuple( - P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results) - ) + # Build sharding_flags and in_specs based on args + def build_specs(grid, *args): + sharding_flags = [] + in_specs = [] + for arg in args: + if arg.shape[1:] == grid.shape: + sharding_flags.append(True) + else: + sharding_flags.append(False) + + in_specs = tuple( + P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() + for flag in sharding_flags + ) + out_specs = tuple( + P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results) + ) + return tuple(sharding_flags), in_specs, out_specs + + def _wrapped_operator(*args): + sharding_flags, in_specs, out_specs = build_specs(grid, *args) - if len(out_specs) == 1: - out_specs = out_specs[0] + if len(out_specs) == 1: + out_specs = out_specs[0] - distributed_operator = shard_map( - _sharded_operator, - mesh=grid.global_mesh, - in_specs=in_specs, - out_specs=out_specs, - check_rep=False, - ) - distributed_operator = jit(distributed_operator) + distributed_operator = shard_map( + _sharded_operator, + mesh=grid.global_mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + return distributed_operator(*args) - return distributed_operator + return jit(_wrapped_operator) From e2616bd876b4ac181ed5200401bf3939170e5c56 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 30 Jul 2024 00:04:49 -0400 Subject: [PATCH 050/144] somewhat improved bc handling using structs --- xlb/operator/stepper/nse_stepper.py | 61 ++++++++++++++------ xlb/operator/stepper/stepper.py | 86 ++++++++--------------------- 2 files changed, 68 insertions(+), 79 deletions(-) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 11b5615..29ad088 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -95,11 +95,15 @@ def _construct_warp(self): self.velocity_set.q, dtype=wp.uint8 ) # TODO fix vec bool - # Get the boundary condition ids - _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) - _do_nothing_bc = wp.uint8(self.do_nothing_bc.id) - _halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id) - _fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id) + @wp.struct + class BoundaryConditionIDStruct: + # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning + # One needs to manually add the names of additional BC's as they are added. + # TODO: Anyway to improve this + id_EquilibriumBC: wp.uint8 + id_DoNothingBC: wp.uint8 + id_HalfwayBounceBackBC: wp.uint8 + id_FullwayBounceBackBC: wp.uint8 @wp.kernel def kernel2d( @@ -107,6 +111,7 @@ def kernel2d( f_1: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), + bc_struct: BoundaryConditionIDStruct, timestep: int, ): # Get the global index @@ -124,20 +129,20 @@ def kernel2d( _missing_mask[l] = wp.uint8(0) # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: + if (_boundary_id == wp.uint8(0)) or _boundary_id == bc_struct.id_FullwayBounceBackBC: # Regular streaming f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + elif _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional( f_0, _missing_mask, index ) - elif _boundary_id == _do_nothing_bc: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post_stream = self.do_nothing_bc.warp_functional( f_0, _missing_mask, index ) - elif _boundary_id == _halfway_bounce_back_bc: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional( f_0, _missing_mask, index @@ -158,7 +163,7 @@ def kernel2d( ) # Apply collision type boundary conditions - if _boundary_id == _fullway_bounce_back_bc: + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, @@ -177,6 +182,7 @@ def kernel3d( f_1: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), + bc_struct: BoundaryConditionIDStruct, timestep: int, ): # Get the global index @@ -194,20 +200,20 @@ def kernel3d( _missing_mask[l] = wp.uint8(0) # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: + if (_boundary_id == wp.uint8(0)) or _boundary_id == bc_struct.id_FullwayBounceBackBC: # Regular streaming f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + elif _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional( f_0, _missing_mask, index ) - elif _boundary_id == _do_nothing_bc: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post_stream = self.do_nothing_bc.warp_functional( f_0, _missing_mask, index ) - elif _boundary_id == _halfway_bounce_back_bc: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional( f_0, _missing_mask, index @@ -223,7 +229,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply collision type boundary conditions - if _boundary_id == _fullway_bounce_back_bc: + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, @@ -238,10 +244,32 @@ def kernel3d( # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return None, kernel + return BoundaryConditionIDStruct, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): + + # Get the boundary condition ids + from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + bc_to_id = boundary_condition_registry.bc_to_id + + bc_struct = self.warp_functional() + bc_attribute_list = [] + for bc in self.boundary_conditions: + # Setting the Struct attributes based on the BC class names + attribute_str = bc.__class__.__name__ + setattr(bc_struct, 'id_' + attribute_str, bc_to_id[attribute_str]) + bc_attribute_list.append('id_' + attribute_str) + + # Unused attributes of the struct are set to inernal (id=0) + ll = vars(bc_struct) + for var in ll: + if var not in bc_attribute_list and not var.startswith('_'): + # set unassigned boundaries to the maximum integer in uint8 + attribute_str = bc.__class__.__name__ + setattr(bc_struct, var, 255) + + # Launch the warp kernel wp.launch( self.warp_kernel, @@ -250,6 +278,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): f_1, boundary_mask, missing_mask, + bc_struct, timestep, ], dim=f_0.shape[1:], diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index c11b39b..fca088e 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,17 +1,5 @@ # Base class for all stepper operators - -from ast import Raise -from functools import partial -import jax.numpy as jnp -from jax import jit -import warp as wp - -from xlb.operator.equilibrium.equilibrium import Equilibrium -from xlb.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -from xlb.operator.precision_caster import PrecisionCaster -from xlb.operator.equilibrium import Equilibrium from xlb import DefaultConfig @@ -59,65 +47,37 @@ def __init__(self, operators, boundary_conditions): ) # Add boundary conditions - # Warp cannot handle lists of functions currently - # Because of this we manually unpack the boundary conditions ############################################ + # Warp cannot handle lists of functions currently # TODO: Fix this later ############################################ from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC - from xlb.operator.boundary_condition.bc_halfway_bounce_back import ( - HalfwayBounceBackBC, - ) - from xlb.operator.boundary_condition.bc_fullway_bounce_back import ( - FullwayBounceBackBC, - ) + from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC + from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC + + + # Define a list of tuples with attribute names and their corresponding classes + conditions = [ + ("equilibrium_bc", EquilibriumBC), + ("do_nothing_bc", DoNothingBC), + ("halfway_bounce_back_bc", HalfwayBounceBackBC), + ("fullway_bounce_back_bc", FullwayBounceBackBC), + ] + + # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. + bc_fallback = boundary_conditions[0] - self.equilibrium_bc = None - self.do_nothing_bc = None - self.halfway_bounce_back_bc = None - self.fullway_bounce_back_bc = None + # Iterate over each boundary condition + for attr_name, bc_class in conditions: + for bc in boundary_conditions: + if isinstance(bc, bc_class): + setattr(self, attr_name, bc) + break + elif not hasattr(self, attr_name): + setattr(self, attr_name, bc_fallback) - for bc in boundary_conditions: - if isinstance(bc, EquilibriumBC): - self.equilibrium_bc = bc - elif isinstance(bc, DoNothingBC): - self.do_nothing_bc = bc - elif isinstance(bc, HalfwayBounceBackBC): - self.halfway_bounce_back_bc = bc - elif isinstance(bc, FullwayBounceBackBC): - self.fullway_bounce_back_bc = bc - if self.equilibrium_bc is None: - # Select the equilibrium operator based on its type - self.equilibrium_bc = EquilibriumBC( - rho=1.0, - u=(0.0, 0.0, 0.0), - equilibrium_operator=next( - (op for op in self.operators if isinstance(op, Equilibrium)), None - ), - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.do_nothing_bc is None: - self.do_nothing_bc = DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.halfway_bounce_back_bc is None: - self.halfway_bounce_back_bc = HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.fullway_bounce_back_bc is None: - self.fullway_bounce_back_bc = FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) ############################################ # Initialize operator From 8c2bc40e9de8282c30f287b2bf44f27dff2d1c31 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 30 Jul 2024 15:04:41 -0400 Subject: [PATCH 051/144] minor clean up. Warp MLUPs is not affected. --- xlb/operator/stepper/nse_stepper.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 29ad088..eba6364 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -128,11 +128,11 @@ def kernel2d( else: _missing_mask[l] = wp.uint8(0) - # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Regular streaming - f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == bc_struct.id_EquilibriumBC: + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f_0, index) + + # Apply post-streaming type boundary conditions + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional( f_0, _missing_mask, index @@ -162,7 +162,7 @@ def kernel2d( u, ) - # Apply collision type boundary conditions + # Apply post-collision type boundary conditions if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( @@ -199,11 +199,11 @@ def kernel3d( else: _missing_mask[l] = wp.uint8(0) - # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Regular streaming - f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == bc_struct.id_EquilibriumBC: + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f_0, index) + + # Apply post-streaming boundary conditions + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional( f_0, _missing_mask, index From d6e265d7193f6062fa23fb16b72e00397107b24e Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Wed, 31 Jul 2024 19:05:39 -0400 Subject: [PATCH 052/144] modified BC kernels kernels in warp to resemble jax and also NSE kernels. --- .../bc_equilibrium/test_bc_equilibrium_warp.py | 2 +- .../test_bc_fullway_bounce_back_warp.py | 5 +++-- xlb/operator/boundary_condition/bc_do_nothing.py | 12 +++++------- xlb/operator/boundary_condition/bc_equilibrium.py | 12 +++++------- .../boundary_condition/bc_fullway_bounce_back.py | 12 +++++------- .../boundary_condition/bc_halfway_bounce_back.py | 12 +++++------- 6 files changed, 24 insertions(+), 31 deletions(-) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 7bb78cf..9a2c824 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -73,7 +73,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask, f) + f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask) f = f.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index 3f8f0d0..4881892 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -6,6 +6,7 @@ from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig +from xlb.operator.boundary_masker import IndicesBoundaryMasker def init_xlb_env(velocity_set): xlb.init( @@ -37,7 +38,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + indices_boundary_masker = IndicesBoundaryMasker() # Make indices for boundary conditions (sphere) sphere_radius = grid_shape[0] // 4 @@ -71,7 +72,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask, f_pre) + f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask) f = f_pre.numpy() f_post = f_post.numpy() diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 38697c3..a769fdf 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -87,7 +87,6 @@ def kernel2d( f_post: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.uint8), - f: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() @@ -113,7 +112,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = _f[l] # Construct the warp kernel @wp.kernel @@ -122,7 +121,6 @@ def kernel3d( f_post: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() @@ -148,7 +146,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = _f[l] functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d @@ -156,11 +154,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask], dim=f_pre.shape[1:], ) - return f + return f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 0af61e1..1a4586d 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -99,7 +99,6 @@ def kernel2d( f_post: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), - f: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() @@ -125,7 +124,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = _f[l] @wp.func def functional3d( @@ -143,7 +142,6 @@ def kernel3d( f_post: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() @@ -169,7 +167,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d functional = functional3d if self.velocity_set.d == 3 else functional2d @@ -177,11 +175,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask], dim=f_pre.shape[1:], ) - return f + return f_post diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index a445e07..90ff564 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -77,7 +77,6 @@ def kernel2d( f_post: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), - f: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) @@ -107,7 +106,7 @@ def kernel2d( # Write the result to the output for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = _f[l] # Construct the warp kernel @wp.kernel @@ -116,7 +115,6 @@ def kernel3d( f_post: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() @@ -147,18 +145,18 @@ def kernel3d( # Write the result to the output for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask], dim=f_pre.shape[1:], ) - return f + return f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 55472a3..8852e8f 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -134,7 +134,6 @@ def kernel2d( f_post: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), - f: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() @@ -160,7 +159,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = _f[l] # Construct the warp kernel @wp.kernel @@ -169,7 +168,6 @@ def kernel3d( f_post: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() @@ -195,7 +193,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d functional = functional3d if self.velocity_set.d == 3 else functional2d @@ -203,11 +201,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask], dim=f_pre.shape[1:], ) - return f + return f_post From 50cc016b8d16ca425339bb50183e472ff7f24d1d Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 1 Aug 2024 11:58:20 -0400 Subject: [PATCH 053/144] Fixed precision policy properties --- xlb/precision_policy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 25b0583..5a59b97 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -54,13 +54,13 @@ def compute_precision(self): if self == PrecisionPolicy.FP64FP64: return Precision.FP64 elif self == PrecisionPolicy.FP64FP32: - return Precision.FP32 + return Precision.FP64 elif self == PrecisionPolicy.FP64FP16: - return Precision.FP16 + return Precision.FP64 elif self == PrecisionPolicy.FP32FP32: return Precision.FP32 elif self == PrecisionPolicy.FP32FP16: - return Precision.FP16 + return Precision.FP32 else: raise ValueError("Invalid precision policy") From 9822d4f684a56aa82ce9f0f778cb597cfcf2ec8d Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 1 Aug 2024 12:07:30 -0400 Subject: [PATCH 054/144] Removed deprecated sharding_flags --- examples/performance/mlups_3d.py | 48 ++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index b8f238c..75ecccc 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -9,23 +9,31 @@ from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC from xlb.distribute import distribute + def parse_arguments(): parser = argparse.ArgumentParser( description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)" ) - parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") + parser.add_argument( + "cube_edge", type=int, help="Length of the edge of the cubic grid" + ) parser.add_argument("num_steps", type=int, help="Timestep for the simulation") - parser.add_argument("backend", type=str, help="Backend for the simulation (jax or warp)") - parser.add_argument("precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)") + parser.add_argument( + "backend", type=str, help="Backend for the simulation (jax or warp)" + ) + parser.add_argument( + "precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)" + ) return parser.parse_args() + def setup_simulation(args): backend = ComputeBackend.JAX if args.backend == "jax" else ComputeBackend.WARP precision_policy_map = { "fp32/fp32": PrecisionPolicy.FP32FP32, "fp64/fp64": PrecisionPolicy.FP64FP64, "fp64/fp32": PrecisionPolicy.FP64FP32, - "fp32/fp16": PrecisionPolicy.FP32FP16 + "fp32/fp16": PrecisionPolicy.FP32FP16, } precision_policy = precision_policy_map.get(args.precision) if precision_policy is None: @@ -39,25 +47,34 @@ def setup_simulation(args): return backend, precision_policy + def create_grid_and_fields(cube_edge): grid_shape = (cube_edge, cube_edge, cube_edge) grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) return grid, f_0, f_1, missing_mask, boundary_mask + def define_boundary_indices(grid): - lid = grid.boundingBoxIndices['top'] - walls = [grid.boundingBoxIndices['bottom'][i] + grid.boundingBoxIndices['left'][i] + - grid.boundingBoxIndices['right'][i] + grid.boundingBoxIndices['front'][i] + - grid.boundingBoxIndices['back'][i] for i in range(xlb.velocity_set.D3Q19().d)] + lid = grid.boundingBoxIndices["top"] + walls = [ + grid.boundingBoxIndices["bottom"][i] + + grid.boundingBoxIndices["left"][i] + + grid.boundingBoxIndices["right"][i] + + grid.boundingBoxIndices["front"][i] + + grid.boundingBoxIndices["back"][i] + for i in range(xlb.velocity_set.D3Q19().d) + ] return lid, walls - + + def setup_boundary_conditions(grid): lid, walls = define_boundary_indices(grid) bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid) bc_walls = FullwayBounceBackBC(indices=walls) return [bc_top, bc_walls] + def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): omega = 1.0 stepper = IncompressibleNavierStokesStepper( @@ -66,7 +83,9 @@ def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): if backend == ComputeBackend.JAX: stepper = distribute( - stepper, grid, xlb.velocity_set.D3Q19(), sharding_flags=(True, True, True, True, False) + stepper, + grid, + xlb.velocity_set.D3Q19(), ) start_time = time.time() @@ -79,22 +98,27 @@ def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): end_time = time.time() return end_time - start_time + def calculate_mlups(cube_edge, num_steps, elapsed_time): total_lattice_updates = cube_edge**3 * num_steps mlups = (total_lattice_updates / elapsed_time) / 1e6 return mlups + def main(): args = parse_arguments() backend, precision_policy = setup_simulation(args) grid, f_0, f_1, missing_mask, boundary_mask = create_grid_and_fields(args.cube_edge) f_0 = initialize_eq(f_0, grid, xlb.velocity_set.D3Q19(), backend) - elapsed_time = run(f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps) + elapsed_time = run( + f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps + ) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") print(f"MLUPs: {mlups:.2f}") + if __name__ == "__main__": - main() \ No newline at end of file + main() From a9f35fcd618d6a8f13f77ce46d0a3bb4d43391a2 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 2 Aug 2024 12:25:04 -0400 Subject: [PATCH 055/144] WIP: fixing missing mask in both jax and warp --- .../indices_boundary_masker.py | 107 ++++++++++++------ xlb/operator/stream/stream.py | 2 + 2 files changed, 76 insertions(+), 33 deletions(-) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 460bd3b..b446884 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -1,9 +1,12 @@ import numpy as np import warp as wp +import jax +import jax.numpy as jnp from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator from xlb.operator.stream.stream import Stream - +from xlb.grid import grid_factory +from xlb.precision_policy import Precision class IndicesBoundaryMasker(Operator): """ @@ -27,8 +30,27 @@ def __init__( # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) def jax_implementation( - self, bclist, boundary_mask, mask, start_index=None + self, bclist, boundary_mask, missing_mask, start_index=None ): + + # Pad the missing mask to create a grid mask to identify out of bound boundaries + # Set padded regin to True (i.e. boundary) + dim = missing_mask.ndim - 1 + nDevices = jax.device_count() + pad_x, pad_y, pad_z = nDevices, 1, 1 + if dim == 2: + grid_mask = jnp.pad(missing_mask, ((0,0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True) + if dim == 3: + grid_mask = jnp.pad(missing_mask, ((0,0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True) + + + # shift indices + shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z) + if start_index is None: + start_index = shift_tup + else: + start_index = tuple( a + b for a, b in zip(start_index, shift_tup)) + # define a helper function def compute_boundary_id_and_mask(boundary_mask, mask): if dim == 2: @@ -46,20 +68,21 @@ def compute_boundary_id_and_mask(boundary_mask, mask): ].set(True) return boundary_mask, mask - dim = mask.ndim - 1 - if start_index is None: - start_index = (0,) * dim for bc in bclist: assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC!' id_number = bc.id - local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] - boundary_mask, mask = compute_boundary_id_and_mask(boundary_mask, mask) + local_indices = np.array(bc.indices) + np.array(start_index)[:, np.newaxis] + boundary_mask, grid_mask = compute_boundary_id_and_mask(boundary_mask, grid_mask) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop('indices', None) - mask = self.stream(mask) - return boundary_mask, mask + grid_mask = self.stream(grid_mask) + if dim == 2: + missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] + if dim == 3: + missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] + return boundary_mask, missing_mask def _construct_warp(self): # Make constants for warp @@ -72,7 +95,7 @@ def kernel2d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), boundary_mask: wp.array3d(dtype=wp.uint8), - mask: wp.array3d(dtype=wp.bool), + missing_mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): # Get the index of indices @@ -80,25 +103,33 @@ def kernel2d( # Get local indices index = wp.vec2i() - index[0] = indices[0, ii] - start_index[0] - index[1] = indices[1, ii] - start_index[1] + index[0] = indices[0, ii] + start_index[0] + index[1] = indices[1, ii] + start_index[1] - # Check if in bounds + # Check if index is in bounds if ( index[0] >= 0 - and index[0] < mask.shape[1] + and index[0] < missing_mask.shape[1] and index[1] >= 0 - and index[1] < mask.shape[2] + and index[1] < missing_mask.shape[2] ): # Stream indices for l in range(_q): # Get the index of the streaming direction - push_index = wp.vec2i() + pull_index = wp.vec2i() for d in range(self.velocity_set.d): - push_index[d] = index[d] + _c[d, l] - - # Set the boundary id and mask - mask[l, push_index[0], push_index[1]] = True + pull_index[d] = index[d] - _c[d, l] + + # check if pull index is out of bound + # These directions will have missing information after streaming + if ( + pull_index[0] < 0 + or pull_index[0] >= missing_mask.shape[1] + or pull_index[1] < 0 + or pull_index[1] >= missing_mask.shape[2] + ): + # Set the missing mask + missing_mask[l, index[0], index[1]] = True boundary_mask[0, index[0], index[1]] = id_number[ii] @@ -108,7 +139,7 @@ def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), boundary_mask: wp.array4d(dtype=wp.uint8), - mask: wp.array4d(dtype=wp.bool), + missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): # Get the index of indices @@ -116,28 +147,38 @@ def kernel3d( # Get local indices index = wp.vec3i() - index[0] = indices[0, ii] - start_index[0] - index[1] = indices[1, ii] - start_index[1] - index[2] = indices[2, ii] - start_index[2] + index[0] = indices[0, ii] + start_index[0] + index[1] = indices[1, ii] + start_index[1] + index[2] = indices[2, ii] + start_index[2] - # Check if in bounds + # Check if index is in bounds if ( index[0] >= 0 - and index[0] < mask.shape[1] + and index[0] < missing_mask.shape[1] and index[1] >= 0 - and index[1] < mask.shape[2] + and index[1] < missing_mask.shape[2] and index[2] >= 0 - and index[2] < mask.shape[3] + and index[2] < missing_mask.shape[3] ): # Stream indices for l in range(_q): # Get the index of the streaming direction - push_index = wp.vec3i() + pull_index = wp.vec3i() for d in range(self.velocity_set.d): - push_index[d] = index[d] + _c[d, l] - - # Set the mask - mask[l, push_index[0], push_index[1], push_index[2]] = True + pull_index[d] = index[d] - _c[d, l] + + # check if pull index is out of bound + # These directions will have missing information after streaming + if ( + pull_index[0] < 0 + or pull_index[0] >= missing_mask.shape[1] + or pull_index[1] < 0 + or pull_index[1] >= missing_mask.shape[2] + or pull_index[2] < 0 + or pull_index[2] >= missing_mask.shape[3] + ): + # Set the missing mask + missing_mask[l, index[0], index[1], index[2]] = True boundary_mask[0, index[0], index[1], index[2]] = id_number[ii] diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 77cf22d..7ea528b 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -72,6 +72,7 @@ def functional2d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - _c[d, l] + # impose periodicity for out of bound values if pull_index[d] < 0: pull_index[d] = f.shape[d + 1] - 1 elif pull_index[d] >= f.shape[d + 1]: @@ -112,6 +113,7 @@ def functional3d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - _c[d, l] + # impose periodicity for out of bound values if pull_index[d] < 0: pull_index[d] = f.shape[d + 1] - 1 elif pull_index[d] >= f.shape[d + 1]: From 1b179459df66af62b47afffffa105f239a38b220 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 2 Aug 2024 14:21:01 -0400 Subject: [PATCH 056/144] WIP: making the signature of all BC functional consistent in Warp --- xlb/operator/boundary_condition/bc_do_nothing.py | 10 +++++----- xlb/operator/boundary_condition/bc_equilibrium.py | 8 ++++---- .../boundary_condition/bc_fullway_bounce_back.py | 6 +++--- .../boundary_condition/bc_halfway_bounce_back.py | 8 ++++---- xlb/operator/stepper/nse_stepper.py | 4 ++-- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index a769fdf..3a76a34 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -61,24 +61,24 @@ def _construct_warp(self): @wp.func def functional2d( - f: wp.array3d(dtype=Any), + f_pre: wp.array3d(dtype=Any), missing_mask: Any, index: Any, ): _f = _f_vec() for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] + _f[l] = f_pre[l, index[0], index[1]] return _f @wp.func def functional3d( - f: wp.array4d(dtype=Any), + f_pre: wp.array4d(dtype=Any), missing_mask: Any, index: Any, ): _f = _f_vec() for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1], index[2]] + _f[l] = f_pre[l, index[0], index[1], index[2]] return _f @wp.kernel @@ -104,7 +104,7 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional3d(f_pre, _missing_mask, index) + _f = functional2d(f_pre, _missing_mask, index) else: _f = _f_vec() for l in range(self.velocity_set.q): diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 1a4586d..7c9eaa4 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -85,7 +85,7 @@ def _construct_warp(self): # Construct the funcional to get streamed indices @wp.func def functional2d( - f: wp.array3d(dtype=Any), + f_pre: wp.array3d(dtype=Any), missing_mask: Any, index: Any, ): @@ -116,7 +116,7 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional2d(f_post, _missing_mask, index) + _f = functional2d(f_pre, _missing_mask, index) else: _f = _f_vec() for l in range(self.velocity_set.q): @@ -128,7 +128,7 @@ def kernel2d( @wp.func def functional3d( - f: wp.array4d(dtype=Any), + f_pre: wp.array4d(dtype=Any), missing_mask: Any, index: Any, ): @@ -159,7 +159,7 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional3d(f_post, _missing_mask, index) + _f = functional3d(f_pre, _missing_mask, index) else: _f = _f_vec() for l in range(self.velocity_set.q): diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 90ff564..eca2d3d 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -63,8 +63,8 @@ def _construct_warp(self): @wp.func def functional( f_pre: Any, - f_post: Any, missing_mask: Any, + index: Any, ): fliped_f = _f_vec() for l in range(_q): @@ -100,7 +100,7 @@ def kernel2d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _mask) + _f = functional(_f_pre, _mask, index) else: _f = _f_post @@ -139,7 +139,7 @@ def kernel3d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _mask) + _f = functional(_f_pre, _mask, index) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 8852e8f..3117602 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -70,7 +70,7 @@ def _construct_warp(self): @wp.func def functional2d( - f: wp.array3d(dtype=Any), + f_pre: wp.array3d(dtype=Any), missing_mask: Any, index: Any, ): @@ -93,14 +93,14 @@ def functional2d( pull_index[d] = index[d] - _c[d, l] # Get the distribution function - _f[l] = f[use_l, pull_index[0], pull_index[1]] + _f[l] = f_pre[use_l, pull_index[0], pull_index[1]] return _f # Construct the funcional to get streamed indices @wp.func def functional3d( - f: wp.array4d(dtype=Any), + f_pre: wp.array4d(dtype=Any), missing_mask: Any, index: Any, ): @@ -123,7 +123,7 @@ def functional3d( pull_index[d] = index[d] - _c[d, l] # Get the distribution function - _f[l] = f[use_l, pull_index[0], pull_index[1], pull_index[2]] + _f[l] = f_pre[use_l, pull_index[0], pull_index[1], pull_index[2]] return _f diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index eba6364..50ad5ec 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -167,8 +167,8 @@ def kernel2d( # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, - f_post_collision, _missing_mask, + index ) # Set the output @@ -233,8 +233,8 @@ def kernel3d( # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, - f_post_collision, _missing_mask, + index ) # Set the output From a2ec3a243365df2ebd8cbc01b7a769a4d15f4674 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 2 Aug 2024 16:11:56 -0400 Subject: [PATCH 057/144] fix a bug in the latest jax implementation of indices_masker --- .../indices_boundary_masker.py | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index b446884..922a5d6 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -47,36 +47,23 @@ def jax_implementation( # shift indices shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z) if start_index is None: - start_index = shift_tup - else: - start_index = tuple( a + b for a, b in zip(start_index, shift_tup)) - - # define a helper function - def compute_boundary_id_and_mask(boundary_mask, mask): - if dim == 2: - boundary_mask = boundary_mask.at[ - 0, local_indices[0], local_indices[1] - ].set(id_number) - mask = mask.at[:, local_indices[0], local_indices[1]].set(True) - - if dim == 3: - boundary_mask = boundary_mask.at[ - 0, local_indices[0], local_indices[1], local_indices[2] - ].set(id_number) - mask = mask.at[ - :, local_indices[0], local_indices[1], local_indices[2] - ].set(True) - return boundary_mask, mask - + start_index = (0,) * dim + bid = boundary_mask[0] for bc in bclist: assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC!' id_number = bc.id local_indices = np.array(bc.indices) + np.array(start_index)[:, np.newaxis] - boundary_mask, grid_mask = compute_boundary_id_and_mask(boundary_mask, grid_mask) + global_indices = local_indices + np.array(shift_tup)[:, np.newaxis] + bid = bid.at[tuple(local_indices)].set(id_number) + if dim == 2: + grid_mask = grid_mask.at[:, global_indices[0], global_indices[1]].set(True) + if dim == 3: + grid_mask = grid_mask.at[:, global_indices[0], global_indices[1], global_indices[2]].set(True) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop('indices', None) + boundary_mask = boundary_mask.at[0].set(bid) grid_mask = self.stream(grid_mask) if dim == 2: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] From 1810e9c43a4097a5247be363d81e8c7da1ce7f8d Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 2 Aug 2024 14:25:50 -0400 Subject: [PATCH 058/144] Fixed distribute and post processing bugs --- examples/cfd/flow_past_sphere_3d.py | 55 ++++++++++------- examples/cfd/lid_driven_cavity_2d.py | 47 +++++++++------ .../cfd/lid_driven_cavity_2d_distributed.py | 3 +- examples/cfd/windtunnel_3d.py | 60 ++++++++++++------- xlb/distribute/distribute.py | 44 ++++++++++++-- xlb/operator/stepper/nse_stepper.py | 4 +- 6 files changed, 147 insertions(+), 66 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index a1682b8..a7d9dac 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -15,44 +15,51 @@ import numpy as np import jax.numpy as jnp + class FlowOverSphere: def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): - # initialize backend xlb.init( velocity_set=velocity_set, default_backend=backend, default_precision_policy=precision_policy, ) - + self.grid_shape = grid_shape self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( + create_nse_fields(grid_shape) + ) self.stepper = None self.boundary_conditions = [] # Setup the simulation BC, its initial conditions, and the stepper self._setup(omega) - + def _setup(self, omega): self.setup_boundary_conditions() self.setup_boundary_masks() self.initialize_fields() self.setup_stepper(omega) - + def define_boundary_indices(self): - inlet = self.grid.boundingBoxIndices['left'] - outlet = self.grid.boundingBoxIndices['right'] - walls = [self.grid.boundingBoxIndices['bottom'][i] + self.grid.boundingBoxIndices['top'][i] + - self.grid.boundingBoxIndices['front'][i] + self.grid.boundingBoxIndices['back'][i] for i in range(self.velocity_set.d)] - + inlet = self.grid.boundingBoxIndices["left"] + outlet = self.grid.boundingBoxIndices["right"] + walls = [ + self.grid.boundingBoxIndices["bottom"][i] + + self.grid.boundingBoxIndices["top"][i] + + self.grid.boundingBoxIndices["front"][i] + + self.grid.boundingBoxIndices["back"][i] + for i in range(self.velocity_set.d) + ] + sphere_radius = self.grid_shape[1] // 12 x = np.arange(self.grid_shape[0]) y = np.arange(self.grid_shape[1]) z = np.arange(self.grid_shape[2]) - X, Y, Z = np.meshgrid(x, y, z, indexing='ij') + X, Y, Z = np.meshgrid(x, y, z, indexing="ij") indices = np.where( (X - self.grid_shape[0] // 6) ** 2 + (Y - self.grid_shape[1] // 2) ** 2 @@ -62,7 +69,7 @@ def define_boundary_indices(self): sphere = [tuple(indices[i]) for i in range(self.velocity_set.d)] return inlet, outlet, walls, sphere - + def setup_boundary_conditions(self): inlet, outlet, walls, sphere = self.define_boundary_indices() bc_left = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=inlet) @@ -83,24 +90,31 @@ def setup_boundary_masks(self): def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) - + def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper( omega, boundary_conditions=self.boundary_conditions ) - def run(self, num_steps): + def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_1 = self.stepper( + self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i + ) self.f_0, self.f_1 = self.f_1, self.f_0 + if i % post_process_interval == 0 or i == num_steps - 1: + self.post_process(i) + def post_process(self, i): # Write the results. We'll use JAX backend for the post-processing if not isinstance(self.f_0, jnp.ndarray): - self.f_0 = wp.to_jax(self.f_0) + f_0 = wp.to_jax(self.f_0) + else: + f_0 = self.f_0 macro = Macroscopic(compute_backend=ComputeBackend.JAX) - rho, u = macro(self.f_0) + rho, u = macro(f_0) # remove boundary cells u = u[:, 1:-1, 1:-1, 1:-1] @@ -120,6 +134,7 @@ def post_process(self, i): precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 - simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy) - simulation.run(num_steps=10000) - simulation.post_process(i=10000) + simulation = FlowOverSphere( + omega, grid_shape, velocity_set, backend, precision_policy + ) + simulation.run(num_steps=10000, post_process_interval=1000) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index eac2909..ecc6f0b 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -13,22 +13,23 @@ class LidDrivenCavity2D: def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): - # initialize backend xlb.init( velocity_set=velocity_set, default_backend=backend, default_precision_policy=precision_policy, ) - + self.grid_shape = grid_shape self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( + create_nse_fields(grid_shape) + ) self.stepper = None self.boundary_conditions = [] - + # Setup the simulation BC, its initial conditions, and the stepper self._setup(omega) @@ -39,11 +40,15 @@ def _setup(self, omega): self.setup_stepper(omega) def define_boundary_indices(self): - lid = self.grid.boundingBoxIndices['top'] - walls = [self.grid.boundingBoxIndices['bottom'][i] + self.grid.boundingBoxIndices['left'][i] + - self.grid.boundingBoxIndices['right'][i] for i in range(self.velocity_set.d)] + lid = self.grid.boundingBoxIndices["top"] + walls = [ + self.grid.boundingBoxIndices["bottom"][i] + + self.grid.boundingBoxIndices["left"][i] + + self.grid.boundingBoxIndices["right"][i] + for i in range(self.velocity_set.d) + ] return lid, walls - + def setup_boundary_conditions(self): lid, walls = self.define_boundary_indices() bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid) @@ -62,25 +67,32 @@ def setup_boundary_masks(self): def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) - + def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper( omega, boundary_conditions=self.boundary_conditions ) - def run(self, num_steps): + def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_1 = self.stepper( + self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i + ) self.f_0, self.f_1 = self.f_1, self.f_0 + if i % post_process_interval == 0 or i == num_steps - 1: + self.post_process(i) + def post_process(self, i): # Write the results. We'll use JAX backend for the post-processing if not isinstance(self.f_0, jnp.ndarray): - self.f_0 = wp.to_jax(self.f_0) + f_0 = wp.to_jax(self.f_0) + else: + f_0 = self.f_0 macro = Macroscopic(compute_backend=ComputeBackend.JAX) - rho, u = macro(self.f_0) + rho, u = macro(f_0) # remove boundary cells rho = rho[:, 1:-1, 1:-1] @@ -95,13 +107,14 @@ def post_process(self, i): if __name__ == "__main__": # Running the simulation - grid_size = 128 + grid_size = 500 grid_shape = (grid_size, grid_size) backend = ComputeBackend.JAX velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 - simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy) - simulation.run(num_steps=500) - simulation.post_process(i=500) + simulation = LidDrivenCavity2D( + omega, grid_shape, velocity_set, backend, precision_policy + ) + simulation.run(num_steps=5000, post_process_interval=1000) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index e71da52..7a43a14 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -31,5 +31,4 @@ def setup_stepper(self, omega): omega=1.6 simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy) - simulation.run(num_steps=5000) - simulation.post_process(i=5000) \ No newline at end of file + simulation.run(num_steps=5000, post_process_interval=1000) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 49e26f7..35140ee 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -19,23 +19,26 @@ class WindTunnel3D: - def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precision_policy): - + def __init__( + self, omega, wind_speed, grid_shape, velocity_set, backend, precision_policy + ): # initialize backend xlb.init( velocity_set=velocity_set, default_backend=backend, default_precision_policy=precision_policy, ) - + self.grid_shape = grid_shape self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( + create_nse_fields(grid_shape) + ) self.stepper = None self.boundary_conditions = [] - + # Setup the simulation BC, its initial conditions, and the stepper self._setup(omega, wind_speed) @@ -54,31 +57,38 @@ def voxelize_stl(self, stl_filename, length_lbm_unit): return mesh_matrix, pitch def define_boundary_indices(self): - inlet = self.grid.boundingBoxIndices['left'] - outlet = self.grid.boundingBoxIndices['right'] - walls = [self.grid.boundingBoxIndices['bottom'][i] + self.grid.boundingBoxIndices['top'][i] + - self.grid.boundingBoxIndices['front'][i] + self.grid.boundingBoxIndices['back'][i] for i in range(self.velocity_set.d)] - + inlet = self.grid.boundingBoxIndices["left"] + outlet = self.grid.boundingBoxIndices["right"] + walls = [ + self.grid.boundingBoxIndices["bottom"][i] + + self.grid.boundingBoxIndices["top"][i] + + self.grid.boundingBoxIndices["front"][i] + + self.grid.boundingBoxIndices["back"][i] + for i in range(self.velocity_set.d) + ] + stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl" grid_size_x = self.grid_shape[0] car_length_lbm_unit = grid_size_x / 4 car_voxelized, pitch = self.voxelize_stl(stl_filename, car_length_lbm_unit) car_area = np.prod(car_voxelized.shape[1:]) - tx, ty, tz = np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape + tx, ty, tz = ( + np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape + ) shift = [tx // 4, ty // 2, 0] car = np.argwhere(car_voxelized) + shift car = np.array(car).T car = [tuple(car[i]) for i in range(self.velocity_set.d)] return inlet, outlet, walls, car - + def setup_boundary_conditions(self, wind_speed): inlet, outlet, walls, car = self.define_boundary_indices() bc_left = EquilibriumBC(rho=1.0, u=(wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = DoNothingBC(indices=outlet) - bc_car= FullwayBounceBackBC(indices=car) + bc_car = FullwayBounceBackBC(indices=car) self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_car] def setup_boundary_masks(self): @@ -93,26 +103,35 @@ def setup_boundary_masks(self): def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) - + def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper( omega, boundary_conditions=self.boundary_conditions, collision_type="KBC" ) - def run(self, num_steps, print_interval): + def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_1 = self.stepper( + self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i + ) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: elapsed_time = time.time() - start_time - print(f"Iteration: {i+1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") + print( + f"Iteration: {i+1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s" + ) + + if i % post_process_interval == 0 or i == num_steps - 1: + self.post_process(i) def post_process(self, i): # Write the results. We'll use JAX backend for the post-processing if not isinstance(self.f_0, jnp.ndarray): f_0 = wp.to_jax(self.f_0) + else: + f_0 = self.f_0 macro = Macroscopic(compute_backend=ComputeBackend.JAX) @@ -159,6 +178,7 @@ def post_process(self, i): print(f"Max iterations: {num_steps}") print("\n" + "=" * 50 + "\n") - simulation = WindTunnel3D(omega, wind_speed, grid_shape, velocity_set, backend, precision_policy) - simulation.run(num_steps, print_interval) - simulation.post_process(i=num_steps) + simulation = WindTunnel3D( + omega, wind_speed, grid_shape, velocity_set, backend, precision_policy + ) + simulation.run(num_steps, print_interval, post_process_interval=1000) diff --git a/xlb/distribute/distribute.py b/xlb/distribute/distribute.py index ad07dd0..bcee7dd 100644 --- a/xlb/distribute/distribute.py +++ b/xlb/distribute/distribute.py @@ -1,16 +1,13 @@ from jax.sharding import PartitionSpec as P from xlb.operator import Operator -from xlb import DefaultConfig -from xlb import ComputeBackend +from xlb.operator.stepper import IncompressibleNavierStokesStepper +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep from jax import lax from jax.experimental.shard_map import shard_map from jax import jit -import jax.numpy as jnp -import warp as wp -from typing import Tuple -def distribute( +def distribute_operator( operator: Operator, grid, velocity_set, @@ -85,3 +82,38 @@ def _wrapped_operator(*args): return distributed_operator(*args) return jit(_wrapped_operator) + + +def distribute(operator, grid, velocity_set, num_results=1, ops="permute"): + """ + Distribute an operator or a stepper. + If the operator is a stepper, check for post-streaming boundary conditions + before deciding how to distribute. + """ + if isinstance(operator, IncompressibleNavierStokesStepper): + # Check for post-streaming boundary conditions + has_post_streaming_bc = any( + bc.implementation_step == ImplementationStep.STREAMING + for bc in operator.boundary_conditions + ) + + if has_post_streaming_bc: + # If there are post-streaming BCs, only distribute the stream operator + distributed_stream = distribute_operator( + operator.stream, grid, velocity_set + ) + operator.stream = distributed_stream + else: + # If no post-streaming BCs, distribute the whole operator + distributed_op = distribute_operator( + operator, grid, velocity_set, num_results=num_results, ops=ops + ) + return distributed_op + + return operator + else: + # For other operators, apply the original distribution logic + distributed_op = distribute_operator( + operator, grid, velocity_set, num_results=num_results, ops=ops + ) + return distributed_op diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 11b5615..3f986d3 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -220,7 +220,9 @@ def kernel3d( feq = self.equilibrium.warp_functional(rho, u) # Apply collision - f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) + f_post_collision = self.collision.warp_functional( + f_post_stream, feq, rho, u + ) # Apply collision type boundary conditions if _boundary_id == _fullway_bounce_back_bc: From aeb17703941d336e71b33b15af317dcd8f2426b5 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 2 Aug 2024 16:49:02 -0400 Subject: [PATCH 059/144] Added ruff --- .github/workflows/lint.yml | 27 ++++++ .pre-commit-config.yaml | 6 ++ examples/cfd/flow_past_sphere_3d.py | 25 ++---- examples/cfd/lid_driven_cavity_2d.py | 24 ++--- .../cfd/lid_driven_cavity_2d_distributed.py | 16 ++-- examples/cfd/windtunnel_3d.py | 34 ++----- .../flow_past_sphere.py | 39 ++------ .../cfd_old_to_be_migrated/taylor_green.py | 79 +++++------------ examples/performance/mlups_3d.py | 24 ++--- requirements.txt | 3 +- ruff.toml | 42 +++++++++ setup.py | 3 +- .../bc_equilibrium/test_bc_equilibrium_jax.py | 22 ++--- .../test_bc_equilibrium_warp.py | 23 ++--- .../test_bc_fullway_bounce_back_jax.py | 18 ++-- .../test_bc_fullway_bounce_back_warp.py | 17 ++-- .../mask/test_bc_indices_masker_jax.py | 21 ++--- .../mask/test_bc_indices_masker_warp.py | 17 ++-- tests/grids/test_grid_jax.py | 3 +- .../collision/test_bgk_collision_jax.py | 1 + .../collision/test_bgk_collision_warp.py | 7 +- .../equilibrium/test_equilibrium_jax.py | 9 +- .../equilibrium/test_equilibrium_warp.py | 5 +- .../macroscopic/test_macroscopic_jax.py | 2 +- .../macroscopic/test_macroscopic_warp.py | 13 +-- tests/kernels/stream/test_stream_warp.py | 4 +- xlb/__init__.py | 11 +-- xlb/distribute/__init__.py | 2 +- xlb/distribute/distribute.py | 26 ++---- xlb/experimental/ooc/__init__.py | 4 +- xlb/experimental/ooc/ooc_array.py | 88 ++++--------------- xlb/experimental/ooc/out_of_core.py | 14 +-- xlb/experimental/ooc/tiles/compressed_tile.py | 47 +++------- xlb/experimental/ooc/tiles/dense_tile.py | 14 +-- xlb/experimental/ooc/tiles/dynamic_array.py | 10 +-- xlb/experimental/ooc/tiles/tile.py | 10 +-- xlb/experimental/ooc/utils.py | 2 +- xlb/grid/__init__.py | 2 +- xlb/grid/grid.py | 25 +++--- xlb/grid/jax_grid.py | 17 +--- xlb/grid/warp_grid.py | 9 +- xlb/helper/__init__.py | 4 +- xlb/helper/nse_solver.py | 17 +--- xlb/operator/__init__.py | 4 +- xlb/operator/boundary_condition/__init__.py | 12 +-- .../boundary_condition/bc_do_nothing.py | 8 +- .../boundary_condition/bc_equilibrium.py | 14 +-- .../bc_fullway_bounce_back.py | 10 +-- .../bc_halfway_bounce_back.py | 12 +-- .../boundary_condition/boundary_condition.py | 8 +- .../boundary_condition_registry.py | 8 +- xlb/operator/boundary_masker/__init__.py | 4 +- .../indices_boundary_masker.py | 45 +++------- .../boundary_masker/stl_boundary_masker.py | 18 +--- xlb/operator/collision/__init__.py | 6 +- xlb/operator/collision/kbc.py | 32 +++---- xlb/operator/equilibrium/__init__.py | 4 +- .../equilibrium/quadratic_equilibrium.py | 3 +- xlb/operator/macroscopic/__init__.py | 2 +- xlb/operator/macroscopic/macroscopic.py | 4 +- xlb/operator/operator.py | 21 +---- xlb/operator/parallel_operator.py | 8 +- xlb/operator/precision_caster/__init__.py | 2 +- .../precision_caster/precision_caster.py | 5 +- xlb/operator/stepper/__init__.py | 4 +- xlb/operator/stepper/nse_stepper.py | 38 ++------ xlb/operator/stepper/stepper.py | 55 +++--------- xlb/operator/stream/__init__.py | 2 +- xlb/operator/stream/stream.py | 5 +- xlb/precision_policy.py | 4 +- xlb/precision_policy/precision_policy.py | 21 ++--- xlb/utils/__init__.py | 14 +-- xlb/utils/utils.py | 37 +++----- xlb/velocity_set/__init__.py | 8 +- xlb/velocity_set/d2q9.py | 5 +- xlb/velocity_set/d3q19.py | 9 +- xlb/velocity_set/d3q27.py | 1 + xlb/velocity_set/velocity_set.py | 19 +--- 78 files changed, 415 insertions(+), 823 deletions(-) create mode 100644 .github/workflows/lint.yml create mode 100644 .pre-commit-config.yaml create mode 100644 ruff.toml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..1b44c5a --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,27 @@ +name: Lint + +on: + pull_request: + branches: + - major-refactoring # Remember to add main branch later + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Check out code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + + - name: Run Ruff + run: ruff check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..6a2bd2f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.6 + hooks: + - id: ruff + args: [--fix] diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index a7d9dac..2a580aa 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -29,9 +29,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( - create_nse_fields(grid_shape) - ) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -61,10 +59,7 @@ def define_boundary_indices(self): z = np.arange(self.grid_shape[2]) X, Y, Z = np.meshgrid(x, y, z, indexing="ij") indices = np.where( - (X - self.grid_shape[0] // 6) ** 2 - + (Y - self.grid_shape[1] // 2) ** 2 - + (Z - self.grid_shape[2] // 2) ** 2 - < sphere_radius**2 + (X - self.grid_shape[0] // 6) ** 2 + (Y - self.grid_shape[1] // 2) ** 2 + (Z - self.grid_shape[2] // 2) ** 2 < sphere_radius**2 ) sphere = [tuple(indices[i]) for i in range(self.velocity_set.d)] @@ -84,23 +79,17 @@ def setup_boundary_masks(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker( - self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0) - ) + self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=self.boundary_conditions - ) + self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper( - self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i - ) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: @@ -134,7 +123,5 @@ def post_process(self, i): precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 - simulation = FlowOverSphere( - omega, grid_shape, velocity_set, backend, precision_policy - ) + simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy) simulation.run(num_steps=10000, post_process_interval=1000) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index ecc6f0b..488ebc1 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -24,9 +24,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( - create_nse_fields(grid_shape) - ) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -42,9 +40,7 @@ def _setup(self, omega): def define_boundary_indices(self): lid = self.grid.boundingBoxIndices["top"] walls = [ - self.grid.boundingBoxIndices["bottom"][i] - + self.grid.boundingBoxIndices["left"][i] - + self.grid.boundingBoxIndices["right"][i] + self.grid.boundingBoxIndices["bottom"][i] + self.grid.boundingBoxIndices["left"][i] + self.grid.boundingBoxIndices["right"][i] for i in range(self.velocity_set.d) ] return lid, walls @@ -61,23 +57,17 @@ def setup_boundary_masks(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker( - self.boundary_conditions, self.boundary_mask, self.missing_mask - ) + self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=self.boundary_conditions - ) + self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper( - self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i - ) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: @@ -114,7 +104,5 @@ def post_process(self, i): precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 - simulation = LidDrivenCavity2D( - omega, grid_shape, velocity_set, backend, precision_policy - ) + simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy) simulation.run(num_steps=5000, post_process_interval=1000) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 7a43a14..225d6bd 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -11,24 +11,24 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): super().__init__(omega, grid_shape, velocity_set, backend, precision_policy) def setup_stepper(self, omega): - stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=self.boundary_conditions - ) + stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) distributed_stepper = distribute( - stepper, self.grid, self.velocity_set, - ) + stepper, + self.grid, + self.velocity_set, + ) self.stepper = distributed_stepper return - + if __name__ == "__main__": # Running the simulation grid_size = 512 grid_shape = (grid_size, grid_size) - backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! + backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 - omega=1.6 + omega = 1.6 simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy) simulation.run(num_steps=5000, post_process_interval=1000) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 35140ee..e76b303 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -19,9 +19,7 @@ class WindTunnel3D: - def __init__( - self, omega, wind_speed, grid_shape, velocity_set, backend, precision_policy - ): + def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precision_policy): # initialize backend xlb.init( velocity_set=velocity_set, @@ -33,9 +31,7 @@ def __init__( self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( - create_nse_fields(grid_shape) - ) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -72,10 +68,8 @@ def define_boundary_indices(self): car_length_lbm_unit = grid_size_x / 4 car_voxelized, pitch = self.voxelize_stl(stl_filename, car_length_lbm_unit) - car_area = np.prod(car_voxelized.shape[1:]) - tx, ty, tz = ( - np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape - ) + # car_area = np.prod(car_voxelized.shape[1:]) + tx, ty, _ = np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape shift = [tx // 4, ty // 2, 0] car = np.argwhere(car_voxelized) + shift car = np.array(car).T @@ -97,31 +91,23 @@ def setup_boundary_masks(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker( - self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0) - ) + self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=self.boundary_conditions, collision_type="KBC" - ) + self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper( - self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i - ) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: elapsed_time = time.time() - start_time - print( - f"Iteration: {i+1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s" - ) + print(f"Iteration: {i + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") if i % post_process_interval == 0 or i == num_steps - 1: self.post_process(i) @@ -178,7 +164,5 @@ def post_process(self, i): print(f"Max iterations: {num_steps}") print("\n" + "=" * 50 + "\n") - simulation = WindTunnel3D( - omega, wind_speed, grid_shape, velocity_set, backend, precision_policy - ) + simulation = WindTunnel3D(omega, wind_speed, grid_shape, velocity_set, backend, precision_policy) simulation.run(num_steps, print_interval, post_process_interval=1000) diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py index 7e8af30..68d1c2b 100644 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ b/examples/cfd_old_to_be_migrated/flow_past_sphere.py @@ -22,8 +22,8 @@ from xlb.operator import Operator -class UniformInitializer(Operator): +class UniformInitializer(Operator): def _construct_warp(self): # Construct the warp kernel @wp.kernel @@ -149,48 +149,27 @@ def warp_implementation(self, rho, u, vel): y = np.arange(nr) z = np.arange(nr) X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = np.array(indices).T indices = wp.from_numpy(indices, dtype=wp.int32) # Set boundary conditions on the indices - boundary_mask, missing_mask = indices_boundary_masker( - indices, - half_way_bc.id, - boundary_mask, - missing_mask, - (0, 0, 0) - ) + boundary_mask, missing_mask = indices_boundary_masker(indices, half_way_bc.id, boundary_mask, missing_mask, (0, 0, 0)) # Set inlet bc lower_bound = (0, 0, 0) upper_bound = (0, nr, nr) direction = (1, 0, 0) boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - equilibrium_bc.id, - boundary_mask, - missing_mask, - (0, 0, 0) + lower_bound, upper_bound, direction, equilibrium_bc.id, boundary_mask, missing_mask, (0, 0, 0) ) # Set outlet bc - lower_bound = (nr-1, 0, 0) - upper_bound = (nr-1, nr, nr) + lower_bound = (nr - 1, 0, 0) + upper_bound = (nr - 1, nr, nr) direction = (-1, 0, 0) boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - do_nothing_bc.id, - boundary_mask, - missing_mask, - (0, 0, 0) + lower_bound, upper_bound, direction, do_nothing_bc.id, boundary_mask, missing_mask, (0, 0, 0) ) # Set initial conditions @@ -201,7 +180,7 @@ def warp_implementation(self, rho, u, vel): plot_freq = 512 save_dir = "flow_past_sphere" os.makedirs(save_dir, exist_ok=True) - #compute_mlup = False # Plotting results + # compute_mlup = False # Plotting results compute_mlup = True num_steps = 1024 * 8 start = time.time() @@ -225,4 +204,4 @@ def warp_implementation(self, rho, u, vel): end = time.time() # Print MLUPS - print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") + print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py index 10eb54f..9ed7fa6 100644 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ b/examples/cfd_old_to_be_migrated/taylor_green.py @@ -4,10 +4,8 @@ from tqdm import tqdm import os import matplotlib.pyplot as plt -from functools import partial from typing import Any import jax.numpy as jnp -from jax import jit import warp as wp wp.init() @@ -15,13 +13,14 @@ import xlb from xlb.operator import Operator + class TaylorGreenInitializer(Operator): """ Initialize the Taylor-Green vortex. """ @Operator.register_backend(xlb.ComputeBackend.JAX) - #@partial(jit, static_argnums=(0)) + # @partial(jit, static_argnums=(0)) def jax_implementation(self, vel, nr): # Make meshgrid x = jnp.linspace(0, 2 * jnp.pi, nr) @@ -33,24 +32,14 @@ def jax_implementation(self, vel, nr): u = jnp.stack( [ vel * jnp.sin(X) * jnp.cos(Y) * jnp.cos(Z), - - vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), + -vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), jnp.zeros_like(X), ], axis=0, ) # Compute rho - rho = ( - 3.0 - * vel - * vel - * (1.0 / 16.0) - * ( - jnp.cos(2.0 * X) - + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0)) - ) - + 1.0 - ) + rho = 3.0 * vel * vel * (1.0 / 16.0) * (jnp.cos(2.0 * X) + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0))) + 1.0 rho = jnp.expand_dims(rho, axis=0) return rho, u @@ -74,22 +63,11 @@ def kernel( # Compute u u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) - u[1, i, j, k] = - vel * wp.cos(x) * wp.sin(y) * wp.cos(z) + u[1, i, j, k] = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) u[2, i, j, k] = 0.0 # Compute rho - rho[0, i, j, k] = ( - 3.0 - * vel - * vel - * (1.0 / 16.0) - * ( - wp.cos(2.0 * x) - + (wp.cos(2.0 * y) - * (wp.cos(2.0 * z) + 2.0)) - ) - + 1.0 - ) + rho[0, i, j, k] = 3.0 * vel * vel * (1.0 / 16.0) * (wp.cos(2.0 * x) + (wp.cos(2.0 * y) * (wp.cos(2.0 * z) + 2.0))) + 1.0 return None, kernel @@ -108,8 +86,8 @@ def warp_implementation(self, rho, u, vel, nr): ) return rho, u -def run_taylor_green(backend, compute_mlup=True): +def run_taylor_green(backend, compute_mlup=True): # Set the compute backend if backend == "warp": compute_backend = xlb.ComputeBackend.WARP @@ -139,35 +117,19 @@ def run_taylor_green(backend, compute_mlup=True): missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators - initializer = TaylorGreenInitializer( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - collision = xlb.operator.collision.BGK( - omega=1.9, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) + initializer = TaylorGreenInitializer(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) + collision = xlb.operator.collision.BGK(omega=1.9, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - stream = xlb.operator.stream.Stream( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) + velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend + ) + macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) + stream = xlb.operator.stream.Stream(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, - equilibrium=equilibrium, - macroscopic=macroscopic, - stream=stream) + collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, stream=stream + ) # Parrallelize the stepper TODO: Add this functionality - #stepper = grid.parallelize_operator(stepper) + # stepper = grid.parallelize_operator(stepper) # Set initial conditions if backend == "warp": @@ -200,8 +162,7 @@ def run_taylor_green(backend, compute_mlup=True): elif backend == "jax": rho, local_u = macroscopic(f0) - - plt.imshow(local_u[0, :, nr//2, :]) + plt.imshow(local_u[0, :, nr // 2, :]) plt.colorbar() plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() @@ -209,12 +170,12 @@ def run_taylor_green(backend, compute_mlup=True): end = time.time() # Print MLUPS - print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") + print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") -if __name__ == "__main__": +if __name__ == "__main__": # Run Taylor-Green vortex on different backends backends = ["warp", "jax"] - #backends = ["jax"] + # backends = ["jax"] for backend in backends: run_taylor_green(backend, compute_mlup=True) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 75ecccc..74bfa04 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -11,19 +11,11 @@ def parse_arguments(): - parser = argparse.ArgumentParser( - description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)" - ) - parser.add_argument( - "cube_edge", type=int, help="Length of the edge of the cubic grid" - ) + parser = argparse.ArgumentParser(description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)") + parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") parser.add_argument("num_steps", type=int, help="Timestep for the simulation") - parser.add_argument( - "backend", type=str, help="Backend for the simulation (jax or warp)" - ) - parser.add_argument( - "precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)" - ) + parser.add_argument("backend", type=str, help="Backend for the simulation (jax or warp)") + parser.add_argument("precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)") return parser.parse_args() @@ -77,9 +69,7 @@ def setup_boundary_conditions(grid): def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): omega = 1.0 - stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=setup_boundary_conditions(grid) - ) + stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=setup_boundary_conditions(grid)) if backend == ComputeBackend.JAX: stepper = distribute( @@ -111,9 +101,7 @@ def main(): grid, f_0, f_1, missing_mask, boundary_mask = create_grid_and_fields(args.cube_edge) f_0 = initialize_eq(f_0, grid, xlb.velocity_set.D3Q19(), backend) - elapsed_time = run( - f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps - ) + elapsed_time = run(f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/requirements.txt b/requirements.txt index ebae946..ee107af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git tqdm==4.66.2 warp-lang==1.0.2 numpy-stl==3.1.1 -pydantic==2.7.0 \ No newline at end of file +pydantic==2.7.0 +ruff==0.5.6 \ No newline at end of file diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..83c6370 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,42 @@ +# Adopted from tinygrad's ruff.toml thanks @geohot +indent-width = 4 +preview = true +target-version = "py38" + +lint.select = [ + "F", # Pyflakes + "W6", + "E71", + "E72", + "E112", # no-indented-block + "E113", # unexpected-indentation + # "E124", + "E203", # whitespace-before-punctuation + "E272", # multiple-spaces-before-keyword + "E303", # too-many-blank-lines + "E304", # blank-line-after-decorator + "E501", # line-too-long + # "E502", + "E702", # multiple-statements-on-one-line-semicolon + "E703", # useless-semicolon + "E731", # lambda-assignment + "W191", # tab-indentation + "W291", # trailing-whitespace + "W293", # blank-line-with-whitespace + "UP039", # unnecessary-class-parentheses + "C416", # unnecessary-comprehension + "RET506", # superfluous-else-raise + "RET507", # superfluous-else-continue + "A", # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing + "SIM105", # suppressible-exception + "FURB110",# if-exp-instead-of-or-operator +] + +# unused-variable, shadowing a Python builtin module, Module imported but unused +lint.ignore = ["F841", "A005", "F401"] +line-length = 150 + +exclude = [ + "docs/", + "xlb/experimental/", +] \ No newline at end of file diff --git a/setup.py b/setup.py index 5ef3ed7..6f9780a 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ version="0.0.1", author="", packages=find_packages(), - install_requires=[ - ], + install_requires=[], include_package_data=True, ) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 9017a9c..3e50fdb 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -7,6 +7,7 @@ from xlb import DefaultConfig from xlb.operator.boundary_masker import IndicesBoundaryMasker + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -29,9 +30,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -48,10 +47,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] @@ -62,9 +58,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices=indices, ) - boundary_mask, missing_mask = indices_boundary_masker( - [equilibrium_bc], boundary_mask, missing_mask, start_index=None - ) + boundary_mask, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_mask, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -80,13 +74,9 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): weights = velocity_set.w for i, weight in enumerate(weights): if dim == 2: - assert jnp.allclose( - f[i, indices[0], indices[1]], weight - ), f"Direction {i} in f does not match the expected weight" + assert jnp.allclose(f[i, indices[0], indices[1]], weight), f"Direction {i} in f does not match the expected weight" else: - assert jnp.allclose( - f[i, indices[0], indices[1], indices[2]], weight - ), f"Direction {i} in f does not match the expected weight" + assert jnp.allclose(f[i, indices[0], indices[1], indices[2]], weight), f"Direction {i} in f does not match the expected weight" # Make sure that everywhere else the values are the same as f_post. Note that indices are just int values mask_outside = np.ones(grid_shape, dtype=bool) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 7bb78cf..e319dbd 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -1,12 +1,12 @@ import pytest import numpy as np -import warp as wp import xlb from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig from xlb.operator.boundary_masker import IndicesBoundaryMasker + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -29,9 +29,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -48,10 +46,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium() @@ -63,9 +58,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): indices=indices, ) - boundary_mask, missing_mask = indices_boundary_masker( - [equilibrium_bc], boundary_mask, missing_mask, start_index=None - ) + boundary_mask, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_mask, missing_mask, start_index=None) f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -84,13 +77,9 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): weights = velocity_set.w for i, weight in enumerate(weights): if dim == 2: - assert np.allclose( - f[i, indices[0], indices[1]], weight - ), f"Direction {i} in f does not match the expected weight" + assert np.allclose(f[i, indices[0], indices[1]], weight), f"Direction {i} in f does not match the expected weight" else: - assert np.allclose( - f[i, indices[0], indices[1], indices[2]], weight - ), f"Direction {i} in f does not match the expected weight" + assert np.allclose(f[i, indices[0], indices[1], indices[2]], weight), f"Direction {i} in f does not match the expected weight" # Make sure that everywhere else the values are the same as f_post. Note that indices are just int values mask_outside = np.ones(grid_shape, dtype=bool) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index b6ce4c3..1b7edc2 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -7,6 +7,7 @@ from xlb.grid import grid_factory from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -31,9 +32,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -50,21 +49,14 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - boundary_mask, missing_mask = indices_boundary_masker( - [fullway_bc], boundary_mask, missing_mask, start_index=None - ) + boundary_mask, missing_mask = indices_boundary_masker([fullway_bc], boundary_mask, missing_mask, start_index=None) - f_pre = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0 - ) + f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0) # Generate a random field with the same shape key = jax.random.PRNGKey(0) random_field = jax.random.uniform(key, f_pre.shape) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index 3f8f0d0..963e081 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -2,11 +2,11 @@ import numpy as np import warp as wp import xlb -import jax from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -31,9 +31,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -50,17 +48,12 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - - boundary_mask, missing_mask = indices_boundary_masker( - [fullway_bc], boundary_mask, missing_mask, start_index=None - ) + + boundary_mask, missing_mask = indices_boundary_masker([fullway_bc], boundary_mask, missing_mask, start_index=None) # Generate a random field with the same shape random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index 0de8805..ddbc761 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -32,9 +32,7 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -51,19 +49,14 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - boundary_mask, missing_mask = indices_boundary_masker( - [test_bc], boundary_mask, missing_mask, start_index=None - ) + boundary_mask, missing_mask = indices_boundary_masker([test_bc], boundary_mask, missing_mask, start_index=None) assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype @@ -79,13 +72,9 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): boundary_mask = boundary_mask.at[0, indices[0], indices[1]].set(0) assert jnp.all(boundary_mask == 0) if dim == 3: - assert jnp.all( - boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id - ) + assert jnp.all(boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) # assert that the rest of the boundary_mask is zero - boundary_mask = boundary_mask.at[ - 0, indices[0], indices[1], indices[2] - ].set(0) + boundary_mask = boundary_mask.at[0, indices[0], indices[1], indices[2]].set(0) assert jnp.all(boundary_mask == 0) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 43911f6..6919ba9 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -1,5 +1,4 @@ import pytest -import warp as wp import numpy as np import xlb from xlb.compute_backend import ComputeBackend @@ -31,9 +30,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -50,10 +47,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] @@ -80,15 +74,14 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): if dim == 2: assert np.all(boundary_mask[0, indices[0], indices[1]] == test_bc.id) # assert that the rest of the boundary_mask is zero - boundary_mask[0, indices[0], indices[1]]= 0 + boundary_mask[0, indices[0], indices[1]] = 0 assert np.all(boundary_mask == 0) if dim == 3: - assert np.all( - boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id - ) + assert np.all(boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) # assert that the rest of the boundary_mask is zero boundary_mask[0, indices[0], indices[1], indices[2]] = 0 assert np.all(boundary_mask == 0) + if __name__ == "__main__": pytest.main() diff --git a/tests/grids/test_grid_jax.py b/tests/grids/test_grid_jax.py index ce4bc70..edd9dd0 100644 --- a/tests/grids/test_grid_jax.py +++ b/tests/grids/test_grid_jax.py @@ -7,6 +7,7 @@ from jax.experimental import mesh_utils import jax.numpy as jnp + def init_xlb_env(): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -51,6 +52,7 @@ def test_jax_3d_grid_initialization(grid_size): "z", ), "PartitionSpec is incorrect" + def test_jax_grid_create_field_fill_value(): init_xlb_env() grid_shape = (100, 100) @@ -62,7 +64,6 @@ def test_jax_grid_create_field_fill_value(): assert jnp.allclose(f, fill_value), "Field not properly initialized with fill_value" - @pytest.fixture(autouse=True) def setup_xlb_env(): init_xlb_env() diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index cce5ca4..5a400e0 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -7,6 +7,7 @@ from xlb.grid import grid_factory from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 7509c1d..522ea33 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -1,14 +1,12 @@ import pytest -import warp as wp import numpy as np import xlb from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.macroscopic import Macroscopic from xlb.operator.collision import BGK from xlb.grid import grid_factory from xlb import DefaultConfig -from xlb.precision_policy import Precision + def init_xlb_env(velocity_set): xlb.init( @@ -17,6 +15,7 @@ def init_xlb_env(velocity_set): velocity_set=velocity_set(), ) + @pytest.mark.parametrize( "dim,velocity_set,grid_shape,omega", [ @@ -40,7 +39,6 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) f_eq = compute_macro(rho, u, f_eq) - compute_collision = BGK(omega=omega) f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) @@ -53,5 +51,6 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): assert np.allclose(f_out, f_orig - omega * (f_orig - f_eq), atol=1e-5) + if __name__ == "__main__": pytest.main() diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index fbdadb6..07bafe7 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -6,6 +6,7 @@ from xlb.grid import grid_factory from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -38,16 +39,12 @@ def test_quadratic_equilibrium_jax(dim, velocity_set, grid_shape): # Test sum of f_eq across cardinality at each point sum_f_eq = np.sum(f_eq, axis=0) - assert np.allclose( - sum_f_eq, 1.0 - ), f"Sum of f_eq should be 1.0 across all directions at each grid point" + assert np.allclose(sum_f_eq, 1.0), "Sum of f_eq should be 1.0 across all directions at each grid point" # Test that each direction matches the expected weights weights = DefaultConfig.velocity_set.w for i, weight in enumerate(weights): - assert np.allclose( - f_eq[i, ...], weight - ), f"Direction {i} in f_eq does not match the expected weight" + assert np.allclose(f_eq[i, ...], weight), f"Direction {i} in f_eq does not match the expected weight" if __name__ == "__main__": diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index ef2287f..063a723 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -1,11 +1,12 @@ import pytest -import warp as wp import numpy as np import xlb from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.grid import grid_factory from xlb import DefaultConfig + + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -13,6 +14,7 @@ def init_xlb_env(velocity_set): velocity_set=velocity_set(), ) + @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ @@ -45,6 +47,7 @@ def test_quadratic_equilibrium_warp(dim, velocity_set, grid_shape): for i, weight in enumerate(weights): assert np.allclose(f_eq_np[i, ...], weight), f"Direction {i} in f_eq does not match the expected weight" + # @pytest.fixture(autouse=True) # def setup_xlb_env(request): # dim, velocity_set, grid_shape = request.param diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index 89ef393..50d1735 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -5,7 +5,7 @@ from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.grid import grid_factory -from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index 7a4a8cd..d98a014 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -6,7 +6,6 @@ from xlb.operator.macroscopic import Macroscopic from xlb.grid import grid_factory from xlb import DefaultConfig -import warp as wp def init_xlb_env(velocity_set): @@ -25,8 +24,8 @@ def init_xlb_env(velocity_set): (2, xlb.velocity_set.D2Q9, (100, 100), 1.1, 2.0), (2, xlb.velocity_set.D2Q9, (50, 50), 1.1, 2.0), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0, 0.0), - (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. - (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 2.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), # TODO: Uncommenting will cause a Warp error. Needs investigation. + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 2.0), # TODO: Uncommenting will cause a Warp error. Needs investigation. ], ) def test_macroscopic_warp(dim, velocity_set, grid_shape, rho, velocity): @@ -45,12 +44,8 @@ def test_macroscopic_warp(dim, velocity_set, grid_shape, rho, velocity): rho_calc, u_calc = compute_macro(f_eq, rho_calc, u_calc) - assert np.allclose( - rho_calc.numpy(), rho - ), f"Computed density should be close to initialized density {rho}" - assert np.allclose( - u_calc.numpy(), velocity - ), f"Computed velocity should be close to initialized velocity {velocity}" + assert np.allclose(rho_calc.numpy(), rho), f"Computed density should be close to initialized density {rho}" + assert np.allclose(u_calc.numpy(), velocity), f"Computed velocity should be close to initialized velocity {velocity}" if __name__ == "__main__": diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py index af70b4c..b83368d 100644 --- a/tests/kernels/stream/test_stream_warp.py +++ b/tests/kernels/stream/test_stream_warp.py @@ -70,9 +70,7 @@ def test_stream_operator_warp(dim, velocity_set, grid_shape): f_streamed = my_grid_warp.create_field(cardinality=velocity_set.q) f_streamed = stream_op(f_initial_warp, f_streamed) - assert jnp.allclose( - f_streamed.numpy(), np.array(expected) - ), "Streaming did not occur as expected" + assert jnp.allclose(f_streamed.numpy(), np.array(expected)), "Streaming did not occur as expected" if __name__ == "__main__": diff --git a/xlb/__init__.py b/xlb/__init__.py index be63d06..b58db3b 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -1,10 +1,10 @@ # Enum classes -from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy, Precision -from xlb.physics_type import PhysicsType +from xlb.compute_backend import ComputeBackend as ComputeBackend +from xlb.precision_policy import PrecisionPolicy as PrecisionPolicy, Precision as Precision +from xlb.physics_type import PhysicsType as PhysicsType # Config -from .default_config import init, DefaultConfig +from .default_config import init as init, DefaultConfig as DefaultConfig # Velocity Set import xlb.velocity_set @@ -15,6 +15,7 @@ import xlb.operator.stream import xlb.operator.boundary_condition import xlb.operator.macroscopic + # Grids import xlb.grid @@ -25,4 +26,4 @@ import xlb.utils # Distributed computing -import xlb.distribute \ No newline at end of file +import xlb.distribute diff --git a/xlb/distribute/__init__.py b/xlb/distribute/__init__.py index 33e0d2b..25fa0af 100644 --- a/xlb/distribute/__init__.py +++ b/xlb/distribute/__init__.py @@ -1 +1 @@ -from .distribute import distribute \ No newline at end of file +from .distribute import distribute as distribute diff --git a/xlb/distribute/distribute.py b/xlb/distribute/distribute.py index bcee7dd..c62b915 100644 --- a/xlb/distribute/distribute.py +++ b/xlb/distribute/distribute.py @@ -57,13 +57,8 @@ def build_specs(grid, *args): else: sharding_flags.append(False) - in_specs = tuple( - P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() - for flag in sharding_flags - ) - out_specs = tuple( - P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results) - ) + in_specs = tuple(P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() for flag in sharding_flags) + out_specs = tuple(P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results)) return tuple(sharding_flags), in_specs, out_specs def _wrapped_operator(*args): @@ -92,28 +87,19 @@ def distribute(operator, grid, velocity_set, num_results=1, ops="permute"): """ if isinstance(operator, IncompressibleNavierStokesStepper): # Check for post-streaming boundary conditions - has_post_streaming_bc = any( - bc.implementation_step == ImplementationStep.STREAMING - for bc in operator.boundary_conditions - ) + has_post_streaming_bc = any(bc.implementation_step == ImplementationStep.STREAMING for bc in operator.boundary_conditions) if has_post_streaming_bc: # If there are post-streaming BCs, only distribute the stream operator - distributed_stream = distribute_operator( - operator.stream, grid, velocity_set - ) + distributed_stream = distribute_operator(operator.stream, grid, velocity_set) operator.stream = distributed_stream else: # If no post-streaming BCs, distribute the whole operator - distributed_op = distribute_operator( - operator, grid, velocity_set, num_results=num_results, ops=ops - ) + distributed_op = distribute_operator(operator, grid, velocity_set, num_results=num_results, ops=ops) return distributed_op return operator else: # For other operators, apply the original distribution logic - distributed_op = distribute_operator( - operator, grid, velocity_set, num_results=num_results, ops=ops - ) + distributed_op = distribute_operator(operator, grid, velocity_set, num_results=num_results, ops=ops) return distributed_op diff --git a/xlb/experimental/ooc/__init__.py b/xlb/experimental/ooc/__init__.py index 5206cc1..801683d 100644 --- a/xlb/experimental/ooc/__init__.py +++ b/xlb/experimental/ooc/__init__.py @@ -1,2 +1,2 @@ -from xlb.experimental.ooc.out_of_core import OOCmap -from xlb.experimental.ooc.ooc_array import OOCArray +from xlb.experimental.ooc.out_of_core import OOCmap as OOCmap +from xlb.experimental.ooc.ooc_array import OOCArray as OOCArray diff --git a/xlb/experimental/ooc/ooc_array.py b/xlb/experimental/ooc/ooc_array.py index 6effde6..11aedc3 100644 --- a/xlb/experimental/ooc/ooc_array.py +++ b/xlb/experimental/ooc/ooc_array.py @@ -3,7 +3,6 @@ # from mpi4py import MPI import itertools -from dataclasses import dataclass from xlb.experimental.ooc.tiles.dense_tile import DenseTile, DenseGPUTile, DenseCPUTile from xlb.experimental.ooc.tiles.compressed_tile import ( @@ -63,9 +62,7 @@ def __init__( if self.codec is None: self.Tile = DenseTile self.DeviceTile = DenseGPUTile - self.HostTile = ( - DenseCPUTile # TODO: Possibly make HardDiskTile or something - ) + self.HostTile = DenseCPUTile # TODO: Possibly make HardDiskTile or something else: self.Tile = CompressedTile @@ -84,45 +81,33 @@ def __init__( # Get number of tiles per process if self.nr_tiles % self.nr_proc != 0: - raise ValueError( - f"Number of tiles {self.nr_tiles} does not divide number of processes {self.nr_proc}." - ) + raise ValueError(f"Number of tiles {self.nr_tiles} does not divide number of processes {self.nr_proc}.") self.nr_tiles_per_proc = self.nr_tiles // self.nr_proc # Make the tile mapppings self.tile_process_map = {} self.tile_device_map = {} - for i, tile_index in enumerate( - itertools.product(*[range(n) for n in self.tile_dims]) - ): + for i, tile_index in enumerate(itertools.product(*[range(n) for n in self.tile_dims])): self.tile_process_map[tile_index] = i % self.nr_proc - self.tile_device_map[tile_index] = devices[ - i % len(devices) - ] # Checkoboard pattern, TODO: may not be optimal + self.tile_device_map[tile_index] = devices[i % len(devices)] # Checkoboard pattern, TODO: may not be optimal # Get my device if self.nr_proc != len(self.devices): - raise ValueError( - f"Number of processes {self.nr_proc} does not equal number of devices {len(self.devices)}." - ) + raise ValueError(f"Number of processes {self.nr_proc} does not equal number of devices {len(self.devices)}.") self.device = self.devices[self.pid] # Make the tiles self.tiles = {} for tile_index in self.tile_process_map.keys(): if self.pid == self.tile_process_map[tile_index]: - self.tiles[tile_index] = self.HostTile( - self.tile_shape, self.dtype, self.padding, self.codec - ) + self.tiles[tile_index] = self.HostTile(self.tile_shape, self.dtype, self.padding, self.codec) # Make GPU tiles for copying data between CPU and GPU if self.nr_tiles % self.nr_compute_tiles != 0: raise ValueError( f"Number of tiles {self.nr_tiles} does not divide number of compute tiles {self.nr_compute_tiles}. This is used for asynchronous copies." ) - compute_array_shape = [ - s + 2 * p for (s, p) in zip(self.tile_shape, self.padding) - ] + compute_array_shape = [s + 2 * p for (s, p) in zip(self.tile_shape, self.padding)] self.compute_tiles_htd = [] self.compute_tiles_dth = [] self.compute_streams_htd = [] @@ -132,13 +117,9 @@ def __init__( with cp.cuda.Device(self.device): for i in range(self.nr_compute_tiles): # Make compute tiles for copying data - compute_tile = self.DeviceTile( - self.tile_shape, self.dtype, self.padding, self.codec - ) + compute_tile = self.DeviceTile(self.tile_shape, self.dtype, self.padding, self.codec) self.compute_tiles_htd.append(compute_tile) - compute_tile = self.DeviceTile( - self.tile_shape, self.dtype, self.padding, self.codec - ) + compute_tile = self.DeviceTile(self.tile_shape, self.dtype, self.padding, self.codec) self.compute_tiles_dth.append(compute_tile) # Make cupy stream @@ -185,9 +166,7 @@ def compression_ratio(self): def update_compute_index(self): """Update the current compute index.""" - self.current_compute_index = ( - self.current_compute_index + 1 - ) % self.nr_compute_tiles + self.current_compute_index = (self.current_compute_index + 1) % self.nr_compute_tiles def _guess_next_tile_index(self, tile_index): """Guess the next tile index to use for the compute array.""" @@ -291,9 +270,7 @@ def get_compute_array(self, tile_index): compute_tile.to_array(self.compute_arrays[self.current_compute_index]) # Return the compute array index in global array - global_index = tuple( - [i * s - p for (i, s, p) in zip(tile_index, self.tile_shape, self.padding)] - ) + global_index = tuple([i * s - p for (i, s, p) in zip(tile_index, self.tile_shape, self.padding)]) return self.compute_arrays[self.current_compute_index], global_index @@ -339,12 +316,7 @@ def update_padding(self): # Loop over all padding for pad_index in pad_ind: # Get neighboring tile index - neigh_tile_index = tuple( - [ - (i + p) % s - for (i, p, s) in zip(tile_index, pad_index, self.tile_dims) - ] - ) + neigh_tile_index = tuple([(i + p) % s for (i, p, s) in zip(tile_index, pad_index, self.tile_dims)]) neigh_pad_index = tuple([-p for p in pad_index]) # flip # 4 cases: @@ -354,10 +326,7 @@ def update_padding(self): # 4. the tile and neighboring tile are on different processes # Case 1: the tile and neighboring tile are on the same process - if ( - self.pid == self.tile_process_map[tile_index] - and self.pid == self.tile_process_map[neigh_tile_index] - ): + if self.pid == self.tile_process_map[tile_index] and self.pid == self.tile_process_map[neigh_tile_index]: # Get the tile and neighboring tile tile = self.tiles[tile_index] neigh_tile = self.tiles[neigh_tile_index] @@ -371,10 +340,7 @@ def update_padding(self): neigh_tile._buf_padding[neigh_pad_index] = padding # Case 2: the tile is on this process and the neighboring tile is on another process - if ( - self.pid == self.tile_process_map[tile_index] - and self.pid != self.tile_process_map[neigh_tile_index] - ): + if self.pid == self.tile_process_map[tile_index] and self.pid != self.tile_process_map[neigh_tile_index]: # Get the tile and padding tile = self.tiles[tile_index] padding = tile._padding[pad_index] @@ -387,10 +353,7 @@ def update_padding(self): ) # Case 3: the tile is on another process and the neighboring tile is on this process - if ( - self.pid != self.tile_process_map[tile_index] - and self.pid == self.tile_process_map[neigh_tile_index] - ): + if self.pid != self.tile_process_map[tile_index] and self.pid == self.tile_process_map[neigh_tile_index]: # Get the neighboring tile and padding neigh_tile = self.tiles[neigh_tile_index] neigh_padding = neigh_tile._buf_padding[neigh_pad_index] @@ -403,10 +366,7 @@ def update_padding(self): ) # Case 4: the tile and neighboring tile are on different processes - if ( - self.pid != self.tile_process_map[tile_index] - and self.pid != self.tile_process_map[neigh_tile_index] - ): + if self.pid != self.tile_process_map[tile_index] and self.pid != self.tile_process_map[neigh_tile_index]: pass # Increment the communication tag @@ -429,12 +389,7 @@ def get_array(self): comm_tag = 0 for tile_index in self.tile_process_map.keys(): # Set the center array in the full array - slice_index = tuple( - [ - slice(i * s, (i + 1) * s) - for (i, s) in zip(tile_index, self.tile_shape) - ] - ) + slice_index = tuple([slice(i * s, (i + 1) * s) for (i, s) in zip(tile_index, self.tile_shape)]) # if tile on this process compute the center array if self.comm.rank == self.tile_process_map[tile_index]: @@ -465,18 +420,13 @@ def get_array(self): if self.comm.rank == 0 and self.tile_process_map[tile_index] != 0: # Get the data from the other rank center_array = np.empty(self.tile_shape, dtype=self.dtype) - self.comm.Recv( - center_array, source=self.tile_process_map[tile_index], tag=comm_tag - ) + self.comm.Recv(center_array, source=self.tile_process_map[tile_index], tag=comm_tag) # Set the center array in the full array array[slice_index] = center_array # Case 3: the tile is on this rank and this process is not rank 0 - if ( - self.comm.rank != 0 - and self.tile_process_map[tile_index] == self.comm.rank - ): + if self.comm.rank != 0 and self.tile_process_map[tile_index] == self.comm.rank: # Send the data to rank 0 self.comm.Send(center_array, dest=0, tag=comm_tag) diff --git a/xlb/experimental/ooc/out_of_core.py b/xlb/experimental/ooc/out_of_core.py index 01851e8..bc42fab 100644 --- a/xlb/experimental/ooc/out_of_core.py +++ b/xlb/experimental/ooc/out_of_core.py @@ -1,17 +1,11 @@ # Out-of-core decorator for functions that take a lot of memory -import functools -import warp as wp import cupy as cp -import jax.dlpack as jdlpack -import jax -import numpy as np from xlb.experimental.ooc.ooc_array import OOCArray from xlb.experimental.ooc.utils import ( _cupy_to_backend, _backend_to_cupy, - _stream_to_backend, ) @@ -47,9 +41,7 @@ def wrapper(*args): # TODO: Add better checks for ooc_array in ooc_array_args: if ooc_array_args[0].tile_dims != ooc_array.tile_dims: - raise ValueError( - f"Tile dimensions of ooc arrays do not match. {ooc_array_args[0].tile_dims} != {ooc_array.tile_dims}" - ) + raise ValueError(f"Tile dimensions of ooc arrays do not match. {ooc_array_args[0].tile_dims} != {ooc_array.tile_dims}") # Apply the function to each of the ooc arrays for tile_index in ooc_array_args[0].tiles.keys(): @@ -79,9 +71,7 @@ def wrapper(*args): results = (results,) # Convert the results back to cupy arrays - results = tuple( - [_backend_to_cupy(result, backend) for result in results] - ) + results = tuple([_backend_to_cupy(result, backend) for result in results]) # Write the results back to the ooc array for arg_index, result in zip(ref_args, results): diff --git a/xlb/experimental/ooc/tiles/compressed_tile.py b/xlb/experimental/ooc/tiles/compressed_tile.py index 415f83b..ccdd2bb 100644 --- a/xlb/experimental/ooc/tiles/compressed_tile.py +++ b/xlb/experimental/ooc/tiles/compressed_tile.py @@ -1,9 +1,6 @@ import numpy as np import cupy as cp -import itertools -from dataclasses import dataclass import warnings -import time try: from kvikio._lib.arr import asarray @@ -98,9 +95,7 @@ def compression_ratio(self): # Get total number of bytes in uncompressed tile total_bytes_uncompressed = np.prod(self.shape) * self.dtype_itemsize for pad_ind in self.pad_ind: - total_bytes_uncompressed += ( - np.prod(self._padding_shape[pad_ind]) * self.dtype_itemsize - ) + total_bytes_uncompressed += np.prod(self._padding_shape[pad_ind]) * self.dtype_itemsize # Return compression ratio return total_bytes_uncompressed, total_bytes @@ -147,9 +142,7 @@ def to_gpu_tile(self, dst_gpu_tile): """Copy tile to a GPU tile.""" # Check tile is Compressed - assert isinstance( - dst_gpu_tile, CompressedGPUTile - ), "Destination tile must be a CompressedGPUTile" + assert isinstance(dst_gpu_tile, CompressedGPUTile), "Destination tile must be a CompressedGPUTile" # Copy array dst_gpu_tile._array[: len(self._array.array)].set(self._array.array) @@ -157,9 +150,7 @@ def to_gpu_tile(self, dst_gpu_tile): # Copy padding for pad_ind in self.pad_ind: - dst_gpu_tile._padding[pad_ind][: len(self._padding[pad_ind].array)].set( - self._padding[pad_ind].array - ) + dst_gpu_tile._padding[pad_ind][: len(self._padding[pad_ind].array)].set(self._padding[pad_ind].array) dst_gpu_tile._padding_bytes[pad_ind] = self._padding[pad_ind].nbytes @@ -186,9 +177,7 @@ def allocate_array(self, shape): """Returns a cupy array with the given shape.""" nbytes = np.prod(shape) * self.dtype_itemsize codec = self.codec() - max_compressed_buffer = codec._manager.configure_compression(nbytes)[ - "max_compressed_buffer_size" - ] + max_compressed_buffer = codec._manager.configure_compression(nbytes)["max_compressed_buffer_size"] array = cp.zeros((max_compressed_buffer,), dtype=np.uint8) return array @@ -198,9 +187,7 @@ def to_array(self, array): # Copy center array if self._array_codec is None: self._array_codec = self.codec() - self._array_codec._manager.configure_decompression_with_compressed_buffer( - asarray(self._array[: self._array_bytes]) - ) + self._array_codec._manager.configure_decompression_with_compressed_buffer(asarray(self._array[: self._array_bytes])) self._array_codec.decompression_config = self._array_codec._manager.configure_decompression_with_compressed_buffer( asarray(self._array[: self._array_bytes]) ) @@ -217,17 +204,13 @@ def to_array(self, array): self._padding_codec[pad_ind] = self.codec() self._padding_codec[pad_ind].decompression_config = self._padding_codec[ pad_ind - ]._manager.configure_decompression_with_compressed_buffer( - asarray(self._padding[pad_ind][: self._padding_bytes[pad_ind]]) - ) + ]._manager.configure_decompression_with_compressed_buffer(asarray(self._padding[pad_ind][: self._padding_bytes[pad_ind]])) self.dense_gpu_tile._padding[pad_ind] = _decode( self._padding[pad_ind][: self._padding_bytes[pad_ind]], self.dense_gpu_tile._padding[pad_ind], self._padding_codec[pad_ind], ) - array[self._slice_padding_to_array[pad_ind]] = self.dense_gpu_tile._padding[ - pad_ind - ] + array[self._slice_padding_to_array[pad_ind]] = self.dense_gpu_tile._padding[pad_ind] def from_array(self, array): """Copy a full array to tile.""" @@ -236,17 +219,13 @@ def from_array(self, array): if self._array_codec is None: self._array_codec = self.codec() self._array_codec.configure_compression(self._array.nbytes) - self._array_bytes = _encode( - array[self._slice_center], self._array, self._array_codec - ) + self._array_bytes = _encode(array[self._slice_center], self._array, self._array_codec) # Copy padding for pad_ind in self.pad_ind: if pad_ind not in self._padding_codec: self._padding_codec[pad_ind] = self.codec() - self._padding_codec[pad_ind].configure_compression( - self._padding[pad_ind].nbytes - ) + self._padding_codec[pad_ind].configure_compression(self._padding[pad_ind].nbytes) self._padding_bytes[pad_ind] = _encode( array[self._slice_array_to_padding[pad_ind]], self._padding[pad_ind], @@ -257,9 +236,7 @@ def to_cpu_tile(self, dst_cpu_tile): """Copy tile to a CPU tile.""" # Check tile is Compressed - assert isinstance( - dst_cpu_tile, CompressedCPUTile - ), "Destination tile must be a CompressedCPUTile" + assert isinstance(dst_cpu_tile, CompressedCPUTile), "Destination tile must be a CompressedCPUTile" # Copy array dst_cpu_tile._array.resize(self._array_bytes) @@ -268,6 +245,4 @@ def to_cpu_tile(self, dst_cpu_tile): # Copy padding for pad_ind in self.pad_ind: dst_cpu_tile._padding[pad_ind].resize(self._padding_bytes[pad_ind]) - self._padding[pad_ind][: self._padding_bytes[pad_ind]].get( - out=dst_cpu_tile._padding[pad_ind].array - ) + self._padding[pad_ind][: self._padding_bytes[pad_ind]].get(out=dst_cpu_tile._padding[pad_ind].array) diff --git a/xlb/experimental/ooc/tiles/dense_tile.py b/xlb/experimental/ooc/tiles/dense_tile.py index 8a303e4..41fc129 100644 --- a/xlb/experimental/ooc/tiles/dense_tile.py +++ b/xlb/experimental/ooc/tiles/dense_tile.py @@ -1,7 +1,5 @@ import numpy as np import cupy as cp -import itertools -from dataclasses import dataclass from xlb.experimental.ooc.tiles.tile import Tile @@ -46,9 +44,7 @@ def allocate_array(self, shape): """Returns a cupy array with the given shape.""" # TODO: Seems hacky, but it works. Is there a better way? mem = cp.cuda.alloc_pinned_memory(np.prod(shape) * self.dtype_itemsize) - array = np.frombuffer(mem, dtype=self.dtype, count=np.prod(shape)).reshape( - shape - ) + array = np.frombuffer(mem, dtype=self.dtype, count=np.prod(shape)).reshape(shape) self.nbytes += mem.size() return array @@ -62,9 +58,7 @@ def to_gpu_tile(self, dst_gpu_tile): dst_gpu_tile._array.set(self._array) # Copy padding - for src_array, dst_gpu_array in zip( - self._padding.values(), dst_gpu_tile._padding.values() - ): + for src_array, dst_gpu_array in zip(self._padding.values(), dst_gpu_tile._padding.values()): dst_gpu_array.set(src_array) @@ -90,7 +84,5 @@ def to_cpu_tile(self, dst_cpu_tile): self._array.get(out=dst_cpu_tile._array) # Copy padding - for src_array, dst_array in zip( - self._padding.values(), dst_cpu_tile._padding.values() - ): + for src_array, dst_array in zip(self._padding.values(), dst_cpu_tile._padding.values()): src_array.get(out=dst_array) diff --git a/xlb/experimental/ooc/tiles/dynamic_array.py b/xlb/experimental/ooc/tiles/dynamic_array.py index 2b05b2e..403d164 100644 --- a/xlb/experimental/ooc/tiles/dynamic_array.py +++ b/xlb/experimental/ooc/tiles/dynamic_array.py @@ -3,7 +3,6 @@ import math import cupy as cp import numpy as np -import time class DynamicArray: @@ -46,17 +45,12 @@ def resize(self, nbytes): self.nbytes = nbytes # Check if the number of bytes requested is less than 2xbytes_resize or if the number of bytes requested exceeds the allocated number of bytes - if ( - nbytes < (self.allocated_bytes - 2 * self.bytes_resize) - or nbytes > self.allocated_bytes - ): + if nbytes < (self.allocated_bytes - 2 * self.bytes_resize) or nbytes > self.allocated_bytes: ## Free the memory # del self.mem # Set the new number of allocated bytes - self.allocated_bytes = ( - math.ceil(nbytes / self.bytes_resize) * self.bytes_resize - ) + self.allocated_bytes = math.ceil(nbytes / self.bytes_resize) * self.bytes_resize # Allocate the memory self.mem = cp.cuda.alloc_pinned_memory(self.allocated_bytes) diff --git a/xlb/experimental/ooc/tiles/tile.py b/xlb/experimental/ooc/tiles/tile.py index 90c3334..9bb347b 100644 --- a/xlb/experimental/ooc/tiles/tile.py +++ b/xlb/experimental/ooc/tiles/tile.py @@ -1,7 +1,5 @@ -import numpy as np import cupy as cp import itertools -from dataclasses import dataclass class Tile: @@ -25,9 +23,7 @@ def __init__(self, shape, dtype, padding, codec=None): self.padding = padding self.dtype_itemsize = cp.dtype(self.dtype).itemsize self.nbytes = 0 # Updated when array is allocated - self.codec = ( - codec # Codec to use for compression TODO: Find better abstraction for this - ) + self.codec = codec # Codec to use for compression TODO: Find better abstraction for this # Make center array self._array = self.allocate_array(self.shape) @@ -59,9 +55,7 @@ def __init__(self, shape, dtype, padding, codec=None): self._buf_padding[ind] = self.allocate_array(shape) # Get slicing for array copies - self._slice_center = tuple( - [slice(pad, pad + shape) for (pad, shape) in zip(self.padding, self.shape)] - ) + self._slice_center = tuple([slice(pad, pad + shape) for (pad, shape) in zip(self.padding, self.shape)]) self._slice_padding_to_array = {} self._slice_array_to_padding = {} self._padding_shape = {} diff --git a/xlb/experimental/ooc/utils.py b/xlb/experimental/ooc/utils.py index f607128..1179c76 100644 --- a/xlb/experimental/ooc/utils.py +++ b/xlb/experimental/ooc/utils.py @@ -70,7 +70,7 @@ def _stream_to_backend(stream, backend): # Convert stream to backend stream if backend == "jax": raise ValueError("Jax currently does not support streams") - elif backend == "warp": + if backend == "warp": backend_stream = wp.Stream(cuda_stream=stream.ptr) elif backend == "cupy": backend_stream = stream diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py index 692b453..7d9ec24 100644 --- a/xlb/grid/__init__.py +++ b/xlb/grid/__init__.py @@ -1 +1 @@ -from xlb.grid.grid import grid_factory \ No newline at end of file +from xlb.grid.grid import grid_factory as grid_factory diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 483386d..7d8a678 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -1,14 +1,11 @@ from abc import ABC, abstractmethod -from typing import Any, Literal, Optional, Tuple +from typing import Tuple from xlb import DefaultConfig from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import Precision -def grid_factory( - shape: Tuple[int, ...], - compute_backend: ComputeBackend = None): +def grid_factory(shape: Tuple[int, ...], compute_backend: ComputeBackend = None): compute_backend = compute_backend or DefaultConfig.default_backend if compute_backend == ComputeBackend.WARP: from xlb.grid.warp_grid import WarpGrid @@ -38,16 +35,17 @@ def _bounding_box_indices(self): """ This function calculates the indices of the bounding box of a 2D or 3D grid. The bounding box is defined as the set of grid points on the outer edge of the grid. - + Returns ------- boundingBox (dict): A dictionary where keys are the names of the bounding box faces ("bottom", "top", "left", "right" for 2D; additional "front", "back" for 3D), and values are numpy arrays of indices corresponding to each face. """ - def to_tuple(list): - d = len(list[0]) - return [tuple([sublist[i] for sublist in list]) for i in range(d)] + + def to_tuple(lst): + d = len(lst[0]) + return [tuple([sublist[i] for sublist in lst]) for i in range(d)] if self.dim == 2: # For a 2D grid, the bounding box consists of four edges: bottom, top, left, and right. @@ -58,9 +56,9 @@ def to_tuple(list): "bottom": to_tuple([[i, 0] for i in range(nx)]), "top": to_tuple([[i, ny - 1] for i in range(nx)]), "left": to_tuple([[0, i] for i in range(ny)]), - "right": to_tuple([[nx - 1, i] for i in range(ny)]) + "right": to_tuple([[nx - 1, i] for i in range(ny)]), } - + elif self.dim == 3: # For a 3D grid, the bounding box consists of six faces: bottom, top, left, right, front, and back. # Each face is represented as an array of indices. For example, the bottom face includes all points @@ -72,7 +70,6 @@ def to_tuple(list): "left": to_tuple([[0, j, k] for j in range(ny) for k in range(nz)]), "right": to_tuple([[nx - 1, j, k] for j in range(ny) for k in range(nz)]), "front": to_tuple([[i, 0, k] for i in range(nx) for k in range(nz)]), - "back": to_tuple([[i, ny - 1, k] for i in range(nx) for k in range(nz)]) + "back": to_tuple([[i, ny - 1, k] for i in range(nx) for k in range(nz)]), } - return - + return diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 289790a..24eeb03 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -1,19 +1,16 @@ -from typing import Any, Literal, Optional, Tuple +from typing import Literal from jax.sharding import PartitionSpec as P from jax.sharding import NamedSharding, Mesh from jax.experimental import mesh_utils -from jax.experimental.shard_map import shard_map from xlb.compute_backend import ComputeBackend import jax.numpy as jnp -from jax import lax import jax from xlb import DefaultConfig from .grid import Grid -from xlb.operator import Operator from xlb.precision_policy import Precision @@ -25,9 +22,7 @@ def _initialize_backend(self): self.nDevices = jax.device_count() self.backend = jax.default_backend() self.device_mesh = ( - mesh_utils.create_device_mesh((1, self.nDevices, 1)) - if self.dim == 2 - else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) + mesh_utils.create_device_mesh((1, self.nDevices, 1)) if self.dim == 2 else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) ) self.global_mesh = ( Mesh(self.device_mesh, axis_names=("cardinality", "x", "y")) @@ -53,9 +48,7 @@ def create_field( dtype = dtype.jax_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.jax_dtype - for d, index in self.sharding.addressable_devices_indices_map( - full_shape - ).items(): + for d, index in self.sharding.addressable_devices_indices_map(full_shape).items(): jax.default_device = d if fill_value: x = jnp.full(device_shape, fill_value, dtype=dtype) @@ -63,6 +56,4 @@ def create_field( x = jnp.zeros(shape=device_shape, dtype=dtype) arrays += [jax.device_put(x, d)] jax.default_device = jax.devices()[0] - return jax.make_array_from_single_device_arrays( - full_shape, self.sharding, arrays - ) \ No newline at end of file + return jax.make_array_from_single_device_arrays(full_shape, self.sharding, arrays) diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index 75c3f14..5018962 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -1,13 +1,10 @@ -from dataclasses import field import warp as wp from .grid import Grid -from xlb.operator import Operator from xlb.precision_policy import Precision from xlb.compute_backend import ComputeBackend from typing import Literal from xlb import DefaultConfig -import numpy as np class WarpGrid(Grid): @@ -23,11 +20,7 @@ def create_field( dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None, fill_value=None, ): - dtype = ( - dtype.wp_dtype - if dtype - else DefaultConfig.default_precision_policy.store_precision.wp_dtype - ) + dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype shape = (cardinality,) + (self.shape) if fill_value is None: diff --git a/xlb/helper/__init__.py b/xlb/helper/__init__.py index 29ac3f6..92d3583 100644 --- a/xlb/helper/__init__.py +++ b/xlb/helper/__init__.py @@ -1,2 +1,2 @@ -from xlb.helper.nse_solver import create_nse_fields -from xlb.helper.initializers import initialize_eq \ No newline at end of file +from xlb.helper.nse_solver import create_nse_fields as create_nse_fields +from xlb.helper.initializers import initialize_eq as initialize_eq diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py index 4462edc..a42c6ac 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_solver.py @@ -1,21 +1,13 @@ -import xlb -from xlb.compute_backend import ComputeBackend from xlb import DefaultConfig from xlb.grid import grid_factory from xlb.precision_policy import Precision from typing import Tuple -def create_nse_fields( - grid_shape: Tuple[int, int, int], velocity_set=None, compute_backend=None, precision_policy=None -): - velocity_set = velocity_set if velocity_set else DefaultConfig.velocity_set - compute_backend = ( - compute_backend if compute_backend else DefaultConfig.default_backend - ) - precision_policy = ( - precision_policy if precision_policy else DefaultConfig.default_precision_policy - ) +def create_nse_fields(grid_shape: Tuple[int, int, int], velocity_set=None, compute_backend=None, precision_policy=None): + velocity_set = velocity_set or DefaultConfig.velocity_set + compute_backend = compute_backend or DefaultConfig.default_backend + precision_policy = precision_policy or DefaultConfig.default_precision_policy grid = grid_factory(grid_shape, compute_backend=compute_backend) # Create fields @@ -25,4 +17,3 @@ def create_nse_fields( boundary_mask = grid.create_field(cardinality=1, dtype=Precision.UINT8) return grid, f_0, f_1, missing_mask, boundary_mask - diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index 02b8a59..c88ef83 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.operator import Operator -from xlb.operator.parallel_operator import ParallelOperator +from xlb.operator.operator import Operator as Operator +from xlb.operator.parallel_operator import ParallelOperator as ParallelOperator diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 275a074..1fd2152 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -1,8 +1,8 @@ -from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition +from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition as BoundaryCondition from xlb.operator.boundary_condition.boundary_condition_registry import ( - BoundaryConditionRegistry, + BoundaryConditionRegistry as BoundaryConditionRegistry, ) -from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC -from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC -from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC -from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC +from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC as EquilibriumBC +from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC as DoNothingBC +from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC +from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 38697c3..a57e427 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -6,7 +6,7 @@ from jax import jit from functools import partial import warp as wp -from typing import Any, List +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -34,7 +34,7 @@ def __init__( velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): super().__init__( ImplementationStep.STREAMING, @@ -53,9 +53,7 @@ def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool # Construct the funcional to get streamed indices diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 0af61e1..f1d85d3 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -35,11 +35,11 @@ def __init__( self, rho: float, u: Tuple[float, float, float], - equilibrium_operator : Operator = None, + equilibrium_operator: Operator = None, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): # Store the equilibrium information self.rho = rho @@ -73,14 +73,8 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(self.rho) - _u = ( - _u_vec(self.u[0], self.u[1], self.u[2]) - if self.velocity_set.d == 3 - else _u_vec(self.u[0], self.u[1]) - ) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _u = _u_vec(self.u[0], self.u[1], self.u[2]) if self.velocity_set.d == 3 else _u_vec(self.u[0], self.u[1]) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index a445e07..5b2f2c1 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -6,7 +6,7 @@ from jax import jit from functools import partial import warp as wp -from typing import Any, List +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -33,7 +33,7 @@ def __init__( velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): super().__init__( ImplementationStep.COLLISION, @@ -48,16 +48,14 @@ def __init__( def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) - return jnp.where(boundary, f_pre[self.velocity_set.opp_indices,...], f_post) + return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _opp_indices = self.velocity_set.wp_opp_indices _q = wp.constant(self.velocity_set.q) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 55472a3..1fa4a7c 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -3,12 +3,10 @@ """ import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax +from jax import jit from functools import partial -import numpy as np import warp as wp -from typing import Tuple, Any, List +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -37,7 +35,7 @@ def __init__( velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): # Call the parent constructor super().__init__( @@ -64,9 +62,7 @@ def _construct_warp(self): _c = self.velocity_set.wp_c _opp_indices = self.velocity_set.wp_opp_indices _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool @wp.func def functional2d( diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index dbeadbc..125e45d 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -2,12 +2,7 @@ Base class for boundary conditions in a LBM simulation. """ -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np from enum import Enum, auto -from typing import List from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -15,6 +10,7 @@ from xlb.operator.operator import Operator from xlb import DefaultConfig + # Enum for implementation step class ImplementationStep(Enum): COLLISION = auto() @@ -32,7 +28,7 @@ def __init__( velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py index 0a3b2c7..5b1e092 100644 --- a/xlb/operator/boundary_condition/boundary_condition_registry.py +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -19,11 +19,11 @@ def register_boundary_condition(self, boundary_condition): """ Register a boundary condition. """ - id = self.next_id + _id = self.next_id self.next_id += 1 - self.id_to_bc[id] = boundary_condition - self.bc_to_id[boundary_condition] = id - return id + self.id_to_bc[_id] = boundary_condition + self.bc_to_id[boundary_condition] = _id + return _id boundary_condition_registry = BoundaryConditionRegistry() diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index cc80b85..262e638 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -1,6 +1,6 @@ from xlb.operator.boundary_masker.indices_boundary_masker import ( - IndicesBoundaryMasker, + IndicesBoundaryMasker as IndicesBoundaryMasker, ) from xlb.operator.boundary_masker.stl_boundary_masker import ( - STLBoundaryMasker, + STLBoundaryMasker as STLBoundaryMasker, ) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 460bd3b..7960083 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -22,41 +22,32 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) - @Operator.register_backend(ComputeBackend.JAX) - # TODO HS: figure out why uncommenting the line below fails unlike other operators! + # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) - def jax_implementation( - self, bclist, boundary_mask, mask, start_index=None - ): + def jax_implementation(self, bclist, boundary_mask, mask, start_index=None): # define a helper function def compute_boundary_id_and_mask(boundary_mask, mask): if dim == 2: - boundary_mask = boundary_mask.at[ - 0, local_indices[0], local_indices[1] - ].set(id_number) + boundary_mask = boundary_mask.at[0, local_indices[0], local_indices[1]].set(id_number) mask = mask.at[:, local_indices[0], local_indices[1]].set(True) if dim == 3: - boundary_mask = boundary_mask.at[ - 0, local_indices[0], local_indices[1], local_indices[2] - ].set(id_number) - mask = mask.at[ - :, local_indices[0], local_indices[1], local_indices[2] - ].set(True) + boundary_mask = boundary_mask.at[0, local_indices[0], local_indices[1], local_indices[2]].set(id_number) + mask = mask.at[:, local_indices[0], local_indices[1], local_indices[2]].set(True) return boundary_mask, mask - + dim = mask.ndim - 1 if start_index is None: start_index = (0,) * dim for bc in bclist: - assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC!' + assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" id_number = bc.id local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] boundary_mask, mask = compute_boundary_id_and_mask(boundary_mask, mask) # We are done with bc.indices. Remove them from BC objects - bc.__dict__.pop('indices', None) + bc.__dict__.pop("indices", None) mask = self.stream(mask) return boundary_mask, mask @@ -84,12 +75,7 @@ def kernel2d( index[1] = indices[1, ii] - start_index[1] # Check if in bounds - if ( - index[0] >= 0 - and index[0] < mask.shape[1] - and index[1] >= 0 - and index[1] < mask.shape[2] - ): + if index[0] >= 0 and index[0] < mask.shape[1] and index[1] >= 0 and index[1] < mask.shape[2]: # Stream indices for l in range(_q): # Get the index of the streaming direction @@ -146,10 +132,7 @@ def kernel3d( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation( - self, bclist, boundary_mask, missing_mask, start_index=None - ): - + def warp_implementation(self, bclist, boundary_mask, missing_mask, start_index=None): dim = self.velocity_set.d index_list = [[] for _ in range(dim)] id_list = [] @@ -159,10 +142,10 @@ def warp_implementation( index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) # We are done with bc.indices. Remove them from BC objects - bc.__dict__.pop('indices', None) - - indices = wp.array2d(index_list, dtype = wp.int32) - id_number = wp.array1d(id_list, dtype = wp.uint8) + bc.__dict__.pop("indices", None) + + indices = wp.array2d(index_list, dtype=wp.int32) + id_number = wp.array1d(id_list, dtype=wp.uint8) if start_index is None: start_index = (0,) * dim diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index c2cfc30..b4ea8ca 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -1,19 +1,13 @@ # Base class for all equilibriums -from functools import partial import numpy as np from stl import mesh as np_mesh -import jax.numpy as jnp -from jax import jit import warp as wp -from typing import Tuple -from xlb import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.stream.stream import Stream class STLBoundaryMasker(Operator): @@ -56,9 +50,7 @@ def kernel( index[2] = k - start_index[2] # position of the point - ijk = wp.vec3( - wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]) - ) + ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2])) ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center pos = wp.cw_mul(ijk, spacing) + origin @@ -74,9 +66,7 @@ def kernel( face_u = float(0.0) face_v = float(0.0) sign = float(0.0) - if wp.mesh_query_point_sign_winding_number( - mesh, pos, max_length, sign, face_index, face_u, face_v - ): + if wp.mesh_query_point_sign_winding_number(mesh, pos, max_length, sign, face_index, face_u, face_v): # set point to be solid if sign <= 0: # TODO: fix this # Stream indices @@ -87,9 +77,7 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_mask[ - 0, push_index[0], push_index[1], push_index[2] - ] = wp.uint8(id_number) + boundary_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel diff --git a/xlb/operator/collision/__init__.py b/xlb/operator/collision/__init__.py index 77395e6..b48d0ce 100644 --- a/xlb/operator/collision/__init__.py +++ b/xlb/operator/collision/__init__.py @@ -1,3 +1,3 @@ -from xlb.operator.collision.collision import Collision -from xlb.operator.collision.bgk import BGK -from xlb.operator.collision.kbc import KBC +from xlb.operator.collision.collision import Collision as Collision +from xlb.operator.collision.bgk import BGK as BGK +from xlb.operator.collision.kbc import KBC as KBC diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index ec40b56..fa0857a 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -70,15 +70,13 @@ def jax_implementation( shear = self.decompose_shear_d3q27_jax(fneq) delta_s = shear * rho else: - raise NotImplementedError( - "Velocity set not supported: {}".format(type(self.velocity_set)) - ) + raise NotImplementedError("Velocity set not supported: {}".format(type(self.velocity_set))) # Perform collision delta_h = fneq - delta_s - gamma = self.inv_beta - (2.0 - self.inv_beta) * self.entropic_scalar_product( - delta_s, delta_h, feq - ) / (self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq)) + gamma = self.inv_beta - (2.0 - self.inv_beta) * self.entropic_scalar_product(delta_s, delta_h, feq) / ( + self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq) + ) fout = f - self.beta * (2.0 * delta_s + gamma[None, ...] * delta_h) @@ -206,11 +204,7 @@ def decompose_shear_d2q9_jax(self, fneq): def _construct_warp(self): # Raise error if velocity set is not supported if not isinstance(self.velocity_set, D3Q27): - raise NotImplementedError( - "Velocity set not supported for warp backend: {}".format( - type(self.velocity_set) - ) - ) + raise NotImplementedError("Velocity set not supported for warp backend: {}".format(type(self.velocity_set))) # Set local constants TODO: This is a hack and should be fixed with warp update _w = self.velocity_set.wp_w @@ -323,9 +317,9 @@ def functional2d( # Perform collision delta_h = fneq - delta_s - gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product( - delta_s, delta_h, feq - ) / (_epsilon + entropic_scalar_product(delta_h, delta_h, feq)) + gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( + _epsilon + entropic_scalar_product(delta_h, delta_h, feq) + ) fout = f - _beta * (2.0 * delta_s + gamma * delta_h) return fout @@ -345,9 +339,9 @@ def functional3d( # Perform collision delta_h = fneq - delta_s - gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product( - delta_s, delta_h, feq - ) / (_epsilon + entropic_scalar_product(delta_h, delta_h, feq)) + gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( + _epsilon + entropic_scalar_product(delta_h, delta_h, feq) + ) fout = f - _beta * (2.0 * delta_s + gamma * delta_h) return fout @@ -362,12 +356,13 @@ def kernel2d( fout: wp.array3d(dtype=Any), ): # Get the global index - i, j, k = wp.tid() + i, j = wp.tid() index = wp.vec3i(i, j) # TODO: Warp needs to fix this # Load needed values _f = _f_vec() _feq = _f_vec() + _d = self.velocity_set.d for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1]] _feq[l] = feq[l, index[0], index[1]] @@ -399,6 +394,7 @@ def kernel3d( # Load needed values _f = _f_vec() _feq = _f_vec() + _d = self.velocity_set.d for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] diff --git a/xlb/operator/equilibrium/__init__.py b/xlb/operator/equilibrium/__init__.py index 42b601e..b9f9f08 100644 --- a/xlb/operator/equilibrium/__init__.py +++ b/xlb/operator/equilibrium/__init__.py @@ -1,4 +1,4 @@ from xlb.operator.equilibrium.quadratic_equilibrium import ( - Equilibrium, - QuadraticEquilibrium, + Equilibrium as Equilibrium, + QuadraticEquilibrium as QuadraticEquilibrium, ) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 794d78d..3af6b4a 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -4,11 +4,10 @@ import warp as wp from typing import Any -from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator import Operator -from xlb import DefaultConfig + class QuadraticEquilibrium(Equilibrium): """ diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 91eb36c..3463078 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1 +1 @@ -from xlb.operator.macroscopic.macroscopic import Macroscopic +from xlb.operator.macroscopic.macroscopic import Macroscopic as Macroscopic diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 7fa309f..13d3817 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -4,10 +4,8 @@ import jax.numpy as jnp from jax import jit import warp as wp -from typing import Tuple, Any +from typing import Any -from xlb import DefaultConfig -from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 38e8e15..83c6538 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,12 +1,7 @@ -# Base class for all operators, (collision, streaming, equilibrium, etc.) - import inspect -import warp as wp -from typing import Any import traceback from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy, Precision from xlb import DefaultConfig @@ -22,9 +17,7 @@ class Operator: def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None): # Set the default values from the global config self.velocity_set = velocity_set or DefaultConfig.velocity_set - self.precision_policy = ( - precision_policy or DefaultConfig.default_precision_policy - ) + self.precision_policy = precision_policy or DefaultConfig.default_precision_policy self.compute_backend = compute_backend or DefaultConfig.default_backend # Check if the compute backend is supported @@ -52,17 +45,13 @@ def decorator(func): def __call__(self, *args, callback=None, **kwargs): method_candidates = [ - (key, method) - for key, method in self._backends.items() - if key[0] == self.__class__.__name__ and key[1] == self.compute_backend + (key, method) for key, method in self._backends.items() if key[0] == self.__class__.__name__ and key[1] == self.compute_backend ] bound_arguments = None for key, backend_method in method_candidates: try: # This attempts to bind the provided args and kwargs to the backend method's signature - bound_arguments = inspect.signature(backend_method).bind( - self, *args, **kwargs - ) + bound_arguments = inspect.signature(backend_method).bind(self, *args, **kwargs) bound_arguments.apply_defaults() # This fills in any default values result = backend_method(self, *args, **kwargs) callback_arg = result if result is not None else (args, kwargs) @@ -74,9 +63,7 @@ def __call__(self, *args, callback=None, **kwargs): traceback_str = traceback.format_exc() continue # This skips to the next candidate if binding fails - raise Exception( - f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}\n {traceback_str}" - ) + raise Exception(f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}\n {traceback_str}") @property def supported_compute_backend(self): diff --git a/xlb/operator/parallel_operator.py b/xlb/operator/parallel_operator.py index 9309b21..9f9b5c5 100644 --- a/xlb/operator/parallel_operator.py +++ b/xlb/operator/parallel_operator.py @@ -65,12 +65,8 @@ def _parallel_func(self, f): jax.numpy.ndarray The processed data. """ - rightPerm = [ - (i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices) - ] - leftPerm = [ - ((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices) - ] + rightPerm = [(i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices)] + leftPerm = [((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices)] f = self.func(f) left_comm, right_comm = ( f[self.velocity_set.right_indices, :1, ...], diff --git a/xlb/operator/precision_caster/__init__.py b/xlb/operator/precision_caster/__init__.py index a027c52..c333ab7 100644 --- a/xlb/operator/precision_caster/__init__.py +++ b/xlb/operator/precision_caster/__init__.py @@ -1 +1 @@ -from xlb.operator.precision_caster.precision_caster import PrecisionCaster +from xlb.operator.precision_caster.precision_caster import PrecisionCaster as PrecisionCaster diff --git a/xlb/operator/precision_caster/precision_caster.py b/xlb/operator/precision_caster/precision_caster.py index cb441c5..5427cba 100644 --- a/xlb/operator/precision_caster/precision_caster.py +++ b/xlb/operator/precision_caster/precision_caster.py @@ -3,10 +3,9 @@ """ import jax.numpy as jnp -from jax import jit, device_count +from jax import jit +import warp as wp from functools import partial -import numpy as np -from enum import Enum from xlb.operator.operator import Operator from xlb.velocity_set import VelocitySet diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py index e5d159c..528375d 100644 --- a/xlb/operator/stepper/__init__.py +++ b/xlb/operator/stepper/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.stepper.stepper import Stepper -from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper +from xlb.operator.stepper.stepper import Stepper as Stepper +from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper as IncompressibleNavierStokesStepper diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 3f986d3..3fad2b1 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -1,13 +1,11 @@ # Base class for all stepper operators -from logging import warning from functools import partial from jax import jit import warp as wp from typing import Any from xlb import DefaultConfig -from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator from xlb.operator.stream import Stream @@ -32,9 +30,7 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK"): # Construct the operators self.stream = Stream(velocity_set, precision_policy, compute_backend) - self.equilibrium = QuadraticEquilibrium( - velocity_set, precision_policy, compute_backend - ) + self.equilibrium = QuadraticEquilibrium(velocity_set, precision_policy, compute_backend) self.macroscopic = Macroscopic(velocity_set, precision_policy, compute_backend) operators = [self.macroscopic, self.equilibrium, self.collision, self.stream] @@ -91,9 +87,7 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool # Get the boundary condition ids _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) @@ -129,19 +123,13 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) elif _boundary_id == _equilibrium_bc: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == _do_nothing_bc: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == _halfway_bounce_back_bc: # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -199,19 +187,13 @@ def kernel3d( f_post_stream = self.stream.warp_functional(f_0, index) elif _boundary_id == _equilibrium_bc: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == _do_nothing_bc: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == _halfway_bounce_back_bc: # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -220,9 +202,7 @@ def kernel3d( feq = self.equilibrium.warp_functional(rho, u) # Apply collision - f_post_collision = self.collision.warp_functional( - f_post_stream, feq, rho, u - ) + f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply collision type boundary conditions if _boundary_id == _fullway_bounce_back_bc: diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index c11b39b..adc2564 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,16 +1,5 @@ # Base class for all stepper operators - -from ast import Raise -from functools import partial -import jax.numpy as jnp -from jax import jit -import warp as wp - -from xlb.operator.equilibrium.equilibrium import Equilibrium -from xlb.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -from xlb.operator.precision_caster import PrecisionCaster from xlb.operator.equilibrium import Equilibrium from xlb import DefaultConfig @@ -24,39 +13,17 @@ def __init__(self, operators, boundary_conditions): self.operators = operators self.boundary_conditions = boundary_conditions # Get velocity set, precision policy, and compute backend - velocity_sets = set( - [op.velocity_set for op in self.operators if op is not None] - ) - assert ( - len(velocity_sets) < 2 - ), "All velocity sets must be the same. Got {}".format(velocity_sets) - velocity_set = ( - DefaultConfig.velocity_set if not velocity_sets else velocity_sets.pop() - ) + velocity_sets = set([op.velocity_set for op in self.operators if op is not None]) + assert len(velocity_sets) < 2, "All velocity sets must be the same. Got {}".format(velocity_sets) + velocity_set = DefaultConfig.velocity_set if not velocity_sets else velocity_sets.pop() - precision_policies = set( - [op.precision_policy for op in self.operators if op is not None] - ) - assert ( - len(precision_policies) < 2 - ), "All precision policies must be the same. Got {}".format(precision_policies) - precision_policy = ( - DefaultConfig.default_precision_policy - if not precision_policies - else precision_policies.pop() - ) + precision_policies = set([op.precision_policy for op in self.operators if op is not None]) + assert len(precision_policies) < 2, "All precision policies must be the same. Got {}".format(precision_policies) + precision_policy = DefaultConfig.default_precision_policy if not precision_policies else precision_policies.pop() - compute_backends = set( - [op.compute_backend for op in self.operators if op is not None] - ) - assert ( - len(compute_backends) < 2 - ), "All compute backends must be the same. Got {}".format(compute_backends) - compute_backend = ( - DefaultConfig.default_backend - if not compute_backends - else compute_backends.pop() - ) + compute_backends = set([op.compute_backend for op in self.operators if op is not None]) + assert len(compute_backends) < 2, "All compute backends must be the same. Got {}".format(compute_backends) + compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop() # Add boundary conditions # Warp cannot handle lists of functions currently @@ -93,9 +60,7 @@ def __init__(self, operators, boundary_conditions): self.equilibrium_bc = EquilibriumBC( rho=1.0, u=(0.0, 0.0, 0.0), - equilibrium_operator=next( - (op for op in self.operators if isinstance(op, Equilibrium)), None - ), + equilibrium_operator=next((op for op in self.operators if isinstance(op, Equilibrium)), None), velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, diff --git a/xlb/operator/stream/__init__.py b/xlb/operator/stream/__init__.py index 9093da7..2f5b2f3 100644 --- a/xlb/operator/stream/__init__.py +++ b/xlb/operator/stream/__init__.py @@ -1 +1 @@ -from xlb.operator.stream.stream import Stream +from xlb.operator.stream.stream import Stream as Stream diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 77cf22d..da724c2 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -6,7 +6,6 @@ import warp as wp from typing import Any -from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator @@ -49,9 +48,7 @@ def _streaming_jax_i(f, c): elif self.velocity_set.d == 3: return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) - return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( - f, jnp.array(self.velocity_set.c).T - ) + return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)(f, jnp.array(self.velocity_set.c).T) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 5a59b97..3b0f85f 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import warp as wp + class Precision(Enum): FP64 = auto() FP32 = auto() @@ -42,6 +43,7 @@ def jax_dtype(self): else: raise ValueError("Invalid precision") + class PrecisionPolicy(Enum): FP64FP64 = auto() FP64FP32 = auto() @@ -93,4 +95,4 @@ def cast_to_compute_warp(self, array): def cast_to_store_warp(self, array): store_precision = self.store_precision - return wp.array(array, dtype=store_precision.wp_dtype) \ No newline at end of file + return wp.array(array, dtype=store_precision.wp_dtype) diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index 57f65c0..6b400ee 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from xlb.compute_backend import ComputeBackend from xlb import DefaultConfig from xlb.precision_policy.jax_precision_policy import ( @@ -15,9 +14,7 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp64() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") class Fp64Fp32: @@ -25,9 +22,7 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp32() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") class Fp32Fp32: @@ -35,9 +30,7 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp32Fp32() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") class Fp64Fp16: @@ -45,9 +38,7 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp16() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") class Fp32Fp16: @@ -55,6 +46,4 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp32Fp16() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") diff --git a/xlb/utils/__init__.py b/xlb/utils/__init__.py index 3c8032e..6f1f61a 100644 --- a/xlb/utils/__init__.py +++ b/xlb/utils/__init__.py @@ -1,9 +1,9 @@ from .utils import ( - downsample_field, - save_image, - save_fields_vtk, - save_BCs_vtk, - rotate_geometry, - voxelize_stl, - axangle2mat, + downsample_field as downsample_field, + save_image as save_image, + save_fields_vtk as save_fields_vtk, + save_BCs_vtk as save_BCs_vtk, + rotate_geometry as rotate_geometry, + voxelize_stl as voxelize_stl, + axangle2mat as axangle2mat, ) diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index 7b7ac78..074177e 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -1,7 +1,6 @@ import numpy as np import matplotlib.pylab as plt from matplotlib import cm -import numpy as np from time import time import pyvista as pv from jax.image import resize @@ -38,9 +37,7 @@ def downsample_field(field, factor, method="bicubic"): else: new_shape = tuple(dim // factor for dim in field.shape[:-1]) downsampled_components = [] - for i in range( - field.shape[-1] - ): # Iterate over the last dimension (vector components) + for i in range(field.shape[-1]): # Iterate over the last dimension (vector components) resized = resize(field[..., i], new_shape, method=method) downsampled_components.append(resized) @@ -66,8 +63,10 @@ def save_image(fld, timestep, prefix=None): Notes ----- - This function saves the field as an image in the PNG format. The filename is based on the name of the main script file, the provided prefix, and the timestep number. - If the field is 3D, the magnitude of the field is calculated and saved. The image is saved with the 'nipy_spectral' colormap and the origin set to 'lower'. + This function saves the field as an image in the PNG format. + The filename is based on the name of the main script file, the provided prefix, and the timestep number. + If the field is 3D, the magnitude of the field is calculated and saved. + The image is saved with the 'nipy_spectral' colormap and the origin set to 'lower'. """ if prefix is None: fname = os.path.basename(__main__.__file__) @@ -79,7 +78,7 @@ def save_image(fld, timestep, prefix=None): if len(fld.shape) > 3: raise ValueError("The input field should be 2D!") - elif len(fld.shape) == 3: + if len(fld.shape) == 3: fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2) plt.clf() @@ -118,9 +117,7 @@ def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"): if key == list(fields.keys())[0]: dimensions = value.shape else: - assert ( - value.shape == dimensions - ), "All fields must have the same dimensions!" + assert value.shape == dimensions, "All fields must have the same dimensions!" output_filename = os.path.join(output_dir, prefix + "_" + f"{timestep:07d}.vtk") @@ -231,15 +228,11 @@ def rotate_geometry(indices, origin, axis, angle): This function rotates the mesh by applying a rotation matrix to the voxel indices. The rotation matrix is calculated using the axis-angle representation of rotations. The origin of the rotation axis is assumed to be at (0, 0, 0). """ - indices_rotated = (jnp.array(indices).T - origin) @ axangle2mat( - axis, angle - ) + origin + indices_rotated = (jnp.array(indices).T - origin) @ axangle2mat(axis, angle) + origin return tuple(jnp.rint(indices_rotated).astype("int32").T) -def voxelize_stl( - stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None -): +def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None): """ Converts an STL file to a voxelized mesh. @@ -314,10 +307,8 @@ def axangle2mat(axis, angle, is_normalized=False): xyC = x * yC yzC = y * zC zxC = z * xC - return jnp.array( - [ - [x * xC + c, xyC - zs, zxC + ys], - [xyC + zs, y * yC + c, yzC - xs], - [zxC - ys, yzC + xs, z * zC + c], - ] - ) + return jnp.array([ + [x * xC + c, xyC - zs, zxC + ys], + [xyC + zs, y * yC + c, yzC - xs], + [zxC - ys, yzC + xs, z * zC + c], + ]) diff --git a/xlb/velocity_set/__init__.py b/xlb/velocity_set/__init__.py index 5b7b737..c1338db 100644 --- a/xlb/velocity_set/__init__.py +++ b/xlb/velocity_set/__init__.py @@ -1,4 +1,4 @@ -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.velocity_set.d2q9 import D2Q9 -from xlb.velocity_set.d3q19 import D3Q19 -from xlb.velocity_set.d3q27 import D3Q27 +from xlb.velocity_set.velocity_set import VelocitySet as VelocitySet +from xlb.velocity_set.d2q9 import D2Q9 as D2Q9 +from xlb.velocity_set.d3q19 import D3Q19 as D3Q19 +from xlb.velocity_set.d3q27 import D3Q27 as D3Q27 diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py index 178c89e..700806c 100644 --- a/xlb/velocity_set/d2q9.py +++ b/xlb/velocity_set/d2q9.py @@ -12,14 +12,13 @@ class D2Q9(VelocitySet): D2Q9 stands for two-dimensional nine-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in two dimensions. """ + def __init__(self): # Construct the velocity vectors and weights cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] c = np.array(tuple(zip(cx, cy))).T - w = np.array( - [4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36] - ) + w = np.array([4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36]) # Call the parent constructor super().__init__(2, 9, c, w) diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py index 7f69019..97db1d9 100644 --- a/xlb/velocity_set/d3q19.py +++ b/xlb/velocity_set/d3q19.py @@ -13,15 +13,10 @@ class D3Q19(VelocitySet): D3Q19 stands for three-dimensional nineteen-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ + def __init__(self): # Construct the velocity vectors and weights - c = np.array( - [ - ci - for ci in itertools.product([-1, 0, 1], repeat=3) - if np.sum(np.abs(ci)) <= 2 - ] - ).T + c = np.array([ci for ci in itertools.product([-1, 0, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T w = np.zeros(19) for i in range(19): if np.sum(np.abs(c[:, i])) == 0: diff --git a/xlb/velocity_set/d3q27.py b/xlb/velocity_set/d3q27.py index ac908eb..702acf4 100644 --- a/xlb/velocity_set/d3q27.py +++ b/xlb/velocity_set/d3q27.py @@ -13,6 +13,7 @@ class D3Q27(VelocitySet): D3Q27 stands for three-dimensional twenty-seven-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ + def __init__(self): # Construct the velocity vectors and weights c = np.array(list(itertools.product([0, -1, 1], repeat=3))).T diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 6f2bf4e..47bbae4 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -2,9 +2,6 @@ import math import numpy as np -from functools import partial -import jax.numpy as jnp -from jax import jit, vmap import warp as wp @@ -48,15 +45,9 @@ def __init__(self, d, q, c, w): # Make warp constants for these vectors # TODO: Following warp updates these may not be necessary self.wp_c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self.c)) - self.wp_w = wp.constant( - wp.vec(self.q, dtype=wp.float32)(self.w) - ) # TODO: Make type optional somehow - self.wp_opp_indices = wp.constant( - wp.vec(self.q, dtype=wp.int32)(self.opp_indices) - ) - self.wp_cc = wp.constant( - wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc) - ) + self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow + self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) + self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) @@ -127,9 +118,7 @@ def _construct_main_indices(self): return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] elif self.d == 3: - return np.nonzero( - (np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1) - )[0] + return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] def _construct_right_indices(self): """ From d55a4383ef7e0e92f02ab35fab7b9bddc992e32f Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 2 Aug 2024 18:05:17 -0400 Subject: [PATCH 060/144] added missing ruff formatting --- xlb/operator/stepper/nse_stepper.py | 29 +++++++++++------------------ xlb/operator/stepper/stepper.py | 2 -- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 4c9714a..b26c558 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -124,18 +124,14 @@ def kernel2d( # Apply streaming (pull method) f_post_stream = self.stream.warp_functional(f_0, index) - - # Apply post-streaming type boundary conditions + + # Apply post-streaming type boundary conditions if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) @@ -200,9 +196,7 @@ def kernel3d( f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) @@ -236,27 +230,26 @@ def kernel3d( @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): - # Get the boundary condition ids from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + bc_to_id = boundary_condition_registry.bc_to_id - + bc_struct = self.warp_functional() bc_attribute_list = [] for bc in self.boundary_conditions: # Setting the Struct attributes based on the BC class names attribute_str = bc.__class__.__name__ - setattr(bc_struct, 'id_' + attribute_str, bc_to_id[attribute_str]) - bc_attribute_list.append('id_' + attribute_str) + setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str]) + bc_attribute_list.append("id_" + attribute_str) # Unused attributes of the struct are set to inernal (id=0) ll = vars(bc_struct) for var in ll: - if var not in bc_attribute_list and not var.startswith('_'): + if var not in bc_attribute_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 attribute_str = bc.__class__.__name__ - setattr(bc_struct, var, 255) - + setattr(bc_struct, var, 255) # Launch the warp kernel wp.launch( diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 3e04e4b..2127ea6 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -34,7 +34,6 @@ def __init__(self, operators, boundary_conditions): from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC - # Define a list of tuples with attribute names and their corresponding classes conditions = [ ("equilibrium_bc", EquilibriumBC), @@ -55,7 +54,6 @@ def __init__(self, operators, boundary_conditions): elif not hasattr(self, attr_name): setattr(self, attr_name, bc_fallback) - ############################################ # Initialize operator From 329bd4ca56ade7d803c00ed7bcc3e63d4680f547 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 2 Aug 2024 18:31:52 -0400 Subject: [PATCH 061/144] Warp structs (#56) * somewhat improved bc handling using structs * minor clean up. Warp MLUPs is not affected. * added missing ruff formatting --- xlb/operator/stepper/nse_stepper.py | 74 +++++++++++++++++++--------- xlb/operator/stepper/stepper.py | 75 +++++++++-------------------- 2 files changed, 74 insertions(+), 75 deletions(-) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 3fad2b1..b26c558 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -89,11 +89,15 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - # Get the boundary condition ids - _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) - _do_nothing_bc = wp.uint8(self.do_nothing_bc.id) - _halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id) - _fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id) + @wp.struct + class BoundaryConditionIDStruct: + # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning + # One needs to manually add the names of additional BC's as they are added. + # TODO: Anyway to improve this + id_EquilibriumBC: wp.uint8 + id_DoNothingBC: wp.uint8 + id_HalfwayBounceBackBC: wp.uint8 + id_FullwayBounceBackBC: wp.uint8 @wp.kernel def kernel2d( @@ -101,6 +105,7 @@ def kernel2d( f_1: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), + bc_struct: BoundaryConditionIDStruct, timestep: int, ): # Get the global index @@ -117,17 +122,17 @@ def kernel2d( else: _missing_mask[l] = wp.uint8(0) - # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: - # Regular streaming - f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f_0, index) + + # Apply post-streaming type boundary conditions + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) - elif _boundary_id == _do_nothing_bc: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) - elif _boundary_id == _halfway_bounce_back_bc: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) @@ -145,8 +150,8 @@ def kernel2d( u, ) - # Apply collision type boundary conditions - if _boundary_id == _fullway_bounce_back_bc: + # Apply post-collision type boundary conditions + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, @@ -165,6 +170,7 @@ def kernel3d( f_1: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), + bc_struct: BoundaryConditionIDStruct, timestep: int, ): # Get the global index @@ -181,17 +187,17 @@ def kernel3d( else: _missing_mask[l] = wp.uint8(0) - # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: - # Regular streaming - f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f_0, index) + + # Apply post-streaming boundary conditions + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) - elif _boundary_id == _do_nothing_bc: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) - elif _boundary_id == _halfway_bounce_back_bc: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) @@ -205,7 +211,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply collision type boundary conditions - if _boundary_id == _fullway_bounce_back_bc: + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, @@ -220,10 +226,31 @@ def kernel3d( # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return None, kernel + return BoundaryConditionIDStruct, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): + # Get the boundary condition ids + from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + + bc_to_id = boundary_condition_registry.bc_to_id + + bc_struct = self.warp_functional() + bc_attribute_list = [] + for bc in self.boundary_conditions: + # Setting the Struct attributes based on the BC class names + attribute_str = bc.__class__.__name__ + setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str]) + bc_attribute_list.append("id_" + attribute_str) + + # Unused attributes of the struct are set to inernal (id=0) + ll = vars(bc_struct) + for var in ll: + if var not in bc_attribute_list and not var.startswith("_"): + # set unassigned boundaries to the maximum integer in uint8 + attribute_str = bc.__class__.__name__ + setattr(bc_struct, var, 255) + # Launch the warp kernel wp.launch( self.warp_kernel, @@ -232,6 +259,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): f_1, boundary_mask, missing_mask, + bc_struct, timestep, ], dim=f_0.shape[1:], diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index adc2564..2127ea6 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,6 +1,5 @@ # Base class for all stepper operators from xlb.operator import Operator -from xlb.operator.equilibrium import Equilibrium from xlb import DefaultConfig @@ -26,63 +25,35 @@ def __init__(self, operators, boundary_conditions): compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop() # Add boundary conditions - # Warp cannot handle lists of functions currently - # Because of this we manually unpack the boundary conditions ############################################ + # Warp cannot handle lists of functions currently # TODO: Fix this later ############################################ from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC - from xlb.operator.boundary_condition.bc_halfway_bounce_back import ( - HalfwayBounceBackBC, - ) - from xlb.operator.boundary_condition.bc_fullway_bounce_back import ( - FullwayBounceBackBC, - ) - - self.equilibrium_bc = None - self.do_nothing_bc = None - self.halfway_bounce_back_bc = None - self.fullway_bounce_back_bc = None - - for bc in boundary_conditions: - if isinstance(bc, EquilibriumBC): - self.equilibrium_bc = bc - elif isinstance(bc, DoNothingBC): - self.do_nothing_bc = bc - elif isinstance(bc, HalfwayBounceBackBC): - self.halfway_bounce_back_bc = bc - elif isinstance(bc, FullwayBounceBackBC): - self.fullway_bounce_back_bc = bc + from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC + from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC + + # Define a list of tuples with attribute names and their corresponding classes + conditions = [ + ("equilibrium_bc", EquilibriumBC), + ("do_nothing_bc", DoNothingBC), + ("halfway_bounce_back_bc", HalfwayBounceBackBC), + ("fullway_bounce_back_bc", FullwayBounceBackBC), + ] + + # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. + bc_fallback = boundary_conditions[0] + + # Iterate over each boundary condition + for attr_name, bc_class in conditions: + for bc in boundary_conditions: + if isinstance(bc, bc_class): + setattr(self, attr_name, bc) + break + elif not hasattr(self, attr_name): + setattr(self, attr_name, bc_fallback) - if self.equilibrium_bc is None: - # Select the equilibrium operator based on its type - self.equilibrium_bc = EquilibriumBC( - rho=1.0, - u=(0.0, 0.0, 0.0), - equilibrium_operator=next((op for op in self.operators if isinstance(op, Equilibrium)), None), - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.do_nothing_bc is None: - self.do_nothing_bc = DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.halfway_bounce_back_bc is None: - self.halfway_bounce_back_bc = HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.fullway_bounce_back_bc is None: - self.fullway_bounce_back_bc = FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) ############################################ # Initialize operator From 5fa63f5f3fa379ba96111101ca0ed68edf6a967b Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Sat, 3 Aug 2024 01:20:51 -0400 Subject: [PATCH 062/144] major improvements to Warp impl of BC. all consistent now. --- examples/cfd/lid_driven_cavity_2d.py | 6 +- .../boundary_condition/bc_do_nothing.py | 47 +++++------ .../boundary_condition/bc_equilibrium.py | 40 +++++---- .../bc_fullway_bounce_back.py | 8 +- .../bc_halfway_bounce_back.py | 81 +++++-------------- xlb/operator/stepper/nse_stepper.py | 26 +++--- 6 files changed, 86 insertions(+), 122 deletions(-) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 488ebc1..b4540a9 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -4,7 +4,7 @@ from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.operator.stepper import IncompressibleNavierStokesStepper -from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC +from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image import warp as wp @@ -48,7 +48,7 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): lid, walls = self.define_boundary_indices() bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid) - bc_walls = FullwayBounceBackBC(indices=walls) + bc_walls = HalfwayBounceBackBC(indices=walls) self.boundary_conditions = [bc_top, bc_walls] def setup_boundary_masks(self): @@ -99,7 +99,7 @@ def post_process(self, i): # Running the simulation grid_size = 500 grid_shape = (grid_size, grid_size) - backend = ComputeBackend.JAX + backend = ComputeBackend.WARP velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index c8b7e71..f2a08b1 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -58,26 +58,12 @@ def _construct_warp(self): # Construct the funcional to get streamed indices @wp.func - def functional2d( - f_pre: wp.array3d(dtype=Any), - missing_mask: Any, - index: Any, - ): - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_pre[l, index[0], index[1]] - return _f - - @wp.func - def functional3d( - f_pre: wp.array4d(dtype=Any), + def functional( + f_pre: Any, + f_post: Any, missing_mask: Any, - index: Any, ): - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_pre[l, index[0], index[1], index[2]] - return _f + return f_pre @wp.kernel def kernel2d( @@ -91,9 +77,15 @@ def kernel2d( index = wp.vec2i(i, j) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] + # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) @@ -102,11 +94,9 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional2d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1]] + _f = _f_post # Write the result for l in range(self.velocity_set.q): @@ -125,9 +115,15 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) @@ -136,17 +132,14 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional3d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1], index[2]] + _f = _f_post # Write the result for l in range(self.velocity_set.q): f_post[l, index[0], index[1], index[2]] = _f[l] - functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 8d8544b..559a624 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -78,10 +78,10 @@ def _construct_warp(self): # Construct the funcional to get streamed indices @wp.func - def functional2d( - f_pre: wp.array3d(dtype=Any), + def functional( + f_pre: Any, + f_post: Any, missing_mask: Any, - index: Any, ): _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f @@ -99,9 +99,15 @@ def kernel2d( index = wp.vec2i(i, j) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] + # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) @@ -110,25 +116,14 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional2d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1]] + _f = _f_post # Write the result for l in range(self.velocity_set.q): f_post[l, index[0], index[1]] = _f[l] - @wp.func - def functional3d( - f_pre: wp.array4d(dtype=Any), - missing_mask: Any, - index: Any, - ): - _f = self.equilibrium_operator.warp_functional(_rho, _u) - return _f - # Construct the warp kernel @wp.kernel def kernel3d( @@ -142,9 +137,15 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) @@ -153,18 +154,15 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional3d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1], index[2]] + _f = _f_post # Write the result for l in range(self.velocity_set.q): f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - functional = functional3d if self.velocity_set.d == 3 else functional2d return functional, kernel diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 7b5de88..52ac49e 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -61,8 +61,8 @@ def _construct_warp(self): @wp.func def functional( f_pre: Any, + f_post: Any, missing_mask: Any, - index: Any, ): fliped_f = _f_vec() for l in range(_q): @@ -87,6 +87,7 @@ def kernel2d( _f_post = _f_vec() _mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations _f_pre[l] = f_pre[l, index[0], index[1]] _f_post[l] = f_post[l, index[0], index[1]] @@ -98,7 +99,7 @@ def kernel2d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _mask, index) + _f = functional(_f_pre, _f_post, _mask) else: _f = _f_post @@ -126,6 +127,7 @@ def kernel3d( _f_post = _f_vec() _mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations _f_pre[l] = f_pre[l, index[0], index[1], index[2]] _f_post[l] = f_post[l, index[0], index[1], index[2]] @@ -137,7 +139,7 @@ def kernel3d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _mask, index) + _f = functional(_f_pre, _f_post, _mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 21d9344..b34cbb4 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -65,61 +65,18 @@ def _construct_warp(self): _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool @wp.func - def functional2d( - f_pre: wp.array3d(dtype=Any), - missing_mask: Any, - index: Any, - ): - # Pull the distribution function - _f = _f_vec() - for l in range(self.velocity_set.q): - # Get pull index - pull_index = type(index)() - - # If the mask is missing then take the opposite index - if missing_mask[l] == wp.uint8(1): - use_l = _opp_indices[l] - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - - # Pull the distribution function - else: - use_l = l - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - _c[d, l] - - # Get the distribution function - _f[l] = f_pre[use_l, pull_index[0], pull_index[1]] - - return _f - - # Construct the funcional to get streamed indices - @wp.func - def functional3d( - f_pre: wp.array4d(dtype=Any), + def functional( + f_pre: Any, + f_post: Any, missing_mask: Any, - index: Any, ): - # Pull the distribution function - _f = _f_vec() + # Post-streaming values are only modified at missing direction + _f = f_post for l in range(self.velocity_set.q): - # Get pull index - pull_index = type(index)() - # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): - use_l = _opp_indices[l] - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - - # Pull the distribution function - else: - use_l = l - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - _c[d, l] - - # Get the distribution function - _f[l] = f_pre[use_l, pull_index[0], pull_index[1], pull_index[2]] + # Get the pre-streaming distribution function in oppisite direction + _f[l] = f_pre[_opp_indices[l]] return _f @@ -136,9 +93,14 @@ def kernel2d( index = wp.vec3i(i, j) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) @@ -147,11 +109,9 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional2d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1]] + _f = _f_post # Write the distribution function for l in range(self.velocity_set.q): @@ -170,9 +130,15 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) @@ -181,18 +147,15 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional3d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1], index[2]] + _f = _f_post # Write the distribution function for l in range(self.velocity_set.q): f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - functional = functional3d if self.velocity_set.d == 3 else functional2d return functional, kernel diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 650e3f1..bfc9d8c 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -93,7 +93,7 @@ def _construct_warp(self): class BoundaryConditionIDStruct: # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning # One needs to manually add the names of additional BC's as they are added. - # TODO: Anyway to improve this + # TODO: Any way to improve this? id_EquilibriumBC: wp.uint8 id_DoNothingBC: wp.uint8 id_HalfwayBounceBackBC: wp.uint8 @@ -113,9 +113,13 @@ def kernel2d( index = wp.vec2i(i, j) # TODO warp should fix this # Get the boundary id and missing mask + f_post_collision = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of pre-streaming populations + f_post_collision[l] = f_0[l, index[0], index[1]] + # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) @@ -128,13 +132,13 @@ def kernel2d( # Apply post-streaming type boundary conditions if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.equilibrium_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.do_nothing_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -153,7 +157,7 @@ def kernel2d( # Apply post-collision type boundary conditions if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition - f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, _missing_mask, index) + f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, f_post_collision, _missing_mask) # Set the output for l in range(self.velocity_set.q): @@ -174,9 +178,13 @@ def kernel3d( index = wp.vec3i(i, j, k) # TODO warp should fix this # Get the boundary id and missing mask + f_post_collision = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of pre-streaming populations + f_post_collision[l] = f_0[l, index[0], index[1], index[2]] + # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) @@ -189,13 +197,13 @@ def kernel3d( # Apply post-streaming boundary conditions if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.equilibrium_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.do_nothing_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -209,7 +217,7 @@ def kernel3d( # Apply collision type boundary conditions if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition - f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, _missing_mask, index) + f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, f_post_collision, _missing_mask) # Set the output for l in range(self.velocity_set.q): From 67d7a19f70334542c67c45b48046b00802ec2548 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Sun, 4 Aug 2024 15:11:48 -0400 Subject: [PATCH 063/144] major improvements to Warp impl of BC. all consistent now. --- examples/cfd/lid_driven_cavity_2d.py | 6 +- examples/cfd/windtunnel_3d.py | 4 +- .../test_bc_equilibrium_warp.py | 2 +- .../test_bc_fullway_bounce_back_warp.py | 5 +- .../boundary_condition/bc_do_nothing.py | 59 +++++----- .../boundary_condition/bc_equilibrium.py | 52 ++++----- .../bc_fullway_bounce_back.py | 14 +-- .../bc_halfway_bounce_back.py | 93 +++++---------- .../indices_boundary_masker.py | 106 +++++++++++------- xlb/operator/stepper/nse_stepper.py | 34 +++--- xlb/operator/stream/stream.py | 2 + 11 files changed, 178 insertions(+), 199 deletions(-) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 488ebc1..b4540a9 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -4,7 +4,7 @@ from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.operator.stepper import IncompressibleNavierStokesStepper -from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC +from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image import warp as wp @@ -48,7 +48,7 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): lid, walls = self.define_boundary_indices() bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid) - bc_walls = FullwayBounceBackBC(indices=walls) + bc_walls = HalfwayBounceBackBC(indices=walls) self.boundary_conditions = [bc_top, bc_walls] def setup_boundary_masks(self): @@ -99,7 +99,7 @@ def post_process(self, i): # Running the simulation grid_size = 500 grid_shape = (grid_size, grid_size) - backend = ComputeBackend.JAX + backend = ComputeBackend.WARP velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index e76b303..eaac67e 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -109,8 +109,8 @@ def run(self, num_steps, print_interval, post_process_interval=100): elapsed_time = time.time() - start_time print(f"Iteration: {i + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") - if i % post_process_interval == 0 or i == num_steps - 1: - self.post_process(i) + if i % post_process_interval == 0 or i == num_steps - 1: + self.post_process(i) def post_process(self, i): # Write the results. We'll use JAX backend for the post-processing diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index e319dbd..0274eba 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -66,7 +66,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask, f) + f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask) f = f.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index 963e081..da76f5e 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -5,6 +5,7 @@ from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig +from xlb.operator.boundary_masker import IndicesBoundaryMasker def init_xlb_env(velocity_set): @@ -35,7 +36,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + indices_boundary_masker = IndicesBoundaryMasker() # Make indices for boundary conditions (sphere) sphere_radius = grid_shape[0] // 4 @@ -64,7 +65,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask, f_pre) + f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask) f = f_pre.numpy() f_post = f_post.numpy() diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index a57e427..f2a08b1 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -58,26 +58,12 @@ def _construct_warp(self): # Construct the funcional to get streamed indices @wp.func - def functional2d( - f: wp.array3d(dtype=Any), + def functional( + f_pre: Any, + f_post: Any, missing_mask: Any, - index: Any, ): - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - return _f - - @wp.func - def functional3d( - f: wp.array4d(dtype=Any), - missing_mask: Any, - index: Any, - ): - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1], index[2]] - return _f + return f_pre @wp.kernel def kernel2d( @@ -85,16 +71,21 @@ def kernel2d( f_post: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.uint8), - f: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] + # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) @@ -103,15 +94,13 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional3d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1]] + _f = _f_post # Write the result for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = _f[l] # Construct the warp kernel @wp.kernel @@ -120,16 +109,21 @@ def kernel3d( f_post: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) @@ -138,27 +132,24 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional3d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1], index[2]] + _f = _f_post # Write the result for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = _f[l] - functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask], dim=f_pre.shape[1:], ) - return f + return f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index f1d85d3..559a624 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -78,10 +78,10 @@ def _construct_warp(self): # Construct the funcional to get streamed indices @wp.func - def functional2d( - f: wp.array3d(dtype=Any), + def functional( + f_pre: Any, + f_post: Any, missing_mask: Any, - index: Any, ): _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f @@ -93,16 +93,21 @@ def kernel2d( f_post: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), - f: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] + # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) @@ -111,24 +116,13 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional2d(f_post, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1]] + _f = _f_post # Write the result for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = _f[l] - - @wp.func - def functional3d( - f: wp.array4d(dtype=Any), - missing_mask: Any, - index: Any, - ): - _f = self.equilibrium_operator.warp_functional(_rho, _u) - return _f + f_post[l, index[0], index[1]] = _f[l] # Construct the warp kernel @wp.kernel @@ -137,16 +131,21 @@ def kernel3d( f_post: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) @@ -155,27 +154,24 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional3d(f_post, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1], index[2]] + _f = _f_post # Write the result for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - functional = functional3d if self.velocity_set.d == 3 else functional2d return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask], dim=f_pre.shape[1:], ) - return f + return f_post diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 5b2f2c1..52ac49e 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -75,7 +75,6 @@ def kernel2d( f_post: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), - f: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) @@ -88,6 +87,7 @@ def kernel2d( _f_post = _f_vec() _mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations _f_pre[l] = f_pre[l, index[0], index[1]] _f_post[l] = f_post[l, index[0], index[1]] @@ -105,7 +105,7 @@ def kernel2d( # Write the result to the output for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = _f[l] # Construct the warp kernel @wp.kernel @@ -114,7 +114,6 @@ def kernel3d( f_post: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() @@ -128,6 +127,7 @@ def kernel3d( _f_post = _f_vec() _mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations _f_pre[l] = f_pre[l, index[0], index[1], index[2]] _f_post[l] = f_post[l, index[0], index[1], index[2]] @@ -145,18 +145,18 @@ def kernel3d( # Write the result to the output for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask], dim=f_pre.shape[1:], ) - return f + return f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 1fa4a7c..b34cbb4 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -65,61 +65,18 @@ def _construct_warp(self): _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool @wp.func - def functional2d( - f: wp.array3d(dtype=Any), + def functional( + f_pre: Any, + f_post: Any, missing_mask: Any, - index: Any, ): - # Pull the distribution function - _f = _f_vec() + # Post-streaming values are only modified at missing direction + _f = f_post for l in range(self.velocity_set.q): - # Get pull index - pull_index = type(index)() - - # If the mask is missing then take the opposite index - if missing_mask[l] == wp.uint8(1): - use_l = _opp_indices[l] - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - - # Pull the distribution function - else: - use_l = l - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - _c[d, l] - - # Get the distribution function - _f[l] = f[use_l, pull_index[0], pull_index[1]] - - return _f - - # Construct the funcional to get streamed indices - @wp.func - def functional3d( - f: wp.array4d(dtype=Any), - missing_mask: Any, - index: Any, - ): - # Pull the distribution function - _f = _f_vec() - for l in range(self.velocity_set.q): - # Get pull index - pull_index = type(index)() - # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): - use_l = _opp_indices[l] - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - - # Pull the distribution function - else: - use_l = l - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - _c[d, l] - - # Get the distribution function - _f[l] = f[use_l, pull_index[0], pull_index[1], pull_index[2]] + # Get the pre-streaming distribution function in oppisite direction + _f[l] = f_pre[_opp_indices[l]] return _f @@ -130,16 +87,20 @@ def kernel2d( f_post: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), - f: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() index = wp.vec3i(i, j) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) @@ -148,15 +109,13 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional2d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1]] + _f = _f_post # Write the distribution function for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = _f[l] # Construct the warp kernel @wp.kernel @@ -165,16 +124,21 @@ def kernel3d( f_post: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) @@ -183,27 +147,24 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional3d(f_pre, _missing_mask, index) + _f = functional(_f_pre, _f_post, _missing_mask) else: - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f_post[l, index[0], index[1], index[2]] + _f = _f_post # Write the distribution function for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - functional = functional3d if self.velocity_set.d == 3 else functional2d return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f): + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask, f], + inputs=[f_pre, f_post, boundary_mask, missing_mask], dim=f_pre.shape[1:], ) - return f + return f_post diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 7960083..b5c2fc4 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -1,8 +1,12 @@ import numpy as np import warp as wp +import jax +import jax.numpy as jnp from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator from xlb.operator.stream.stream import Stream +from xlb.grid import grid_factory +from xlb.precision_policy import Precision class IndicesBoundaryMasker(Operator): @@ -25,32 +29,43 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) - def jax_implementation(self, bclist, boundary_mask, mask, start_index=None): - # define a helper function - def compute_boundary_id_and_mask(boundary_mask, mask): - if dim == 2: - boundary_mask = boundary_mask.at[0, local_indices[0], local_indices[1]].set(id_number) - mask = mask.at[:, local_indices[0], local_indices[1]].set(True) - - if dim == 3: - boundary_mask = boundary_mask.at[0, local_indices[0], local_indices[1], local_indices[2]].set(id_number) - mask = mask.at[:, local_indices[0], local_indices[1], local_indices[2]].set(True) - return boundary_mask, mask - - dim = mask.ndim - 1 + def jax_implementation(self, bclist, boundary_mask, missing_mask, start_index=None): + # Pad the missing mask to create a grid mask to identify out of bound boundaries + # Set padded regin to True (i.e. boundary) + dim = missing_mask.ndim - 1 + nDevices = jax.device_count() + pad_x, pad_y, pad_z = nDevices, 1, 1 + if dim == 2: + grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True) + if dim == 3: + grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True) + + # shift indices + shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z) if start_index is None: start_index = (0,) * dim + bid = boundary_mask[0] for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" id_number = bc.id - local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] - boundary_mask, mask = compute_boundary_id_and_mask(boundary_mask, mask) + local_indices = np.array(bc.indices) + np.array(start_index)[:, np.newaxis] + padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] + bid = bid.at[tuple(local_indices)].set(id_number) + if dim == 2: + grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) + if dim == 3: + grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) - mask = self.stream(mask) - return boundary_mask, mask + boundary_mask = boundary_mask.at[0].set(bid) + grid_mask = self.stream(grid_mask) + if dim == 2: + missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] + if dim == 3: + missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] + return boundary_mask, missing_mask def _construct_warp(self): # Make constants for warp @@ -63,7 +78,7 @@ def kernel2d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), boundary_mask: wp.array3d(dtype=wp.uint8), - mask: wp.array3d(dtype=wp.bool), + missing_mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): # Get the index of indices @@ -71,20 +86,23 @@ def kernel2d( # Get local indices index = wp.vec2i() - index[0] = indices[0, ii] - start_index[0] - index[1] = indices[1, ii] - start_index[1] + index[0] = indices[0, ii] + start_index[0] + index[1] = indices[1, ii] + start_index[1] - # Check if in bounds - if index[0] >= 0 and index[0] < mask.shape[1] and index[1] >= 0 and index[1] < mask.shape[2]: + # Check if index is in bounds + if index[0] >= 0 and index[0] < missing_mask.shape[1] and index[1] >= 0 and index[1] < missing_mask.shape[2]: # Stream indices for l in range(_q): # Get the index of the streaming direction - push_index = wp.vec2i() + pull_index = wp.vec2i() for d in range(self.velocity_set.d): - push_index[d] = index[d] + _c[d, l] + pull_index[d] = index[d] - _c[d, l] - # Set the boundary id and mask - mask[l, push_index[0], push_index[1]] = True + # check if pull index is out of bound + # These directions will have missing information after streaming + if pull_index[0] < 0 or pull_index[0] >= missing_mask.shape[1] or pull_index[1] < 0 or pull_index[1] >= missing_mask.shape[2]: + # Set the missing mask + missing_mask[l, index[0], index[1]] = True boundary_mask[0, index[0], index[1]] = id_number[ii] @@ -94,7 +112,7 @@ def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), boundary_mask: wp.array4d(dtype=wp.uint8), - mask: wp.array4d(dtype=wp.bool), + missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): # Get the index of indices @@ -102,28 +120,38 @@ def kernel3d( # Get local indices index = wp.vec3i() - index[0] = indices[0, ii] - start_index[0] - index[1] = indices[1, ii] - start_index[1] - index[2] = indices[2, ii] - start_index[2] + index[0] = indices[0, ii] + start_index[0] + index[1] = indices[1, ii] + start_index[1] + index[2] = indices[2, ii] + start_index[2] - # Check if in bounds + # Check if index is in bounds if ( index[0] >= 0 - and index[0] < mask.shape[1] + and index[0] < missing_mask.shape[1] and index[1] >= 0 - and index[1] < mask.shape[2] + and index[1] < missing_mask.shape[2] and index[2] >= 0 - and index[2] < mask.shape[3] + and index[2] < missing_mask.shape[3] ): # Stream indices for l in range(_q): # Get the index of the streaming direction - push_index = wp.vec3i() + pull_index = wp.vec3i() for d in range(self.velocity_set.d): - push_index[d] = index[d] + _c[d, l] - - # Set the mask - mask[l, push_index[0], push_index[1], push_index[2]] = True + pull_index[d] = index[d] - _c[d, l] + + # check if pull index is out of bound + # These directions will have missing information after streaming + if ( + pull_index[0] < 0 + or pull_index[0] >= missing_mask.shape[1] + or pull_index[1] < 0 + or pull_index[1] >= missing_mask.shape[2] + or pull_index[2] < 0 + or pull_index[2] >= missing_mask.shape[3] + ): + # Set the missing mask + missing_mask[l, index[0], index[1], index[2]] = True boundary_mask[0, index[0], index[1], index[2]] = id_number[ii] diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index b26c558..bfc9d8c 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -93,7 +93,7 @@ def _construct_warp(self): class BoundaryConditionIDStruct: # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning # One needs to manually add the names of additional BC's as they are added. - # TODO: Anyway to improve this + # TODO: Any way to improve this? id_EquilibriumBC: wp.uint8 id_DoNothingBC: wp.uint8 id_HalfwayBounceBackBC: wp.uint8 @@ -113,9 +113,13 @@ def kernel2d( index = wp.vec2i(i, j) # TODO warp should fix this # Get the boundary id and missing mask + f_post_collision = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of pre-streaming populations + f_post_collision[l] = f_0[l, index[0], index[1]] + # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) @@ -128,13 +132,13 @@ def kernel2d( # Apply post-streaming type boundary conditions if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.equilibrium_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.do_nothing_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -153,11 +157,7 @@ def kernel2d( # Apply post-collision type boundary conditions if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition - f_post_collision = self.fullway_bounce_back_bc.warp_functional( - f_post_stream, - f_post_collision, - _missing_mask, - ) + f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, f_post_collision, _missing_mask) # Set the output for l in range(self.velocity_set.q): @@ -178,9 +178,13 @@ def kernel3d( index = wp.vec3i(i, j, k) # TODO warp should fix this # Get the boundary id and missing mask + f_post_collision = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): + # q-sized vector of pre-streaming populations + f_post_collision[l] = f_0[l, index[0], index[1], index[2]] + # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) @@ -193,13 +197,13 @@ def kernel3d( # Apply post-streaming boundary conditions if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.equilibrium_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.do_nothing_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) + f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -213,11 +217,7 @@ def kernel3d( # Apply collision type boundary conditions if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition - f_post_collision = self.fullway_bounce_back_bc.warp_functional( - f_post_stream, - f_post_collision, - _missing_mask, - ) + f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, f_post_collision, _missing_mask) # Set the output for l in range(self.velocity_set.q): diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index da724c2..f91b567 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -69,6 +69,7 @@ def functional2d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - _c[d, l] + # impose periodicity for out of bound values if pull_index[d] < 0: pull_index[d] = f.shape[d + 1] - 1 elif pull_index[d] >= f.shape[d + 1]: @@ -109,6 +110,7 @@ def functional3d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - _c[d, l] + # impose periodicity for out of bound values if pull_index[d] < 0: pull_index[d] = f.shape[d + 1] - 1 elif pull_index[d] >= f.shape[d + 1]: From 025c28effab74f4b8d1ab3ff1cba7a6fbff6b0a1 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Wed, 7 Aug 2024 12:47:17 -0400 Subject: [PATCH 064/144] ZouHe BC added in JAX --- xlb/operator/boundary_condition/__init__.py | 1 + xlb/operator/boundary_condition/bc_zouhe.py | 271 ++++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 xlb/operator/boundary_condition/bc_zouhe.py diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 1fd2152..4887085 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -6,3 +6,4 @@ from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC as DoNothingBC from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC +from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py new file mode 100644 index 0000000..8e75768 --- /dev/null +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -0,0 +1,271 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit +from functools import partial +import warp as wp +from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) +from xlb.operator.equilibrium import QuadraticEquilibrium + + +class ZouHeBC(BoundaryCondition): + """ + Zou-He boundary condition for a lattice Boltzmann method simulation. + + This class implements the Zou-He boundary condition, which is a non-equilibrium bounce-back boundary condition. + It can be used to set inflow and outflow boundary conditions with prescribed pressure or velocity. + """ + + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + bc_type=None, + prescribed_value=None, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + indices=None, + ): + assert bc_type in ["velocity", "pressure"], f'The boundary type must be either "velocity" or "pressure"' + self.bc_type = bc_type + self.equilibrium_operator = QuadraticEquilibrium() + + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + indices, + ) + + # Set the prescribed value for pressure or velocity + dim = self.velocity_set.d + self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None] + # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! + + @partial(jit, static_argnums=(0,), inline=True) + def _get_known_middle_mask(self, missing_mask): + known_mask = missing_mask[self.velocity_set.opp_indices] + middle_mask = ~(missing_mask | known_mask) + return known_mask, middle_mask + + @partial(jit, static_argnums=(0,), inline=True) + def _get_normal_vec(self, missing_mask): + main_c = self.velocity_set.c[:, self.velocity_set.main_indices] + m = missing_mask[self.velocity_set.main_indices] + normals = -jnp.tensordot(main_c, m, axes=(-1, 0)) + return normals + + @partial(jit, static_argnums=(0,), inline=True) + def get_rho(self, fpop, missing_mask): + if self.bc_type == "velocity": + vel = self.get_vel(fpop, missing_mask) + rho = self.calculate_rho(fpop, vel, missing_mask) + elif self.bc_type == "pressure": + rho = self.prescribed_value + else: + raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") + return rho + + @partial(jit, static_argnums=(0,), inline=True) + def get_vel(self, fpop, missing_mask): + if self.bc_type == "velocity": + vel = self.prescribed_value + elif self.bc_type == "pressure": + rho = self.get_rho(fpop, missing_mask) + vel = self.calculate_vel(fpop, rho, missing_mask) + else: + raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") + return vel + + @partial(jit, static_argnums=(0,), inline=True) + def calculate_vel(self, fpop, rho, missing_mask): + """ + Calculate velocity based on the prescribed pressure/density (Zou/He BC) + """ + + normals = self._get_normal_vec(missing_mask) + known_mask, middle_mask = self._get_known_middle_mask(missing_mask) + + unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=-1, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=-1, keepdims=True)) + + # Return the above unormal as a normal vector which sets the tangential velocities to zero + vel = unormal * normals + return vel + + @partial(jit, static_argnums=(0,), inline=True) + def calculate_rho(self, fpop, vel, missing_mask): + """ + Calculate density based on the prescribed velocity (Zou/He BC) + """ + normals = self._get_normal_vec(missing_mask) + known_mask, middle_mask = self._get_known_middle_mask(missing_mask) + unormal = jnp.sum(normals * vel, keepdims=True, axis=0) + rho = (1.0 / (1.0 + unormal)) * (jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True)) + return rho + + @partial(jit, static_argnums=(0,), inline=True) + def calculate_equilibrium(self, fpop, missing_mask): + """ + This is the ZouHe method of calculating the missing macroscopic variables at the boundary. + """ + rho = self.get_rho(fpop, missing_mask) + vel = self.get_vel(fpop, missing_mask) + + # compute feq at the boundary + feq = self.equilibrium_operator(rho, vel) + return feq + + @partial(jit, static_argnums=(0,), inline=True) + def bounceback_nonequilibrium(self, fpop, feq, missing_mask): + """ + Calculate unknown populations using bounce-back of non-equilibrium populations + a la original Zou & He formulation + """ + opp = self.velocity_set.opp_indices + fknown = fpop[opp] + feq - feq[opp] + fpop = jnp.where(missing_mask, fknown, fpop) + return fpop + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + # creat a mask to slice boundary cells + boundary = boundary_mask == self.id + boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + + # compute the equilibrium based on prescribed values and the type of BC + feq = self.calculate_equilibrium(f_post, missing_mask) + + # set the unknown f populations based on the non-equilibrium bounce-back method + f_post_bd = self.bounceback_nonequilibrium(f_post, feq, missing_mask) + f_post = jnp.where(boundary, f_post_bd, f_post) + return f_post + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _opp_indices = self.velocity_set.wp_opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + + @wp.func + def functional( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + for l in range(self.velocity_set.q): + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + # Get the pre-streaming distribution function in oppisite direction + _f[l] = f_pre[_opp_indices[l]] + + return _f + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec3i(i, j) + + # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() + _boundary_id = boundary_mask[0, index[0], index[1]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + _f = functional(_f_pre, _f_post, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1]] = _f[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + _f = functional(_f_pre, _f_post, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1], index[2]] = _f[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post From 90969d156d6e6429f8f2fc29cb18da321d8fe3da Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Wed, 7 Aug 2024 18:49:47 -0400 Subject: [PATCH 065/144] added some warp helper functions in the BoundaryCondition class to read thread data --- .../boundary_condition/bc_do_nothing.py | 41 ++---------- .../boundary_condition/bc_equilibrium.py | 38 ++---------- .../bc_fullway_bounce_back.py | 45 +++----------- .../bc_halfway_bounce_back.py | 39 ++---------- xlb/operator/boundary_condition/bc_zouhe.py | 37 ++--------- .../boundary_condition/boundary_condition.py | 62 +++++++++++++++++++ 6 files changed, 91 insertions(+), 171 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index f2a08b1..5049b1f 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -51,12 +51,7 @@ def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): return jnp.where(boundary, f_pre, f_post) def _construct_warp(self): - # Set local constants TODO: This is a hack and should be fixed with warp update - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - - # Construct the funcional to get streamed indices - + # Construct the functional for this BC @wp.func def functional( f_pre: Any, @@ -76,21 +71,8 @@ def kernel2d( i, j = wp.tid() index = wp.vec2i(i, j) - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1]] - _f_post[l] = f_post[l, index[0], index[1]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): @@ -114,21 +96,8 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1], index[2]] - _f_post[l] = f_post[l, index[0], index[1], index[2]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 559a624..1937a17 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -70,13 +70,11 @@ def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(self.rho) _u = _u_vec(self.u[0], self.u[1], self.u[2]) if self.velocity_set.d == 3 else _u_vec(self.u[0], self.u[1]) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - # Construct the funcional to get streamed indices + # Construct the functional for this BC @wp.func def functional( f_pre: Any, @@ -98,21 +96,8 @@ def kernel2d( i, j = wp.tid() index = wp.vec2i(i, j) - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1]] - _f_post[l] = f_post[l, index[0], index[1]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): @@ -136,21 +121,8 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1], index[2]] - _f_post[l] = f_post[l, index[0], index[1], index[2]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 52ac49e..85a3c35 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -55,9 +55,8 @@ def _construct_warp(self): _opp_indices = self.velocity_set.wp_opp_indices _q = wp.constant(self.velocity_set.q) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - # Construct the funcional to get streamed indices + # Construct the functional for this BC @wp.func def functional( f_pre: Any, @@ -79,27 +78,12 @@ def kernel2d( i, j = wp.tid() index = wp.vec2i(i, j) - # Get the boundary id and missing mask - _boundary_id = boundary_mask[0, index[0], index[1]] - - # Make vectors for the lattice - _f_pre = _f_vec() - _f_post = _f_vec() - _mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1]] - _f_post[l] = f_post[l, index[0], index[1]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1]]: - _mask[l] = wp.uint8(1) - else: - _mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _mask) + _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post @@ -119,27 +103,12 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) - # Get the boundary id and missing mask - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - - # Make vectors for the lattice - _f_pre = _f_vec() - _f_post = _f_vec() - _mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1], index[2]] - _f_post[l] = f_post[l, index[0], index[1], index[2]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _mask[l] = wp.uint8(1) - else: - _mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _mask) + _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index b34cbb4..df947c6 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -58,12 +58,10 @@ def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): ) def _construct_warp(self): - # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c + # Set local constants _opp_indices = self.velocity_set.wp_opp_indices - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + # Construct the functional for this BC @wp.func def functional( f_pre: Any, @@ -92,20 +90,8 @@ def kernel2d( i, j = wp.tid() index = wp.vec3i(i, j) - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1]] - _f_post[l] = f_post[l, index[0], index[1]] - # TODO fix vec bool - if missing_mask[l, index[0], index[1]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): @@ -129,21 +115,8 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1], index[2]] - _f_post[l] = f_post[l, index[0], index[1], index[2]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 8e75768..4e2eebc 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -193,23 +193,11 @@ def kernel2d( i, j = wp.tid() index = wp.vec3i(i, j) - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1]] - _f_post[l] = f_post[l, index[0], index[1]] - # TODO fix vec bool - if missing_mask[l, index[0], index[1]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + if _boundary_id == wp.uint8(ZouHeBC.id): _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post @@ -230,24 +218,11 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1], index[2]] - _f_post[l] = f_post[l, index[0], index[1], index[2]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + if _boundary_id == wp.uint8(ZouHeBC.id): _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 125e45d..f4e0a1b 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -3,6 +3,8 @@ """ from enum import Enum, auto +import warp as wp +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -41,3 +43,63 @@ def __init__( # Set the implementation step self.implementation_step = implementation_step + + if self.compute_backend == ComputeBackend.WARP: + # Set local constants TODO: This is a hack and should be fixed with warp update + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + + @wp.func + def _get_thread_data_2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + index: wp.vec2i, + ): + # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() + _boundary_id = boundary_mask[0, index[0], index[1]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1]] + _f_post[l] = f_post[l, index[0], index[1]] + + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + return _f_pre, _f_post, _boundary_id, _missing_mask + + @wp.func + def _get_thread_data_3d( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + index: wp.vec3i, + ): + # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # q-sized vector of populations + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + return _f_pre, _f_post, _boundary_id, _missing_mask + + # Construct some helper warp functions for getting tid data + if self.compute_backend == ComputeBackend.WARP: + self._get_thread_data_2d = _get_thread_data_2d + self._get_thread_data_3d = _get_thread_data_3d From 4c6d7d553e436daed0e2176441b57c2f6b7af807 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 9 Aug 2024 16:40:49 -0400 Subject: [PATCH 066/144] added ZouHe in warp --- xlb/operator/boundary_condition/bc_zouhe.py | 172 ++++++++++++++++++-- xlb/operator/collision/kbc.py | 4 +- xlb/operator/stepper/nse_stepper.py | 11 +- xlb/operator/stepper/stepper.py | 2 + xlb/velocity_set/velocity_set.py | 1 + 5 files changed, 171 insertions(+), 19 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 4e2eebc..06be077 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -34,16 +34,17 @@ class ZouHeBC(BoundaryCondition): def __init__( self, - bc_type=None, - prescribed_value=None, + bc_type, + prescribed_value, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, ): - assert bc_type in ["velocity", "pressure"], f'The boundary type must be either "velocity" or "pressure"' + assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() + self.prescribed_value = prescribed_value # Call the parent constructor super().__init__( @@ -56,8 +57,9 @@ def __init__( # Set the prescribed value for pressure or velocity dim = self.velocity_set.d - self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None] - # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! + if self.compute_backend == ComputeBackend.JAX: + self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None] + # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! @partial(jit, static_argnums=(0,), inline=True) def _get_known_middle_mask(self, missing_mask): @@ -103,7 +105,7 @@ def calculate_vel(self, fpop, rho, missing_mask): normals = self._get_normal_vec(missing_mask) known_mask, middle_mask = self._get_known_middle_mask(missing_mask) - unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=-1, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=-1, keepdims=True)) + unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True)) # Return the above unormal as a normal vector which sets the tangential velocities to zero vel = unormal * normals @@ -159,26 +161,160 @@ def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): return f_post def _construct_warp(self): + # assign placeholders for both u and rho based on prescribed_value + _d = self.velocity_set.d + _q = self.velocity_set.q + u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d + rho = self.prescribed_value if self.bc_type == "pressure" else 0.0 + # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c + # _u_vec = wp.vec(_d, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _rho = wp.float32(rho) + _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.wp_opp_indices - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _c = self.velocity_set.wp_c + _c32 = self.velocity_set.wp_c32 + # TODO: this is way less than ideal. we should not be making new types + + @wp.func + def get_normal_vectors_2d( + lattice_direction: Any, + ): + l = lattice_direction + if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + normals = -_u_vec(_c32[0, l], _c32[1, l]) + return normals @wp.func - def functional( + def get_normal_vectors_3d( + lattice_direction: Any, + ): + l = lattice_direction + if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + normals = -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return normals + + @wp.func + def _helper_functional( + fpop: Any, + fsum: Any, + missing_mask: Any, + lattice_direction: Any, + ): + l = lattice_direction + known_mask = missing_mask[_opp_indices[l]] + middle_mask = ~(missing_mask[l] | known_mask) + # fsum += fpop[l] * float(middle_mask) + 2.0 * fpop[l] * float(known_mask) + if middle_mask and known_mask: + fsum += fpop[l] + 2.0 * fpop[l] + elif middle_mask: + fsum += fpop[l] + elif known_mask: + fsum += 2.0 * fpop[l] + return fsum + + @wp.func + def bounceback_nonequilibrium( + fpop: Any, + missing_mask: Any, + density: Any, + velocity: Any, + ): + feq = self.equilibrium_operator.warp_functional(density, velocity) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] + return fpop + + @wp.func + def functional3d_velocity( f_pre: Any, f_post: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post - for l in range(self.velocity_set.q): - # If the mask is missing then take the opposite index + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): if missing_mask[l] == wp.uint8(1): - # Get the pre-streaming distribution function in oppisite direction - _f[l] = f_pre[_opp_indices[l]] + normals = get_normal_vectors_3d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + for d in range(_d): + unormal += _u[d] * normals[d] + _rho = _fsum / (1.0 + unormal) + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional3d_pressure( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_3d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + unormal = -1.0 + _fsum / _rho + _u = unormal * normals + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional2d_velocity( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_2d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + for d in range(_d): + unormal += _u[d] * normals[d] + _rho = _fsum / (1.0 + unormal) + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional2d_pressure( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_2d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + unormal = -1.0 + _fsum / _rho + _u = unormal * normals + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) return _f # Construct the warp kernel @@ -232,6 +368,14 @@ def kernel3d( f_post[l, index[0], index[1], index[2]] = _f[l] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + if self.velocity_set.d == 3 and self.bc_type == "velocity": + functional = functional3d_velocity + elif self.velocity_set.d == 3 and self.bc_type == "pressure": + functional = functional3d_pressure + elif self.bc_type == "velocity": + functional = functional2d_velocity + else: + functional = functional2d_pressure return functional, kernel diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index fa0857a..9297363 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -366,7 +366,7 @@ def kernel2d( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1]] _feq[l] = feq[l, index[0], index[1]] - _u = self._warp_u_vec() + _u = self.warp_u_vec() for l in range(_d): _u[l] = u[l, index[0], index[1]] _rho = rho[0, index[0], index[1]] @@ -398,7 +398,7 @@ def kernel3d( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] - _u = self._warp_u_vec() + _u = self.warp_u_vec() for l in range(_d): _u[l] = u[l, index[0], index[1], index[2]] _rho = rho[0, index[0], index[1], index[2]] diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index bfc9d8c..84a0b8f 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -98,6 +98,7 @@ class BoundaryConditionIDStruct: id_DoNothingBC: wp.uint8 id_HalfwayBounceBackBC: wp.uint8 id_FullwayBounceBackBC: wp.uint8 + id_ZouHeBC: wp.uint8 @wp.kernel def kernel2d( @@ -139,6 +140,9 @@ def kernel2d( elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + elif _boundary_id == bc_struct.id_ZouHeBC: + # Zouhe boundary condition + f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -204,6 +208,9 @@ def kernel3d( elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + elif _boundary_id == bc_struct.id_ZouHeBC: + # Zouhe boundary condition + f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -237,9 +244,8 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): bc_struct = self.warp_functional() bc_attribute_list = [] - for bc in self.boundary_conditions: + for attribute_str in bc_to_id.keys(): # Setting the Struct attributes based on the BC class names - attribute_str = bc.__class__.__name__ setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str]) bc_attribute_list.append("id_" + attribute_str) @@ -248,7 +254,6 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): for var in ll: if var not in bc_attribute_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 - attribute_str = bc.__class__.__name__ setattr(bc_struct, var, 255) # Launch the warp kernel diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 2127ea6..1608342 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -33,6 +33,7 @@ def __init__(self, operators, boundary_conditions): from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC + from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC # Define a list of tuples with attribute names and their corresponding classes conditions = [ @@ -40,6 +41,7 @@ def __init__(self, operators, boundary_conditions): ("do_nothing_bc", DoNothingBC), ("halfway_bounce_back_bc", HalfwayBounceBackBC), ("fullway_bounce_back_bc", FullwayBounceBackBC), + ("zouhe_bc", ZouHeBC), ] # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 47bbae4..cd63b36 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -48,6 +48,7 @@ def __init__(self, d, q, c, w): self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) + self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c)) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) From 310aa370d84a10a9329c44eab76cdb41dcd4962b Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 13 Aug 2024 20:37:00 -0400 Subject: [PATCH 067/144] added a new macroscopic operator to compute second moments --- xlb/operator/collision/kbc.py | 60 ++-------- xlb/operator/macroscopic/__init__.py | 1 + xlb/operator/macroscopic/second_moment.py | 134 ++++++++++++++++++++++ 3 files changed, 142 insertions(+), 53 deletions(-) create mode 100644 xlb/operator/macroscopic/second_moment.py diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 9297363..da0aee5 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -6,12 +6,13 @@ from jax import jit import warp as wp from typing import Any +from functools import partial from xlb.velocity_set import VelocitySet, D2Q9, D3Q27 from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision from xlb.operator import Operator -from functools import partial +from xlb.operator.macroscopic import SecondMoment class KBC(Collision): @@ -28,6 +29,7 @@ def __init__( precision_policy=None, compute_backend=None, ): + self.momentum_flux = SecondMoment() self.epsilon = 1e-32 self.beta = omega * 0.5 self.inv_beta = 1.0 / self.beta @@ -94,33 +96,6 @@ def entropic_scalar_product(self, x: jnp.ndarray, y: jnp.ndarray, feq: jnp.ndarr """ return jnp.sum(x * y / feq, axis=0) - @partial(jit, static_argnums=(0,), donate_argnums=(1,)) - def momentum_flux_jax( - self, - fneq: jnp.ndarray, - ): - """ - This function computes the momentum flux, which is the product of the non-equilibrium - distribution functions (fneq) and the lattice moments (cc). - - The momentum flux is used in the computation of the stress tensor in the Lattice Boltzmann - Method (LBM). - - # TODO: probably move this to equilibrium calculation - - Parameters - ---------- - fneq: jax.numpy.ndarray - The non-equilibrium distribution functions. - - Returns - ------- - jax.numpy.ndarray - The computed momentum flux. - """ - - return jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0)) - @partial(jit, static_argnums=(0,), inline=True) def decompose_shear_d3q27_jax(self, fneq): """ @@ -138,7 +113,7 @@ def decompose_shear_d3q27_jax(self, fneq): """ # Calculate the momentum flux - Pi = self.momentum_flux_jax(fneq) + Pi = self.momentum_flux(fneq) # Calculating Nxz and Nyz with indices moved to the first dimension Nxz = Pi[0, ...] - Pi[5, ...] Nyz = Pi[3, ...] - Pi[5, ...] @@ -187,7 +162,7 @@ def decompose_shear_d2q9_jax(self, fneq): jax.numpy.array Shear components of fneq. """ - Pi = self.momentum_flux_jax(fneq) + Pi = self.momentum_flux(fneq) N = Pi[0, ...] - Pi[2, ...] s = jnp.zeros_like(fneq) s = s.at[3, ...].set(N) @@ -207,35 +182,14 @@ def _construct_warp(self): raise NotImplementedError("Velocity set not supported for warp backend: {}".format(type(self.velocity_set))) # Set local constants TODO: This is a hack and should be fixed with warp update - _w = self.velocity_set.wp_w - _cc = self.velocity_set.wp_cc - _omega = wp.constant(self.compute_dtype(self.omega)) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2 - _pi_vec = wp.vec( - _pi_dim, - dtype=self.compute_dtype, - ) _epsilon = wp.constant(self.compute_dtype(self.epsilon)) _beta = wp.constant(self.compute_dtype(self.beta)) _inv_beta = wp.constant(self.compute_dtype(1.0 / self.beta)) - # Construct functional for computing momentum flux - @wp.func - def momentum_flux_warp( - fneq: Any, - ): - # Get momentum flux - pi = _pi_vec() - for d in range(_pi_dim): - pi[d] = 0.0 - for q in range(self.velocity_set.q): - pi[d] += _cc[q, d] * fneq[q] - return pi - @wp.func def decompose_shear_d2q9(fneq: Any): - pi = momentum_flux_warp(fneq) + pi = self.momentum_flux.warp_functional(fneq) N = pi[0] - pi[1] s = wp.vec9(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) s[3] = N @@ -254,7 +208,7 @@ def decompose_shear_d3q27( fneq: Any, ): # Get momentum flux - pi = momentum_flux_warp(fneq) + pi = self.momentum_flux.warp_functional(fneq) nxz = pi[0] - pi[5] nyz = pi[3] - pi[5] diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 3463078..c0755ad 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1 +1,2 @@ from xlb.operator.macroscopic.macroscopic import Macroscopic as Macroscopic +from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py new file mode 100644 index 0000000..db8fce6 --- /dev/null +++ b/xlb/operator/macroscopic/second_moment.py @@ -0,0 +1,134 @@ +# Base class for all equilibriums + +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class SecondMoment(Operator): + """ + Operator to calculate the second moment of distribution functions. + + The second moment may be used to compute the momentum flux in the computation of + the stress tensor in the Lattice Boltzmann Method (LBM). + + Important Note: + Note that this rank 2 symmetric tensor (dim*dim) has been converted into a rank one + vector where the diagonal and off-diagonal components correspond to the following elements of + the vector: + if self.grid.dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif self.grid.dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + + ** For any reduction operation on the full tensor it is crucial to account for the full tensor by + considering all diagonal and off-diagonal components. + """ + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0,), donate_argnums=(1,)) + def jax_implementation( + self, + fneq: jnp.ndarray, + ): + """ + This function computes the second order moment, which is the product of the + distribution functions (f) and the lattice moments (cc). + + Parameters + ---------- + fneq: jax.numpy.ndarray + The distribution functions. + + Returns + ------- + jax.numpy.ndarray + The computed second moment. + """ + return jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0)) + + def _construct_warp(self): + # Make constants for warp + _cc = self.velocity_set.wp_cc + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2 + _pi_vec = wp.vec( + _pi_dim, + dtype=self.compute_dtype, + ) + + # Construct functional for computing second moment + @wp.func + def functional( + fneq: Any, + ): + # Get second order moment (a symmetric tensore shaped into a vector) + pi = _pi_vec() + for d in range(_pi_dim): + pi[d] = 0.0 + for q in range(self.velocity_set.q): + pi[d] += _cc[q, d] * fneq[q] + return pi + + # Construct the kernel + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + pi: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the equilibrium + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _pi = functional(_f) + + # Set the output + for d in range(_pi_dim): + pi[d, index[0], index[1], index[2]] = _pi[d] + + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + pi: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the equilibrium + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _pi = functional(_f) + + # Set the output + for d in range(_pi_dim): + pi[d, index[0], index[1]] = _pi[d] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, pi): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + f, + pi, + ], + dim=pi.shape[1:], + ) + return pi From 9fe1a6182e75228fb23e7ef792349637047d6827 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 15 Aug 2024 10:18:26 -0400 Subject: [PATCH 068/144] fixed a subtle bug in ZouHe (JAX) --- examples/cfd/flow_past_sphere_3d.py | 21 ++++++++++----- xlb/operator/boundary_condition/bc_zouhe.py | 26 ++++++++++++------- .../indices_boundary_masker.py | 8 +++--- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 2a580aa..55b26d4 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -5,6 +5,7 @@ from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import ( FullwayBounceBackBC, + ZouHeBC, EquilibriumBC, DoNothingBC, ) @@ -67,11 +68,14 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): inlet, outlet, walls, sphere = self.define_boundary_indices() - bc_left = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=inlet) + bc_left = ZouHeBC("velocity", (0.04, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) - bc_do_nothing = DoNothingBC(indices=outlet) + bc_outlet = ZouHeBC("pressure", 1.0, indices=outlet) bc_sphere = FullwayBounceBackBC(indices=sphere) - self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_sphere] + self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls] + # Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because + # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. + # TODO: how to ensure about this behind in the src code? def setup_boundary_masks(self): indices_boundary_masker = IndicesBoundaryMasker( @@ -107,9 +111,14 @@ def post_process(self, i): # remove boundary cells u = u[:, 1:-1, 1:-1, 1:-1] + rho = rho[:, 1:-1, 1:-1, 1:-1][0] u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - fields = {"u_magnitude": u_magnitude} + fields = {"u_magnitude": u_magnitude, + "u_x": u[0], + "u_y": u[1], + "u_z": u[2], + "rho": rho} save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) @@ -117,9 +126,9 @@ def post_process(self, i): if __name__ == "__main__": # Running the simulation - grid_shape = (512, 128, 128) + grid_shape = (512//2, 128//2, 128//2) velocity_set = xlb.velocity_set.D3Q19() - backend = ComputeBackend.WARP + backend = ComputeBackend.JAX precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 06be077..3603f35 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -26,12 +26,14 @@ class ZouHeBC(BoundaryCondition): """ Zou-He boundary condition for a lattice Boltzmann method simulation. - This class implements the Zou-He boundary condition, which is a non-equilibrium bounce-back boundary condition. - It can be used to set inflow and outflow boundary conditions with prescribed pressure or velocity. + This method applies the Zou-He boundary condition by first computing the equilibrium distribution functions based + on the prescribed values and the type of boundary condition, and then setting the unknown distribution functions + based on the non-equilibrium bounce-back method. + Tangential velocity is not ensured to be zero by adding transverse contributions based on + Hecth & Harting (2010) (doi:10.1088/1742-5468/2010/01/P01018) as it caused numerical instabilities at higher + Reynolds numbers. One needs to use "Regularized" BC at higher Reynolds. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, bc_type, @@ -41,6 +43,9 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, ): + # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC + # may have different types (velocity or pressure). + self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__) assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() @@ -58,7 +63,7 @@ def __init__( # Set the prescribed value for pressure or velocity dim = self.velocity_set.d if self.compute_backend == ComputeBackend.JAX: - self.prescribed_value = jnp.array(prescribed_value)[:, None, None, None] if dim == 3 else jnp.array(prescribed_value)[:, None, None] + self.prescribed_value = jnp.atleast_1d(prescribed_value)[(slice(None),) + (None,) * dim] # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! @partial(jit, static_argnums=(0,), inline=True) @@ -77,7 +82,7 @@ def _get_normal_vec(self, missing_mask): @partial(jit, static_argnums=(0,), inline=True) def get_rho(self, fpop, missing_mask): if self.bc_type == "velocity": - vel = self.get_vel(fpop, missing_mask) + vel = self.prescribed_value rho = self.calculate_rho(fpop, vel, missing_mask) elif self.bc_type == "pressure": rho = self.prescribed_value @@ -90,7 +95,7 @@ def get_vel(self, fpop, missing_mask): if self.bc_type == "velocity": vel = self.prescribed_value elif self.bc_type == "pressure": - rho = self.get_rho(fpop, missing_mask) + rho = self.prescribed_value vel = self.calculate_vel(fpop, rho, missing_mask) else: raise ValueError(f"type = {self.bc_type} not supported! Use 'pressure' or 'velocity'.") @@ -104,8 +109,8 @@ def calculate_vel(self, fpop, rho, missing_mask): normals = self._get_normal_vec(missing_mask) known_mask, middle_mask = self._get_known_middle_mask(missing_mask) - - unormal = -1.0 + 1.0 / rho * (jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True)) + fsum = jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True) + unormal = -1.0 + fsum/rho # Return the above unormal as a normal vector which sets the tangential velocities to zero vel = unormal * normals @@ -119,7 +124,8 @@ def calculate_rho(self, fpop, vel, missing_mask): normals = self._get_normal_vec(missing_mask) known_mask, middle_mask = self._get_known_middle_mask(missing_mask) unormal = jnp.sum(normals * vel, keepdims=True, axis=0) - rho = (1.0 / (1.0 + unormal)) * (jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True)) + fsum = jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True) + rho = fsum / (1.0 + unormal) return rho @partial(jit, static_argnums=(0,), inline=True) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index b5c2fc4..7548cf0 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -52,10 +52,10 @@ def jax_implementation(self, bclist, boundary_mask, missing_mask, start_index=No local_indices = np.array(bc.indices) + np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] bid = bid.at[tuple(local_indices)].set(id_number) - if dim == 2: - grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) - if dim == 3: - grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) + # if dim == 2: + # grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) + # if dim == 3: + # grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) From 12b5f3d58c1b33c3b52f4f7b4522656f9789cb3e Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 15 Aug 2024 10:19:45 -0400 Subject: [PATCH 069/144] WIP: regularized initial commit --- xlb/operator/boundary_condition/__init__.py | 1 + .../boundary_condition/bc_regularized.py | 365 ++++++++++++++++++ xlb/operator/stepper/nse_stepper.py | 7 + xlb/operator/stepper/stepper.py | 2 + 4 files changed, 375 insertions(+) create mode 100644 xlb/operator/boundary_condition/bc_regularized.py diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 4887085..506da1d 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -7,3 +7,4 @@ from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC +from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py new file mode 100644 index 0000000..63c8df4 --- /dev/null +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -0,0 +1,365 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit +from functools import partial +import warp as wp +from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.macroscopic.second_moment import SecondMoment + + +class RegularizedBC(ZouHeBC): + """ + Regularized boundary condition for a lattice Boltzmann method simulation. + + This class implements the regularized boundary condition, which is a non-equilibrium bounce-back boundary condition + with additional regularization. It can be used to set inflow and outflow boundary conditions with prescribed pressure + or velocity. + + Attributes + ---------- + name : str + The name of the boundary condition. For this class, it is "Regularized". + Qi : numpy.ndarray + The Qi tensor, which is used in the regularization of the distribution functions. + + References + ---------- + Latt, J. (2007). Hydrodynamic limit of lattice Boltzmann equations. PhD thesis, University of Geneva. + Latt, J., Chopard, B., Malaspinas, O., Deville, M., & Michler, A. (2008). Straight velocity boundaries in the + lattice Boltzmann method. Physical Review E, 77(5), 056703. doi:10.1103/PhysRevE.77.056703 + """ + + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + bc_type, + prescribed_value, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + indices=None, + ): + # Call the parent constructor + super().__init__( + bc_type, + prescribed_value, + velocity_set, + precision_policy, + compute_backend, + indices, + ) + + # The operator to compute the momentum flux + self.momentum_flux = SecondMoment() + + @partial(jit, static_argnums=(0,), inline=True) + def regularize_fpop(self, fpop, feq): + """ + Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. + + Parameters + ---------- + fpop : jax.numpy.ndarray + The distribution functions. + feq : jax.numpy.ndarray + The equilibrium distribution functions. + + Returns + ------- + jax.numpy.ndarray + The regularized distribution functions. + """ + # Qi = cc - cs^2*I + dim = self.velocity_set.d + weights = self.velocity_set.w[(slice(None),) + (None,) * dim] + Qi = jnp.array(self.velocity_set.cc, dtype=self.compute_dtype) + if dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + else: + raise ValueError(f"dim = {dim} not supported") + + # Qi = cc - cs^2*I + # multiply off-diagonal elements by 2 because the Q tensor is symmetric + # Qi = Qi.at[:, diagonal].add(-1.0 / 3.0) + # Qi = Qi.at[:, offdiagonal].multiply(2.0) + + # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} + f_neq = fpop - feq + PiNeq = self.momentum_flux(f_neq) + # PiNeq = self.momentum_flux(fpop) - self.momentum_flux(feq) + + # Compute double dot product Qi:Pi1 + # QiPi1 = np.zeros_like(fpop) + # Pi1 = PiNeq + QiPi1 = jnp.tensordot(Qi, PiNeq, axes=(1, 0)) + + # assign all populations based on eq 45 of Latt et al (2008) + # fneq ~ f^1 + fpop1 = 9.0 / 2.0 * weights * QiPi1 + fpop_regularized = feq + fpop1 + return fpop_regularized + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + # creat a mask to slice boundary cells + boundary = boundary_mask == self.id + boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + + # compute the equilibrium based on prescribed values and the type of BC + feq = self.calculate_equilibrium(f_post, missing_mask) + + # set the unknown f populations based on the non-equilibrium bounce-back method + f_post_bd = self.bounceback_nonequilibrium(f_post, feq, missing_mask) + + # Regularize the boundary fpop + f_post_bd = self.regularize_fpop(f_post_bd, feq) + + # apply bc + f_post = jnp.where(boundary, f_post_bd, f_post) + return f_post + + def _construct_warp(self): + # assign placeholders for both u and rho based on prescribed_value + _d = self.velocity_set.d + _q = self.velocity_set.q + u = self.prescribed_value if self.bc_type == "velocity" else (0,) * _d + rho = self.prescribed_value if self.bc_type == "pressure" else 0.0 + + # Set local constants TODO: This is a hack and should be fixed with warp update + # _u_vec = wp.vec(_d, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _rho = wp.float32(rho) + _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) + _opp_indices = self.velocity_set.wp_opp_indices + _c = self.velocity_set.wp_c + _c32 = self.velocity_set.wp_c32 + # TODO: this is way less than ideal. we should not be making new types + + @wp.func + def get_normal_vectors_2d( + lattice_direction: Any, + ): + l = lattice_direction + if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + normals = -_u_vec(_c32[0, l], _c32[1, l]) + return normals + + @wp.func + def get_normal_vectors_3d( + lattice_direction: Any, + ): + l = lattice_direction + if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + normals = -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return normals + + @wp.func + def _helper_functional( + fpop: Any, + fsum: Any, + missing_mask: Any, + lattice_direction: Any, + ): + l = lattice_direction + known_mask = missing_mask[_opp_indices[l]] + middle_mask = ~(missing_mask[l] | known_mask) + # fsum += fpop[l] * float(middle_mask) + 2.0 * fpop[l] * float(known_mask) + if middle_mask and known_mask: + fsum += fpop[l] + 2.0 * fpop[l] + elif middle_mask: + fsum += fpop[l] + elif known_mask: + fsum += 2.0 * fpop[l] + return fsum + + @wp.func + def bounceback_nonequilibrium( + fpop: Any, + missing_mask: Any, + density: Any, + velocity: Any, + ): + feq = self.equilibrium_operator.warp_functional(density, velocity) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] + return fpop + + @wp.func + def functional3d_velocity( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_3d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + for d in range(_d): + unormal += _u[d] * normals[d] + _rho = _fsum / (1.0 + unormal) + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional3d_pressure( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_3d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + unormal = -1.0 + _fsum / _rho + _u = unormal * normals + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional2d_velocity( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_2d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + for d in range(_d): + unormal += _u[d] * normals[d] + _rho = _fsum / (1.0 + unormal) + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + @wp.func + def functional2d_pressure( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + _fsum = self.compute_dtype(0.0) + unormal = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + normals = get_normal_vectors_2d(l) + _fsum = _helper_functional(_f, _fsum, missing_mask, l) + + unormal = -1.0 + _fsum / _rho + _u = unormal * normals + + # impose non-equilibrium bounceback + _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + return _f + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec3i(i, j) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + + # Apply the boundary condition + if _boundary_id == wp.uint8(ZouHeBC.id): + _f = functional(_f_pre, _f_post, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1]] = _f[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + + # Apply the boundary condition + if _boundary_id == wp.uint8(ZouHeBC.id): + _f = functional(_f_pre, _f_post, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1], index[2]] = _f[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + if self.velocity_set.d == 3 and self.bc_type == "velocity": + functional = functional3d_velocity + elif self.velocity_set.d == 3 and self.bc_type == "pressure": + functional = functional3d_pressure + elif self.bc_type == "velocity": + functional = functional2d_velocity + else: + functional = functional2d_pressure + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 84a0b8f..6ce8a26 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -99,6 +99,7 @@ class BoundaryConditionIDStruct: id_HalfwayBounceBackBC: wp.uint8 id_FullwayBounceBackBC: wp.uint8 id_ZouHeBC: wp.uint8 + id_RegularizedBC: wp.uint8 @wp.kernel def kernel2d( @@ -143,6 +144,9 @@ def kernel2d( elif _boundary_id == bc_struct.id_ZouHeBC: # Zouhe boundary condition f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + elif _boundary_id == bc_struct.id_RegularizedBC: + # Regularized boundary condition + f_post_stream = self.regularized_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -211,6 +215,9 @@ def kernel3d( elif _boundary_id == bc_struct.id_ZouHeBC: # Zouhe boundary condition f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + elif _boundary_id == bc_struct.id_RegularizedBC: + # Regularized boundary condition + f_post_stream = self.regularized_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 1608342..b40a69e 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -34,6 +34,7 @@ def __init__(self, operators, boundary_conditions): from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC + from xlb.operator.boundary_condition.bc_regularized import RegularizedBC # Define a list of tuples with attribute names and their corresponding classes conditions = [ @@ -42,6 +43,7 @@ def __init__(self, operators, boundary_conditions): ("halfway_bounce_back_bc", HalfwayBounceBackBC), ("fullway_bounce_back_bc", FullwayBounceBackBC), ("zouhe_bc", ZouHeBC), + ("regularized_bc", RegularizedBC), ] # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. From 4b230cb9924e480f7ae735a6b3185a100c85b8a2 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 15 Aug 2024 12:36:11 -0400 Subject: [PATCH 070/144] improving the call to BC functionals in Warp --- examples/cfd/flow_past_sphere_3d.py | 8 +- .../boundary_condition/bc_regularized.py | 2 - xlb/operator/boundary_condition/bc_zouhe.py | 4 +- xlb/operator/stepper/nse_stepper.py | 124 +++++++++++------- xlb/operator/stepper/stepper.py | 40 +----- 5 files changed, 81 insertions(+), 97 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 55b26d4..e674544 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -114,11 +114,7 @@ def post_process(self, i): rho = rho[:, 1:-1, 1:-1, 1:-1][0] u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - fields = {"u_magnitude": u_magnitude, - "u_x": u[0], - "u_y": u[1], - "u_z": u[2], - "rho": rho} + fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho} save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) @@ -126,7 +122,7 @@ def post_process(self, i): if __name__ == "__main__": # Running the simulation - grid_shape = (512//2, 128//2, 128//2) + grid_shape = (512 // 2, 128 // 2, 128 // 2) velocity_set = xlb.velocity_set.D3Q19() backend = ComputeBackend.JAX precision_policy = PrecisionPolicy.FP32FP32 diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 63c8df4..1f76037 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -40,8 +40,6 @@ class RegularizedBC(ZouHeBC): lattice Boltzmann method. Physical Review E, 77(5), 056703. doi:10.1103/PhysRevE.77.056703 """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, bc_type, diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 3603f35..e27c0d9 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -45,8 +45,8 @@ def __init__( ): # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC # may have different types (velocity or pressure). - self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__) assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." + self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__ + "_" + bc_type) self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() self.prescribed_value = prescribed_value @@ -110,7 +110,7 @@ def calculate_vel(self, fpop, rho, missing_mask): normals = self._get_normal_vec(missing_mask) known_mask, middle_mask = self._get_known_middle_mask(missing_mask) fsum = jnp.sum(fpop * middle_mask, axis=0, keepdims=True) + 2.0 * jnp.sum(fpop * known_mask, axis=0, keepdims=True) - unormal = -1.0 + fsum/rho + unormal = -1.0 + fsum / rho # Return the above unormal as a normal vector which sets the tangential velocities to zero vel = unormal * normals diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 6ce8a26..524c430 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -98,8 +98,55 @@ class BoundaryConditionIDStruct: id_DoNothingBC: wp.uint8 id_HalfwayBounceBackBC: wp.uint8 id_FullwayBounceBackBC: wp.uint8 - id_ZouHeBC: wp.uint8 - id_RegularizedBC: wp.uint8 + id_ZouHeBC_velocity: wp.uint8 + id_ZouHeBC_pressure: wp.uint8 + id_RegularizedBC_velocity: wp.uint8 + id_RegularizedBC_pressure: wp.uint8 + + @wp.func + def apply_post_streaming_bc( + f_pre: Any, + f_post: Any, + missing_mask: Any, + _boundary_id: Any, + bc_struct: BoundaryConditionIDStruct, + ): + # Apply post-streaming type boundary conditions + if _boundary_id == bc_struct.id_EquilibriumBC: + # Equilibrium boundary condition + f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_DoNothingBC: + # Do nothing boundary condition + f_post = self.DoNothingBC.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: + # Half way boundary condition + f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_ZouHeBC_velocity: + # Zouhe boundary condition (bc type = velocity) + f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_ZouHeBC_pressure: + # Zouhe boundary condition (bc type = pressure) + f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_RegularizedBC_velocity: + # Regularized boundary condition (bc type = velocity) + f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, missing_mask) + elif _boundary_id == bc_struct.id_RegularizedBC_pressure: + # Regularized boundary condition (bc type = velocity) + f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, missing_mask) + return f_post + + @wp.func + def apply_post_collision_bc( + f_pre: Any, + f_post: Any, + missing_mask: Any, + _boundary_id: Any, + bc_struct: BoundaryConditionIDStruct, + ): + if _boundary_id == bc_struct.id_FullwayBounceBackBC: + # Full way boundary condition + f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) + return f_post @wp.kernel def kernel2d( @@ -132,21 +179,7 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) # Apply post-streaming type boundary conditions - if _boundary_id == bc_struct.id_EquilibriumBC: - # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_DoNothingBC: - # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: - # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_ZouHeBC: - # Zouhe boundary condition - f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_RegularizedBC: - # Regularized boundary condition - f_post_stream = self.regularized_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -163,9 +196,7 @@ def kernel2d( ) # Apply post-collision type boundary conditions - if _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Full way boundary condition - f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, f_post_collision, _missing_mask) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -202,22 +233,8 @@ def kernel3d( # Apply streaming (pull method) f_post_stream = self.stream.warp_functional(f_0, index) - # Apply post-streaming boundary conditions - if _boundary_id == bc_struct.id_EquilibriumBC: - # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_DoNothingBC: - # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: - # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_ZouHeBC: - # Zouhe boundary condition - f_post_stream = self.zouhe_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) - elif _boundary_id == bc_struct.id_RegularizedBC: - # Regularized boundary condition - f_post_stream = self.regularized_bc.warp_functional(f_post_collision, f_post_stream, _missing_mask) + # Apply post-streaming type boundary conditions + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -228,10 +245,8 @@ def kernel3d( # Apply collision f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) - # Apply collision type boundary conditions - if _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Full way boundary condition - f_post_collision = self.fullway_bounce_back_bc.warp_functional(f_post_stream, f_post_collision, _missing_mask) + # Apply post-collision type boundary conditions + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -247,22 +262,29 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Get the boundary condition ids from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + # Read the list of bc_to_id created upon instantiation bc_to_id = boundary_condition_registry.bc_to_id - + id_to_bc = boundary_condition_registry.id_to_bc bc_struct = self.warp_functional() - bc_attribute_list = [] - for attribute_str in bc_to_id.keys(): - # Setting the Struct attributes based on the BC class names - setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str]) - bc_attribute_list.append("id_" + attribute_str) - - # Unused attributes of the struct are set to inernal (id=0) - ll = vars(bc_struct) - for var in ll: - if var not in bc_attribute_list and not var.startswith("_"): + active_bc_list = [] + for bc in self.boundary_conditions: + # Setting the Struct attributes and active BC classes based on the BC class names + bc_name = id_to_bc[bc.id] + setattr(self, bc_name, bc) + setattr(bc_struct, "id_" + bc_name, bc_to_id[bc_name]) + active_bc_list.append("id_" + bc_name) + + # Setting the Struct attributes and active BC classes based on the BC class names + bc_fallback = self.boundary_conditions[0] + for var in vars(bc_struct): + if var not in active_bc_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 setattr(bc_struct, var, 255) + # Assing a fall-back BC for inactive BCs. This is just to ensure Warp codegen does not + # produce error when a particular BC is not used in an example. + setattr(self, var.replace("id_", ""), bc_fallback) + # Launch the warp kernel wp.launch( self.warp_kernel, diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index b40a69e..44aab5f 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -9,8 +9,12 @@ class Stepper(Operator): """ def __init__(self, operators, boundary_conditions): + # Get the boundary condition ids + from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + self.operators = operators self.boundary_conditions = boundary_conditions + # Get velocity set, precision policy, and compute backend velocity_sets = set([op.velocity_set for op in self.operators if op is not None]) assert len(velocity_sets) < 2, "All velocity sets must be the same. Got {}".format(velocity_sets) @@ -24,41 +28,5 @@ def __init__(self, operators, boundary_conditions): assert len(compute_backends) < 2, "All compute backends must be the same. Got {}".format(compute_backends) compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop() - # Add boundary conditions - ############################################ - # Warp cannot handle lists of functions currently - # TODO: Fix this later - ############################################ - from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC - from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC - from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC - from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC - from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC - from xlb.operator.boundary_condition.bc_regularized import RegularizedBC - - # Define a list of tuples with attribute names and their corresponding classes - conditions = [ - ("equilibrium_bc", EquilibriumBC), - ("do_nothing_bc", DoNothingBC), - ("halfway_bounce_back_bc", HalfwayBounceBackBC), - ("fullway_bounce_back_bc", FullwayBounceBackBC), - ("zouhe_bc", ZouHeBC), - ("regularized_bc", RegularizedBC), - ] - - # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. - bc_fallback = boundary_conditions[0] - - # Iterate over each boundary condition - for attr_name, bc_class in conditions: - for bc in boundary_conditions: - if isinstance(bc, bc_class): - setattr(self, attr_name, bc) - break - elif not hasattr(self, attr_name): - setattr(self, attr_name, bc_fallback) - - ############################################ - # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) From 3f0cda3da0c71b7c51dc65cc7b57068154e01f4a Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 15 Aug 2024 17:28:46 -0400 Subject: [PATCH 071/144] fixed a nasty bug with struct type --- xlb/operator/boundary_condition/bc_zouhe.py | 4 ++-- xlb/operator/stepper/nse_stepper.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index e27c0d9..59d19a9 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -339,7 +339,7 @@ def kernel2d( _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(ZouHeBC.id): + if _boundary_id == wp.uint8(self.id): _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post @@ -364,7 +364,7 @@ def kernel3d( _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(ZouHeBC.id): + if _boundary_id == wp.uint8(self.id): _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 524c430..bd2403d 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -109,7 +109,7 @@ def apply_post_streaming_bc( f_post: Any, missing_mask: Any, _boundary_id: Any, - bc_struct: BoundaryConditionIDStruct, + bc_struct: Any, ): # Apply post-streaming type boundary conditions if _boundary_id == bc_struct.id_EquilibriumBC: @@ -141,7 +141,7 @@ def apply_post_collision_bc( f_post: Any, missing_mask: Any, _boundary_id: Any, - bc_struct: BoundaryConditionIDStruct, + bc_struct: Any, ): if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition @@ -154,7 +154,7 @@ def kernel2d( f_1: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), - bc_struct: BoundaryConditionIDStruct, + bc_struct: Any, timestep: int, ): # Get the global index @@ -209,7 +209,7 @@ def kernel3d( f_1: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), - bc_struct: BoundaryConditionIDStruct, + bc_struct: Any, timestep: int, ): # Get the global index From 0ddfb2b95f435000ad08e4f032f3a1756ab6fe02 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 16 Aug 2024 00:57:10 -0400 Subject: [PATCH 072/144] fixed a really gnarly bug in ZouHe (Warp) --- xlb/operator/boundary_condition/bc_zouhe.py | 93 ++++++++++----------- 1 file changed, 44 insertions(+), 49 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 59d19a9..1d2c334 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -193,32 +193,29 @@ def get_normal_vectors_2d( return normals @wp.func - def get_normal_vectors_3d( - lattice_direction: Any, + def _helper_function( + fpop: Any, + missing_mask: Any, ): - l = lattice_direction - if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - normals = -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) - return normals + fsum_known = self.compute_dtype(0.0) + fsum_middle = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[_opp_indices[l]] == wp.uint8(1): + fsum_known += 2. * fpop[l] + elif missing_mask[l] != wp.uint8(1): + fsum_middle += fpop[l] + return fsum_known + fsum_middle @wp.func - def _helper_functional( - fpop: Any, - fsum: Any, + def get_normal_vectors_3d( missing_mask: Any, - lattice_direction: Any, ): - l = lattice_direction - known_mask = missing_mask[_opp_indices[l]] - middle_mask = ~(missing_mask[l] | known_mask) - # fsum += fpop[l] * float(middle_mask) + 2.0 * fpop[l] * float(known_mask) - if middle_mask and known_mask: - fsum += fpop[l] + 2.0 * fpop[l] - elif middle_mask: - fsum += fpop[l] - elif known_mask: - fsum += 2.0 * fpop[l] - return fsum + for l in range(_q): + if ( + missing_mask[l] == wp.uint8(1) + and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1 + ): + return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) @wp.func def bounceback_nonequilibrium( @@ -241,16 +238,16 @@ def functional3d_velocity( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_3d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) + # Find normal vector + normals = get_normal_vectors_3d(missing_mask) + + # calculate rho + fsum = _helper_function(_f, missing_mask) + unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = _fsum / (1.0 + unormal) + _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) @@ -264,14 +261,13 @@ def functional3d_pressure( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_3d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) - unormal = -1.0 + _fsum / _rho + # Find normal vector + normals = get_normal_vectors_3d(missing_mask) + + # calculate velocity + fsum = _helper_function(_f, missing_mask) + unormal = -1.0 + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback @@ -286,16 +282,16 @@ def functional2d_velocity( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_2d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) + # Find normal vector + normals = get_normal_vectors_2d(missing_mask) + + # calculate rho + fsum = _helper_function(_f, missing_mask) + unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = _fsum / (1.0 + unormal) + _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) @@ -309,14 +305,13 @@ def functional2d_pressure( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_2d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) - unormal = -1.0 + _fsum / _rho + # Find normal vector + normals = get_normal_vectors_2d(missing_mask) + + # calculate velocity + fsum = _helper_function(_f, missing_mask) + unormal = -1.0 + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback From 398092d1ba96d67c55709b83e5f008c5300e9942 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 16 Aug 2024 12:42:48 -0400 Subject: [PATCH 073/144] Regularized (JAX) works! --- examples/cfd/flow_past_sphere_3d.py | 11 +- .../boundary_condition/bc_regularized.py | 101 +++++++++--------- 2 files changed, 55 insertions(+), 57 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index e674544..70639c4 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -6,6 +6,7 @@ from xlb.operator.boundary_condition import ( FullwayBounceBackBC, ZouHeBC, + RegularizedBC, EquilibriumBC, DoNothingBC, ) @@ -68,9 +69,11 @@ def define_boundary_indices(self): def setup_boundary_conditions(self): inlet, outlet, walls, sphere = self.define_boundary_indices() - bc_left = ZouHeBC("velocity", (0.04, 0.0, 0.0), indices=inlet) + bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet) + # bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) - bc_outlet = ZouHeBC("pressure", 1.0, indices=outlet) + bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) + # bc_outlet = DoNothingBC(indices=outlet) bc_sphere = FullwayBounceBackBC(indices=sphere) self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls] # Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because @@ -111,10 +114,10 @@ def post_process(self, i): # remove boundary cells u = u[:, 1:-1, 1:-1, 1:-1] - rho = rho[:, 1:-1, 1:-1, 1:-1][0] + rho = rho[:, 1:-1, 1:-1, 1:-1] u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho} + fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho[0]} save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 1f76037..838d5d9 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -94,8 +94,8 @@ def regularize_fpop(self, fpop, feq): # Qi = cc - cs^2*I # multiply off-diagonal elements by 2 because the Q tensor is symmetric - # Qi = Qi.at[:, diagonal].add(-1.0 / 3.0) - # Qi = Qi.at[:, offdiagonal].multiply(2.0) + Qi = Qi.at[:, diagonal].add(-1.0 / 3.0) + Qi = Qi.at[:, offdiagonal].multiply(2.0) # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} f_neq = fpop - feq @@ -160,32 +160,29 @@ def get_normal_vectors_2d( return normals @wp.func - def get_normal_vectors_3d( - lattice_direction: Any, + def _helper_function( + fpop: Any, + missing_mask: Any, ): - l = lattice_direction - if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - normals = -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) - return normals + fsum_known = self.compute_dtype(0.0) + fsum_middle = self.compute_dtype(0.0) + for l in range(_q): + if missing_mask[_opp_indices[l]] == wp.uint8(1): + fsum_known += 2. * fpop[l] + elif missing_mask[l] != wp.uint8(1): + fsum_middle += fpop[l] + return fsum_known + fsum_middle @wp.func - def _helper_functional( - fpop: Any, - fsum: Any, + def get_normal_vectors_3d( missing_mask: Any, - lattice_direction: Any, ): - l = lattice_direction - known_mask = missing_mask[_opp_indices[l]] - middle_mask = ~(missing_mask[l] | known_mask) - # fsum += fpop[l] * float(middle_mask) + 2.0 * fpop[l] * float(known_mask) - if middle_mask and known_mask: - fsum += fpop[l] + 2.0 * fpop[l] - elif middle_mask: - fsum += fpop[l] - elif known_mask: - fsum += 2.0 * fpop[l] - return fsum + for l in range(_q): + if ( + missing_mask[l] == wp.uint8(1) + and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1 + ): + return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) @wp.func def bounceback_nonequilibrium( @@ -208,16 +205,16 @@ def functional3d_velocity( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_3d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) + # Find normal vector + normals = get_normal_vectors_3d(missing_mask) + + # calculate rho + fsum = _helper_function(_f, missing_mask) + unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = _fsum / (1.0 + unormal) + _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) @@ -231,14 +228,13 @@ def functional3d_pressure( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_3d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) - unormal = -1.0 + _fsum / _rho + # Find normal vector + normals = get_normal_vectors_3d(missing_mask) + + # calculate velocity + fsum = _helper_function(_f, missing_mask) + unormal = -1.0 + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback @@ -253,16 +249,16 @@ def functional2d_velocity( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_2d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) + # Find normal vector + normals = get_normal_vectors_2d(missing_mask) + + # calculate rho + fsum = _helper_function(_f, missing_mask) + unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = _fsum / (1.0 + unormal) + _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) @@ -276,14 +272,13 @@ def functional2d_pressure( ): # Post-streaming values are only modified at missing direction _f = f_post - _fsum = self.compute_dtype(0.0) - unormal = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - normals = get_normal_vectors_2d(l) - _fsum = _helper_functional(_f, _fsum, missing_mask, l) - unormal = -1.0 + _fsum / _rho + # Find normal vector + normals = get_normal_vectors_2d(missing_mask) + + # calculate velocity + fsum = _helper_function(_f, missing_mask) + unormal = -1.0 + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback @@ -306,7 +301,7 @@ def kernel2d( _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(ZouHeBC.id): + if _boundary_id == wp.uint8(self.id): _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post @@ -331,7 +326,7 @@ def kernel3d( _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(ZouHeBC.id): + if _boundary_id == wp.uint8(self.id): _f = functional(_f_pre, _f_post, _missing_mask) else: _f = _f_post From 3d54ffa06befcc8f5290356b8f69e15d42fd9f71 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 16 Aug 2024 15:59:34 -0400 Subject: [PATCH 074/144] some cleanup and rearrangement --- .../boundary_condition/bc_regularized.py | 35 ++++++++++++------- xlb/operator/boundary_condition/bc_zouhe.py | 23 ++++++------ 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 838d5d9..437cf03 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -168,7 +168,7 @@ def _helper_function( fsum_middle = self.compute_dtype(0.0) for l in range(_q): if missing_mask[_opp_indices[l]] == wp.uint8(1): - fsum_known += 2. * fpop[l] + fsum_known += 2.0 * fpop[l] elif missing_mask[l] != wp.uint8(1): fsum_middle += fpop[l] return fsum_known + fsum_middle @@ -178,20 +178,15 @@ def get_normal_vectors_3d( missing_mask: Any, ): for l in range(_q): - if ( - missing_mask[l] == wp.uint8(1) - and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1 - ): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) @wp.func def bounceback_nonequilibrium( fpop: Any, + feq: Any, missing_mask: Any, - density: Any, - velocity: Any, ): - feq = self.equilibrium_operator.warp_functional(density, velocity) for l in range(_q): if missing_mask[l] == wp.uint8(1): fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] @@ -217,7 +212,11 @@ def functional3d_velocity( _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) + + # Regularize the boundary fpop + _f = regularize_fpop(_f, feq) return _f @wp.func @@ -238,7 +237,11 @@ def functional3d_pressure( _u = unormal * normals # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) + + # Regularize the boundary fpop + _f = regularize_fpop(_f, feq) return _f @wp.func @@ -261,7 +264,11 @@ def functional2d_velocity( _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) + + # Regularize the boundary fpop + _f = regularize_fpop(_f, feq) return _f @wp.func @@ -282,7 +289,11 @@ def functional2d_pressure( _u = unormal * normals # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) + + # Regularize the boundary fpop + _f = regularize_fpop(_f, feq) return _f # Construct the warp kernel diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 1d2c334..fbb463b 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -201,7 +201,7 @@ def _helper_function( fsum_middle = self.compute_dtype(0.0) for l in range(_q): if missing_mask[_opp_indices[l]] == wp.uint8(1): - fsum_known += 2. * fpop[l] + fsum_known += 2.0 * fpop[l] elif missing_mask[l] != wp.uint8(1): fsum_middle += fpop[l] return fsum_known + fsum_middle @@ -211,20 +211,15 @@ def get_normal_vectors_3d( missing_mask: Any, ): for l in range(_q): - if ( - missing_mask[l] == wp.uint8(1) - and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1 - ): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) @wp.func def bounceback_nonequilibrium( fpop: Any, + feq: Any, missing_mask: Any, - density: Any, - velocity: Any, ): - feq = self.equilibrium_operator.warp_functional(density, velocity) for l in range(_q): if missing_mask[l] == wp.uint8(1): fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] @@ -250,7 +245,8 @@ def functional3d_velocity( _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f @wp.func @@ -271,7 +267,8 @@ def functional3d_pressure( _u = unormal * normals # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f @wp.func @@ -294,7 +291,8 @@ def functional2d_velocity( _rho = fsum / (1.0 + unormal) # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f @wp.func @@ -315,7 +313,8 @@ def functional2d_pressure( _u = unormal * normals # impose non-equilibrium bounceback - _f = bounceback_nonequilibrium(_f, missing_mask, _rho, _u) + feq = self.equilibrium_operator.warp_functional(_rho, _u) + _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f # Construct the warp kernel From b15bb05d7a5d25c2007d64396bda396a4337ea25 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 16 Aug 2024 16:36:50 -0400 Subject: [PATCH 075/144] Regularized (Warp) also completed and verified! --- examples/cfd/flow_past_sphere_3d.py | 4 +- .../boundary_condition/bc_regularized.py | 51 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 70639c4..0ce07e9 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -92,7 +92,7 @@ def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) + self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") def run(self, num_steps, post_process_interval=100): for i in range(num_steps): @@ -127,7 +127,7 @@ def post_process(self, i): # Running the simulation grid_shape = (512 // 2, 128 // 2, 128 // 2) velocity_set = xlb.velocity_set.D3Q19() - backend = ComputeBackend.JAX + backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 437cf03..3d38f89 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -62,6 +62,25 @@ def __init__( # The operator to compute the momentum flux self.momentum_flux = SecondMoment() + # helper function + def compute_qi(self): + # Qi = cc - cs^2*I + dim = self.velocity_set.d + Qi = self.velocity_set.cc + if dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + else: + raise ValueError(f"dim = {dim} not supported") + + # multiply off-diagonal elements by 2 because the Q tensor is symmetric + Qi[:, diagonal] += -1.0 / 3.0 + Qi[:, offdiagonal] *= 2.0 + return Qi + @partial(jit, static_argnums=(0,), inline=True) def regularize_fpop(self, fpop, feq): """ @@ -82,6 +101,8 @@ def regularize_fpop(self, fpop, feq): # Qi = cc - cs^2*I dim = self.velocity_set.d weights = self.velocity_set.w[(slice(None),) + (None,) * dim] + # TODO: if I use the following I get NaN ! figure out why! + # Qi = jnp.array(self.compute_qi(), dtype=self.compute_dtype) Qi = jnp.array(self.velocity_set.cc, dtype=self.compute_dtype) if dim == 3: diagonal = (0, 3, 5) @@ -142,10 +163,14 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) + # compute Qi tensor and store it in self + _qi = wp.constant(wp.mat((_q, _d * (_d + 1) // 2), dtype=wp.float32)(self.compute_qi())) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(rho) _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.wp_opp_indices + _w = self.velocity_set.wp_w _c = self.velocity_set.wp_c _c32 = self.velocity_set.wp_c32 # TODO: this is way less than ideal. we should not be making new types @@ -192,6 +217,32 @@ def bounceback_nonequilibrium( fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] return fpop + @wp.func + def regularize_fpop( + fpop: Any, + feq: Any, + ): + """ + Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. + """ + # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} + f_neq = fpop - feq + PiNeq = self.momentum_flux.warp_functional(f_neq) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + QiPi1 = _f_vec() + for l in range(_q): + QiPi1[l] = 0.0 + for t in range(nt): + QiPi1[l] += _qi[l, t] * PiNeq[t] + + # assign all populations based on eq 45 of Latt et al (2008) + # fneq ~ f^1 + fpop1 = 9.0 / 2.0 * _w[l] * QiPi1[l] + fpop[l] = feq[l] + fpop1 + return fpop + @wp.func def functional3d_velocity( f_pre: Any, From 4cecd27e81e474cc0c207dd7179da97bb75a4875 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Sun, 18 Aug 2024 15:49:28 -0400 Subject: [PATCH 076/144] consistent naming for Moments. Macroscopic is now an alias for zeroth and first moments. --- xlb/operator/macroscopic/__init__.py | 2 +- .../macroscopic/{macroscopic.py => zero_first_moments.py} | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename xlb/operator/macroscopic/{macroscopic.py => zero_first_moments.py} (97%) diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index c0755ad..42b747f 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.macroscopic.macroscopic import Macroscopic as Macroscopic +from xlb.operator.macroscopic.zero_first_moments import FirstAndZerothMoment as Macroscopic from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/zero_first_moments.py similarity index 97% rename from xlb/operator/macroscopic/macroscopic.py rename to xlb/operator/macroscopic/zero_first_moments.py index 13d3817..7e64fc3 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/zero_first_moments.py @@ -10,9 +10,9 @@ from xlb.operator.operator import Operator -class Macroscopic(Operator): +class FirstAndZerothMoment(Operator): """ - Base class for all macroscopic operators + A class to compute first and zeroth moments of distribution functions. TODO: Currently this is only used for the standard rho and u moments. In the future, this should be extended to include higher order moments From 76597d4f5cf8f88b69272e62cdb7cf9e43c62969 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 20 Aug 2024 08:14:29 -0400 Subject: [PATCH 077/144] used lax.broadcast_in_dim instead of jnp.repeat plus other minor changes --- .../boundary_condition/bc_fullway_bounce_back.py | 4 +++- .../boundary_condition/bc_halfway_bounce_back.py | 4 +++- xlb/operator/boundary_condition/bc_regularized.py | 14 ++++++++------ xlb/operator/boundary_condition/bc_zouhe.py | 14 ++++++++------ xlb/operator/macroscopic/__init__.py | 2 +- xlb/operator/macroscopic/zero_first_moments.py | 2 +- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 85a3c35..0083bae 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -47,7 +48,8 @@ def __init__( @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) def _construct_warp(self): diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index df947c6..2ed0067 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -50,7 +51,8 @@ def __init__( @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 3d38f89..413a37f 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -139,7 +140,8 @@ def regularize_fpop(self, fpop, feq): def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): # creat a mask to slice boundary cells boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) # compute the equilibrium based on prescribed values and the type of BC feq = self.calculate_equilibrium(f_post, missing_mask) @@ -185,7 +187,7 @@ def get_normal_vectors_2d( return normals @wp.func - def _helper_function( + def _get_fsum( fpop: Any, missing_mask: Any, ): @@ -256,7 +258,7 @@ def functional3d_velocity( normals = get_normal_vectors_3d(missing_mask) # calculate rho - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -283,7 +285,7 @@ def functional3d_pressure( normals = get_normal_vectors_3d(missing_mask) # calculate velocity - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = -1.0 + fsum / _rho _u = unormal * normals @@ -308,7 +310,7 @@ def functional2d_velocity( normals = get_normal_vectors_2d(missing_mask) # calculate rho - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -335,7 +337,7 @@ def functional2d_pressure( normals = get_normal_vectors_2d(missing_mask) # calculate velocity - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = -1.0 + fsum / _rho _u = unormal * normals diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index fbb463b..3b69b21 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -156,7 +157,8 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): # creat a mask to slice boundary cells boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) # compute the equilibrium based on prescribed values and the type of BC feq = self.calculate_equilibrium(f_post, missing_mask) @@ -193,7 +195,7 @@ def get_normal_vectors_2d( return normals @wp.func - def _helper_function( + def _get_fsum( fpop: Any, missing_mask: Any, ): @@ -238,7 +240,7 @@ def functional3d_velocity( normals = get_normal_vectors_3d(missing_mask) # calculate rho - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -262,7 +264,7 @@ def functional3d_pressure( normals = get_normal_vectors_3d(missing_mask) # calculate velocity - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = -1.0 + fsum / _rho _u = unormal * normals @@ -284,7 +286,7 @@ def functional2d_velocity( normals = get_normal_vectors_2d(missing_mask) # calculate rho - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -308,7 +310,7 @@ def functional2d_pressure( normals = get_normal_vectors_2d(missing_mask) # calculate velocity - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = -1.0 + fsum / _rho _u = unormal * normals diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 42b747f..38195cd 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.macroscopic.zero_first_moments import FirstAndZerothMoment as Macroscopic +from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment diff --git a/xlb/operator/macroscopic/zero_first_moments.py b/xlb/operator/macroscopic/zero_first_moments.py index 7e64fc3..fbf7c93 100644 --- a/xlb/operator/macroscopic/zero_first_moments.py +++ b/xlb/operator/macroscopic/zero_first_moments.py @@ -10,7 +10,7 @@ from xlb.operator.operator import Operator -class FirstAndZerothMoment(Operator): +class ZeroAndFirstMoments(Operator): """ A class to compute first and zeroth moments of distribution functions. From 77f17ca80d52312cc02d7a99325678ea8756a1c7 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 20 Aug 2024 17:36:41 -0400 Subject: [PATCH 078/144] minor bugs --- .../bc_halfway_bounce_back.py | 2 +- .../boundary_condition/bc_regularized.py | 19 +++++++++---------- xlb/operator/boundary_condition/bc_zouhe.py | 2 +- xlb/operator/collision/kbc.py | 2 +- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 2ed0067..a25d669 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -90,7 +90,7 @@ def kernel2d( ): # Get the global index i, j = wp.tid() - index = wp.vec3i(i, j) + index = wp.vec2i(i, j) # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 413a37f..b1bbbe2 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -177,15 +177,6 @@ def _construct_warp(self): _c32 = self.velocity_set.wp_c32 # TODO: this is way less than ideal. we should not be making new types - @wp.func - def get_normal_vectors_2d( - lattice_direction: Any, - ): - l = lattice_direction - if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - normals = -_u_vec(_c32[0, l], _c32[1, l]) - return normals - @wp.func def _get_fsum( fpop: Any, @@ -200,6 +191,14 @@ def _get_fsum( fsum_middle += fpop[l] return fsum_known + fsum_middle + @wp.func + def get_normal_vectors_2d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + return -_u_vec(_c32[0, l], _c32[1, l]) + @wp.func def get_normal_vectors_3d( missing_mask: Any, @@ -359,7 +358,7 @@ def kernel2d( ): # Get the global index i, j = wp.tid() - index = wp.vec3i(i, j) + index = wp.vec2i(i, j) # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 3b69b21..f028ea7 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -329,7 +329,7 @@ def kernel2d( ): # Get the global index i, j = wp.tid() - index = wp.vec3i(i, j) + index = wp.vec2i(i, j) # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index da0aee5..e4e5d58 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -311,7 +311,7 @@ def kernel2d( ): # Get the global index i, j = wp.tid() - index = wp.vec3i(i, j) # TODO: Warp needs to fix this + index = wp.vec2i(i, j) # TODO: Warp needs to fix this # Load needed values _f = _f_vec() From f83dcce0c587f5675fb9b1359a0bb3b058fbe464 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 20 Aug 2024 17:37:03 -0400 Subject: [PATCH 079/144] WIP: extrapolation outflow bc initial commit --- xlb/operator/boundary_condition/__init__.py | 1 + .../bc_extrapolation_outflow.py | 177 ++++++++++++++++++ 2 files changed, 178 insertions(+) create mode 100644 xlb/operator/boundary_condition/bc_extrapolation_outflow.py diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 506da1d..b7ede03 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -8,3 +8,4 @@ from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC +from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC as ExtrapolationOutflowBC diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py new file mode 100644 index 0000000..0270f43 --- /dev/null +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -0,0 +1,177 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit +import jax.lax as lax +from functools import partial +import warp as wp +from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + + +class ExtrapolationOutflowBC(BoundaryCondition): + """ + Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. + + TODO: Implement moving boundary conditions for this + """ + + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + indices=None, + ): + # Auxiliary population variable to store previous time-step results + self.fpop_aux = None + + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + indices, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + boundary = boundary_mask == self.id + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) + return jnp.where( + jnp.logical_and(missing_mask, boundary), + f_pre[self.velocity_set.opp_indices], + f_post, + ) + + def _construct_warp(self): + # Set local constants + sound_speed = 1.0 / wp.sqrt(3.0) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _c = self.velocity_set.wp_c + _q = self.velocity_set.q + + @wp.func + def get_normal_vectors_2d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + return -wp.vec2i(_c[0, l], _c[1, l]) + + @wp.func + def get_normal_vectors_3d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + return -wp.vec2i(_c[0, l], _c[1, l], _c[2, l]) + + # Construct the functional for this BC + @wp.func + def functional( + f_pre: Any, + f_post: Any, + f_nbr: Any, + missing_mask: Any, + ): + # Post-streaming values are only modified at missing direction + _f = f_post + for l in range(self.velocity_set.q): + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + _f[l] = (1.0 - sound_speed) * f_pre[l] + sound_speed * f_nbr[l] + + return _f + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + index_nbr = index - get_normal_vectors_2d(missing_mask) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _faux = _f_vec() + for l in range(self.velocity_set.q): + # q-sized vector of populations + _faux[l] = f_pre[l, index_nbr[0], index_nbr[1]] + + # Apply the boundary condition + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + _f = functional(_f_pre, _f_post, _faux, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1]] = _f[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + index_nbr = index - get_normal_vectors_3d(missing_mask) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _faux = _f_vec() + for l in range(self.velocity_set.q): + # q-sized vector of populations + _faux[l] = f_pre[l, index_nbr[0], index_nbr[1], index_nbr[2]] + + # Apply the boundary condition + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + _f = functional(_f_pre, _f_post, _faux, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1], index[2]] = _f[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post From b20b2af66a2a0fafe9b751c54e53bf4f8a78a5b1 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 20 Aug 2024 23:46:47 -0400 Subject: [PATCH 080/144] Extrapolation outflow BC added in Warp --- examples/cfd/flow_past_sphere_3d.py | 11 +--- .../boundary_condition/bc_do_nothing.py | 1 + .../boundary_condition/bc_equilibrium.py | 1 + .../bc_extrapolation_outflow.py | 36 +++++++----- .../bc_halfway_bounce_back.py | 1 + .../boundary_condition/bc_regularized.py | 4 ++ xlb/operator/boundary_condition/bc_zouhe.py | 4 ++ xlb/operator/stepper/nse_stepper.py | 55 ++++++++++++++++--- 8 files changed, 82 insertions(+), 31 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 0ce07e9..220489d 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -3,13 +3,7 @@ from xlb.precision_policy import PrecisionPolicy from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.stepper import IncompressibleNavierStokesStepper -from xlb.operator.boundary_condition import ( - FullwayBounceBackBC, - ZouHeBC, - RegularizedBC, - EquilibriumBC, - DoNothingBC, -) +from xlb.operator.boundary_condition import FullwayBounceBackBC, ZouHeBC, RegularizedBC, EquilibriumBC, DoNothingBC, ExtrapolationOutflowBC from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image @@ -72,8 +66,9 @@ def setup_boundary_conditions(self): bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet) # bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) - bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) + # bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) # bc_outlet = DoNothingBC(indices=outlet) + bc_outlet = ExtrapolationOutflowBC(indices=outlet) bc_sphere = FullwayBounceBackBC(indices=sphere) self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls] # Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 5049b1f..df0186a 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -56,6 +56,7 @@ def _construct_warp(self): def functional( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): return f_pre diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 1937a17..29f07bb 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -79,6 +79,7 @@ def _construct_warp(self): def functional( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): _f = self.equilibrium_operator.warp_functional(_rho, _u) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 0270f43..d401717 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -24,9 +24,16 @@ class ExtrapolationOutflowBC(BoundaryCondition): """ - Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. + Extrapolation outflow boundary condition for a lattice Boltzmann method simulation. - TODO: Implement moving boundary conditions for this + This class implements the extrapolation outflow boundary condition, which is a type of outflow boundary condition + that uses extrapolation to avoid strong wave reflections. + + References + ---------- + Geier, M., Schönherr, M., Pasquali, A., & Krafczyk, M. (2015). The cumulant lattice Boltzmann equation in three + dimensions: Theory and validation. Computers & Mathematics with Applications, 70(4), 507-547. + doi:10.1016/j.camwa.2015.05.001. """ id = boundary_condition_registry.register_boundary_condition(__qualname__) @@ -38,9 +45,6 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, ): - # Auxiliary population variable to store previous time-step results - self.fpop_aux = None - # Call the parent constructor super().__init__( ImplementationStep.STREAMING, @@ -83,7 +87,7 @@ def get_normal_vectors_3d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -wp.vec2i(_c[0, l], _c[1, l], _c[2, l]) + return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l]) # Construct the functional for this BC @wp.func @@ -113,14 +117,16 @@ def kernel2d( # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) - index_nbr = index - get_normal_vectors_2d(missing_mask) # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) _faux = _f_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _faux[l] = f_pre[l, index_nbr[0], index_nbr[1]] + + # special preparation of auxiliary data + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + index_nbr = index - get_normal_vectors_2d(_missing_mask) + for l in range(self.velocity_set.q): + _faux[l] = _f_pre[l, index_nbr[0], index_nbr[1]] # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): @@ -143,14 +149,16 @@ def kernel3d( # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) - index_nbr = index - get_normal_vectors_3d(missing_mask) # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) _faux = _f_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _faux[l] = f_pre[l, index_nbr[0], index_nbr[1], index_nbr[2]] + + # special preparation of auxiliary data + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + index_nbr = index - get_normal_vectors_3d(_missing_mask) + for l in range(self.velocity_set.q): + _faux[l] = _f_pre[l, index_nbr[0], index_nbr[1], index_nbr[2]] # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index a25d669..004f792 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -68,6 +68,7 @@ def _construct_warp(self): def functional( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index b1bbbe2..36ce152 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -248,6 +248,7 @@ def regularize_fpop( def functional3d_velocity( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -275,6 +276,7 @@ def functional3d_velocity( def functional3d_pressure( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -300,6 +302,7 @@ def functional3d_pressure( def functional2d_velocity( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -327,6 +330,7 @@ def functional2d_velocity( def functional2d_pressure( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index f028ea7..8fe76d1 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -231,6 +231,7 @@ def bounceback_nonequilibrium( def functional3d_velocity( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -255,6 +256,7 @@ def functional3d_velocity( def functional3d_pressure( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -277,6 +279,7 @@ def functional3d_pressure( def functional2d_velocity( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -301,6 +304,7 @@ def functional2d_velocity( def functional2d_pressure( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index bd2403d..014f7ee 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -88,6 +88,8 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _c = self.velocity_set.wp_c + _q = self.velocity_set.q @wp.struct class BoundaryConditionIDStruct: @@ -102,11 +104,13 @@ class BoundaryConditionIDStruct: id_ZouHeBC_pressure: wp.uint8 id_RegularizedBC_velocity: wp.uint8 id_RegularizedBC_pressure: wp.uint8 + id_ExtrapolationOutflowBC: wp.uint8 @wp.func def apply_post_streaming_bc( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, _boundary_id: Any, bc_struct: Any, @@ -114,25 +118,28 @@ def apply_post_streaming_bc( # Apply post-streaming type boundary conditions if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, missing_mask) + f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post = self.DoNothingBC.warp_functional(f_pre, f_post, missing_mask) + f_post = self.DoNothingBC.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition - f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) + f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_ZouHeBC_velocity: # Zouhe boundary condition (bc type = velocity) - f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, missing_mask) + f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_ZouHeBC_pressure: # Zouhe boundary condition (bc type = pressure) - f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, missing_mask) + f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_RegularizedBC_velocity: # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, missing_mask) + f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_RegularizedBC_pressure: # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, missing_mask) + f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) + elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + # Regularized boundary condition (bc type = velocity) + f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) return f_post @wp.func @@ -148,6 +155,22 @@ def apply_post_collision_bc( f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) return f_post + @wp.func + def get_normal_vectors_2d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + return -wp.vec2i(_c[0, l], _c[1, l]) + + @wp.func + def get_normal_vectors_3d( + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l]) + @wp.kernel def kernel2d( f_0: wp.array3d(dtype=Any), @@ -163,6 +186,7 @@ def kernel2d( # Get the boundary id and missing mask f_post_collision = _f_vec() + f_auxiliary = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): @@ -178,8 +202,14 @@ def kernel2d( # Apply streaming (pull method) f_post_stream = self.stream.warp_functional(f_0, index) + # special preparation of auxiliary data + if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + index_nbr = index - get_normal_vectors_2d(_missing_mask) + for l in range(self.velocity_set.q): + f_auxiliary[l] = f_0[l, index_nbr[0], index_nbr[1]] + # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, _missing_mask, _boundary_id, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -218,6 +248,7 @@ def kernel3d( # Get the boundary id and missing mask f_post_collision = _f_vec() + f_auxiliary = _f_vec() _boundary_id = boundary_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): @@ -233,8 +264,14 @@ def kernel3d( # Apply streaming (pull method) f_post_stream = self.stream.warp_functional(f_0, index) + # special preparation of auxiliary data + if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + index_nbr = index - get_normal_vectors_3d(_missing_mask) + for l in range(self.velocity_set.q): + f_auxiliary[l] = f_0[l, index_nbr[0], index_nbr[1], index_nbr[2]] + # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, _missing_mask, _boundary_id, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) From c3db21cea3b37e723e05d2fbb508704431ab94f6 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 26 Aug 2024 14:21:37 -0400 Subject: [PATCH 081/144] added ExtrapolationOutflow (JAX) --- .../bc_extrapolation_outflow.py | 61 +++++++++++++++++++ .../boundary_condition/boundary_condition.py | 10 +++ xlb/operator/stepper/nse_stepper.py | 1 + 3 files changed, 72 insertions(+) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index d401717..e0b0244 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -8,6 +8,8 @@ from functools import partial import warp as wp from typing import Any +from collections import Counter +import numpy as np from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -54,6 +56,65 @@ def __init__( indices, ) + # find and store the normal vector using indices + self._get_normal_vec(indices) + + def _get_normal_vec(self, indices): + # Get the frequency count and most common element directly + freq_counts = [Counter(coord).most_common(1)[0] for coord in indices] + + # Extract counts and elements + counts = np.array([count for _, count in freq_counts]) + elements = np.array([element for element, _ in freq_counts]) + + # Normalize the counts + self.normal = counts // counts.max() + + # Reverse the normal vector if the most frequent element is 0 + if elements[np.argmax(counts)] == 0: + self.normal *= -1 + return + + @partial(jit, static_argnums=(0,), inline=True) + def _roll(self, fld, vec): + """ + Perform rolling operation of a field with dimentions [q, nx, ny, nz] in a direction + given by vec. All q-directions are rolled at the same time. + # TODO: how to improve this for multi-gpu runs? + """ + if self.velocity_set.d == 2: + return jnp.roll(fld, (vec[0], vec[1]), axis=(1, 2)) + elif self.velocity_set.d == 3: + return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) + + @partial(jit, static_argnums=(0,), inline=True) + def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): + """ + Prepare the auxilary distribution functions for the boundary condition. + Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision + """ + sound_speed = 1.0 / jnp.sqrt(3.0) + boundary = boundary_mask == self.id + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) + + # Roll boundary mask in the opposite of the normal vector to mask its next immediate neighbour + neighbour = self._roll(boundary, -self.normal) + + # gather post-streaming values associated with previous time-step to construct the auxilary data for BC + fpop = jnp.where(boundary, f_pre, f_post) + fpop_neighbour = jnp.where(neighbour, f_pre, f_post) + + # With fpop_neighbour isolated, now roll it back to be positioned at the boundary for subsequent operations + fpop_neighbour = self._roll(fpop_neighbour, self.normal) + fpop_extrapolated = sound_speed * fpop_neighbour + (1.0 - sound_speed) * fpop + + # Use the iknown directions of f_postcollision that leave the domain during streaming to store the BC data + opp = self.velocity_set.opp_indices + known_mask = missing_mask[opp] + f_post = jnp.where(jnp.logical_and(boundary, known_mask), fpop_extrapolated[opp], f_post) + return f_post + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index f4e0a1b..eabb9fe 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -5,6 +5,8 @@ from enum import Enum, auto import warp as wp from typing import Any +from jax import jit +from functools import partial from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -103,3 +105,11 @@ def _get_thread_data_3d( if self.compute_backend == ComputeBackend.WARP: self._get_thread_data_2d = _get_thread_data_2d self._get_thread_data_3d = _get_thread_data_3d + + @partial(jit, static_argnums=(0,), inline=True) + def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): + """ + A placeholder function for prepare the auxilary distribution functions for the boundary condition. + currently being called after collision only. + """ + return f_post diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 014f7ee..885c0da 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -58,6 +58,7 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: + f_post_collision = bc.prepare_bc_auxilary_data(f_0, f_post_collision, boundary_mask, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_0, From 25fbef81598f7d1e3658fc7d452d512fc7cea36f Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 26 Aug 2024 16:12:29 -0400 Subject: [PATCH 082/144] corrected the extrapolation BC but MLUPS drop of 12% --- examples/cfd/windtunnel_3d.py | 5 +- .../bc_extrapolation_outflow.py | 3 +- .../bc_fullway_bounce_back.py | 1 + xlb/operator/stepper/nse_stepper.py | 48 ++++++++++++++++--- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index eaac67e..460b217 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -9,6 +9,7 @@ FullwayBounceBackBC, EquilibriumBC, DoNothingBC, + ExtrapolationOutflowBC, ) from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker @@ -81,9 +82,9 @@ def setup_boundary_conditions(self, wind_speed): inlet, outlet, walls, car = self.define_boundary_indices() bc_left = EquilibriumBC(rho=1.0, u=(wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) - bc_do_nothing = DoNothingBC(indices=outlet) + bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) bc_car = FullwayBounceBackBC(indices=car) - self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_car] + self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] def setup_boundary_masks(self): indices_boundary_masker = IndicesBoundaryMasker( diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index e0b0244..f2f42fd 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -133,6 +133,7 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _c = self.velocity_set.wp_c _q = self.velocity_set.q + _opp_indices = self.velocity_set.wp_opp_indices @wp.func def get_normal_vectors_2d( @@ -163,7 +164,7 @@ def functional( for l in range(self.velocity_set.q): # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): - _f[l] = (1.0 - sound_speed) * f_pre[l] + sound_speed * f_nbr[l] + _f[l] = f_pre[_opp_indices[l]] return _f diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 0083bae..410a39b 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -63,6 +63,7 @@ def _construct_warp(self): def functional( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, ): fliped_f = _f_vec() diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 885c0da..d9bb57e 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -91,6 +91,8 @@ def _construct_warp(self): _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool _c = self.velocity_set.wp_c _q = self.velocity_set.q + _opp_indices = self.velocity_set.wp_opp_indices + sound_speed = 1.0 / wp.sqrt(3.0) @wp.struct class BoundaryConditionIDStruct: @@ -143,17 +145,35 @@ def apply_post_streaming_bc( f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) return f_post + @wp.func + def ExtrapolationOutflowBC_functional2( + f_pre: Any, + f_post: Any, + f_aux: Any, + missing_mask: Any, + ): + for l in range(self.velocity_set.q): + if missing_mask[l] == wp.uint8(1): + f_post[_opp_indices[l]] = (1.0 - sound_speed) * f_pre[l] + sound_speed * f_aux[l] + return f_post + @wp.func def apply_post_collision_bc( f_pre: Any, f_post: Any, + f_aux: Any, missing_mask: Any, _boundary_id: Any, bc_struct: Any, ): if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition - f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, missing_mask) + f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + # f_aux is the neighbour's post-streaming values + # Storing post-streaming data in directions that leave the domain + f_post = ExtrapolationOutflowBC_functional2(f_pre, f_post, f_aux, missing_mask) + return f_post @wp.func @@ -205,9 +225,16 @@ def kernel2d( # special preparation of auxiliary data if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - index_nbr = index - get_normal_vectors_2d(_missing_mask) + nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): - f_auxiliary[l] = f_0[l, index_nbr[0], index_nbr[1]] + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1]] # Apply post-streaming type boundary conditions f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) @@ -227,7 +254,7 @@ def kernel2d( ) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, _missing_mask, _boundary_id, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -267,9 +294,16 @@ def kernel3d( # special preparation of auxiliary data if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - index_nbr = index - get_normal_vectors_3d(_missing_mask) + nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): - f_auxiliary[l] = f_0[l, index_nbr[0], index_nbr[1], index_nbr[2]] + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1], pull_index[2]] # Apply post-streaming type boundary conditions f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) @@ -284,7 +318,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, _missing_mask, _boundary_id, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): From 8a4c7cdb72b9ba9387ac7c040d55488474b58601 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 26 Aug 2024 17:28:27 -0400 Subject: [PATCH 083/144] fixed the performance issue. All great now! --- .../bc_extrapolation_outflow.py | 26 ++++++++++++++++--- .../boundary_condition/boundary_condition.py | 20 ++++++++++++++ xlb/operator/stepper/nse_stepper.py | 16 ++---------- 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index f2f42fd..3890f7d 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -59,6 +59,10 @@ def __init__( # find and store the normal vector using indices self._get_normal_vec(indices) + # Unpack the two warp functionals needed for this BC! + if self.compute_backend == ComputeBackend.WARP: + self.warp_functional_poststream, self.warp_functional_postcollision = self.warp_functional + def _get_normal_vec(self, indices): # Get the frequency count and most common element directly freq_counts = [Counter(coord).most_common(1)[0] for coord in indices] @@ -151,9 +155,9 @@ def get_normal_vectors_3d( if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l]) - # Construct the functional for this BC + # Construct the functionals for this BC @wp.func - def functional( + def functional_poststream( f_pre: Any, f_post: Any, f_nbr: Any, @@ -168,6 +172,22 @@ def functional( return _f + @wp.func + def functional_postcollision( + f_pre: Any, + f_post: Any, + f_aux: Any, + missing_mask: Any, + ): + # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and + # f_pre (posti-streaming values of the current voxel). We use directions that leave the domain + # for storing this prepared data. + _f = f_post + for l in range(self.velocity_set.q): + if missing_mask[l] == wp.uint8(1): + _f[_opp_indices[l]] = (1.0 - sound_speed) * f_pre[l] + sound_speed * f_aux[l] + return _f + # Construct the warp kernel @wp.kernel def kernel2d( @@ -234,7 +254,7 @@ def kernel3d( kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel + return [functional_poststream, functional_postcollision], kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index eabb9fe..eb6f92f 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -51,6 +51,24 @@ def __init__( _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + @wp.func + def functional_postcollision( + f_pre: Any, + f_post: Any, + f_aux: Any, + missing_mask: Any, + ): + return f_post + + @wp.func + def functional_poststream( + f_pre: Any, + f_post: Any, + f_aux: Any, + missing_mask: Any, + ): + return f_post + @wp.func def _get_thread_data_2d( f_pre: wp.array3d(dtype=Any), @@ -105,6 +123,8 @@ def _get_thread_data_3d( if self.compute_backend == ComputeBackend.WARP: self._get_thread_data_2d = _get_thread_data_2d self._get_thread_data_3d = _get_thread_data_3d + self.warp_functional_poststream = functional_poststream + self.warp_functional_postcollision = functional_postcollision @partial(jit, static_argnums=(0,), inline=True) def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index d9bb57e..8d90d93 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -142,19 +142,7 @@ def apply_post_streaming_bc( f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # Regularized boundary condition (bc type = velocity) - f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - return f_post - - @wp.func - def ExtrapolationOutflowBC_functional2( - f_pre: Any, - f_post: Any, - f_aux: Any, - missing_mask: Any, - ): - for l in range(self.velocity_set.q): - if missing_mask[l] == wp.uint8(1): - f_post[_opp_indices[l]] = (1.0 - sound_speed) * f_pre[l] + sound_speed * f_aux[l] + f_post = self.ExtrapolationOutflowBC.warp_functional_poststream(f_pre, f_post, f_aux, missing_mask) return f_post @wp.func @@ -172,7 +160,7 @@ def apply_post_collision_bc( elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # f_aux is the neighbour's post-streaming values # Storing post-streaming data in directions that leave the domain - f_post = ExtrapolationOutflowBC_functional2(f_pre, f_post, f_aux, missing_mask) + f_post = self.ExtrapolationOutflowBC.warp_functional_postcollision(f_pre, f_post, f_aux, missing_mask) return f_post From 14051cef7fbcb364c3886badde6217d620537d00 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 26 Aug 2024 18:06:38 -0400 Subject: [PATCH 084/144] fixed minor syntax bugs --- .../boundary_condition/bc_do_nothing.py | 6 ++-- .../boundary_condition/bc_equilibrium.py | 6 ++-- .../bc_extrapolation_outflow.py | 30 +++++++++++++++---- .../bc_fullway_bounce_back.py | 6 ++-- .../bc_halfway_bounce_back.py | 6 ++-- .../boundary_condition/bc_regularized.py | 6 ++-- xlb/operator/boundary_condition/bc_zouhe.py | 6 ++-- 7 files changed, 48 insertions(+), 18 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index df0186a..e8a91ee 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -77,7 +77,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -102,7 +103,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 29f07bb..27d5eb2 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -102,7 +102,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -127,7 +128,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 3890f7d..2754a1f 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -206,13 +206,22 @@ def kernel2d( # special preparation of auxiliary data if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - index_nbr = index - get_normal_vectors_2d(_missing_mask) + nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): - _faux[l] = _f_pre[l, index_nbr[0], index_nbr[1]] + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + _faux[l] = _f_pre[l, pull_index[0], pull_index[1]] # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - _f = functional(_f_pre, _f_post, _faux, _missing_mask) + # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both + # collision and streaming? + _f = functional_poststream(_f_pre, _f_post, _faux, _missing_mask) else: _f = _f_post @@ -238,13 +247,22 @@ def kernel3d( # special preparation of auxiliary data if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - index_nbr = index - get_normal_vectors_3d(_missing_mask) + nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): - _faux[l] = _f_pre[l, index_nbr[0], index_nbr[1], index_nbr[2]] + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + _faux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - _f = functional(_f_pre, _f_post, _faux, _missing_mask) + # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both + # collision and streaming? + _f = functional_poststream(_f_pre, _f_post, _faux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 410a39b..6272aca 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -86,7 +86,8 @@ def kernel2d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_vec() + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -111,7 +112,8 @@ def kernel3d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_vec() + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 004f792..a363479 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -98,7 +98,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -123,7 +124,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 36ce152..6bf7af1 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -369,7 +369,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_vec() + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -394,7 +395,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_vec() + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 8fe76d1..61783f8 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -340,7 +340,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -365,7 +366,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f = functional(_f_pre, _f_post, _missing_mask) + _f_aux = _f_post + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post From 2cd61fd28330c3d992f22e894d37aa0dc113d5f7 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 26 Aug 2024 18:31:53 -0400 Subject: [PATCH 085/144] minor refactoring --- .../boundary_condition/bc_regularized.py | 4 +- xlb/operator/collision/kbc.py | 4 +- xlb/operator/stepper/nse_stepper.py | 124 ++++++++++++------ 3 files changed, 88 insertions(+), 44 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 6bf7af1..6958d1b 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -16,7 +16,7 @@ from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC from xlb.operator.boundary_condition.boundary_condition import ImplementationStep from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry -from xlb.operator.macroscopic.second_moment import SecondMoment +from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux class RegularizedBC(ZouHeBC): @@ -61,7 +61,7 @@ def __init__( ) # The operator to compute the momentum flux - self.momentum_flux = SecondMoment() + self.momentum_flux = MomentumFlux() # helper function def compute_qi(self): diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index e4e5d58..ddd7ecc 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -12,7 +12,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision from xlb.operator import Operator -from xlb.operator.macroscopic import SecondMoment +from xlb.operator.macroscopic import SecondMoment as MomentumFlux class KBC(Collision): @@ -29,7 +29,7 @@ def __init__( precision_policy=None, compute_backend=None, ): - self.momentum_flux = SecondMoment() + self.momentum_flux = MomentumFlux() self.epsilon = 1e-32 self.beta = omega * 0.5 self.inv_beta = 1.0 / self.beta diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 8d90d93..98224ef 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -180,23 +180,14 @@ def get_normal_vectors_3d( if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l]) - @wp.kernel - def kernel2d( + @wp.func + def get_thread_data_2d( f_0: wp.array3d(dtype=Any), - f_1: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), - bc_struct: Any, - timestep: int, + index: Any, ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) # TODO warp should fix this - # Get the boundary id and missing mask f_post_collision = _f_vec() - f_auxiliary = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations @@ -207,11 +198,38 @@ def kernel2d( _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) + return f_post_collision, _missing_mask - # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f_0, index) + @wp.func + def get_thread_data_3d( + f_0: wp.array4d(dtype=Any), + missing_mask: wp.array4d(dtype=Any), + index: Any, + ): + # Get the boundary id and missing mask + f_post_collision = _f_vec() + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # q-sized vector of pre-streaming populations + f_post_collision[l] = f_0[l, index[0], index[1], index[2]] + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + return f_post_collision, _missing_mask + + @wp.func + def prepare_bc_auxilary_data_2d( + f_0: wp.array3d(dtype=Any), + index: Any, + _boundary_id: Any, + _missing_mask: Any, + bc_struct: Any, + ): # special preparation of auxiliary data + f_auxiliary = _f_vec() if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): @@ -223,6 +241,53 @@ def kernel2d( pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1]] + return f_auxiliary + + @wp.func + def prepare_bc_auxilary_data_3d( + f_0: wp.array4d(dtype=Any), + index: Any, + _boundary_id: Any, + _missing_mask: Any, + bc_struct: Any, + ): + # special preparation of auxiliary data + f_auxiliary = _f_vec() + if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + nv = get_normal_vectors_3d(_missing_mask) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1], pull_index[2]] + return f_auxiliary + + @wp.kernel + def kernel2d( + f_0: wp.array3d(dtype=Any), + f_1: wp.array3d(dtype=Any), + boundary_mask: wp.array3d(dtype=Any), + missing_mask: wp.array3d(dtype=Any), + bc_struct: Any, + timestep: int, + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) # TODO warp should fix this + + # Read thread data for populations and missing mask + f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) + + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f_0, index) + + # Prepare auxilary data for BC (if applicable) + _boundary_id = boundary_mask[0, index[0], index[1]] + f_auxiliary = prepare_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) @@ -262,36 +327,15 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) # TODO warp should fix this - # Get the boundary id and missing mask - f_post_collision = _f_vec() - f_auxiliary = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of pre-streaming populations - f_post_collision[l] = f_0[l, index[0], index[1], index[2]] - - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + # Read thread data for populations and missing mask + f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) # Apply streaming (pull method) f_post_stream = self.stream.warp_functional(f_0, index) - # special preparation of auxiliary data - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - nv = get_normal_vectors_3d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1], pull_index[2]] + # Prepare auxilary data for BC (if applicable) + _boundary_id = boundary_mask[0, index[0], index[1], index[2]] + f_auxiliary = prepare_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) From fbcf525097039bb1a9280533381f73d09d307531 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 27 Aug 2024 09:15:34 -0400 Subject: [PATCH 086/144] weird nan bug in Reg/Zouhe fixed. Python pointer issue! --- .../boundary_condition/bc_regularized.py | 40 ++----------------- xlb/velocity_set/velocity_set.py | 20 ++++++++++ 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 6958d1b..84dbbf9 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -63,25 +63,6 @@ def __init__( # The operator to compute the momentum flux self.momentum_flux = MomentumFlux() - # helper function - def compute_qi(self): - # Qi = cc - cs^2*I - dim = self.velocity_set.d - Qi = self.velocity_set.cc - if dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) - elif dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) - else: - raise ValueError(f"dim = {dim} not supported") - - # multiply off-diagonal elements by 2 because the Q tensor is symmetric - Qi[:, diagonal] += -1.0 / 3.0 - Qi[:, offdiagonal] *= 2.0 - return Qi - @partial(jit, static_argnums=(0,), inline=True) def regularize_fpop(self, fpop, feq): """ @@ -102,22 +83,7 @@ def regularize_fpop(self, fpop, feq): # Qi = cc - cs^2*I dim = self.velocity_set.d weights = self.velocity_set.w[(slice(None),) + (None,) * dim] - # TODO: if I use the following I get NaN ! figure out why! - # Qi = jnp.array(self.compute_qi(), dtype=self.compute_dtype) - Qi = jnp.array(self.velocity_set.cc, dtype=self.compute_dtype) - if dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) - elif dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) - else: - raise ValueError(f"dim = {dim} not supported") - - # Qi = cc - cs^2*I - # multiply off-diagonal elements by 2 because the Q tensor is symmetric - Qi = Qi.at[:, diagonal].add(-1.0 / 3.0) - Qi = Qi.at[:, offdiagonal].multiply(2.0) + Qi = jnp.array(self.velocity_set.qi, dtype=self.compute_dtype) # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} f_neq = fpop - feq @@ -166,7 +132,6 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) # compute Qi tensor and store it in self - _qi = wp.constant(wp.mat((_q, _d * (_d + 1) // 2), dtype=wp.float32)(self.compute_qi())) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(rho) @@ -175,7 +140,8 @@ def _construct_warp(self): _w = self.velocity_set.wp_w _c = self.velocity_set.wp_c _c32 = self.velocity_set.wp_c32 - # TODO: this is way less than ideal. we should not be making new types + _qi = self.velocity_set.wp_qi + # TODO: related to _c32: this is way less than ideal. we should not be making new types @wp.func def _get_fsum( diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index cd63b36..a93d039 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -41,6 +41,7 @@ def __init__(self, d, q, c, w): self.main_indices = self._construct_main_indices() self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() + self.qi = self._construct_qi() # Make warp constants for these vectors # TODO: Following warp updates these may not be necessary @@ -49,6 +50,7 @@ def __init__(self, d, q, c, w): self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c)) + self.wp_qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.qi)) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) @@ -59,6 +61,24 @@ def warp_u_vec(self, dtype): def warp_stream_mat(self, dtype): return wp.mat((self.q, self.d), dtype=dtype) + def _construct_qi(self): + # Qi = cc - cs^2*I + dim = self.d + Qi = self.cc.copy() + if dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + else: + raise ValueError(f"dim = {dim} not supported") + + # multiply off-diagonal elements by 2 because the Q tensor is symmetric + Qi[:, diagonal] += -1.0 / 3.0 + Qi[:, offdiagonal] *= 2.0 + return Qi + def _construct_lattice_moment(self): """ This function constructs the moments of the lattice. From 91c706a880987ba715e38d68fd9c7fe342181f13 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 27 Aug 2024 13:54:19 -0400 Subject: [PATCH 087/144] modified the sequence of lbm step operators in JAX to match stream-then-collide pattern in Warp. --- xlb/operator/stepper/nse_stepper.py | 36 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 98224ef..76276b1 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -47,41 +47,41 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): f_0 = self.precision_policy.cast_to_compute_jax(f_0) f_1 = self.precision_policy.cast_to_compute_jax(f_1) + # Apply streaming + f_post_stream = self.stream(f_0) + + # Apply boundary conditions + for bc in self.boundary_conditions: + if bc.implementation_step == ImplementationStep.STREAMING: + f_post_stream = bc( + f_0, + f_post_stream, + boundary_mask, + missing_mask, + ) + # Compute the macroscopic variables - rho, u = self.macroscopic(f_0) + rho, u = self.macroscopic(f_post_stream) # Compute equilibrium feq = self.equilibrium(rho, u) # Apply collision - f_post_collision = self.collision(f_0, feq, rho, u) + f_post_collision = self.collision(f_post_stream, feq, rho, u) # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.prepare_bc_auxilary_data(f_0, f_post_collision, boundary_mask, missing_mask) + f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, boundary_mask, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( - f_0, - f_post_collision, - boundary_mask, - missing_mask, - ) - - # Apply streaming - f_1 = self.stream(f_post_collision) - - # Apply boundary conditions - for bc in self.boundary_conditions: - if bc.implementation_step == ImplementationStep.STREAMING: - f_1 = bc( + f_post_stream, f_post_collision, - f_1, boundary_mask, missing_mask, ) # Copy back to store precision - f_1 = self.precision_policy.cast_to_store_jax(f_1) + f_1 = self.precision_policy.cast_to_store_jax(f_post_collision) return f_1 From c0e311731114b3dc21292989e78e620bca2b8867 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 27 Aug 2024 14:39:41 -0400 Subject: [PATCH 088/144] WIP: initial commit for force computation using momentum exchange method --- xlb/operator/force/__init__.py | 1 + xlb/operator/force/momentum_transfer.py | 216 ++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 xlb/operator/force/__init__.py create mode 100644 xlb/operator/force/momentum_transfer.py diff --git a/xlb/operator/force/__init__.py b/xlb/operator/force/__init__.py new file mode 100644 index 0000000..6a991ce --- /dev/null +++ b/xlb/operator/force/__init__.py @@ -0,0 +1 @@ +from xlb.operator.force.momentum_transfer import MomentumTransfer as MomentumTransfer diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py new file mode 100644 index 0000000..c5ad23b --- /dev/null +++ b/xlb/operator/force/momentum_transfer.py @@ -0,0 +1,216 @@ +from functools import partial +import jax.numpy as jnp +from jax import jit, lax +import warp as wp +from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream import Stream + + +class MomentumTransfer(Operator): + """ + An opertor for the momentum exchange method to compute the boundary force vector exerted on the solid geometry + based on [1] as described in [3]. Ref [2] shows how [1] is applicable to curved geometries only by using a + bounce-back method (e.g. Bouzidi) that accounts for curved boundaries. + NOTE: this function should be called after BC's are imposed. + [1] A.J.C. Ladd, Numerical simulations of particular suspensions via a discretized Boltzmann equation. + Part 2 (numerical results), J. Fluid Mech. 271 (1994) 311-339. + [2] R. Mei, D. Yu, W. Shyy, L.-S. Luo, Force evaluation in the lattice Boltzmann method involving + curved geometry, Phys. Rev. E 65 (2002) 041203. + [3] Caiazzo, A., & Junk, M. (2008). Boundary forces in lattice Boltzmann: Analysis of momentum exchange + algorithm. Computers & Mathematics with Applications, 55(7), 1415-1423. + + Notes + ----- + This method computes the force exerted on the solid geometry at each boundary node using the momentum exchange method. + The force is computed based on the post-streaming and post-collision distribution functions. This method + should be called after the boundary conditions are imposed. + """ + + def __init__( + self, + no_slip_bc_instance, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + self.no_slip_bc_instance = no_slip_bc_instance + self.stream = Stream(velocity_set, precision_policy, compute_backend) + + # Call the parent constructor + super().__init__( + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f, boundary_id, missing_mask): + """ + Parameters + ---------- + f : jax.numpy.ndarray + The post-collision distribution function at each node in the grid. + boundary_id : jax.numpy.ndarray + A grid field with 0 everywhere except for boundary nodes which are designated + by their respective boundary id's. + missing_mask : jax.numpy.ndarray + A grid field with lattice cardinality that specifies missing lattice directions + for each boundary node. + + Returns + ------- + jax.numpy.ndarray + The force exerted on the solid geometry at each boundary node. + """ + # Give the input post-collision populations, streaming once and apply the BC the find post-stream values. + f_post_collision = f + f_post_stream = self.stream(f_post_collision) + f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, boundary_id, missing_mask) + + # Compute momentum transfer + boundary = boundary_id == self.no_slip_bc_instance.id + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) + + # the following will return force as a grid-based field with zero everywhere except for boundary nodes. + opp = self.velocity_set.opp_indices + phi = f_post_collision[opp] + f_post_stream + phi = jnp.where(jnp.logical_and(boundary, missing_mask), phi, 0.0) + force = jnp.tensordot(self.velocity_set.c[:, opp], phi, axes=(-1, 0)) + return force + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _opp_indices = self.velocity_set.wp_opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _no_slip_id = self.no_slip_bc_instance.id + + # Find velocity index for 0, 0, 0 + for l in range(self.velocity_set.q): + if _c[0, l] == 0 and _c[1, l] == 0 and _c[2, l] == 0: + zero_index = l + _zero_index = wp.int32(zero_index) + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + boundary_id: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + force: wp.array(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Get the boundary id + _boundary_id = boundary_id[0, index[0], index[1]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Determin if boundary is an edge by checking if center is missing + is_edge = wp.bool(False) + if _boundary_id == wp.uint8(_no_slip_id): + if _missing_mask[_zero_index] == wp.uint8(0): + is_edge = wp.bool(True) + + # If the boundary is an edge then add the momentum transfer + m = wp.vec2() + if is_edge: + # Get the distribution function + f_post_collision = _f_vec() + for l in range(self.velocity_set.q): + f_post_collision[l] = f[l, index[0], index[1]] + + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + + # Compute the momentum transfer + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] + for d in range(self.velocity_set.d): + m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + + wp.atomic_add(force, 0, m) + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + force: wp.array(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Determin if boundary is an edge by checking if center is missing + is_edge = wp.bool(False) + if _boundary_id == wp.uint8(_no_slip_id): + if _missing_mask[_zero_index] == wp.uint8(0): + is_edge = wp.bool(True) + + # If the boundary is an edge then add the momentum transfer + m = wp.vec3() + if is_edge: + # Get the distribution function + f_post_collision = _f_vec() + for l in range(self.velocity_set.q): + f_post_collision[l] = f[l, index[0], index[1], index[2]] + + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + + # Compute the momentum transfer + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] + for d in range(self.velocity_set.d): + m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + + wp.atomic_add(force, 0, m) + + # Return the correct kernel + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, boundary_id, missing_mask): + # Allocate the force vector (the total integral value will be computed) + force = wp.zeros((1), dtype=wp.vec3) if self.velocity_set.d == 3 else wp.zeros((1), dtype=wp.vec2) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f, boundary_id, missing_mask, force], + dim=f.shape[1:], + ) + return force.numpy() From 23dd4e00f19c80967ee9c36179492b863a70eb4b Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 27 Aug 2024 16:27:12 -0400 Subject: [PATCH 089/144] WIP: using stl_boundary_masker for reading STL geometries --- .../boundary_masker/stl_boundary_masker.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index b4ea8ca..e4be36f 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -12,7 +12,7 @@ class STLBoundaryMasker(Operator): """ - Operator for creating a boundary mask from an STL file + Operator for creating a boundary missing_mask from an STL file """ def __init__( @@ -24,6 +24,12 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) + @Operator.register_backend(ComputeBackend.JAX) + def jax_implementation(self, stl_file, origin, spacing, id_number, boundary_id, missing_mask, start_index=(0, 0, 0)): + # Use Warp backend even for this particular operation. + boundary_id, missing_mask = self.warp_implementation(stl_file, origin, spacing, id_number, boundary_id, missing_mask, start_index=(0, 0, 0)) + return wp.to_jax(boundary_id), wp.to_jax(missing_mask) + def _construct_warp(self): # Make constants for warp _c = self.velocity_set.wp_c @@ -32,12 +38,12 @@ def _construct_warp(self): # Construct the warp kernel @wp.kernel def kernel( - mesh: wp.uint64, + mesh_id: wp.uint64, origin: wp.vec3, spacing: wp.vec3, id_number: wp.int32, - boundary_mask: wp.array4d(dtype=wp.uint8), - mask: wp.array4d(dtype=wp.bool), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): # get index @@ -56,9 +62,9 @@ def kernel( # Compute the maximum length max_length = wp.sqrt( - (spacing[0] * wp.float32(boundary_mask.shape[1])) ** 2.0 - + (spacing[1] * wp.float32(boundary_mask.shape[2])) ** 2.0 - + (spacing[2] * wp.float32(boundary_mask.shape[3])) ** 2.0 + (spacing[0] * wp.float32(boundary_id.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(boundary_id.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(boundary_id.shape[3])) ** 2.0 ) # evaluate if point is inside mesh @@ -66,7 +72,7 @@ def kernel( face_u = float(0.0) face_v = float(0.0) sign = float(0.0) - if wp.mesh_query_point_sign_winding_number(mesh, pos, max_length, sign, face_index, face_u, face_v): + if wp.mesh_query_point_sign_winding_number(mesh_id, pos, max_length, sign, face_index, face_u, face_v): # set point to be solid if sign <= 0: # TODO: fix this # Stream indices @@ -76,9 +82,9 @@ def kernel( for d in range(self.velocity_set.d): push_index[d] = index[d] + _c[d, l] - # Set the boundary id and mask - boundary_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) - mask[l, push_index[0], push_index[1], push_index[2]] = True + # Set the boundary id and missing_mask + boundary_id[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + missing_mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel @@ -89,8 +95,8 @@ def warp_implementation( origin, spacing, id_number, - boundary_mask, - mask, + boundary_id, + missing_mask, start_index=(0, 0, 0), ): # Load the mesh @@ -110,11 +116,11 @@ def warp_implementation( origin, spacing, id_number, - boundary_mask, - mask, + boundary_id, + missing_mask, start_index, ], - dim=boundary_mask.shape[1:], + dim=boundary_id.shape[1:], ) - return boundary_mask, mask + return boundary_id, missing_mask From 47a5f26ebb2d8a077a0ec52cd6f1ee147231a6f2 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 27 Aug 2024 16:42:00 -0400 Subject: [PATCH 090/144] renamed boundary_mask to boundary_map --- examples/cfd/flow_past_sphere_3d.py | 10 +-- examples/cfd/lid_driven_cavity_2d.py | 10 +-- examples/cfd/windtunnel_3d.py | 10 +-- .../flow_past_sphere.py | 16 ++--- .../cfd_old_to_be_migrated/taylor_green.py | 6 +- examples/performance/mlups_3d.py | 12 ++-- .../bc_equilibrium/test_bc_equilibrium_jax.py | 6 +- .../test_bc_equilibrium_warp.py | 6 +- .../test_bc_fullway_bounce_back_jax.py | 6 +- .../test_bc_fullway_bounce_back_warp.py | 6 +- .../mask/test_bc_indices_masker_jax.py | 24 +++---- .../mask/test_bc_indices_masker_warp.py | 28 ++++---- xlb/helper/nse_solver.py | 4 +- .../boundary_condition/bc_do_nothing.py | 20 +++--- .../boundary_condition/bc_equilibrium.py | 20 +++--- .../bc_extrapolation_outflow.py | 28 ++++---- .../bc_fullway_bounce_back.py | 20 +++--- .../bc_halfway_bounce_back.py | 20 +++--- .../boundary_condition/bc_regularized.py | 20 +++--- xlb/operator/boundary_condition/bc_zouhe.py | 20 +++--- .../boundary_condition/boundary_condition.py | 14 ++-- .../indices_boundary_masker.py | 22 +++---- .../boundary_masker/stl_boundary_masker.py | 24 +++---- xlb/operator/force/momentum_transfer.py | 24 +++---- xlb/operator/stepper/nse_stepper.py | 64 +++++++++---------- 25 files changed, 218 insertions(+), 222 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 220489d..ba8bb44 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -25,7 +25,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -34,7 +34,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): def _setup(self, omega): self.setup_boundary_conditions() - self.setup_boundary_masks() + self.setup_boundary_masker() self.initialize_fields() self.setup_stepper(omega) @@ -75,13 +75,13 @@ def setup_boundary_conditions(self): # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. # TODO: how to ensure about this behind in the src code? - def setup_boundary_masks(self): + def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0)) + self.boundary_map, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_map, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -91,7 +91,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index b4540a9..16fb4f9 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -24,7 +24,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -33,7 +33,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): def _setup(self, omega): self.setup_boundary_conditions() - self.setup_boundary_masks() + self.setup_boundary_masker() self.initialize_fields() self.setup_stepper(omega) @@ -51,13 +51,13 @@ def setup_boundary_conditions(self): bc_walls = HalfwayBounceBackBC(indices=walls) self.boundary_conditions = [bc_top, bc_walls] - def setup_boundary_masks(self): + def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask) + self.boundary_map, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_map, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -67,7 +67,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 460b217..b456562 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -32,7 +32,7 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -41,7 +41,7 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi def _setup(self, omega, wind_speed): self.setup_boundary_conditions(wind_speed) - self.setup_boundary_masks() + self.setup_boundary_masker() self.initialize_fields() self.setup_stepper(omega) @@ -86,13 +86,13 @@ def setup_boundary_conditions(self, wind_speed): bc_car = FullwayBounceBackBC(indices=car) self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] - def setup_boundary_masks(self): + def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0)) + self.boundary_map, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_map, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -103,7 +103,7 @@ def setup_stepper(self, omega): def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py index 68d1c2b..7214130 100644 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ b/examples/cfd_old_to_be_migrated/flow_past_sphere.py @@ -75,7 +75,7 @@ def warp_implementation(self, rho, u, vel): u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - boundary_mask = grid.create_field(cardinality=1, dtype=wp.uint8) + boundary_map = grid.create_field(cardinality=1, dtype=wp.uint8) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) # Make operators @@ -154,23 +154,19 @@ def warp_implementation(self, rho, u, vel): indices = wp.from_numpy(indices, dtype=wp.int32) # Set boundary conditions on the indices - boundary_mask, missing_mask = indices_boundary_masker(indices, half_way_bc.id, boundary_mask, missing_mask, (0, 0, 0)) + boundary_map, missing_mask = indices_boundary_masker(indices, half_way_bc.id, boundary_map, missing_mask, (0, 0, 0)) # Set inlet bc lower_bound = (0, 0, 0) upper_bound = (0, nr, nr) direction = (1, 0, 0) - boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, upper_bound, direction, equilibrium_bc.id, boundary_mask, missing_mask, (0, 0, 0) - ) + boundary_map, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, boundary_map, missing_mask, (0, 0, 0)) # Set outlet bc lower_bound = (nr - 1, 0, 0) upper_bound = (nr - 1, nr, nr) direction = (-1, 0, 0) - boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, upper_bound, direction, do_nothing_bc.id, boundary_mask, missing_mask, (0, 0, 0) - ) + boundary_map, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, boundary_map, missing_mask, (0, 0, 0)) # Set initial conditions rho, u = initializer(rho, u, vel) @@ -185,7 +181,7 @@ def warp_implementation(self, rho, u, vel): num_steps = 1024 * 8 start = time.time() for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_mask, missing_mask, _) + f1 = stepper(f0, f1, boundary_map, missing_mask, _) f1, f0 = f0, f1 if (_ % plot_freq == 0) and (not compute_mlup): rho, u = macroscopic(f0, rho, u) @@ -195,7 +191,7 @@ def warp_implementation(self, rho, u, vel): plt.imshow(u[0, :, nr // 2, :].numpy()) plt.colorbar() plt.subplot(1, 2, 2) - plt.imshow(boundary_mask[0, :, nr // 2, :].numpy()) + plt.imshow(boundary_map[0, :, nr // 2, :].numpy()) plt.colorbar() plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py index 9ed7fa6..c5b40b7 100644 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ b/examples/cfd_old_to_be_migrated/taylor_green.py @@ -113,7 +113,7 @@ def run_taylor_green(backend, compute_mlup=True): u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_mask = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + boundary_map = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators @@ -149,10 +149,10 @@ def run_taylor_green(backend, compute_mlup=True): for _ in tqdm(range(num_steps)): # Time step if backend == "warp": - f1 = stepper(f0, f1, boundary_mask, missing_mask, _) + f1 = stepper(f0, f1, boundary_map, missing_mask, _) f1, f0 = f0, f1 elif backend == "jax": - f0 = stepper(f0, boundary_mask, missing_mask, _) + f0 = stepper(f0, boundary_map, missing_mask, _) # Plot if needed if (_ % plot_freq == 0) and (not compute_mlup): diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 74bfa04..602e741 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -42,9 +42,9 @@ def setup_simulation(args): def create_grid_and_fields(cube_edge): grid_shape = (cube_edge, cube_edge, cube_edge) - grid, f_0, f_1, missing_mask, boundary_mask = create_nse_fields(grid_shape) + grid, f_0, f_1, missing_mask, boundary_map = create_nse_fields(grid_shape) - return grid, f_0, f_1, missing_mask, boundary_mask + return grid, f_0, f_1, missing_mask, boundary_map def define_boundary_indices(grid): @@ -67,7 +67,7 @@ def setup_boundary_conditions(grid): return [bc_top, bc_walls] -def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): +def run(f_0, f_1, backend, grid, boundary_map, missing_mask, num_steps): omega = 1.0 stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=setup_boundary_conditions(grid)) @@ -81,7 +81,7 @@ def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): start_time = time.time() for i in range(num_steps): - f_1 = stepper(f_0, f_1, boundary_mask, missing_mask, i) + f_1 = stepper(f_0, f_1, boundary_map, missing_mask, i) f_0, f_1 = f_1, f_0 wp.synchronize() @@ -98,10 +98,10 @@ def calculate_mlups(cube_edge, num_steps, elapsed_time): def main(): args = parse_arguments() backend, precision_policy = setup_simulation(args) - grid, f_0, f_1, missing_mask, boundary_mask = create_grid_and_fields(args.cube_edge) + grid, f_0, f_1, missing_mask, boundary_map = create_grid_and_fields(args.cube_edge) f_0 = initialize_eq(f_0, grid, xlb.velocity_set.D3Q19(), backend) - elapsed_time = run(f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps) + elapsed_time = run(f_0, f_1, backend, grid, boundary_map, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 3e50fdb..9d2e4ff 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -32,7 +32,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -58,7 +58,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices=indices, ) - boundary_mask, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_map, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -66,7 +66,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask) + f = equilibrium_bc(f_pre, f_post, boundary_map, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 0274eba..917e7e4 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -31,7 +31,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -58,7 +58,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): indices=indices, ) - boundary_mask, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_map, missing_mask, start_index=None) f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -66,7 +66,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask) + f = equilibrium_bc(f_pre, f_post, boundary_map, missing_mask) f = f.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 1b7edc2..2fe0b40 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -34,7 +34,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -54,7 +54,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - boundary_mask, missing_mask = indices_boundary_masker([fullway_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([fullway_bc], boundary_map, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0) # Generate a random field with the same shape @@ -67,7 +67,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = fullway_bc(f_pre, f_post, boundary_mask, missing_mask) + f = fullway_bc(f_pre, f_post, boundary_map, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index da76f5e..b25d39e 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -34,7 +34,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -54,7 +54,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - boundary_mask, missing_mask = indices_boundary_masker([fullway_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([fullway_bc], boundary_map, missing_mask, start_index=None) # Generate a random field with the same shape random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32) @@ -65,7 +65,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask) + f_pre = fullway_bc(f_pre, f_post, boundary_map, missing_mask) f = f_pre.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index ddbc761..af121d3 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -34,7 +34,7 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -56,26 +56,26 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - boundary_mask, missing_mask = indices_boundary_masker([test_bc], boundary_mask, missing_mask, start_index=None) + boundary_map, missing_mask = indices_boundary_masker([test_bc], boundary_map, missing_mask, start_index=None) assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype - assert boundary_mask.dtype == xlb.Precision.UINT8.jax_dtype + assert boundary_map.dtype == xlb.Precision.UINT8.jax_dtype - assert boundary_mask.shape == (1,) + grid_shape + assert boundary_map.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert jnp.all(boundary_mask[0, indices[0], indices[1]] == test_bc.id) - # assert that the rest of the boundary_mask is zero - boundary_mask = boundary_mask.at[0, indices[0], indices[1]].set(0) - assert jnp.all(boundary_mask == 0) + assert jnp.all(boundary_map[0, indices[0], indices[1]] == test_bc.id) + # assert that the rest of the boundary_map is zero + boundary_map = boundary_map.at[0, indices[0], indices[1]].set(0) + assert jnp.all(boundary_map == 0) if dim == 3: - assert jnp.all(boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) - # assert that the rest of the boundary_mask is zero - boundary_mask = boundary_mask.at[0, indices[0], indices[1], indices[2]].set(0) - assert jnp.all(boundary_mask == 0) + assert jnp.all(boundary_map[0, indices[0], indices[1], indices[2]] == test_bc.id) + # assert that the rest of the boundary_map is zero + boundary_map = boundary_map.at[0, indices[0], indices[1], indices[2]].set(0) + assert jnp.all(boundary_map == 0) if __name__ == "__main__": diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 6919ba9..4d02540 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -32,7 +32,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -54,33 +54,33 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - boundary_mask, missing_mask = indices_boundary_masker( + boundary_map, missing_mask = indices_boundary_masker( [test_bc], - boundary_mask, + boundary_map, missing_mask, start_index=(0, 0, 0) if dim == 3 else (0, 0), ) assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype - assert boundary_mask.dtype == xlb.Precision.UINT8.wp_dtype + assert boundary_map.dtype == xlb.Precision.UINT8.wp_dtype - boundary_mask = boundary_mask.numpy() + boundary_map = boundary_map.numpy() missing_mask = missing_mask.numpy() - assert boundary_mask.shape == (1,) + grid_shape + assert boundary_map.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert np.all(boundary_mask[0, indices[0], indices[1]] == test_bc.id) - # assert that the rest of the boundary_mask is zero - boundary_mask[0, indices[0], indices[1]] = 0 - assert np.all(boundary_mask == 0) + assert np.all(boundary_map[0, indices[0], indices[1]] == test_bc.id) + # assert that the rest of the boundary_map is zero + boundary_map[0, indices[0], indices[1]] = 0 + assert np.all(boundary_map == 0) if dim == 3: - assert np.all(boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) - # assert that the rest of the boundary_mask is zero - boundary_mask[0, indices[0], indices[1], indices[2]] = 0 - assert np.all(boundary_mask == 0) + assert np.all(boundary_map[0, indices[0], indices[1], indices[2]] == test_bc.id) + # assert that the rest of the boundary_map is zero + boundary_map[0, indices[0], indices[1], indices[2]] = 0 + assert np.all(boundary_map == 0) if __name__ == "__main__": diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py index a42c6ac..96befa6 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_solver.py @@ -14,6 +14,6 @@ def create_nse_fields(grid_shape: Tuple[int, int, int], velocity_set=None, compu f_0 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) f_1 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL) - boundary_mask = grid.create_field(cardinality=1, dtype=Precision.UINT8) + boundary_map = grid.create_field(cardinality=1, dtype=Precision.UINT8) - return grid, f_0, f_1, missing_mask, boundary_mask + return grid, f_0, f_1, missing_mask, boundary_map diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index e8a91ee..4afadfb 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -46,8 +46,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): - boundary = boundary_mask == self.id + def jax_implementation(self, f_pre, f_post, boundary_map, missing_mask): + boundary = boundary_map == self.id return jnp.where(boundary, f_pre, f_post) def _construct_warp(self): @@ -65,7 +65,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.uint8), ): # Get the global index @@ -73,10 +73,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(DoNothingBC.id): + if _boundary_map == wp.uint8(DoNothingBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -91,7 +91,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -99,10 +99,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(DoNothingBC.id): + if _boundary_map == wp.uint8(DoNothingBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -117,11 +117,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 27d5eb2..eb28ee8 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -60,11 +60,11 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, boundary_map, missing_mask): feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) new_shape = feq.shape + (1,) * self.velocity_set.d feq = lax.broadcast_in_dim(feq, new_shape, [0]) - boundary = boundary_mask == self.id + boundary = boundary_map == self.id return jnp.where(boundary, feq, f_post) @@ -90,7 +90,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -98,10 +98,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(EquilibriumBC.id): + if _boundary_map == wp.uint8(EquilibriumBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -116,7 +116,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -124,10 +124,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(EquilibriumBC.id): + if _boundary_map == wp.uint8(EquilibriumBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -142,11 +142,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 2754a1f..b019320 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -92,13 +92,13 @@ def _roll(self, fld, vec): return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): + def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_map, missing_mask): """ Prepare the auxilary distribution functions for the boundary condition. Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision """ sound_speed = 1.0 / jnp.sqrt(3.0) - boundary = boundary_mask == self.id + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -121,8 +121,8 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): - boundary = boundary_mask == self.id + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( @@ -193,7 +193,7 @@ def functional_postcollision( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -201,11 +201,11 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) _faux = _f_vec() # special preparation of auxiliary data - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -218,7 +218,7 @@ def kernel2d( _faux[l] = _f_pre[l, pull_index[0], pull_index[1]] # Apply the boundary condition - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both # collision and streaming? _f = functional_poststream(_f_pre, _f_post, _faux, _missing_mask) @@ -234,7 +234,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -242,11 +242,11 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) _faux = _f_vec() # special preparation of auxiliary data - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -259,7 +259,7 @@ def kernel3d( _faux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] # Apply the boundary condition - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both # collision and streaming? _f = functional_poststream(_f_pre, _f_post, _faux, _missing_mask) @@ -275,11 +275,11 @@ def kernel3d( return [functional_poststream, functional_postcollision], kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 6272aca..d613891 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -46,8 +46,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): - boundary = boundary_mask == self.id + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) @@ -75,17 +75,17 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Check if the boundary is active - if _boundary_id == wp.uint8(FullwayBounceBackBC.id): + if _boundary_map == wp.uint8(FullwayBounceBackBC.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -100,7 +100,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -108,10 +108,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Check if the boundary is active - if _boundary_id == wp.uint8(FullwayBounceBackBC.id): + if _boundary_map == wp.uint8(FullwayBounceBackBC.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -126,11 +126,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index a363479..94cf31b 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -49,8 +49,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): - boundary = boundary_mask == self.id + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( @@ -86,7 +86,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -94,10 +94,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + if _boundary_map == wp.uint8(HalfwayBounceBackBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -112,7 +112,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -120,10 +120,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + if _boundary_map == wp.uint8(HalfwayBounceBackBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -138,11 +138,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 84dbbf9..a2ec0e1 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -103,9 +103,9 @@ def regularize_fpop(self, fpop, feq): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): # creat a mask to slice boundary cells - boundary = boundary_mask == self.id + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -323,7 +323,7 @@ def functional2d_pressure( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -331,10 +331,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): + if _boundary_map == wp.uint8(self.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -349,7 +349,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -357,10 +357,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): + if _boundary_map == wp.uint8(self.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -383,11 +383,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 61783f8..7a848eb 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -154,9 +154,9 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): + def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): # creat a mask to slice boundary cells - boundary = boundary_mask == self.id + boundary = boundary_map == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -328,7 +328,7 @@ def functional2d_pressure( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -336,10 +336,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): + if _boundary_map == wp.uint8(self.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -354,7 +354,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -362,10 +362,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) + _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): + if _boundary_map == wp.uint8(self.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -388,11 +388,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): + def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_mask, missing_mask], + inputs=[f_pre, f_post, boundary_map, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index eb6f92f..62b2304 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -73,14 +73,14 @@ def functional_poststream( def _get_thread_data_2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), index: wp.vec2i, ): # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1]] + _boundary_map = boundary_map[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -92,20 +92,20 @@ def _get_thread_data_2d( _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return _f_pre, _f_post, _boundary_id, _missing_mask + return _f_pre, _f_post, _boundary_map, _missing_mask @wp.func def _get_thread_data_3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), index: wp.vec3i, ): # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] + _boundary_map = boundary_map[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -117,7 +117,7 @@ def _get_thread_data_3d( _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return _f_pre, _f_post, _boundary_id, _missing_mask + return _f_pre, _f_post, _boundary_map, _missing_mask # Construct some helper warp functions for getting tid data if self.compute_backend == ComputeBackend.WARP: @@ -127,7 +127,7 @@ def _get_thread_data_3d( self.warp_functional_postcollision = functional_postcollision @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): + def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_map, missing_mask): """ A placeholder function for prepare the auxilary distribution functions for the boundary condition. currently being called after collision only. diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 7548cf0..b697d34 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -29,7 +29,7 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) - def jax_implementation(self, bclist, boundary_mask, missing_mask, start_index=None): + def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=None): # Pad the missing mask to create a grid mask to identify out of bound boundaries # Set padded regin to True (i.e. boundary) dim = missing_mask.ndim - 1 @@ -45,7 +45,7 @@ def jax_implementation(self, bclist, boundary_mask, missing_mask, start_index=No if start_index is None: start_index = (0,) * dim - bid = boundary_mask[0] + bid = boundary_map[0] for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" id_number = bc.id @@ -59,13 +59,13 @@ def jax_implementation(self, bclist, boundary_mask, missing_mask, start_index=No # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) - boundary_mask = boundary_mask.at[0].set(bid) + boundary_map = boundary_map.at[0].set(bid) grid_mask = self.stream(grid_mask) if dim == 2: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] if dim == 3: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] - return boundary_mask, missing_mask + return boundary_map, missing_mask def _construct_warp(self): # Make constants for warp @@ -77,7 +77,7 @@ def _construct_warp(self): def kernel2d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), - boundary_mask: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): @@ -104,14 +104,14 @@ def kernel2d( # Set the missing mask missing_mask[l, index[0], index[1]] = True - boundary_mask[0, index[0], index[1]] = id_number[ii] + boundary_map[0, index[0], index[1]] = id_number[ii] # Construct the warp 3D kernel @wp.kernel def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), - boundary_mask: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -153,14 +153,14 @@ def kernel3d( # Set the missing mask missing_mask[l, index[0], index[1], index[2]] = True - boundary_mask[0, index[0], index[1], index[2]] = id_number[ii] + boundary_map[0, index[0], index[1], index[2]] = id_number[ii] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, bclist, boundary_mask, missing_mask, start_index=None): + def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=None): dim = self.velocity_set.d index_list = [[] for _ in range(dim)] id_list = [] @@ -184,11 +184,11 @@ def warp_implementation(self, bclist, boundary_mask, missing_mask, start_index=N inputs=[ indices, id_number, - boundary_mask, + boundary_map, missing_mask, start_index, ], dim=indices.shape[1], ) - return boundary_mask, missing_mask + return boundary_map, missing_mask diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index e4be36f..cec5e72 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -25,10 +25,10 @@ def __init__( super().__init__(velocity_set, precision_policy, compute_backend) @Operator.register_backend(ComputeBackend.JAX) - def jax_implementation(self, stl_file, origin, spacing, id_number, boundary_id, missing_mask, start_index=(0, 0, 0)): + def jax_implementation(self, stl_file, origin, spacing, id_number, boundary_map, missing_mask, start_index=(0, 0, 0)): # Use Warp backend even for this particular operation. - boundary_id, missing_mask = self.warp_implementation(stl_file, origin, spacing, id_number, boundary_id, missing_mask, start_index=(0, 0, 0)) - return wp.to_jax(boundary_id), wp.to_jax(missing_mask) + boundary_map, missing_mask = self.warp_implementation(stl_file, origin, spacing, id_number, boundary_map, missing_mask, start_index=(0, 0, 0)) + return wp.to_jax(boundary_map), wp.to_jax(missing_mask) def _construct_warp(self): # Make constants for warp @@ -42,7 +42,7 @@ def kernel( origin: wp.vec3, spacing: wp.vec3, id_number: wp.int32, - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -62,9 +62,9 @@ def kernel( # Compute the maximum length max_length = wp.sqrt( - (spacing[0] * wp.float32(boundary_id.shape[1])) ** 2.0 - + (spacing[1] * wp.float32(boundary_id.shape[2])) ** 2.0 - + (spacing[2] * wp.float32(boundary_id.shape[3])) ** 2.0 + (spacing[0] * wp.float32(boundary_map.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(boundary_map.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(boundary_map.shape[3])) ** 2.0 ) # evaluate if point is inside mesh @@ -83,7 +83,7 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and missing_mask - boundary_id[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + boundary_map[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) missing_mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel @@ -95,7 +95,7 @@ def warp_implementation( origin, spacing, id_number, - boundary_id, + boundary_map, missing_mask, start_index=(0, 0, 0), ): @@ -116,11 +116,11 @@ def warp_implementation( origin, spacing, id_number, - boundary_id, + boundary_map, missing_mask, start_index, ], - dim=boundary_id.shape[1:], + dim=boundary_map.shape[1:], ) - return boundary_id, missing_mask + return boundary_map, missing_mask diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index c5ad23b..264e983 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -50,13 +50,13 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f, boundary_id, missing_mask): + def jax_implementation(self, f, boundary_map, missing_mask): """ Parameters ---------- f : jax.numpy.ndarray The post-collision distribution function at each node in the grid. - boundary_id : jax.numpy.ndarray + boundary_map : jax.numpy.ndarray A grid field with 0 everywhere except for boundary nodes which are designated by their respective boundary id's. missing_mask : jax.numpy.ndarray @@ -71,10 +71,10 @@ def jax_implementation(self, f, boundary_id, missing_mask): # Give the input post-collision populations, streaming once and apply the BC the find post-stream values. f_post_collision = f f_post_stream = self.stream(f_post_collision) - f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, boundary_id, missing_mask) + f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, boundary_map, missing_mask) # Compute momentum transfer - boundary = boundary_id == self.no_slip_bc_instance.id + boundary = boundary_map == self.no_slip_bc_instance.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -103,7 +103,7 @@ def _construct_warp(self): @wp.kernel def kernel2d( f: wp.array3d(dtype=Any), - boundary_id: wp.array3d(dtype=wp.uint8), + boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), force: wp.array(dtype=Any), ): @@ -112,7 +112,7 @@ def kernel2d( index = wp.vec2i(i, j) # Get the boundary id - _boundary_id = boundary_id[0, index[0], index[1]] + _boundary_map = boundary_map[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -123,7 +123,7 @@ def kernel2d( # Determin if boundary is an edge by checking if center is missing is_edge = wp.bool(False) - if _boundary_id == wp.uint8(_no_slip_id): + if _boundary_map == wp.uint8(_no_slip_id): if _missing_mask[_zero_index] == wp.uint8(0): is_edge = wp.bool(True) @@ -152,7 +152,7 @@ def kernel2d( @wp.kernel def kernel3d( f: wp.array4d(dtype=Any), - boundary_id: wp.array4d(dtype=wp.uint8), + boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), force: wp.array(dtype=Any), ): @@ -161,7 +161,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id - _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _boundary_map = boundary_map[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -172,7 +172,7 @@ def kernel3d( # Determin if boundary is an edge by checking if center is missing is_edge = wp.bool(False) - if _boundary_id == wp.uint8(_no_slip_id): + if _boundary_map == wp.uint8(_no_slip_id): if _missing_mask[_zero_index] == wp.uint8(0): is_edge = wp.bool(True) @@ -203,14 +203,14 @@ def kernel3d( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, boundary_id, missing_mask): + def warp_implementation(self, f, boundary_map, missing_mask): # Allocate the force vector (the total integral value will be computed) force = wp.zeros((1), dtype=wp.vec3) if self.velocity_set.d == 3 else wp.zeros((1), dtype=wp.vec2) # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f, boundary_id, missing_mask, force], + inputs=[f, boundary_map, missing_mask, force], dim=f.shape[1:], ) return force.numpy() diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 76276b1..0feaf6d 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -39,7 +39,7 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK"): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): + def jax_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ @@ -56,7 +56,7 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): f_post_stream = bc( f_0, f_post_stream, - boundary_mask, + boundary_map, missing_mask, ) @@ -71,12 +71,12 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, boundary_mask, missing_mask) + f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, boundary_map, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_post_stream, f_post_collision, - boundary_mask, + boundary_map, missing_mask, ) @@ -115,32 +115,32 @@ def apply_post_streaming_bc( f_post: Any, f_aux: Any, missing_mask: Any, - _boundary_id: Any, + _boundary_map: Any, bc_struct: Any, ): # Apply post-streaming type boundary conditions - if _boundary_id == bc_struct.id_EquilibriumBC: + if _boundary_map == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_DoNothingBC: + elif _boundary_map == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post = self.DoNothingBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: + elif _boundary_map == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_ZouHeBC_velocity: + elif _boundary_map == bc_struct.id_ZouHeBC_velocity: # Zouhe boundary condition (bc type = velocity) f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_ZouHeBC_pressure: + elif _boundary_map == bc_struct.id_ZouHeBC_pressure: # Zouhe boundary condition (bc type = pressure) f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_RegularizedBC_velocity: + elif _boundary_map == bc_struct.id_RegularizedBC_velocity: # Regularized boundary condition (bc type = velocity) f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_RegularizedBC_pressure: + elif _boundary_map == bc_struct.id_RegularizedBC_pressure: # Regularized boundary condition (bc type = velocity) f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + elif _boundary_map == bc_struct.id_ExtrapolationOutflowBC: # Regularized boundary condition (bc type = velocity) f_post = self.ExtrapolationOutflowBC.warp_functional_poststream(f_pre, f_post, f_aux, missing_mask) return f_post @@ -151,13 +151,13 @@ def apply_post_collision_bc( f_post: Any, f_aux: Any, missing_mask: Any, - _boundary_id: Any, + _boundary_map: Any, bc_struct: Any, ): - if _boundary_id == bc_struct.id_FullwayBounceBackBC: + if _boundary_map == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + elif _boundary_map == bc_struct.id_ExtrapolationOutflowBC: # f_aux is the neighbour's post-streaming values # Storing post-streaming data in directions that leave the domain f_post = self.ExtrapolationOutflowBC.warp_functional_postcollision(f_pre, f_post, f_aux, missing_mask) @@ -224,13 +224,13 @@ def get_thread_data_3d( def prepare_bc_auxilary_data_2d( f_0: wp.array3d(dtype=Any), index: Any, - _boundary_id: Any, + _boundary_map: Any, _missing_mask: Any, bc_struct: Any, ): # special preparation of auxiliary data f_auxiliary = _f_vec() - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + if _boundary_map == bc_struct.id_ExtrapolationOutflowBC: nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -247,13 +247,13 @@ def prepare_bc_auxilary_data_2d( def prepare_bc_auxilary_data_3d( f_0: wp.array4d(dtype=Any), index: Any, - _boundary_id: Any, + _boundary_map: Any, _missing_mask: Any, bc_struct: Any, ): # special preparation of auxiliary data f_auxiliary = _f_vec() - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + if _boundary_map == bc_struct.id_ExtrapolationOutflowBC: nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -270,7 +270,7 @@ def prepare_bc_auxilary_data_3d( def kernel2d( f_0: wp.array3d(dtype=Any), f_1: wp.array3d(dtype=Any), - boundary_mask: wp.array3d(dtype=Any), + boundary_map: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), bc_struct: Any, timestep: int, @@ -286,11 +286,11 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_id = boundary_mask[0, index[0], index[1]] - f_auxiliary = prepare_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _boundary_map = boundary_map[0, index[0], index[1]] + f_auxiliary = prepare_bc_auxilary_data_2d(f_0, index, _boundary_map, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_map, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -307,7 +307,7 @@ def kernel2d( ) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_map, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -318,7 +318,7 @@ def kernel2d( def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), - boundary_mask: wp.array4d(dtype=Any), + boundary_map: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), bc_struct: Any, timestep: int, @@ -334,11 +334,11 @@ def kernel3d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - f_auxiliary = prepare_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _boundary_map = boundary_map[0, index[0], index[1], index[2]] + f_auxiliary = prepare_bc_auxilary_data_3d(f_0, index, _boundary_map, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_map, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -350,7 +350,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_map, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -362,7 +362,7 @@ def kernel3d( return BoundaryConditionIDStruct, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): + def warp_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): # Get the boundary condition ids from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry @@ -395,7 +395,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): inputs=[ f_0, f_1, - boundary_mask, + boundary_map, missing_mask, bc_struct, timestep, From 1db32421e9980cf94ff3fffff0a1c629a05c22b8 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 27 Aug 2024 21:15:24 -0400 Subject: [PATCH 091/144] WIP: changed mesh_boundary_masker and used it in the wind tunnel example --- examples/cfd/windtunnel_3d.py | 39 ++++++++++++------- .../boundary_condition/bc_do_nothing.py | 2 + .../boundary_condition/bc_equilibrium.py | 2 + .../bc_extrapolation_outflow.py | 2 + .../bc_fullway_bounce_back.py | 2 + .../bc_halfway_bounce_back.py | 2 + .../boundary_condition/bc_regularized.py | 2 + xlb/operator/boundary_condition/bc_zouhe.py | 2 + .../boundary_condition/boundary_condition.py | 2 + xlb/operator/boundary_masker/__init__.py | 4 +- .../indices_boundary_masker.py | 14 ++++--- ...dary_masker.py => mesh_boundary_masker.py} | 38 +++++++++++++----- 12 files changed, 80 insertions(+), 31 deletions(-) rename xlb/operator/boundary_masker/{stl_boundary_masker.py => mesh_boundary_masker.py} (75%) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index b456562..696cb7b 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -12,7 +12,7 @@ ExtrapolationOutflowBC, ) from xlb.operator.macroscopic import Macroscopic -from xlb.operator.boundary_masker import IndicesBoundaryMasker +from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np @@ -64,17 +64,20 @@ def define_boundary_indices(self): for i in range(self.velocity_set.d) ] + # Load the mesh stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl" - grid_size_x = self.grid_shape[0] - car_length_lbm_unit = grid_size_x / 4 - car_voxelized, pitch = self.voxelize_stl(stl_filename, car_length_lbm_unit) - - # car_area = np.prod(car_voxelized.shape[1:]) - tx, ty, _ = np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape - shift = [tx // 4, ty // 2, 0] - car = np.argwhere(car_voxelized) + shift - car = np.array(car).T - car = [tuple(car[i]) for i in range(self.velocity_set.d)] + mesh = trimesh.load_mesh(stl_filename, process=False) + mesh_points = mesh.vertices + + # Transform the mesh points to be located in the right position in the wind tunnel + mesh_points -= mesh_points.min(axis=0) + mesh_extents = mesh_points.max(axis=0) + length_phys_unit = mesh_extents.max() + length_lbm_unit = self.grid_shape[0] / 4 + dx = length_phys_unit / length_lbm_unit + shift = np.array([self.grid_shape[0] * dx / 4, (self.grid_shape[1] * dx - mesh_extents[1]) / 2, 0]) + car = mesh_points + shift + self.grid_spacing = dx return inlet, outlet, walls, car @@ -83,7 +86,7 @@ def setup_boundary_conditions(self, wind_speed): bc_left = EquilibriumBC(rho=1.0, u=(wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) - bc_car = FullwayBounceBackBC(indices=car) + bc_car = FullwayBounceBackBC(mesh_points=car) self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] def setup_boundary_masker(self): @@ -92,7 +95,17 @@ def setup_boundary_masker(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_map, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_map, self.missing_mask, (0, 0, 0)) + mesh_boundary_masker = MeshBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.backend, + ) + bclist_other = self.boundary_conditions[:-1] + bc_mesh = self.boundary_conditions[-1] + dx = self.grid_spacing + origin, spacing = (0, 0, 0), (dx, dx, dx) + self.boundary_map, self.missing_mask = indices_boundary_masker(bclist_other, self.boundary_map, self.missing_mask) + self.boundary_map, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.boundary_map, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 4afadfb..96b733f 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -35,6 +35,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_points=None, ): super().__init__( ImplementationStep.STREAMING, @@ -42,6 +43,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_points, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index eb28ee8..ff55cf0 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -40,6 +40,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_points=None, ): # Store the equilibrium information self.rho = rho @@ -56,6 +57,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_points, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index b019320..777b1ec 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -46,6 +46,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_points=None, ): # Call the parent constructor super().__init__( @@ -54,6 +55,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_points, ) # find and store the normal vector using indices diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index d613891..30065e4 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -35,6 +35,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_points=None, ): super().__init__( ImplementationStep.COLLISION, @@ -42,6 +43,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_points, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 94cf31b..3b26810 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -37,6 +37,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_points=None, ): # Call the parent constructor super().__init__( @@ -45,6 +46,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_points, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index a2ec0e1..3c71c96 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -49,6 +49,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_points=None, ): # Call the parent constructor super().__init__( @@ -58,6 +59,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_points, ) # The operator to compute the momentum flux diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 7a848eb..5153ced 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -43,6 +43,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_points=None, ): # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC # may have different types (velocity or pressure). @@ -59,6 +60,7 @@ def __init__( precision_policy, compute_backend, indices, + mesh_points, ) # Set the prescribed value for pressure or velocity diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 62b2304..79963a0 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -33,6 +33,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, + mesh_points=None, ): velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy @@ -42,6 +43,7 @@ def __init__( # Set the BC indices self.indices = indices + self.mesh_points = mesh_points # Set the implementation step self.implementation_step = implementation_step diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index 262e638..20b16b5 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -1,6 +1,6 @@ from xlb.operator.boundary_masker.indices_boundary_masker import ( IndicesBoundaryMasker as IndicesBoundaryMasker, ) -from xlb.operator.boundary_masker.stl_boundary_masker import ( - STLBoundaryMasker as STLBoundaryMasker, +from xlb.operator.boundary_masker.mesh_boundary_masker import ( + MeshBoundaryMasker as MeshBoundaryMasker, ) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index b697d34..7c09ac1 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -48,8 +48,9 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non bid = boundary_map[0] for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" + assert bc.mesh_points is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" id_number = bc.id - local_indices = np.array(bc.indices) + np.array(start_index)[:, np.newaxis] + local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] bid = bid.at[tuple(local_indices)].set(id_number) # if dim == 2: @@ -86,8 +87,8 @@ def kernel2d( # Get local indices index = wp.vec2i() - index[0] = indices[0, ii] + start_index[0] - index[1] = indices[1, ii] + start_index[1] + index[0] = indices[0, ii] - start_index[0] + index[1] = indices[1, ii] - start_index[1] # Check if index is in bounds if index[0] >= 0 and index[0] < missing_mask.shape[1] and index[1] >= 0 and index[1] < missing_mask.shape[2]: @@ -120,9 +121,9 @@ def kernel3d( # Get local indices index = wp.vec3i() - index[0] = indices[0, ii] + start_index[0] - index[1] = indices[1, ii] + start_index[1] - index[2] = indices[2, ii] + start_index[2] + index[0] = indices[0, ii] - start_index[0] + index[1] = indices[1, ii] - start_index[1] + index[2] = indices[2, ii] - start_index[2] # Check if index is in bounds if ( @@ -166,6 +167,7 @@ def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=No id_list = [] for bc in bclist: assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC using keyword "indices"!' + assert bc.mesh_points is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py similarity index 75% rename from xlb/operator/boundary_masker/stl_boundary_masker.py rename to xlb/operator/boundary_masker/mesh_boundary_masker.py index cec5e72..1ecec13 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -1,7 +1,6 @@ # Base class for all equilibriums import numpy as np -from stl import mesh as np_mesh import warp as wp from xlb.velocity_set.velocity_set import VelocitySet @@ -10,7 +9,7 @@ from xlb.operator.operator import Operator -class STLBoundaryMasker(Operator): +class MeshBoundaryMasker(Operator): """ Operator for creating a boundary missing_mask from an STL file """ @@ -19,15 +18,29 @@ def __init__( self, velocity_set: VelocitySet, precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, + compute_backend: ComputeBackend.WARP, ): # Call super super().__init__(velocity_set, precision_policy, compute_backend) + # Also using Warp kernels for JAX implementation + if self.compute_backend == ComputeBackend.JAX: + self.warp_functional, self.warp_kernel = self._construct_warp() + @Operator.register_backend(ComputeBackend.JAX) - def jax_implementation(self, stl_file, origin, spacing, id_number, boundary_map, missing_mask, start_index=(0, 0, 0)): + def jax_implementation( + self, + bc, + origin, + spacing, + boundary_map, + missing_mask, + start_index=(0, 0, 0), + ): # Use Warp backend even for this particular operation. - boundary_map, missing_mask = self.warp_implementation(stl_file, origin, spacing, id_number, boundary_map, missing_mask, start_index=(0, 0, 0)) + boundary_map, missing_mask = self.warp_implementation( + bc, origin, spacing, wp.array(np.array(boundary_map)), wp.array(np.array(missing_mask)), start_index + ) return wp.to_jax(boundary_map), wp.to_jax(missing_mask) def _construct_warp(self): @@ -91,17 +104,22 @@ def kernel( @Operator.register_backend(ComputeBackend.WARP) def warp_implementation( self, - stl_file, + bc, origin, spacing, - id_number, boundary_map, missing_mask, start_index=(0, 0, 0), ): - # Load the mesh - mesh = np_mesh.Mesh.from_file(stl_file) - mesh_points = mesh.points.reshape(-1, 3) + assert bc.mesh_points is not None, f'Please provide the mesh points for {bc.__class__.__name__} BC using keyword "mesh_points"!' + assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" + assert bc.mesh_points.shape[1] == self.velocity_set.d, "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + mesh_points = bc.mesh_points + id_number = bc.id + + # We are done with bc.mesh_points. Remove them from BC objects + bc.__dict__.pop("mesh_points", None) + mesh_indices = np.arange(mesh_points.shape[0]) mesh = wp.Mesh( points=wp.array(mesh_points, dtype=wp.vec3), From 8a2255fa76cece859c649bd9f8fcd14f6a5133a8 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 29 Aug 2024 11:53:45 -0400 Subject: [PATCH 092/144] applied changes requested in the PR review --- .../bc_extrapolation_outflow.py | 24 +++++++++---------- .../boundary_condition/boundary_condition.py | 14 ++--------- xlb/operator/stepper/nse_stepper.py | 13 +++++----- 3 files changed, 21 insertions(+), 30 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 2754a1f..d16068b 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -61,7 +61,7 @@ def __init__( # Unpack the two warp functionals needed for this BC! if self.compute_backend == ComputeBackend.WARP: - self.warp_functional_poststream, self.warp_functional_postcollision = self.warp_functional + self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional def _get_normal_vec(self, indices): # Get the frequency count and most common element directly @@ -157,10 +157,10 @@ def get_normal_vectors_3d( # Construct the functionals for this BC @wp.func - def functional_poststream( + def functional( f_pre: Any, f_post: Any, - f_nbr: Any, + f_aux: Any, missing_mask: Any, ): # Post-streaming values are only modified at missing direction @@ -173,7 +173,7 @@ def functional_poststream( return _f @wp.func - def functional_postcollision( + def prepare_bc_auxilary_data( f_pre: Any, f_post: Any, f_aux: Any, @@ -202,7 +202,7 @@ def kernel2d( # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index) - _faux = _f_vec() + _f_aux = _f_vec() # special preparation of auxiliary data if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): @@ -215,13 +215,13 @@ def kernel2d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - _faux[l] = _f_pre[l, pull_index[0], pull_index[1]] + _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]] # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both + # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both # collision and streaming? - _f = functional_poststream(_f_pre, _f_post, _faux, _missing_mask) + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -243,7 +243,7 @@ def kernel3d( # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index) - _faux = _f_vec() + _f_aux = _f_vec() # special preparation of auxiliary data if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): @@ -256,13 +256,13 @@ def kernel3d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - _faux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] + _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both # collision and streaming? - _f = functional_poststream(_f_pre, _f_post, _faux, _missing_mask) + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: _f = _f_post @@ -272,7 +272,7 @@ def kernel3d( kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return [functional_poststream, functional_postcollision], kernel + return (functional, prepare_bc_auxilary_data), kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask): diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index eb6f92f..29ca2db 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -52,16 +52,7 @@ def __init__( _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool @wp.func - def functional_postcollision( - f_pre: Any, - f_post: Any, - f_aux: Any, - missing_mask: Any, - ): - return f_post - - @wp.func - def functional_poststream( + def prepare_bc_auxilary_data( f_pre: Any, f_post: Any, f_aux: Any, @@ -123,8 +114,7 @@ def _get_thread_data_3d( if self.compute_backend == ComputeBackend.WARP: self._get_thread_data_2d = _get_thread_data_2d self._get_thread_data_3d = _get_thread_data_3d - self.warp_functional_poststream = functional_poststream - self.warp_functional_postcollision = functional_postcollision + self.prepare_bc_auxilary_data = prepare_bc_auxilary_data @partial(jit, static_argnums=(0,), inline=True) def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask): diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 98224ef..efbf847 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -142,7 +142,7 @@ def apply_post_streaming_bc( f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # Regularized boundary condition (bc type = velocity) - f_post = self.ExtrapolationOutflowBC.warp_functional_poststream(f_pre, f_post, f_aux, missing_mask) + f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) return f_post @wp.func @@ -160,7 +160,7 @@ def apply_post_collision_bc( elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # f_aux is the neighbour's post-streaming values # Storing post-streaming data in directions that leave the domain - f_post = self.ExtrapolationOutflowBC.warp_functional_postcollision(f_pre, f_post, f_aux, missing_mask) + f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(f_pre, f_post, f_aux, missing_mask) return f_post @@ -221,7 +221,7 @@ def get_thread_data_3d( return f_post_collision, _missing_mask @wp.func - def prepare_bc_auxilary_data_2d( + def get_bc_auxilary_data_2d( f_0: wp.array3d(dtype=Any), index: Any, _boundary_id: Any, @@ -244,7 +244,7 @@ def prepare_bc_auxilary_data_2d( return f_auxiliary @wp.func - def prepare_bc_auxilary_data_3d( + def get_bc_auxilary_data_3d( f_0: wp.array4d(dtype=Any), index: Any, _boundary_id: Any, @@ -287,7 +287,7 @@ def kernel2d( # Prepare auxilary data for BC (if applicable) _boundary_id = boundary_mask[0, index[0], index[1]] - f_auxiliary = prepare_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) + f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) @@ -335,7 +335,7 @@ def kernel3d( # Prepare auxilary data for BC (if applicable) _boundary_id = boundary_mask[0, index[0], index[1], index[2]] - f_auxiliary = prepare_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) + f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) @@ -380,6 +380,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): # Setting the Struct attributes and active BC classes based on the BC class names bc_fallback = self.boundary_conditions[0] + # TODO: what if self.boundary_conditions is an empty list e.g. when we have periodic BC all around! for var in vars(bc_struct): if var not in active_bc_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 From 18c8ee581b64a2db037db53a4356db05d14ee34d Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 29 Aug 2024 13:36:08 -0400 Subject: [PATCH 093/144] mesh_boundary_masker working as expected and added to the wind tunnel example. JAX implementation using WARP backend remains! --- examples/cfd/windtunnel_3d.py | 8 ++++++-- xlb/operator/boundary_masker/mesh_boundary_masker.py | 10 ++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 696cb7b..9bf1869 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -9,6 +9,8 @@ FullwayBounceBackBC, EquilibriumBC, DoNothingBC, + RegularizedBC, + HalfwayBounceBackBC, ExtrapolationOutflowBC, ) from xlb.operator.macroscopic import Macroscopic @@ -75,7 +77,7 @@ def define_boundary_indices(self): length_phys_unit = mesh_extents.max() length_lbm_unit = self.grid_shape[0] / 4 dx = length_phys_unit / length_lbm_unit - shift = np.array([self.grid_shape[0] * dx / 4, (self.grid_shape[1] * dx - mesh_extents[1]) / 2, 0]) + shift = np.array([self.grid_shape[0] * dx / 4, (self.grid_shape[1] * dx - mesh_extents[1]) / 2, 0.0]) car = mesh_points + shift self.grid_spacing = dx @@ -84,9 +86,11 @@ def define_boundary_indices(self): def setup_boundary_conditions(self, wind_speed): inlet, outlet, walls, car = self.define_boundary_indices() bc_left = EquilibriumBC(rho=1.0, u=(wind_speed, 0.0, 0.0), indices=inlet) + # bc_left = RegularizedBC('velocity', (wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) - bc_car = FullwayBounceBackBC(mesh_points=car) + bc_car = HalfwayBounceBackBC(mesh_points=car) + # bc_car = FullwayBounceBackBC(mesh_points=car) self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] def setup_boundary_masker(self): diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 1ecec13..4c8e003 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -2,7 +2,7 @@ import numpy as np import warp as wp - +import jax from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend @@ -37,10 +37,12 @@ def jax_implementation( missing_mask, start_index=(0, 0, 0), ): + raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") # Use Warp backend even for this particular operation. - boundary_map, missing_mask = self.warp_implementation( - bc, origin, spacing, wp.array(np.array(boundary_map)), wp.array(np.array(missing_mask)), start_index - ) + wp.init() + boundary_map = wp.from_jax(boundary_map) + missing_mask = wp.from_jax(missing_mask) + boundary_map, missing_mask = self.warp_implementation(bc, origin, spacing, boundary_map, missing_mask, start_index) return wp.to_jax(boundary_map), wp.to_jax(missing_mask) def _construct_warp(self): From 0adb6f7e8e2ffca01d1a81004cd004c2902460cb Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 29 Aug 2024 14:30:38 -0400 Subject: [PATCH 094/144] added drag and force calculation to the windtunnel example. --- examples/cfd/windtunnel_3d.py | 87 +++++++++++++++---- .../boundary_condition/bc_do_nothing.py | 4 +- .../boundary_condition/bc_equilibrium.py | 4 +- .../bc_extrapolation_outflow.py | 4 +- .../bc_fullway_bounce_back.py | 4 +- .../bc_halfway_bounce_back.py | 4 +- .../boundary_condition/bc_regularized.py | 4 +- xlb/operator/boundary_condition/bc_zouhe.py | 4 +- .../boundary_condition/boundary_condition.py | 4 +- .../indices_boundary_masker.py | 4 +- .../boundary_masker/mesh_boundary_masker.py | 16 ++-- xlb/operator/force/momentum_transfer.py | 2 +- 12 files changed, 100 insertions(+), 41 deletions(-) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 9bf1869..8395579 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -13,12 +13,14 @@ HalfwayBounceBackBC, ExtrapolationOutflowBC, ) +from xlb.operator.force.momentum_transfer import MomentumTransfer from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np import jax.numpy as jnp +import matplotlib.pyplot as plt class WindTunnel3D: @@ -39,13 +41,20 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.boundary_conditions = [] # Setup the simulation BC, its initial conditions, and the stepper - self._setup(omega, wind_speed) + self.wind_speed = wind_speed + self.omega = omega + self._setup() - def _setup(self, omega, wind_speed): - self.setup_boundary_conditions(wind_speed) + # Make list to store drag coefficients + self.time_steps = [] + self.drag_coefficients = [] + self.lift_coefficients = [] + + def _setup(self): + self.setup_boundary_conditions() self.setup_boundary_masker() self.initialize_fields() - self.setup_stepper(omega) + self.setup_stepper() def voxelize_stl(self, stl_filename, length_lbm_unit): mesh = trimesh.load_mesh(stl_filename, process=False) @@ -69,28 +78,29 @@ def define_boundary_indices(self): # Load the mesh stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl" mesh = trimesh.load_mesh(stl_filename, process=False) - mesh_points = mesh.vertices + mesh_vertices = mesh.vertices # Transform the mesh points to be located in the right position in the wind tunnel - mesh_points -= mesh_points.min(axis=0) - mesh_extents = mesh_points.max(axis=0) + mesh_vertices -= mesh_vertices.min(axis=0) + mesh_extents = mesh_vertices.max(axis=0) length_phys_unit = mesh_extents.max() length_lbm_unit = self.grid_shape[0] / 4 dx = length_phys_unit / length_lbm_unit shift = np.array([self.grid_shape[0] * dx / 4, (self.grid_shape[1] * dx - mesh_extents[1]) / 2, 0.0]) - car = mesh_points + shift + car = mesh_vertices + shift self.grid_spacing = dx + self.car_cross_section = np.prod(mesh_extents[1:]) / dx**2 return inlet, outlet, walls, car - def setup_boundary_conditions(self, wind_speed): + def setup_boundary_conditions(self): inlet, outlet, walls, car = self.define_boundary_indices() - bc_left = EquilibriumBC(rho=1.0, u=(wind_speed, 0.0, 0.0), indices=inlet) - # bc_left = RegularizedBC('velocity', (wind_speed, 0.0, 0.0), indices=inlet) + bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet) + # bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) - bc_car = HalfwayBounceBackBC(mesh_points=car) - # bc_car = FullwayBounceBackBC(mesh_points=car) + bc_car = HalfwayBounceBackBC(mesh_vertices=car) + # bc_car = FullwayBounceBackBC(mesh_vertices=car) self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] def setup_boundary_masker(self): @@ -114,10 +124,14 @@ def setup_boundary_masker(self): def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) - def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") + def setup_stepper(self): + self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") def run(self, num_steps, print_interval, post_process_interval=100): + # Setup the operator for computing surface forces at the interface of the specified BC + bc_car = self.boundary_conditions[-1] + self.momentum_transfer = MomentumTransfer(bc_car) + start_time = time.time() for i in range(num_steps): self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) @@ -150,6 +164,49 @@ def post_process(self, i): save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) + # Compute lift and drag + boundary_force = self.momentum_transfer(self.f_0, self.boundary_map, self.missing_mask) + drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane + lift = boundary_force[2] + c_d = 2.0 * drag / (self.wind_speed**2 * self.car_cross_section) + c_l = 2.0 * lift / (self.wind_speed**2 * self.car_cross_section) + self.drag_coefficients.append(c_d) + self.lift_coefficients.append(c_l) + self.time_steps.append(i) + + # Save monitor plot + self.plot_drag_coefficient() + return + + def plot_drag_coefficient(self): + # Compute moving average of drag coefficient, 100, 1000, 10000 + drag_coefficients = np.array(self.drag_coefficients) + self.drag_coefficients_ma_10 = np.convolve(drag_coefficients, np.ones(10) / 10, mode="valid") + self.drag_coefficients_ma_100 = np.convolve(drag_coefficients, np.ones(100) / 100, mode="valid") + self.drag_coefficients_ma_1000 = np.convolve(drag_coefficients, np.ones(1000) / 1000, mode="valid") + self.drag_coefficients_ma_10000 = np.convolve(drag_coefficients, np.ones(10000) / 10000, mode="valid") + self.drag_coefficients_ma_100000 = np.convolve(drag_coefficients, np.ones(100000) / 100000, mode="valid") + + # Plot drag coefficient + plt.plot(self.time_steps, drag_coefficients, label="Raw") + if len(self.time_steps) > 10: + plt.plot(self.time_steps[9:], self.drag_coefficients_ma_10, label="MA 10") + if len(self.time_steps) > 100: + plt.plot(self.time_steps[99:], self.drag_coefficients_ma_100, label="MA 100") + if len(self.time_steps) > 1000: + plt.plot(self.time_steps[999:], self.drag_coefficients_ma_1000, label="MA 1,000") + if len(self.time_steps) > 10000: + plt.plot(self.time_steps[9999:], self.drag_coefficients_ma_10000, label="MA 10,000") + if len(self.time_steps) > 100000: + plt.plot(self.time_steps[99999:], self.drag_coefficients_ma_100000, label="MA 100,000") + + plt.ylim(-1.0, 1.0) + plt.legend() + plt.xlabel("Time step") + plt.ylabel("Drag coefficient") + plt.savefig("drag_coefficient_ma.png") + plt.close() + if __name__ == "__main__": # Grid parameters diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 96b733f..6e8d317 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -35,7 +35,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, - mesh_points=None, + mesh_vertices=None, ): super().__init__( ImplementationStep.STREAMING, @@ -43,7 +43,7 @@ def __init__( precision_policy, compute_backend, indices, - mesh_points, + mesh_vertices, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index ff55cf0..6853c0e 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -40,7 +40,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, - mesh_points=None, + mesh_vertices=None, ): # Store the equilibrium information self.rho = rho @@ -57,7 +57,7 @@ def __init__( precision_policy, compute_backend, indices, - mesh_points, + mesh_vertices, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 777b1ec..a658be7 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -46,7 +46,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, - mesh_points=None, + mesh_vertices=None, ): # Call the parent constructor super().__init__( @@ -55,7 +55,7 @@ def __init__( precision_policy, compute_backend, indices, - mesh_points, + mesh_vertices, ) # find and store the normal vector using indices diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 30065e4..6af4226 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -35,7 +35,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, - mesh_points=None, + mesh_vertices=None, ): super().__init__( ImplementationStep.COLLISION, @@ -43,7 +43,7 @@ def __init__( precision_policy, compute_backend, indices, - mesh_points, + mesh_vertices, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 3b26810..5c001d9 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -37,7 +37,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, - mesh_points=None, + mesh_vertices=None, ): # Call the parent constructor super().__init__( @@ -46,7 +46,7 @@ def __init__( precision_policy, compute_backend, indices, - mesh_points, + mesh_vertices, ) @Operator.register_backend(ComputeBackend.JAX) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 3c71c96..b74c0b1 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -49,7 +49,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, - mesh_points=None, + mesh_vertices=None, ): # Call the parent constructor super().__init__( @@ -59,7 +59,7 @@ def __init__( precision_policy, compute_backend, indices, - mesh_points, + mesh_vertices, ) # The operator to compute the momentum flux diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 5153ced..56c6868 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -43,7 +43,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, - mesh_points=None, + mesh_vertices=None, ): # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC # may have different types (velocity or pressure). @@ -60,7 +60,7 @@ def __init__( precision_policy, compute_backend, indices, - mesh_points, + mesh_vertices, ) # Set the prescribed value for pressure or velocity diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 79963a0..d0780cf 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -33,7 +33,7 @@ def __init__( precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, - mesh_points=None, + mesh_vertices=None, ): velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy @@ -43,7 +43,7 @@ def __init__( # Set the BC indices self.indices = indices - self.mesh_points = mesh_points + self.mesh_vertices = mesh_vertices # Set the implementation step self.implementation_step = implementation_step diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 7c09ac1..29a96e9 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -48,7 +48,7 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non bid = boundary_map[0] for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" - assert bc.mesh_points is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" + assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" id_number = bc.id local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] @@ -167,7 +167,7 @@ def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=No id_list = [] for bc in bclist: assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC using keyword "indices"!' - assert bc.mesh_points is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" + assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 4c8e003..366c9d6 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -113,18 +113,20 @@ def warp_implementation( missing_mask, start_index=(0, 0, 0), ): - assert bc.mesh_points is not None, f'Please provide the mesh points for {bc.__class__.__name__} BC using keyword "mesh_points"!' + assert bc.mesh_vertices is not None, f'Please provide the mesh points for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" - assert bc.mesh_points.shape[1] == self.velocity_set.d, "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" - mesh_points = bc.mesh_points + assert ( + bc.mesh_vertices.shape[1] == self.velocity_set.d + ), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + mesh_vertices = bc.mesh_vertices id_number = bc.id - # We are done with bc.mesh_points. Remove them from BC objects - bc.__dict__.pop("mesh_points", None) + # We are done with bc.mesh_vertices. Remove them from BC objects + bc.__dict__.pop("mesh_vertices", None) - mesh_indices = np.arange(mesh_points.shape[0]) + mesh_indices = np.arange(mesh_vertices.shape[0]) mesh = wp.Mesh( - points=wp.array(mesh_points, dtype=wp.vec3), + points=wp.array(mesh_vertices, dtype=wp.vec3), indices=wp.array(mesh_indices, dtype=int), ) diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 264e983..66dba13 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -213,4 +213,4 @@ def warp_implementation(self, f, boundary_map, missing_mask): inputs=[f, boundary_map, missing_mask, force], dim=f.shape[1:], ) - return force.numpy() + return force.numpy()[0] From 60aad2e3fcfd8dd26ae764d1a46c7815a712b13e Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 30 Aug 2024 10:33:40 -0400 Subject: [PATCH 095/144] now indices_boundary_masker also handles boundary geometries in the interior of the domain --- examples/cfd/flow_past_sphere_3d.py | 18 ++++- .../indices_boundary_masker.py | 80 ++++++++++++++++--- 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index ba8bb44..ff39b1a 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -3,13 +3,22 @@ from xlb.precision_policy import PrecisionPolicy from xlb.helper import create_nse_fields, initialize_eq from xlb.operator.stepper import IncompressibleNavierStokesStepper -from xlb.operator.boundary_condition import FullwayBounceBackBC, ZouHeBC, RegularizedBC, EquilibriumBC, DoNothingBC, ExtrapolationOutflowBC +from xlb.operator.boundary_condition import ( + FullwayBounceBackBC, + HalfwayBounceBackBC, + ZouHeBC, + RegularizedBC, + EquilibriumBC, + DoNothingBC, + ExtrapolationOutflowBC, +) from xlb.operator.macroscopic import Macroscopic from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np import jax.numpy as jnp +import time class FlowOverSphere: @@ -69,7 +78,8 @@ def setup_boundary_conditions(self): # bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet) # bc_outlet = DoNothingBC(indices=outlet) bc_outlet = ExtrapolationOutflowBC(indices=outlet) - bc_sphere = FullwayBounceBackBC(indices=sphere) + bc_sphere = HalfwayBounceBackBC(indices=sphere) + self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls] # Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. @@ -90,12 +100,16 @@ def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") def run(self, num_steps, post_process_interval=100): + start_time = time.time() for i in range(num_steps): self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: self.post_process(i) + end_time = time.time() + print(f"Completing {i} iterations. Time elapsed for 1000 LBM steps in {end_time - start_time:.6f} seconds.") + start_time = time.time() def post_process(self, i): # Write the results. We'll use JAX backend for the post-processing diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 29a96e9..208f50f 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -26,6 +26,23 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) + def are_indices_in_interior(self, indices, shape): + """ + Check if all 2D or 3D indices are inside the bounds of the domain with the given shape and not + at its boundary. + + :param indices: List of tuples, where each tuple contains indices for each dimension. + :param shape: Tuple representing the shape of the domain (nx, ny) for 2D or (nx, ny, nz) for 3D. + :return: Boolean flag is_inside indicating whether all indices are inside the bounds. + """ + # Ensure that the number of dimensions in indices matches the domain shape + dim = len(shape) + if len(indices) != dim: + raise ValueError(f"Indices tuple must have {dim} dimensions to match the domain shape.") + + # Check if all indices are within the bounds + return all(0 < idx < shape[d] - 1 for d, idx_list in enumerate(indices) for idx in idx_list) + @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) @@ -45,22 +62,32 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non if start_index is None: start_index = (0,) * dim - bid = boundary_map[0] + bmap = boundary_map[0] + domain_shape = bmap.shape for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" id_number = bc.id local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] - bid = bid.at[tuple(local_indices)].set(id_number) - # if dim == 2: - # grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) - # if dim == 3: - # grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) + bmap = bmap.at[tuple(local_indices)].set(id_number) + if self.are_indices_in_interior(bc.indices, domain_shape): + # checking if all indices associated with this BC are in the interior of the domain (not at the boundary). + # This flag is needed e.g. if the no-slip geometry is anywhere but at the boundaries of the computational domain. + if dim == 2: + grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) + if dim == 3: + grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) + + # Assign the boundary id to the push indices + push_indices = local_indices[:, :, None] + self.velocity_set.c[:, None, :] + push_indices = push_indices.reshape(3, -1) + bmap = bmap.at[tuple(push_indices)].set(id_number) + # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) - boundary_map = boundary_map.at[0].set(bid) + boundary_map = boundary_map.at[0].set(bmap) grid_mask = self.stream(grid_mask) if dim == 2: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] @@ -78,6 +105,7 @@ def _construct_warp(self): def kernel2d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), + is_interior: wp.array1d(dtype=wp.bool), boundary_map: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, @@ -96,22 +124,36 @@ def kernel2d( for l in range(_q): # Get the index of the streaming direction pull_index = wp.vec2i() + push_index = wp.vec2i() for d in range(self.velocity_set.d): pull_index[d] = index[d] - _c[d, l] + push_index[d] = index[d] + _c[d, l] # check if pull index is out of bound # These directions will have missing information after streaming if pull_index[0] < 0 or pull_index[0] >= missing_mask.shape[1] or pull_index[1] < 0 or pull_index[1] >= missing_mask.shape[2]: # Set the missing mask missing_mask[l, index[0], index[1]] = True + boundary_map[0, index[0], index[1]] = id_number[ii] - boundary_map[0, index[0], index[1]] = id_number[ii] + # handling geometries in the interior of the computational domain + elif ( + is_interior[ii] + and push_index[0] >= 0 + and push_index[0] < missing_mask.shape[1] + and push_index[1] >= 0 + and push_index[1] < missing_mask.shape[2] + ): + # Set the missing mask + missing_mask[l, push_index[0], push_index[1]] = True + boundary_map[0, push_index[0], push_index[1]] = id_number[ii] # Construct the warp 3D kernel @wp.kernel def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), + is_interior: wp.array1d(dtype=wp.bool), boundary_map: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, @@ -138,8 +180,10 @@ def kernel3d( for l in range(_q): # Get the index of the streaming direction pull_index = wp.vec3i() + push_index = wp.vec3i() for d in range(self.velocity_set.d): pull_index[d] = index[d] - _c[d, l] + push_index[d] = index[d] + _c[d, l] # check if pull index is out of bound # These directions will have missing information after streaming @@ -153,8 +197,21 @@ def kernel3d( ): # Set the missing mask missing_mask[l, index[0], index[1], index[2]] = True + boundary_map[0, index[0], index[1], index[2]] = id_number[ii] - boundary_map[0, index[0], index[1], index[2]] = id_number[ii] + # handling geometries in the interior of the computational domain + elif ( + is_interior[ii] + and push_index[0] >= 0 + and push_index[0] < missing_mask.shape[1] + and push_index[1] >= 0 + and push_index[1] < missing_mask.shape[2] + and push_index[2] >= 0 + and push_index[2] < missing_mask.shape[3] + ): + # Set the missing mask + missing_mask[l, push_index[0], push_index[1], push_index[2]] = True + boundary_map[0, push_index[0], push_index[1], push_index[2]] = id_number[ii] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d @@ -165,17 +222,21 @@ def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=No dim = self.velocity_set.d index_list = [[] for _ in range(dim)] id_list = [] + is_interior = [] for bc in bclist: assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC using keyword "indices"!' assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) + is_interior += [self.are_indices_in_interior(bc.indices, boundary_map[0].shape)] * len(bc.indices[0]) + # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) indices = wp.array2d(index_list, dtype=wp.int32) id_number = wp.array1d(id_list, dtype=wp.uint8) + is_interior = wp.array1d(is_interior, dtype=wp.bool) if start_index is None: start_index = (0,) * dim @@ -186,6 +247,7 @@ def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=No inputs=[ indices, id_number, + is_interior, boundary_map, missing_mask, start_index, From 2a3e6c601a1bbae8bf81dde9c0cc03c4156d320b Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 30 Aug 2024 12:40:32 -0400 Subject: [PATCH 096/144] Fixing the syntax error in Warp when bc list is empty. When empty, perf is really bad, otherwise normal. --- examples/cfd/flow_past_sphere_3d.py | 2 +- xlb/operator/stepper/nse_stepper.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index ff39b1a..7d07e05 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -128,7 +128,7 @@ def post_process(self, i): fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho[0]} - save_fields_vtk(fields, timestep=i) + # save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index ad9831f..ba3e294 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -14,6 +14,7 @@ from xlb.operator.macroscopic import Macroscopic from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition.boundary_condition import ImplementationStep +from xlb.operator.boundary_condition import DoNothingBC as DummyBC class IncompressibleNavierStokesStepper(Stepper): @@ -378,9 +379,16 @@ def warp_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): setattr(bc_struct, "id_" + bc_name, bc_to_id[bc_name]) active_bc_list.append("id_" + bc_name) - # Setting the Struct attributes and active BC classes based on the BC class names - bc_fallback = self.boundary_conditions[0] - # TODO: what if self.boundary_conditions is an empty list e.g. when we have periodic BC all around! + # Check if boundary_conditions is an empty list (e.g. all periodic and no BC) + # TODO: There is a huge issue here with perf. when boundary_conditions list + # is empty and is initialized with a dummy BC. If it is not empty, no perf + # loss ocurrs. The following code at least prevents syntax error for periodic examples. + if self.boundary_conditions: + bc_dummy = self.boundary_conditions[0] + else: + bc_dummy = DummyBC() + + # Setting the Struct attributes for inactive BC classes for var in vars(bc_struct): if var not in active_bc_list and not var.startswith("_"): # set unassigned boundaries to the maximum integer in uint8 @@ -388,7 +396,7 @@ def warp_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): # Assing a fall-back BC for inactive BCs. This is just to ensure Warp codegen does not # produce error when a particular BC is not used in an example. - setattr(self, var.replace("id_", ""), bc_fallback) + setattr(self, var.replace("id_", ""), bc_dummy) # Launch the warp kernel wp.launch( From 7a7399dec7feb457f98e859a5c99278aa719aa60 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 6 Sep 2024 16:47:32 -0400 Subject: [PATCH 097/144] fixing the issue for cases where internal geometry contacts the domain boundary --- .../indices_boundary_masker.py | 25 +++++++++++-------- xlb/operator/force/momentum_transfer.py | 22 ++++++++-------- xlb/operator/stepper/nse_stepper.py | 2 -- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 208f50f..6c58ff5 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -28,20 +28,21 @@ def __init__( def are_indices_in_interior(self, indices, shape): """ - Check if all 2D or 3D indices are inside the bounds of the domain with the given shape and not + Check if each 2D or 3D index is inside the bounds of the domain with the given shape and not at its boundary. :param indices: List of tuples, where each tuple contains indices for each dimension. :param shape: Tuple representing the shape of the domain (nx, ny) for 2D or (nx, ny, nz) for 3D. - :return: Boolean flag is_inside indicating whether all indices are inside the bounds. + :return: List of boolean flags where each flag indicates whether the corresponding index is inside the bounds. """ # Ensure that the number of dimensions in indices matches the domain shape dim = len(shape) if len(indices) != dim: raise ValueError(f"Indices tuple must have {dim} dimensions to match the domain shape.") - # Check if all indices are within the bounds - return all(0 < idx < shape[d] - 1 for d, idx_list in enumerate(indices) for idx in idx_list) + # Check each index tuple and return a list of boolean flags + flags = [all(0 < idx[d] < shape[d] - 1 for d in range(dim)) for idx in np.array(indices).T] + return flags @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! @@ -54,24 +55,25 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non pad_x, pad_y, pad_z = nDevices, 1, 1 if dim == 2: grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True) + bmap = jnp.pad(boundary_map[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) if dim == 3: grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True) + bmap = jnp.pad(boundary_map[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) # shift indices shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z) if start_index is None: start_index = (0,) * dim - bmap = boundary_map[0] - domain_shape = bmap.shape + domain_shape = boundary_map[0].shape for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" id_number = bc.id local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] - bmap = bmap.at[tuple(local_indices)].set(id_number) - if self.are_indices_in_interior(bc.indices, domain_shape): + bmap = bmap.at[tuple(padded_indices)].set(id_number) + if any(self.are_indices_in_interior(bc.indices, domain_shape)): # checking if all indices associated with this BC are in the interior of the domain (not at the boundary). # This flag is needed e.g. if the no-slip geometry is anywhere but at the boundaries of the computational domain. if dim == 2: @@ -80,19 +82,20 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) # Assign the boundary id to the push indices - push_indices = local_indices[:, :, None] + self.velocity_set.c[:, None, :] + push_indices = padded_indices[:, :, None] + self.velocity_set.c[:, None, :] push_indices = push_indices.reshape(3, -1) bmap = bmap.at[tuple(push_indices)].set(id_number) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) - boundary_map = boundary_map.at[0].set(bmap) grid_mask = self.stream(grid_mask) if dim == 2: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] + boundary_map = boundary_map.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y]) if dim == 3: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] + boundary_map = boundary_map.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]) return boundary_map, missing_mask def _construct_warp(self): @@ -229,7 +232,7 @@ def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=No for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) - is_interior += [self.are_indices_in_interior(bc.indices, boundary_map[0].shape)] * len(bc.indices[0]) + is_interior += self.are_indices_in_interior(bc.indices, boundary_map[0].shape) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 66dba13..5b6db25 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -140,11 +140,12 @@ def kernel2d( f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) # Compute the momentum transfer - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] - for d in range(self.velocity_set.d): - m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + for d in range(self.velocity_set.d): + m[d] = self.compute_dtype(0.0) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] + m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) wp.atomic_add(force, 0, m) @@ -189,11 +190,12 @@ def kernel3d( f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) # Compute the momentum transfer - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] - for d in range(self.velocity_set.d): - m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + for d in range(self.velocity_set.d): + m[d] = self.compute_dtype(0.0) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] + m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) wp.atomic_add(force, 0, m) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index ba3e294..d45d595 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -92,8 +92,6 @@ def _construct_warp(self): _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool _c = self.velocity_set.wp_c _q = self.velocity_set.q - _opp_indices = self.velocity_set.wp_opp_indices - sound_speed = 1.0 / wp.sqrt(3.0) @wp.struct class BoundaryConditionIDStruct: From 431b855707bb2d0d52be4dd0c8c015eda5808e33 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 6 Sep 2024 18:08:28 -0400 Subject: [PATCH 098/144] minor renaming for clarity --- .../boundary_condition/bc_do_nothing.py | 8 ++-- .../boundary_condition/bc_equilibrium.py | 8 ++-- .../bc_extrapolation_outflow.py | 12 ++--- .../bc_fullway_bounce_back.py | 8 ++-- .../bc_halfway_bounce_back.py | 8 ++-- .../boundary_condition/bc_regularized.py | 8 ++-- xlb/operator/boundary_condition/bc_zouhe.py | 8 ++-- .../boundary_condition/boundary_condition.py | 8 ++-- xlb/operator/force/momentum_transfer.py | 8 ++-- xlb/operator/stepper/nse_stepper.py | 48 +++++++++---------- 10 files changed, 62 insertions(+), 62 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 6e8d317..77e3e97 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -75,10 +75,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(DoNothingBC.id): + if _boundary_id == wp.uint8(DoNothingBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -101,10 +101,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(DoNothingBC.id): + if _boundary_id == wp.uint8(DoNothingBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 6853c0e..e7f4a7c 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -100,10 +100,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(EquilibriumBC.id): + if _boundary_id == wp.uint8(EquilibriumBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -126,10 +126,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(EquilibriumBC.id): + if _boundary_id == wp.uint8(EquilibriumBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 55f094d..cd53a8f 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -203,11 +203,11 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) _f_aux = _f_vec() # special preparation of auxiliary data - if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -220,7 +220,7 @@ def kernel2d( _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]] # Apply the boundary condition - if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both # collision and streaming? _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) @@ -244,11 +244,11 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) _f_aux = _f_vec() # special preparation of auxiliary data - if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -261,7 +261,7 @@ def kernel3d( _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] # Apply the boundary condition - if _boundary_map == wp.uint8(ExtrapolationOutflowBC.id): + if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both # collision and streaming? _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 6af4226..d11eead 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -84,10 +84,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Check if the boundary is active - if _boundary_map == wp.uint8(FullwayBounceBackBC.id): + if _boundary_id == wp.uint8(FullwayBounceBackBC.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -110,10 +110,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Check if the boundary is active - if _boundary_map == wp.uint8(FullwayBounceBackBC.id): + if _boundary_id == wp.uint8(FullwayBounceBackBC.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 5c001d9..a663252 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -96,10 +96,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(HalfwayBounceBackBC.id): + if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -122,10 +122,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(HalfwayBounceBackBC.id): + if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index b74c0b1..674778d 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -333,10 +333,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(self.id): + if _boundary_id == wp.uint8(self.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -359,10 +359,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(self.id): + if _boundary_id == wp.uint8(self.id): _f_aux = _f_vec() _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 56c6868..c1a9b7d 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -338,10 +338,10 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(self.id): + if _boundary_id == wp.uint8(self.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: @@ -364,10 +364,10 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_map, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) # Apply the boundary condition - if _boundary_map == wp.uint8(self.id): + if _boundary_id == wp.uint8(self.id): _f_aux = _f_post _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) else: diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 2cf6d67..a4c2181 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -73,7 +73,7 @@ def _get_thread_data_2d( # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_map = boundary_map[0, index[0], index[1]] + _boundary_id = boundary_map[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -85,7 +85,7 @@ def _get_thread_data_2d( _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return _f_pre, _f_post, _boundary_map, _missing_mask + return _f_pre, _f_post, _boundary_id, _missing_mask @wp.func def _get_thread_data_3d( @@ -98,7 +98,7 @@ def _get_thread_data_3d( # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_map = boundary_map[0, index[0], index[1], index[2]] + _boundary_id = boundary_map[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -110,7 +110,7 @@ def _get_thread_data_3d( _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return _f_pre, _f_post, _boundary_map, _missing_mask + return _f_pre, _f_post, _boundary_id, _missing_mask # Construct some helper warp functions for getting tid data if self.compute_backend == ComputeBackend.WARP: diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 5b6db25..4455ee5 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -112,7 +112,7 @@ def kernel2d( index = wp.vec2i(i, j) # Get the boundary id - _boundary_map = boundary_map[0, index[0], index[1]] + _boundary_id = boundary_map[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -123,7 +123,7 @@ def kernel2d( # Determin if boundary is an edge by checking if center is missing is_edge = wp.bool(False) - if _boundary_map == wp.uint8(_no_slip_id): + if _boundary_id == wp.uint8(_no_slip_id): if _missing_mask[_zero_index] == wp.uint8(0): is_edge = wp.bool(True) @@ -162,7 +162,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id - _boundary_map = boundary_map[0, index[0], index[1], index[2]] + _boundary_id = boundary_map[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -173,7 +173,7 @@ def kernel3d( # Determin if boundary is an edge by checking if center is missing is_edge = wp.bool(False) - if _boundary_map == wp.uint8(_no_slip_id): + if _boundary_id == wp.uint8(_no_slip_id): if _missing_mask[_zero_index] == wp.uint8(0): is_edge = wp.bool(True) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index d45d595..1e3bbf7 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -114,32 +114,32 @@ def apply_post_streaming_bc( f_post: Any, f_aux: Any, missing_mask: Any, - _boundary_map: Any, + _boundary_id: Any, bc_struct: Any, ): # Apply post-streaming type boundary conditions - if _boundary_map == bc_struct.id_EquilibriumBC: + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_map == bc_struct.id_DoNothingBC: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post = self.DoNothingBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_map == bc_struct.id_HalfwayBounceBackBC: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_map == bc_struct.id_ZouHeBC_velocity: + elif _boundary_id == bc_struct.id_ZouHeBC_velocity: # Zouhe boundary condition (bc type = velocity) f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_map == bc_struct.id_ZouHeBC_pressure: + elif _boundary_id == bc_struct.id_ZouHeBC_pressure: # Zouhe boundary condition (bc type = pressure) f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_map == bc_struct.id_RegularizedBC_velocity: + elif _boundary_id == bc_struct.id_RegularizedBC_velocity: # Regularized boundary condition (bc type = velocity) f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_map == bc_struct.id_RegularizedBC_pressure: + elif _boundary_id == bc_struct.id_RegularizedBC_pressure: # Regularized boundary condition (bc type = velocity) f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_map == bc_struct.id_ExtrapolationOutflowBC: + elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # Regularized boundary condition (bc type = velocity) f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) return f_post @@ -150,13 +150,13 @@ def apply_post_collision_bc( f_post: Any, f_aux: Any, missing_mask: Any, - _boundary_map: Any, + _boundary_id: Any, bc_struct: Any, ): - if _boundary_map == bc_struct.id_FullwayBounceBackBC: + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) - elif _boundary_map == bc_struct.id_ExtrapolationOutflowBC: + elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # f_aux is the neighbour's post-streaming values # Storing post-streaming data in directions that leave the domain f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(f_pre, f_post, f_aux, missing_mask) @@ -223,13 +223,13 @@ def get_thread_data_3d( def get_bc_auxilary_data_2d( f_0: wp.array3d(dtype=Any), index: Any, - _boundary_map: Any, + _boundary_id: Any, _missing_mask: Any, bc_struct: Any, ): # special preparation of auxiliary data f_auxiliary = _f_vec() - if _boundary_map == bc_struct.id_ExtrapolationOutflowBC: + if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: nv = get_normal_vectors_2d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -246,13 +246,13 @@ def get_bc_auxilary_data_2d( def get_bc_auxilary_data_3d( f_0: wp.array4d(dtype=Any), index: Any, - _boundary_map: Any, + _boundary_id: Any, _missing_mask: Any, bc_struct: Any, ): # special preparation of auxiliary data f_auxiliary = _f_vec() - if _boundary_map == bc_struct.id_ExtrapolationOutflowBC: + if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): @@ -285,11 +285,11 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_map = boundary_map[0, index[0], index[1]] - f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_map, _missing_mask, bc_struct) + _boundary_id = boundary_map[0, index[0], index[1]] + f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_map, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -306,7 +306,7 @@ def kernel2d( ) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_map, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -333,11 +333,11 @@ def kernel3d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_map = boundary_map[0, index[0], index[1], index[2]] - f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_map, _missing_mask, bc_struct) + _boundary_id = boundary_map[0, index[0], index[1], index[2]] + f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_map, bc_struct) + f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -349,7 +349,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_map, bc_struct) + f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): From 721dc0ecfc8cd36a0224eccfe6f76f950317e792 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 9 Sep 2024 15:34:27 -0400 Subject: [PATCH 099/144] renaming boundary_map to bc_id --- examples/cfd/flow_past_sphere_3d.py | 6 ++-- examples/cfd/lid_driven_cavity_2d.py | 6 ++-- examples/cfd/windtunnel_3d.py | 10 +++--- .../flow_past_sphere.py | 12 +++---- .../cfd_old_to_be_migrated/taylor_green.py | 6 ++-- examples/performance/mlups_3d.py | 12 +++---- .../bc_equilibrium/test_bc_equilibrium_jax.py | 6 ++-- .../test_bc_equilibrium_warp.py | 6 ++-- .../test_bc_fullway_bounce_back_jax.py | 6 ++-- .../test_bc_fullway_bounce_back_warp.py | 6 ++-- .../mask/test_bc_indices_masker_jax.py | 24 ++++++------- .../mask/test_bc_indices_masker_warp.py | 28 +++++++-------- xlb/helper/nse_solver.py | 4 +-- .../boundary_condition/bc_do_nothing.py | 16 ++++----- .../boundary_condition/bc_equilibrium.py | 16 ++++----- .../bc_extrapolation_outflow.py | 20 +++++------ .../bc_fullway_bounce_back.py | 16 ++++----- .../bc_halfway_bounce_back.py | 16 ++++----- .../boundary_condition/bc_regularized.py | 16 ++++----- xlb/operator/boundary_condition/bc_zouhe.py | 16 ++++----- .../boundary_condition/boundary_condition.py | 10 +++--- .../indices_boundary_masker.py | 34 +++++++++---------- .../boundary_masker/mesh_boundary_masker.py | 26 +++++++------- xlb/operator/force/momentum_transfer.py | 20 +++++------ xlb/operator/stepper/nse_stepper.py | 20 +++++------ 25 files changed, 179 insertions(+), 179 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 7d07e05..f1ff7c3 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -34,7 +34,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_id = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -91,7 +91,7 @@ def setup_boundary_masker(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_map, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_map, self.missing_mask, (0, 0, 0)) + self.bc_id, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_id, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -102,7 +102,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_id, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 16fb4f9..277c51e 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -24,7 +24,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_id = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -57,7 +57,7 @@ def setup_boundary_masker(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_map, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_map, self.missing_mask) + self.bc_id, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_id, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -67,7 +67,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_id, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 8395579..891e2dd 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -36,7 +36,7 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_map = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_id = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -118,8 +118,8 @@ def setup_boundary_masker(self): bc_mesh = self.boundary_conditions[-1] dx = self.grid_spacing origin, spacing = (0, 0, 0), (dx, dx, dx) - self.boundary_map, self.missing_mask = indices_boundary_masker(bclist_other, self.boundary_map, self.missing_mask) - self.boundary_map, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.boundary_map, self.missing_mask) + self.bc_id, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_id, self.missing_mask) + self.bc_id, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_id, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -134,7 +134,7 @@ def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_map, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_id, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: @@ -165,7 +165,7 @@ def post_process(self, i): save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) # Compute lift and drag - boundary_force = self.momentum_transfer(self.f_0, self.boundary_map, self.missing_mask) + boundary_force = self.momentum_transfer(self.f_0, self.bc_id, self.missing_mask) drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane lift = boundary_force[2] c_d = 2.0 * drag / (self.wind_speed**2 * self.car_cross_section) diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py index 7214130..ca404a9 100644 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ b/examples/cfd_old_to_be_migrated/flow_past_sphere.py @@ -75,7 +75,7 @@ def warp_implementation(self, rho, u, vel): u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - boundary_map = grid.create_field(cardinality=1, dtype=wp.uint8) + bc_id = grid.create_field(cardinality=1, dtype=wp.uint8) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) # Make operators @@ -154,19 +154,19 @@ def warp_implementation(self, rho, u, vel): indices = wp.from_numpy(indices, dtype=wp.int32) # Set boundary conditions on the indices - boundary_map, missing_mask = indices_boundary_masker(indices, half_way_bc.id, boundary_map, missing_mask, (0, 0, 0)) + bc_id, missing_mask = indices_boundary_masker(indices, half_way_bc.id, bc_id, missing_mask, (0, 0, 0)) # Set inlet bc lower_bound = (0, 0, 0) upper_bound = (0, nr, nr) direction = (1, 0, 0) - boundary_map, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, boundary_map, missing_mask, (0, 0, 0)) + bc_id, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, bc_id, missing_mask, (0, 0, 0)) # Set outlet bc lower_bound = (nr - 1, 0, 0) upper_bound = (nr - 1, nr, nr) direction = (-1, 0, 0) - boundary_map, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, boundary_map, missing_mask, (0, 0, 0)) + bc_id, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, bc_id, missing_mask, (0, 0, 0)) # Set initial conditions rho, u = initializer(rho, u, vel) @@ -181,7 +181,7 @@ def warp_implementation(self, rho, u, vel): num_steps = 1024 * 8 start = time.time() for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_map, missing_mask, _) + f1 = stepper(f0, f1, bc_id, missing_mask, _) f1, f0 = f0, f1 if (_ % plot_freq == 0) and (not compute_mlup): rho, u = macroscopic(f0, rho, u) @@ -191,7 +191,7 @@ def warp_implementation(self, rho, u, vel): plt.imshow(u[0, :, nr // 2, :].numpy()) plt.colorbar() plt.subplot(1, 2, 2) - plt.imshow(boundary_map[0, :, nr // 2, :].numpy()) + plt.imshow(bc_id[0, :, nr // 2, :].numpy()) plt.colorbar() plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py index c5b40b7..970a246 100644 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ b/examples/cfd_old_to_be_migrated/taylor_green.py @@ -113,7 +113,7 @@ def run_taylor_green(backend, compute_mlup=True): u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - boundary_map = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + bc_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators @@ -149,10 +149,10 @@ def run_taylor_green(backend, compute_mlup=True): for _ in tqdm(range(num_steps)): # Time step if backend == "warp": - f1 = stepper(f0, f1, boundary_map, missing_mask, _) + f1 = stepper(f0, f1, bc_id, missing_mask, _) f1, f0 = f0, f1 elif backend == "jax": - f0 = stepper(f0, boundary_map, missing_mask, _) + f0 = stepper(f0, bc_id, missing_mask, _) # Plot if needed if (_ % plot_freq == 0) and (not compute_mlup): diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 602e741..52be973 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -42,9 +42,9 @@ def setup_simulation(args): def create_grid_and_fields(cube_edge): grid_shape = (cube_edge, cube_edge, cube_edge) - grid, f_0, f_1, missing_mask, boundary_map = create_nse_fields(grid_shape) + grid, f_0, f_1, missing_mask, bc_id = create_nse_fields(grid_shape) - return grid, f_0, f_1, missing_mask, boundary_map + return grid, f_0, f_1, missing_mask, bc_id def define_boundary_indices(grid): @@ -67,7 +67,7 @@ def setup_boundary_conditions(grid): return [bc_top, bc_walls] -def run(f_0, f_1, backend, grid, boundary_map, missing_mask, num_steps): +def run(f_0, f_1, backend, grid, bc_id, missing_mask, num_steps): omega = 1.0 stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=setup_boundary_conditions(grid)) @@ -81,7 +81,7 @@ def run(f_0, f_1, backend, grid, boundary_map, missing_mask, num_steps): start_time = time.time() for i in range(num_steps): - f_1 = stepper(f_0, f_1, boundary_map, missing_mask, i) + f_1 = stepper(f_0, f_1, bc_id, missing_mask, i) f_0, f_1 = f_1, f_0 wp.synchronize() @@ -98,10 +98,10 @@ def calculate_mlups(cube_edge, num_steps, elapsed_time): def main(): args = parse_arguments() backend, precision_policy = setup_simulation(args) - grid, f_0, f_1, missing_mask, boundary_map = create_grid_and_fields(args.cube_edge) + grid, f_0, f_1, missing_mask, bc_id = create_grid_and_fields(args.cube_edge) f_0 = initialize_eq(f_0, grid, xlb.velocity_set.D3Q19(), backend) - elapsed_time = run(f_0, f_1, backend, grid, boundary_map, missing_mask, args.num_steps) + elapsed_time = run(f_0, f_1, backend, grid, bc_id, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 9d2e4ff..6f1d7d8 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -32,7 +32,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -58,7 +58,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices=indices, ) - boundary_map, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_map, missing_mask, start_index=None) + bc_id, missing_mask = indices_boundary_masker([equilibrium_bc], bc_id, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -66,7 +66,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_map, missing_mask) + f = equilibrium_bc(f_pre, f_post, bc_id, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 917e7e4..7414a3b 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -31,7 +31,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -58,7 +58,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): indices=indices, ) - boundary_map, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_map, missing_mask, start_index=None) + bc_id, missing_mask = indices_boundary_masker([equilibrium_bc], bc_id, missing_mask, start_index=None) f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -66,7 +66,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, boundary_map, missing_mask) + f = equilibrium_bc(f_pre, f_post, bc_id, missing_mask) f = f.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 2fe0b40..9704ff6 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -34,7 +34,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -54,7 +54,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - boundary_map, missing_mask = indices_boundary_masker([fullway_bc], boundary_map, missing_mask, start_index=None) + bc_id, missing_mask = indices_boundary_masker([fullway_bc], bc_id, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0) # Generate a random field with the same shape @@ -67,7 +67,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = fullway_bc(f_pre, f_post, boundary_map, missing_mask) + f = fullway_bc(f_pre, f_post, bc_id, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index b25d39e..72c0495 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -34,7 +34,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -54,7 +54,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - boundary_map, missing_mask = indices_boundary_masker([fullway_bc], boundary_map, missing_mask, start_index=None) + bc_id, missing_mask = indices_boundary_masker([fullway_bc], bc_id, missing_mask, start_index=None) # Generate a random field with the same shape random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32) @@ -65,7 +65,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f_pre = fullway_bc(f_pre, f_post, boundary_map, missing_mask) + f_pre = fullway_bc(f_pre, f_post, bc_id, missing_mask) f = f_pre.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index af121d3..7bc3b9c 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -34,7 +34,7 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -56,26 +56,26 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - boundary_map, missing_mask = indices_boundary_masker([test_bc], boundary_map, missing_mask, start_index=None) + bc_id, missing_mask = indices_boundary_masker([test_bc], bc_id, missing_mask, start_index=None) assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype - assert boundary_map.dtype == xlb.Precision.UINT8.jax_dtype + assert bc_id.dtype == xlb.Precision.UINT8.jax_dtype - assert boundary_map.shape == (1,) + grid_shape + assert bc_id.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert jnp.all(boundary_map[0, indices[0], indices[1]] == test_bc.id) - # assert that the rest of the boundary_map is zero - boundary_map = boundary_map.at[0, indices[0], indices[1]].set(0) - assert jnp.all(boundary_map == 0) + assert jnp.all(bc_id[0, indices[0], indices[1]] == test_bc.id) + # assert that the rest of the bc_id is zero + bc_id = bc_id.at[0, indices[0], indices[1]].set(0) + assert jnp.all(bc_id == 0) if dim == 3: - assert jnp.all(boundary_map[0, indices[0], indices[1], indices[2]] == test_bc.id) - # assert that the rest of the boundary_map is zero - boundary_map = boundary_map.at[0, indices[0], indices[1], indices[2]].set(0) - assert jnp.all(boundary_map == 0) + assert jnp.all(bc_id[0, indices[0], indices[1], indices[2]] == test_bc.id) + # assert that the rest of the bc_id is zero + bc_id = bc_id.at[0, indices[0], indices[1], indices[2]].set(0) + assert jnp.all(bc_id == 0) if __name__ == "__main__": diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 4d02540..734e428 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -32,7 +32,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - boundary_map = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -54,33 +54,33 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - boundary_map, missing_mask = indices_boundary_masker( + bc_id, missing_mask = indices_boundary_masker( [test_bc], - boundary_map, + bc_id, missing_mask, start_index=(0, 0, 0) if dim == 3 else (0, 0), ) assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype - assert boundary_map.dtype == xlb.Precision.UINT8.wp_dtype + assert bc_id.dtype == xlb.Precision.UINT8.wp_dtype - boundary_map = boundary_map.numpy() + bc_id = bc_id.numpy() missing_mask = missing_mask.numpy() - assert boundary_map.shape == (1,) + grid_shape + assert bc_id.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert np.all(boundary_map[0, indices[0], indices[1]] == test_bc.id) - # assert that the rest of the boundary_map is zero - boundary_map[0, indices[0], indices[1]] = 0 - assert np.all(boundary_map == 0) + assert np.all(bc_id[0, indices[0], indices[1]] == test_bc.id) + # assert that the rest of the bc_id is zero + bc_id[0, indices[0], indices[1]] = 0 + assert np.all(bc_id == 0) if dim == 3: - assert np.all(boundary_map[0, indices[0], indices[1], indices[2]] == test_bc.id) - # assert that the rest of the boundary_map is zero - boundary_map[0, indices[0], indices[1], indices[2]] = 0 - assert np.all(boundary_map == 0) + assert np.all(bc_id[0, indices[0], indices[1], indices[2]] == test_bc.id) + # assert that the rest of the bc_id is zero + bc_id[0, indices[0], indices[1], indices[2]] = 0 + assert np.all(bc_id == 0) if __name__ == "__main__": diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py index 96befa6..58027ff 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_solver.py @@ -14,6 +14,6 @@ def create_nse_fields(grid_shape: Tuple[int, int, int], velocity_set=None, compu f_0 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) f_1 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL) - boundary_map = grid.create_field(cardinality=1, dtype=Precision.UINT8) + bc_id = grid.create_field(cardinality=1, dtype=Precision.UINT8) - return grid, f_0, f_1, missing_mask, boundary_map + return grid, f_0, f_1, missing_mask, bc_id diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 77e3e97..54371dc 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -48,8 +48,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, boundary_map, missing_mask): - boundary = boundary_map == self.id + def jax_implementation(self, f_pre, f_post, bc_id, missing_mask): + boundary = bc_id == self.id return jnp.where(boundary, f_pre, f_post) def _construct_warp(self): @@ -67,7 +67,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.uint8), ): # Get the global index @@ -75,7 +75,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): @@ -93,7 +93,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -101,7 +101,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): @@ -119,11 +119,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_map, missing_mask], + inputs=[f_pre, f_post, bc_id, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index e7f4a7c..3e3b132 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -62,11 +62,11 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, boundary_map, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_id, missing_mask): feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) new_shape = feq.shape + (1,) * self.velocity_set.d feq = lax.broadcast_in_dim(feq, new_shape, [0]) - boundary = boundary_map == self.id + boundary = bc_id == self.id return jnp.where(boundary, feq, f_post) @@ -92,7 +92,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -100,7 +100,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): @@ -118,7 +118,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -126,7 +126,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): @@ -144,11 +144,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_map, missing_mask], + inputs=[f_pre, f_post, bc_id, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index cd53a8f..ee21da6 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -94,13 +94,13 @@ def _roll(self, fld, vec): return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_map, missing_mask): + def prepare_bc_auxilary_data(self, f_pre, f_post, bc_id, missing_mask): """ Prepare the auxilary distribution functions for the boundary condition. Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision """ sound_speed = 1.0 / jnp.sqrt(3.0) - boundary = boundary_map == self.id + boundary = bc_id == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -123,8 +123,8 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_map, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): - boundary = boundary_map == self.id + def apply_jax(self, f_pre, f_post, bc_id, missing_mask): + boundary = bc_id == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( @@ -195,7 +195,7 @@ def prepare_bc_auxilary_data( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -203,7 +203,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) _f_aux = _f_vec() # special preparation of auxiliary data @@ -236,7 +236,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -244,7 +244,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) _f_aux = _f_vec() # special preparation of auxiliary data @@ -277,11 +277,11 @@ def kernel3d( return (functional, prepare_bc_auxilary_data), kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_map, missing_mask], + inputs=[f_pre, f_post, bc_id, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index d11eead..cf1330a 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -48,8 +48,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): - boundary = boundary_map == self.id + def apply_jax(self, f_pre, f_post, bc_id, missing_mask): + boundary = bc_id == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) @@ -77,14 +77,14 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): @@ -102,7 +102,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -110,7 +110,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): @@ -128,11 +128,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_map, missing_mask], + inputs=[f_pre, f_post, bc_id, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index a663252..688f71b 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -51,8 +51,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): - boundary = boundary_map == self.id + def apply_jax(self, f_pre, f_post, bc_id, missing_mask): + boundary = bc_id == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( @@ -88,7 +88,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -96,7 +96,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): @@ -114,7 +114,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -122,7 +122,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): @@ -140,11 +140,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_map, missing_mask], + inputs=[f_pre, f_post, bc_id, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 674778d..90634f0 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -105,9 +105,9 @@ def regularize_fpop(self, fpop, feq): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): + def apply_jax(self, f_pre, f_post, bc_id, missing_mask): # creat a mask to slice boundary cells - boundary = boundary_map == self.id + boundary = bc_id == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -325,7 +325,7 @@ def functional2d_pressure( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -333,7 +333,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): @@ -351,7 +351,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -359,7 +359,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): @@ -385,11 +385,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_map, missing_mask], + inputs=[f_pre, f_post, bc_id, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index c1a9b7d..6b7ca35 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -156,9 +156,9 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): + def apply_jax(self, f_pre, f_post, bc_id, missing_mask): # creat a mask to slice boundary cells - boundary = boundary_map == self.id + boundary = bc_id == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -330,7 +330,7 @@ def functional2d_pressure( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -338,7 +338,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): @@ -356,7 +356,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -364,7 +364,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_map, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): @@ -390,11 +390,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, boundary_map, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, boundary_map, missing_mask], + inputs=[f_pre, f_post, bc_id, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index a4c2181..faeac77 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -66,14 +66,14 @@ def prepare_bc_auxilary_data( def _get_thread_data_2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), index: wp.vec2i, ): # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_id = boundary_map[0, index[0], index[1]] + _boundary_id = bc_id[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -91,14 +91,14 @@ def _get_thread_data_2d( def _get_thread_data_3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), index: wp.vec3i, ): # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_id = boundary_map[0, index[0], index[1], index[2]] + _boundary_id = bc_id[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -119,7 +119,7 @@ def _get_thread_data_3d( self.prepare_bc_auxilary_data = prepare_bc_auxilary_data @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_map, missing_mask): + def prepare_bc_auxilary_data(self, f_pre, f_post, bc_id, missing_mask): """ A placeholder function for prepare the auxilary distribution functions for the boundary condition. currently being called after collision only. diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 6c58ff5..e54df4b 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -47,7 +47,7 @@ def are_indices_in_interior(self, indices, shape): @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) - def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=None): + def jax_implementation(self, bclist, bc_id, missing_mask, start_index=None): # Pad the missing mask to create a grid mask to identify out of bound boundaries # Set padded regin to True (i.e. boundary) dim = missing_mask.ndim - 1 @@ -55,17 +55,17 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non pad_x, pad_y, pad_z = nDevices, 1, 1 if dim == 2: grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True) - bmap = jnp.pad(boundary_map[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) + bmap = jnp.pad(bc_id[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) if dim == 3: grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True) - bmap = jnp.pad(boundary_map[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) + bmap = jnp.pad(bc_id[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) # shift indices shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z) if start_index is None: start_index = (0,) * dim - domain_shape = boundary_map[0].shape + domain_shape = bc_id[0].shape for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" @@ -92,11 +92,11 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non grid_mask = self.stream(grid_mask) if dim == 2: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] - boundary_map = boundary_map.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y]) + bc_id = bc_id.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y]) if dim == 3: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] - boundary_map = boundary_map.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]) - return boundary_map, missing_mask + bc_id = bc_id.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]) + return bc_id, missing_mask def _construct_warp(self): # Make constants for warp @@ -109,7 +109,7 @@ def kernel2d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), is_interior: wp.array1d(dtype=wp.bool), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): @@ -137,7 +137,7 @@ def kernel2d( if pull_index[0] < 0 or pull_index[0] >= missing_mask.shape[1] or pull_index[1] < 0 or pull_index[1] >= missing_mask.shape[2]: # Set the missing mask missing_mask[l, index[0], index[1]] = True - boundary_map[0, index[0], index[1]] = id_number[ii] + bc_id[0, index[0], index[1]] = id_number[ii] # handling geometries in the interior of the computational domain elif ( @@ -149,7 +149,7 @@ def kernel2d( ): # Set the missing mask missing_mask[l, push_index[0], push_index[1]] = True - boundary_map[0, push_index[0], push_index[1]] = id_number[ii] + bc_id[0, push_index[0], push_index[1]] = id_number[ii] # Construct the warp 3D kernel @wp.kernel @@ -157,7 +157,7 @@ def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), is_interior: wp.array1d(dtype=wp.bool), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -200,7 +200,7 @@ def kernel3d( ): # Set the missing mask missing_mask[l, index[0], index[1], index[2]] = True - boundary_map[0, index[0], index[1], index[2]] = id_number[ii] + bc_id[0, index[0], index[1], index[2]] = id_number[ii] # handling geometries in the interior of the computational domain elif ( @@ -214,14 +214,14 @@ def kernel3d( ): # Set the missing mask missing_mask[l, push_index[0], push_index[1], push_index[2]] = True - boundary_map[0, push_index[0], push_index[1], push_index[2]] = id_number[ii] + bc_id[0, push_index[0], push_index[1], push_index[2]] = id_number[ii] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=None): + def warp_implementation(self, bclist, bc_id, missing_mask, start_index=None): dim = self.velocity_set.d index_list = [[] for _ in range(dim)] id_list = [] @@ -232,7 +232,7 @@ def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=No for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) - is_interior += self.are_indices_in_interior(bc.indices, boundary_map[0].shape) + is_interior += self.are_indices_in_interior(bc.indices, bc_id[0].shape) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) @@ -251,11 +251,11 @@ def warp_implementation(self, bclist, boundary_map, missing_mask, start_index=No indices, id_number, is_interior, - boundary_map, + bc_id, missing_mask, start_index, ], dim=indices.shape[1], ) - return boundary_map, missing_mask + return bc_id, missing_mask diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 366c9d6..f3fb7ba 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -33,17 +33,17 @@ def jax_implementation( bc, origin, spacing, - boundary_map, + bc_id, missing_mask, start_index=(0, 0, 0), ): raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") # Use Warp backend even for this particular operation. wp.init() - boundary_map = wp.from_jax(boundary_map) + bc_id = wp.from_jax(bc_id) missing_mask = wp.from_jax(missing_mask) - boundary_map, missing_mask = self.warp_implementation(bc, origin, spacing, boundary_map, missing_mask, start_index) - return wp.to_jax(boundary_map), wp.to_jax(missing_mask) + bc_id, missing_mask = self.warp_implementation(bc, origin, spacing, bc_id, missing_mask, start_index) + return wp.to_jax(bc_id), wp.to_jax(missing_mask) def _construct_warp(self): # Make constants for warp @@ -57,7 +57,7 @@ def kernel( origin: wp.vec3, spacing: wp.vec3, id_number: wp.int32, - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -77,9 +77,9 @@ def kernel( # Compute the maximum length max_length = wp.sqrt( - (spacing[0] * wp.float32(boundary_map.shape[1])) ** 2.0 - + (spacing[1] * wp.float32(boundary_map.shape[2])) ** 2.0 - + (spacing[2] * wp.float32(boundary_map.shape[3])) ** 2.0 + (spacing[0] * wp.float32(bc_id.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(bc_id.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(bc_id.shape[3])) ** 2.0 ) # evaluate if point is inside mesh @@ -98,7 +98,7 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and missing_mask - boundary_map[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + bc_id[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) missing_mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel @@ -109,7 +109,7 @@ def warp_implementation( bc, origin, spacing, - boundary_map, + bc_id, missing_mask, start_index=(0, 0, 0), ): @@ -138,11 +138,11 @@ def warp_implementation( origin, spacing, id_number, - boundary_map, + bc_id, missing_mask, start_index, ], - dim=boundary_map.shape[1:], + dim=bc_id.shape[1:], ) - return boundary_map, missing_mask + return bc_id, missing_mask diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 4455ee5..f0b474d 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -50,13 +50,13 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f, boundary_map, missing_mask): + def jax_implementation(self, f, bc_id, missing_mask): """ Parameters ---------- f : jax.numpy.ndarray The post-collision distribution function at each node in the grid. - boundary_map : jax.numpy.ndarray + bc_id : jax.numpy.ndarray A grid field with 0 everywhere except for boundary nodes which are designated by their respective boundary id's. missing_mask : jax.numpy.ndarray @@ -71,10 +71,10 @@ def jax_implementation(self, f, boundary_map, missing_mask): # Give the input post-collision populations, streaming once and apply the BC the find post-stream values. f_post_collision = f f_post_stream = self.stream(f_post_collision) - f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, boundary_map, missing_mask) + f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_id, missing_mask) # Compute momentum transfer - boundary = boundary_map == self.no_slip_bc_instance.id + boundary = bc_id == self.no_slip_bc_instance.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -103,7 +103,7 @@ def _construct_warp(self): @wp.kernel def kernel2d( f: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=wp.uint8), + bc_id: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), force: wp.array(dtype=Any), ): @@ -112,7 +112,7 @@ def kernel2d( index = wp.vec2i(i, j) # Get the boundary id - _boundary_id = boundary_map[0, index[0], index[1]] + _boundary_id = bc_id[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -153,7 +153,7 @@ def kernel2d( @wp.kernel def kernel3d( f: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=wp.uint8), + bc_id: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), force: wp.array(dtype=Any), ): @@ -162,7 +162,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id - _boundary_id = boundary_map[0, index[0], index[1], index[2]] + _boundary_id = bc_id[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -205,14 +205,14 @@ def kernel3d( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, boundary_map, missing_mask): + def warp_implementation(self, f, bc_id, missing_mask): # Allocate the force vector (the total integral value will be computed) force = wp.zeros((1), dtype=wp.vec3) if self.velocity_set.d == 3 else wp.zeros((1), dtype=wp.vec2) # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f, boundary_map, missing_mask, force], + inputs=[f, bc_id, missing_mask, force], dim=f.shape[1:], ) return force.numpy()[0] diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 1e3bbf7..0b3af5b 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -40,7 +40,7 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK"): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): + def jax_implementation(self, f_0, f_1, bc_id, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ @@ -57,7 +57,7 @@ def jax_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): f_post_stream = bc( f_0, f_post_stream, - boundary_map, + bc_id, missing_mask, ) @@ -72,12 +72,12 @@ def jax_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, boundary_map, missing_mask) + f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, bc_id, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_post_stream, f_post_collision, - boundary_map, + bc_id, missing_mask, ) @@ -269,7 +269,7 @@ def get_bc_auxilary_data_3d( def kernel2d( f_0: wp.array3d(dtype=Any), f_1: wp.array3d(dtype=Any), - boundary_map: wp.array3d(dtype=Any), + bc_id: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), bc_struct: Any, timestep: int, @@ -285,7 +285,7 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_id = boundary_map[0, index[0], index[1]] + _boundary_id = bc_id[0, index[0], index[1]] f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions @@ -317,7 +317,7 @@ def kernel2d( def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), - boundary_map: wp.array4d(dtype=Any), + bc_id: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), bc_struct: Any, timestep: int, @@ -333,7 +333,7 @@ def kernel3d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_id = boundary_map[0, index[0], index[1], index[2]] + _boundary_id = bc_id[0, index[0], index[1], index[2]] f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions @@ -361,7 +361,7 @@ def kernel3d( return BoundaryConditionIDStruct, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): + def warp_implementation(self, f_0, f_1, bc_id, missing_mask, timestep): # Get the boundary condition ids from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry @@ -402,7 +402,7 @@ def warp_implementation(self, f_0, f_1, boundary_map, missing_mask, timestep): inputs=[ f_0, f_1, - boundary_map, + bc_id, missing_mask, bc_struct, timestep, From aaff2c3fdf71316908b545e92b29e9aa0eac68c3 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 9 Sep 2024 17:10:43 -0400 Subject: [PATCH 100/144] added exact difference forcing scheme --- xlb/operator/collision/__init__.py | 1 + xlb/operator/collision/forced_collision.py | 135 +++++++++++++++++ xlb/operator/collision/kbc.py | 9 +- xlb/operator/force/__init__.py | 1 + xlb/operator/force/exact_difference_force.py | 150 +++++++++++++++++++ xlb/operator/stepper/nse_stepper.py | 6 +- 6 files changed, 297 insertions(+), 5 deletions(-) create mode 100644 xlb/operator/collision/forced_collision.py create mode 100644 xlb/operator/force/exact_difference_force.py diff --git a/xlb/operator/collision/__init__.py b/xlb/operator/collision/__init__.py index b48d0ce..0526c8a 100644 --- a/xlb/operator/collision/__init__.py +++ b/xlb/operator/collision/__init__.py @@ -1,3 +1,4 @@ from xlb.operator.collision.collision import Collision as Collision from xlb.operator.collision.bgk import BGK as BGK from xlb.operator.collision.kbc import KBC as KBC +from xlb.operator.collision.forced_collision import ForcedCollision as ForcedCollision diff --git a/xlb/operator/collision/forced_collision.py b/xlb/operator/collision/forced_collision.py new file mode 100644 index 0000000..31ef392 --- /dev/null +++ b/xlb/operator/collision/forced_collision.py @@ -0,0 +1,135 @@ +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.collision.collision import Collision +from xlb.operator import Operator +from functools import partial +from xlb.operator.force import ExactDifference + + +class ForcedCollision(Collision): + """ + A collision operator for LBM with external force. + """ + + def __init__( + self, + collision_operator: Operator, + forcing_scheme="exact_difference", + force_vector=None, + ): + assert collision_operator is not None + self.collision_operator = collision_operator + super().__init__(self.collision_operator.omega) + + assert forcing_scheme == "exact_difference", NotImplementedError(f"Force model {forcing_scheme} not implemented!") + assert force_vector.shape[0] == self.velocity_set.d, "Check the dimensions of the input force!" + self.force_vector = force_vector + if forcing_scheme == "exact_difference": + self.forcing_operator = ExactDifference(force_vector) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0,)) + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u): + fout = self.collision_operator(f, feq, rho, u) + fout = self.forcing_operator(fout, feq, rho, u) + return fout + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + # Construct the functional + @wp.func + def functional(f: Any, feq: Any, rho: Any, u: Any): + fout = self.collision_operator.warp_functional(f, feq, rho, u) + fout = self.forcing_operator.warp_functional(fout, feq, rho, u) + return fout + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + feq: wp.array3d(dtype=Any), + fout: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) # TODO: Warp needs to fix this + + # Load needed values + _f = _f_vec() + _feq = _f_vec() + _d = self.velocity_set.d + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _feq[l] = feq[l, index[0], index[1]] + _u = _u_vec() + for l in range(_d): + _u[l] = u[l, index[0], index[1]] + _rho = rho[0, index[0], index[1]] + + # Compute the collision + _fout = functional(_f, _feq, _rho, _u) + + # Write the result + for l in range(self.velocity_set.q): + fout[l, index[0], index[1]] = _fout[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + feq: wp.array4d(dtype=Any), + fout: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this + + # Load needed values + _f = _f_vec() + _feq = _f_vec() + _d = self.velocity_set.d + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _feq[l] = feq[l, index[0], index[1], index[2]] + _u = _u_vec() + for l in range(_d): + _u[l] = u[l, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] + + # Compute the collision + _fout = functional(_f, _feq, _rho, _u) + + # Write the result + for l in range(self.velocity_set.q): + fout[l, index[0], index[1], index[2]] = _fout[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, feq, fout, rho, u): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + f, + feq, + fout, + rho, + u, + ], + dim=f.shape[1:], + ) + return fout diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index ddd7ecc..ac8513b 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -182,6 +182,7 @@ def _construct_warp(self): raise NotImplementedError("Velocity set not supported for warp backend: {}".format(type(self.velocity_set))) # Set local constants TODO: This is a hack and should be fixed with warp update + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _epsilon = wp.constant(self.compute_dtype(self.epsilon)) _beta = wp.constant(self.compute_dtype(self.beta)) @@ -305,9 +306,9 @@ def functional3d( def kernel2d( f: wp.array3d(dtype=Any), feq: wp.array3d(dtype=Any), + fout: wp.array3d(dtype=Any), rho: wp.array3d(dtype=Any), u: wp.array3d(dtype=Any), - fout: wp.array3d(dtype=Any), ): # Get the global index i, j = wp.tid() @@ -320,7 +321,7 @@ def kernel2d( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1]] _feq[l] = feq[l, index[0], index[1]] - _u = self.warp_u_vec() + _u = _u_vec() for l in range(_d): _u[l] = u[l, index[0], index[1]] _rho = rho[0, index[0], index[1]] @@ -337,9 +338,9 @@ def kernel2d( def kernel3d( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), + fout: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), - fout: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() @@ -352,7 +353,7 @@ def kernel3d( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] - _u = self.warp_u_vec() + _u = _u_vec() for l in range(_d): _u[l] = u[l, index[0], index[1], index[2]] _rho = rho[0, index[0], index[1], index[2]] diff --git a/xlb/operator/force/__init__.py b/xlb/operator/force/__init__.py index 6a991ce..2f3e3da 100644 --- a/xlb/operator/force/__init__.py +++ b/xlb/operator/force/__init__.py @@ -1 +1,2 @@ from xlb.operator.force.momentum_transfer import MomentumTransfer as MomentumTransfer +from xlb.operator.force.exact_difference_force import ExactDifference as ExactDifference diff --git a/xlb/operator/force/exact_difference_force.py b/xlb/operator/force/exact_difference_force.py new file mode 100644 index 0000000..431a025 --- /dev/null +++ b/xlb/operator/force/exact_difference_force.py @@ -0,0 +1,150 @@ +from functools import partial +from jax import jit +import warp as wp +from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.equilibrium import QuadraticEquilibrium + + +class ExactDifference(Operator): + """ + Add external body force based on the exact-difference method due to Kupershtokh (2004) + + References + ---------- + Kupershtokh, A. (2004). New method of incorporating a body force term into the lattice Boltzmann equation. In + Proceedings of the 5th International EHD Workshop (pp. 241-246). University of Poitiers, Poitiers, France. + Chikatamarla, S. S., & Karlin, I. V. (2013). Entropic lattice Boltzmann method for turbulent flow simulations: + Boundary conditions. Physica A, 392, 1925-1930. + Krüger, T., et al. (2017). The lattice Boltzmann method. Springer International Publishing, 10.978-3, 4-15. + """ + + def __init__( + self, + force_vector, + equilibrium: Operator = None, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + # TODO: currently we are limited to a single force vector not a spatially dependent forcing field + self.force_vector = force_vector + self.equilibrium = QuadraticEquilibrium() if equilibrium is None else equilibrium + + # Call the parent constructor + super().__init__( + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_postcollision, feq, rho, u): + """ + Parameters + ---------- + f_postcollision: jax.numpy.ndarray + The post-collision distribution functions. + feq: jax.numpy.ndarray + The equilibrium distribution functions. + rho: jax.numpy.ndarray + The density field. + + u: jax.numpy.ndarray + The velocity field. + + Returns + ------- + f_postcollision: jax.numpy.ndarray + The post-collision distribution functions with the force applied. + """ + delta_u = self.force_vector + feq_force = self.equilibrium(rho, u + delta_u) + f_postcollision += feq_force - feq + return f_postcollision + + def _construct_warp(self): + _d = self.velocity_set.d + _u_vec = wp.vec(_d, dtype=self.compute_dtype) + if _d == 2: + _force = _u_vec(self.force_vector[0], self.force_vector[1]) + else: + _force = _u_vec(self.force_vector[0], self.force_vector[1], self.force_vector[2]) + + # Construct the functional + @wp.func + def functional(f_postcollision: Any, feq: Any, rho: Any, u: Any): + delta_u = _force + feq_force = self.equilibrium.warp_functional(rho, u + delta_u) + f_postcollision += feq_force - feq + return f_postcollision + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f_postcollision: Any, + feq: Any, + fout: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # Load needed values + _u = _u_vec() + for l in range(_d): + _u[l] = u[l, index[0], index[1]] + _rho = rho[0, index[0], index[1]] + + # Compute the collision + _fout = functional(f_postcollision, feq, _rho, _u) + + # Write the result + for l in range(self.velocity_set.q): + fout[l, index[0], index[1]] = _fout[l] + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f_postcollision: Any, + feq: Any, + fout: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this + + # Load needed values + _u = _u_vec() + for l in range(_d): + _u[l] = u[l, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] + + # Compute the collision + _fout = functional(f_postcollision, feq, _rho, _u) + + # Write the result + for l in range(self.velocity_set.q): + fout[l, index[0], index[1], index[2]] = _fout[l] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_postcollision, feq, fout, rho, u): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_postcollision, feq, fout, rho, u], + dim=f_postcollision.shape[1:], + ) + return fout diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 0b3af5b..ab63de6 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -15,10 +15,11 @@ from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition.boundary_condition import ImplementationStep from xlb.operator.boundary_condition import DoNothingBC as DummyBC +from xlb.operator.collision import ForcedCollision class IncompressibleNavierStokesStepper(Stepper): - def __init__(self, omega, boundary_conditions=[], collision_type="BGK"): + def __init__(self, omega, boundary_conditions=[], collision_type="BGK", forcing_scheme="exact_difference", force_vector=None): velocity_set = DefaultConfig.velocity_set precision_policy = DefaultConfig.default_precision_policy compute_backend = DefaultConfig.default_backend @@ -29,6 +30,9 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK"): elif collision_type == "KBC": self.collision = KBC(omega, velocity_set, precision_policy, compute_backend) + if force_vector is not None: + self.collision = ForcedCollision(collision_operator=self.collision, forcing_scheme=forcing_scheme, force_vector=force_vector) + # Construct the operators self.stream = Stream(velocity_set, precision_policy, compute_backend) self.equilibrium = QuadraticEquilibrium(velocity_set, precision_policy, compute_backend) From d96a394ff1f00129122594abc68b1464366f0795 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 10 Sep 2024 18:28:16 -0400 Subject: [PATCH 101/144] ported the benchmark example of turbulent channel flow. --- .../cfd/data/turbulent_channel_dns_data.json | 1 + examples/cfd/turbulent_channel_3d.py | 202 ++++++++++++++++++ xlb/helper/initializers.py | 6 +- xlb/operator/force/exact_difference_force.py | 4 +- 4 files changed, 209 insertions(+), 4 deletions(-) create mode 100644 examples/cfd/data/turbulent_channel_dns_data.json create mode 100644 examples/cfd/turbulent_channel_3d.py diff --git a/examples/cfd/data/turbulent_channel_dns_data.json b/examples/cfd/data/turbulent_channel_dns_data.json new file mode 100644 index 0000000..bdbee75 --- /dev/null +++ b/examples/cfd/data/turbulent_channel_dns_data.json @@ -0,0 +1 @@ +{"y": [0, 0.000301, 0.0012, 0.00271, 0.00482, 0.00752, 0.0108, 0.0147, 0.0192, 0.0243, 0.03, 0.0362, 0.0431, 0.0505, 0.0585, 0.067, 0.0761, 0.0858, 0.096, 0.107, 0.118, 0.13, 0.142, 0.155, 0.169, 0.182, 0.197, 0.212, 0.227, 0.243, 0.259, 0.276, 0.293, 0.31, 0.328, 0.347, 0.366, 0.385, 0.404, 0.424, 0.444, 0.465, 0.486, 0.507, 0.529, 0.55, 0.572, 0.595, 0.617, 0.64, 0.663, 0.686, 0.71, 0.733, 0.757, 0.781, 0.805, 0.829, 0.853, 0.878, 0.902, 0.926, 0.951, 0.975, 1], "y+": [0, 0.053648, 0.21456, 0.48263, 0.85771, 1.3396, 1.9279, 2.6224, 3.4226, 4.328, 5.3381, 6.4523, 7.67, 8.9902, 10.412, 11.936, 13.559, 15.281, 17.102, 19.019, 21.033, 23.141, 25.342, 27.635, 30.019, 32.492, 35.053, 37.701, 40.432, 43.247, 46.143, 49.118, 52.171, 55.3, 58.503, 61.778, 65.123, 68.536, 72.016, 75.559, 79.164, 82.828, 86.55, 90.327, 94.157, 98.037, 101.97, 105.94, 109.96, 114.02, 118.12, 122.25, 126.42, 130.62, 134.84, 139.1, 143.37, 147.67, 151.99, 156.32, 160.66, 165.02, 169.38, 173.75, 178.12], "Umean": [0, 0.053639, 0.21443, 0.48197, 0.85555, 1.3339, 1.9148, 2.5939, 3.3632, 4.2095, 5.1133, 6.0493, 6.9892, 7.9052, 8.7741, 9.579, 10.311, 10.967, 11.55, 12.066, 12.52, 12.921, 13.276, 13.59, 13.87, 14.121, 14.349, 14.557, 14.75, 14.931, 15.101, 15.264, 15.419, 15.569, 15.714, 15.855, 15.993, 16.128, 16.26, 16.389, 16.515, 16.637, 16.756, 16.872, 16.985, 17.094, 17.2, 17.302, 17.4, 17.494, 17.585, 17.672, 17.756, 17.835, 17.911, 17.981, 18.045, 18.103, 18.154, 18.198, 18.235, 18.264, 18.285, 18.297, 18.301], "dUmean/dy": [178, 178, 178, 178, 177, 176, 175, 173, 169, 163, 155, 144, 131, 116, 101, 87.1, 73.9, 62.2, 52.2, 43.8, 36.9, 31.1, 26.4, 22.6, 19.4, 16.9, 14.9, 13.3, 12, 10.9, 10.1, 9.38, 8.79, 8.29, 7.86, 7.49, 7.19, 6.91, 6.63, 6.35, 6.07, 5.81, 5.58, 5.36, 5.14, 4.92, 4.68, 4.45, 4.23, 4.04, 3.85, 3.66, 3.48, 3.28, 3.06, 2.81, 2.54, 2.25, 1.96, 1.67, 1.35, 1.02, 0.673, 0.33, 0], "Wmean": [0, 7.07e-05, 0.000283, 0.000636, 0.00113, 0.00176, 0.00252, 0.00339, 0.00435, 0.00538, 0.00643, 0.00751, 0.00864, 0.00986, 0.0112, 0.0126, 0.0141, 0.0156, 0.017, 0.0181, 0.0186, 0.0184, 0.0176, 0.0163, 0.0149, 0.0135, 0.0124, 0.0116, 0.0107, 0.00966, 0.00843, 0.00695, 0.00519, 0.00329, 0.00145, -0.000284, -0.00177, -0.00292, -0.00377, -0.00445, -0.00497, -0.0054, -0.00594, -0.00681, -0.0082, -0.00996, -0.0119, -0.0139, -0.0163, -0.0191, -0.0225, -0.0263, -0.0306, -0.0354, -0.0405, -0.0455, -0.05, -0.0539, -0.0577, -0.0615, -0.0653, -0.0685, -0.071, -0.0724, -0.0729], "dWmean/dy": [0.235, 0.235, 0.235, 0.234, 0.234, 0.232, 0.228, 0.22, 0.208, 0.194, 0.179, 0.168, 0.164, 0.164, 0.166, 0.167, 0.162, 0.148, 0.121, 0.076, 0.0159, -0.0439, -0.087, -0.107, -0.106, -0.0871, -0.0643, -0.0546, -0.061, -0.0707, -0.0818, -0.0958, -0.108, -0.106, -0.0989, -0.0881, -0.0697, -0.0506, -0.0379, -0.0303, -0.0221, -0.0216, -0.0314, -0.0522, -0.0756, -0.0841, -0.0884, -0.0974, -0.114, -0.136, -0.154, -0.172, -0.196, -0.214, -0.215, -0.199, -0.174, -0.156, -0.155, -0.159, -0.147, -0.118, -0.0788, -0.0387, 0], "Pmean": [6.217e-13, -7.3193e-10, -1.5832e-07, -3.7598e-06, -3.3837e-05, -0.00017683, -0.00065008, -0.001865, -0.0044488, -0.0092047, -0.017023, -0.028777, -0.045228, -0.066952, -0.094281, -0.12724, -0.16551, -0.20842, -0.25498, -0.30396, -0.35398, -0.40362, -0.45163, -0.49698, -0.5388, -0.57639, -0.60919, -0.63686, -0.6593, -0.67652, -0.68867, -0.69613, -0.69928, -0.69854, -0.69444, -0.68744, -0.67802, -0.66675, -0.65429, -0.64131, -0.62817, -0.61487, -0.60122, -0.58703, -0.57221, -0.55678, -0.5409, -0.52493, -0.50917, -0.49371, -0.47867, -0.46421, -0.4505, -0.43759, -0.4255, -0.41436, -0.40444, -0.39595, -0.389, -0.3836, -0.37966, -0.37702, -0.37542, -0.3746, -0.37436]} \ No newline at end of file diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py new file mode 100644 index 0000000..e160fd5 --- /dev/null +++ b/examples/cfd/turbulent_channel_3d.py @@ -0,0 +1,202 @@ +import xlb +import time +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.helper import create_nse_fields, initialize_eq +from xlb.operator.stepper import IncompressibleNavierStokesStepper +from xlb.operator.boundary_condition import RegularizedBC +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.boundary_masker import IndicesBoundaryMasker +from xlb.utils import save_fields_vtk, save_image +import warp as wp +import numpy as np +import jax.numpy as jnp +import matplotlib.pyplot as plt +import json + + +# helper functions for this benchmark example +def vonKarman_loglaw_wall(yplus): + vonKarmanConst = 0.41 + cplus = 5.5 + uplus = np.log(yplus) / vonKarmanConst + cplus + return uplus + + +def get_dns_data(): + """ + Reference: DNS of Turbulent Channel Flow up to Re_tau=590, 1999, + Physics of Fluids, vol 11, 943-945. + https://turbulence.oden.utexas.edu/data/MKM/chan180/profiles/chan180.means + """ + file_name = "examples/cfd/data/turbulent_channel_dns_data.json" + with open(file_name, "r") as file: + return json.load(file) + + +class TurbulentChannel3D: + def __init__(self, channel_half_width, Re_tau, u_tau, grid_shape, velocity_set, backend, precision_policy): + # initialize backend + xlb.init( + velocity_set=velocity_set, + default_backend=backend, + default_precision_policy=precision_policy, + ) + + self.channel_half_width = channel_half_width + self.Re_tau = Re_tau + self.u_tau = u_tau + self.visc = u_tau * channel_half_width / Re_tau + self.omega = 1.0 / (3.0 * self.visc + 0.5) + # DeltaPlus = Re_tau / channel_half_width + # DeltaPlus = u_tau / nu * Delta where u_tau / nu = Re_tau / channel_half_width + + self.grid_shape = grid_shape + self.velocity_set = velocity_set + self.backend = backend + self.precision_policy = precision_policy + self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_id = create_nse_fields(grid_shape) + self.stepper = None + self.boundary_conditions = [] + + # Setup the simulation BC, its initial conditions, and the stepper + self._setup() + + def get_force(self): + # define the external force + shape = (self.velocity_set.d,) + force = np.zeros(shape) + force[0] = self.Re_tau**2 * self.visc**2 / self.channel_half_width**3 + return force + + def _setup(self): + self.setup_boundary_conditions() + self.setup_boundary_masker() + self.initialize_fields() + self.setup_stepper() + + def define_boundary_indices(self): + # top and bottom sides of the channel are no-slip and the other directions are periodic + walls = [self.grid.boundingBoxIndices["bottom"][i] + self.grid.boundingBoxIndices["top"][i] for i in range(self.velocity_set.d)] + return walls + + def setup_boundary_conditions(self): + walls = self.define_boundary_indices() + bc_walls = RegularizedBC("velocity", (0.0, 0.0, 0.0), indices=walls) + self.boundary_conditions = [bc_walls] + + def setup_boundary_masker(self): + indices_boundary_masker = IndicesBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.backend, + ) + self.bc_id, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_id, self.missing_mask) + + def initialize_fields(self): + shape = (self.velocity_set.d,) + (self.grid_shape) + np.random.seed(0) + u_init = np.random.random(shape) + if self.backend == ComputeBackend.JAX: + u_init = jnp.full(shape=shape, fill_value=1e-2 * u_init) + else: + u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend, u=u_init) + + def setup_stepper(self): + force = self.get_force() + self.stepper = IncompressibleNavierStokesStepper( + self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC", forcing_scheme="exact_difference", force_vector=force + ) + + def run(self, num_steps, print_interval, post_process_interval=100): + start_time = time.time() + for i in range(num_steps): + self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_id, self.missing_mask, i) + self.f_0, self.f_1 = self.f_1, self.f_0 + + if (i + 1) % print_interval == 0: + elapsed_time = time.time() - start_time + print(f"Iteration: {i + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") + + if i % post_process_interval == 0 or i == num_steps - 1: + self.post_process(i) + + def post_process(self, i): + # Write the results. We'll use JAX backend for the post-processing + if not isinstance(self.f_0, jnp.ndarray): + f_0 = wp.to_jax(self.f_0) + else: + f_0 = self.f_0 + + macro = Macroscopic(compute_backend=ComputeBackend.JAX) + + rho, u = macro(f_0) + + # compute velocity magnitude + u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 + fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_z": u[2], "u_magnitude": u_magnitude} + save_fields_vtk(fields, timestep=i) + save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) + + # Save monitor plot + self.plot_uplus(u, i) + return + + def plot_uplus(self, u, timestep): + # Compute moving average of drag coefficient, 100, 1000, 10000 + # mean streamwise velocity in wall units u^+(z) + # Wall distance in wall units to be used inside output_data + zz = np.arange(self.grid_shape[-1]) + zz = np.minimum(zz, zz.max() - zz) + yplus = zz * self.u_tau / self.visc + uplus = np.mean(u[0], axis=(0, 1)) / self.u_tau + uplus_loglaw = vonKarman_loglaw_wall(yplus) + dns_dic = get_dns_data() + plt.clf() + plt.semilogx(yplus, uplus, "r.", yplus, uplus_loglaw, "k:", dns_dic["y+"], dns_dic["Umean"], "b-") + ax = plt.gca() + ax.set_xlim([0.1, 300]) + ax.set_ylim([0, 20]) + fname = "uplus_" + str(timestep // 10000).zfill(5) + ".png" + plt.savefig(fname, format="png") + + +if __name__ == "__main__": + # Problem Configuration + # h: channel half-width + channel_half_width = 50 + + # Define channel geometry based on h + grid_size_x = 6 * channel_half_width + grid_size_y = 3 * channel_half_width + grid_size_z = 2 * channel_half_width + + # Grid parameters + grid_shape = (grid_size_x, grid_size_y, grid_size_z) + + # Define flow regime + # Set up Reynolds number and deduce relaxation time (omega) + Re_tau = 180 + u_tau = 0.001 + + # Runtime & backend configurations + backend = ComputeBackend.WARP + velocity_set = xlb.velocity_set.D3Q27() + precision_policy = PrecisionPolicy.FP32FP32 + num_steps = 100000 + print_interval = 1000 + + # Print simulation info + print("\n" + "=" * 50 + "\n") + print("Simulation Configuration:") + print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") + print(f"Backend: {backend}") + print(f"Velocity set: {velocity_set}") + print(f"Precision policy: {precision_policy}") + print(f"Reynolds number: {Re_tau}") + print(f"Max iterations: {num_steps}") + print("\n" + "=" * 50 + "\n") + + simulation = TurbulentChannel3D(channel_half_width, Re_tau, u_tau, grid_shape, velocity_set, backend, precision_policy) + simulation.run(num_steps, print_interval, post_process_interval=1000) diff --git a/xlb/helper/initializers.py b/xlb/helper/initializers.py index aff3ee8..c8439d9 100644 --- a/xlb/helper/initializers.py +++ b/xlb/helper/initializers.py @@ -3,8 +3,10 @@ def initialize_eq(f, grid, velocity_set, backend, rho=None, u=None): - rho = rho or grid.create_field(cardinality=1, fill_value=1.0) - u = u or grid.create_field(cardinality=velocity_set.d, fill_value=0.0) + if rho is None: + rho = grid.create_field(cardinality=1, fill_value=1.0) + if u is None: + u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0) equilibrium = QuadraticEquilibrium() if backend == ComputeBackend.JAX: diff --git a/xlb/operator/force/exact_difference_force.py b/xlb/operator/force/exact_difference_force.py index 431a025..f148e12 100644 --- a/xlb/operator/force/exact_difference_force.py +++ b/xlb/operator/force/exact_difference_force.py @@ -1,5 +1,5 @@ from functools import partial -from jax import jit +from jax import jit, lax import warp as wp from typing import Any @@ -63,7 +63,7 @@ def jax_implementation(self, f_postcollision, feq, rho, u): f_postcollision: jax.numpy.ndarray The post-collision distribution functions with the force applied. """ - delta_u = self.force_vector + delta_u = lax.broadcast_in_dim(self.force_vector, u.shape, (0,)) feq_force = self.equilibrium(rho, u + delta_u) f_postcollision += feq_force - feq return f_postcollision From e21780ee848552f7eb99a9178ce56671d5980723 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 13 Sep 2024 17:44:47 -0400 Subject: [PATCH 102/144] Removed the need to have separate JAX/Warp constants --- examples/cfd/flow_past_sphere_3d.py | 8 +- examples/cfd/lid_driven_cavity_2d.py | 9 +- .../cfd/lid_driven_cavity_2d_distributed.py | 2 +- examples/cfd/windtunnel_3d.py | 8 +- .../bc_equilibrium/test_bc_equilibrium_jax.py | 3 +- .../test_bc_equilibrium_warp.py | 5 +- .../test_bc_fullway_bounce_back_jax.py | 3 +- .../test_bc_fullway_bounce_back_warp.py | 5 +- .../mask/test_bc_indices_masker_jax.py | 4 +- .../mask/test_bc_indices_masker_warp.py | 5 +- tests/grids/test_grid_jax.py | 13 +- tests/grids/test_grid_warp.py | 14 +- .../collision/test_bgk_collision_jax.py | 4 +- .../collision/test_bgk_collision_warp.py | 5 +- .../equilibrium/test_equilibrium_jax.py | 4 +- .../equilibrium/test_equilibrium_warp.py | 7 +- .../macroscopic/test_macroscopic_jax.py | 3 +- .../macroscopic/test_macroscopic_warp.py | 5 +- .../bc_extrapolation_outflow.py | 4 +- .../bc_fullway_bounce_back.py | 2 +- .../bc_halfway_bounce_back.py | 2 +- .../boundary_condition/bc_regularized.py | 10 +- xlb/operator/boundary_condition/bc_zouhe.py | 6 +- .../indices_boundary_masker.py | 2 +- .../boundary_masker/mesh_boundary_masker.py | 2 +- xlb/operator/collision/bgk.py | 2 +- .../equilibrium/quadratic_equilibrium.py | 4 +- xlb/operator/force/momentum_transfer.py | 4 +- xlb/operator/macroscopic/second_moment.py | 2 +- .../macroscopic/zero_first_moments.py | 2 +- xlb/operator/stepper/nse_stepper.py | 4 +- xlb/operator/stream/stream.py | 2 +- xlb/velocity_set/d2q9.py | 4 +- xlb/velocity_set/d3q19.py | 4 +- xlb/velocity_set/d3q27.py | 4 +- xlb/velocity_set/velocity_set.py | 129 ++++++++++++------ 36 files changed, 182 insertions(+), 114 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 7d07e05..e09a4bf 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -118,7 +118,11 @@ def post_process(self, i): else: f_0 = self.f_0 - macro = Macroscopic(compute_backend=ComputeBackend.JAX) + macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=self.precision_policy, + velocity_set=xlb.velocity_set.D3Q19(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + ) rho, u = macro(f_0) # remove boundary cells @@ -135,9 +139,9 @@ def post_process(self, i): if __name__ == "__main__": # Running the simulation grid_shape = (512 // 2, 128 // 2, 128 // 2) - velocity_set = xlb.velocity_set.D3Q19() backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) omega = 1.6 simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 16fb4f9..837aa01 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -9,6 +9,7 @@ from xlb.utils import save_fields_vtk, save_image import warp as wp import jax.numpy as jnp +import xlb.velocity_set class LidDrivenCavity2D: @@ -80,7 +81,11 @@ def post_process(self, i): else: f_0 = self.f_0 - macro = Macroscopic(compute_backend=ComputeBackend.JAX) + macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=self.precision_policy, + velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + ) rho, u = macro(f_0) @@ -100,8 +105,8 @@ def post_process(self, i): grid_size = 500 grid_shape = (grid_size, grid_size) backend = ComputeBackend.WARP - velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 225d6bd..397d476 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -26,8 +26,8 @@ def setup_stepper(self, omega): grid_size = 512 grid_shape = (grid_size, grid_size) backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! - velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 8395579..3522122 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -151,7 +151,11 @@ def post_process(self, i): else: f_0 = self.f_0 - macro = Macroscopic(compute_backend=ComputeBackend.JAX) + macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=self.precision_policy, + velocity_set=xlb.velocity_set.D3Q27(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + ) rho, u = macro(f_0) @@ -215,8 +219,8 @@ def plot_drag_coefficient(self): # Configuration backend = ComputeBackend.WARP - velocity_set = xlb.velocity_set.D3Q27() precision_policy = PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend) wind_speed = 0.02 num_steps = 100000 print_interval = 1000 diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 9d2e4ff..bd40dfb 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 917e7e4..9f0cd68 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -8,10 +8,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 2fe0b40..dde05bb 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index b25d39e..96e2f21 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index af121d3..9325890 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -8,13 +8,13 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) - @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 4d02540..43ad052 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -7,10 +7,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/tests/grids/test_grid_jax.py b/tests/grids/test_grid_jax.py index edd9dd0..dd74da6 100644 --- a/tests/grids/test_grid_jax.py +++ b/tests/grids/test_grid_jax.py @@ -8,17 +8,18 @@ import jax.numpy as jnp -def init_xlb_env(): +def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=xlb.velocity_set.D2Q9, # does not affect the test + velocity_set=vel_set, ) @pytest.mark.parametrize("grid_size", [50, 100, 150]) def test_jax_2d_grid_initialization(grid_size): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) grid_shape = (grid_size, grid_size) my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9) @@ -34,7 +35,7 @@ def test_jax_2d_grid_initialization(grid_size): @pytest.mark.parametrize("grid_size", [50, 100, 150]) def test_jax_3d_grid_initialization(grid_size): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D3Q19) grid_shape = (grid_size, grid_size, grid_size) my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9) @@ -54,7 +55,7 @@ def test_jax_3d_grid_initialization(grid_size): def test_jax_grid_create_field_fill_value(): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) grid_shape = (100, 100) fill_value = 3.14 my_grid = grid_factory(grid_shape) @@ -66,7 +67,7 @@ def test_jax_grid_create_field_fill_value(): @pytest.fixture(autouse=True) def setup_xlb_env(): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) if __name__ == "__main__": diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 22445cc..11c8b2a 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -7,18 +7,18 @@ from xlb.precision_policy import Precision -def init_xlb_env(): +def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=xlb.velocity_set.D2Q9, + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) - @pytest.mark.parametrize("grid_size", [50, 100, 150]) def test_warp_grid_create_field(grid_size): for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]: - init_xlb_env() + init_xlb_env(xlb.velocity_set.D3Q19) my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9, dtype=Precision.FP32) @@ -27,7 +27,7 @@ def test_warp_grid_create_field(grid_size): def test_warp_grid_create_field_fill_value(): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) grid_shape = (100, 100) fill_value = 3.14 my_grid = grid_factory(grid_shape) @@ -42,7 +42,7 @@ def test_warp_grid_create_field_fill_value(): @pytest.fixture(autouse=True) def setup_xlb_env(): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) if __name__ == "__main__": diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index 5a400e0..aebc726 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -9,13 +9,13 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) - @pytest.mark.parametrize( "dim,velocity_set,grid_shape,omega", [ diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 522ea33..2743050 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index 07bafe7..50418bc 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -8,13 +8,13 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) - @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index 063a723..9759fb2 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -8,13 +8,12 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) - - @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index 50d1735..2c2ad55 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -8,10 +8,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index d98a014..6a97927 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 55f094d..7bd447e 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -137,9 +137,9 @@ def _construct_warp(self): # Set local constants sound_speed = 1.0 / wp.sqrt(3.0) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _q = self.velocity_set.q - _opp_indices = self.velocity_set.wp_opp_indices + _opp_indices = self.velocity_set.opp_indices @wp.func def get_normal_vectors_2d( diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 6af4226..729dbb4 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -56,7 +56,7 @@ def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _opp_indices = self.velocity_set.wp_opp_indices + _opp_indices = self.velocity_set.opp_indices _q = wp.constant(self.velocity_set.q) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 5c001d9..2cb60fa 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -63,7 +63,7 @@ def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): def _construct_warp(self): # Set local constants - _opp_indices = self.velocity_set.wp_opp_indices + _opp_indices = self.velocity_set.opp_indices # Construct the functional for this BC @wp.func diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index b74c0b1..069847b 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -138,11 +138,11 @@ def _construct_warp(self): _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(rho) _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) - _opp_indices = self.velocity_set.wp_opp_indices - _w = self.velocity_set.wp_w - _c = self.velocity_set.wp_c - _c32 = self.velocity_set.wp_c32 - _qi = self.velocity_set.wp_qi + _opp_indices = self.velocity_set.opp_indices + _w = self.velocity_set.w + _c = self.velocity_set.c + _c32 = self.velocity_set.c32 + _qi = self.velocity_set.qi # TODO: related to _c32: this is way less than ideal. we should not be making new types @wp.func diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 56c6868..56cd19a 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -182,9 +182,9 @@ def _construct_warp(self): _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(rho) _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) - _opp_indices = self.velocity_set.wp_opp_indices - _c = self.velocity_set.wp_c - _c32 = self.velocity_set.wp_c32 + _opp_indices = self.velocity_set.opp_indices + _c = self.velocity_set.c + _c32 = self.velocity_set.c32 # TODO: this is way less than ideal. we should not be making new types @wp.func diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 208f50f..54cb7aa 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -97,7 +97,7 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non def _construct_warp(self): # Make constants for warp - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _q = wp.constant(self.velocity_set.q) # Construct the warp 2D kernel diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 366c9d6..c43ea02 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -47,7 +47,7 @@ def jax_implementation( def _construct_warp(self): # Make constants for warp - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _q = wp.constant(self.velocity_set.q) # Construct the warp kernel diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 9dbfabd..196e3ba 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -23,7 +23,7 @@ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _w = self.velocity_set.wp_w + _w = self.velocity_set.w _omega = wp.constant(self.compute_dtype(self.omega)) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 3af6b4a..2fe1526 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -26,8 +26,8 @@ def jax_implementation(self, rho, u): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c - _w = self.velocity_set.wp_w + _c = self.velocity_set.c + _w = self.velocity_set.w _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 66dba13..d273f8a 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -87,8 +87,8 @@ def jax_implementation(self, f, boundary_map, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c - _opp_indices = self.velocity_set.wp_opp_indices + _c = self.velocity_set.c + _opp_indices = self.velocity_set.opp_indices _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool _no_slip_id = self.no_slip_bc_instance.id diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index db8fce6..5209d69 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -56,7 +56,7 @@ def jax_implementation( def _construct_warp(self): # Make constants for warp - _cc = self.velocity_set.wp_cc + _cc = self.velocity_set.cc _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2 _pi_vec = wp.vec( diff --git a/xlb/operator/macroscopic/zero_first_moments.py b/xlb/operator/macroscopic/zero_first_moments.py index fbf7c93..48cf108 100644 --- a/xlb/operator/macroscopic/zero_first_moments.py +++ b/xlb/operator/macroscopic/zero_first_moments.py @@ -46,7 +46,7 @@ def jax_implementation(self, f): def _construct_warp(self): # Make constants for warp - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index ba3e294..60f610a 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -90,9 +90,9 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _q = self.velocity_set.q - _opp_indices = self.velocity_set.wp_opp_indices + _opp_indices = self.velocity_set.opp_indices sound_speed = 1.0 / wp.sqrt(3.0) @wp.struct diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index f91b567..d96c307 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -52,7 +52,7 @@ def _streaming_jax_i(f, c): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) # Construct the warp functional diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py index 700806c..5324618 100644 --- a/xlb/velocity_set/d2q9.py +++ b/xlb/velocity_set/d2q9.py @@ -13,7 +13,7 @@ class D2Q9(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in two dimensions. """ - def __init__(self): + def __init__(self, precision_policy, backend): # Construct the velocity vectors and weights cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] @@ -21,4 +21,4 @@ def __init__(self): w = np.array([4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36]) # Call the parent constructor - super().__init__(2, 9, c, w) + super().__init__(2, 9, c, w, precision_policy=precision_policy, backend=backend) diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py index 97db1d9..4a48c2f 100644 --- a/xlb/velocity_set/d3q19.py +++ b/xlb/velocity_set/d3q19.py @@ -14,7 +14,7 @@ class D3Q19(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ - def __init__(self): + def __init__(self, precision_policy, backend): # Construct the velocity vectors and weights c = np.array([ci for ci in itertools.product([-1, 0, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T w = np.zeros(19) @@ -27,4 +27,4 @@ def __init__(self): w[i] = 1.0 / 36.0 # Initialize the lattice - super().__init__(3, 19, c, w) + super().__init__(3, 19, c, w, precision_policy=precision_policy, backend=backend) diff --git a/xlb/velocity_set/d3q27.py b/xlb/velocity_set/d3q27.py index 702acf4..b056d53 100644 --- a/xlb/velocity_set/d3q27.py +++ b/xlb/velocity_set/d3q27.py @@ -14,7 +14,7 @@ class D3Q27(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ - def __init__(self): + def __init__(self, precision_policy, backend): # Construct the velocity vectors and weights c = np.array(list(itertools.product([0, -1, 1], repeat=3))).T w = np.zeros(27) @@ -29,4 +29,4 @@ def __init__(self): w[i] = 1.0 / 216.0 # Initialize the Lattice - super().__init__(3, 27, c, w) + super().__init__(3, 27, c, w, precision_policy=precision_policy, backend=backend) diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index a93d039..91069d2 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -2,9 +2,11 @@ import math import numpy as np - import warp as wp +import jax.numpy as jnp +from xlb import DefaultConfig +from xlb.compute_backend import ComputeBackend class VelocitySet(object): """ @@ -22,35 +24,86 @@ class VelocitySet(object): The weights of the lattice. Shape: (q,) """ - def __init__(self, d, q, c, w): + def __init__(self, d, q, c, w, precision_policy, backend): # Store the dimension and the number of velocities self.d = d self.q = q + self.precision_policy = precision_policy + self.backend = backend + + # Create all properties in NumPy first + self._init_numpy_properties(c, w) + + # Convert properties to backend-specific format + if self.backend == ComputeBackend.WARP: + self._init_warp_properties() + elif self.backend == ComputeBackend.JAX: + self._init_jax_properties() + else: + raise ValueError(f"Unsupported compute backend: {self.backend}") + + # Set up backend-specific constants + self._init_backend_constants() - # Constants - self.cs = math.sqrt(3) / 3.0 - self.cs2 = 1.0 / 3.0 - self.inv_cs2 = 3.0 - - # Construct the properties of the lattice - self.c = c - self.w = w - self.cc = self._construct_lattice_moment() - self.opp_indices = self._construct_opposite_indices() - self.get_opp_index = lambda i: self.opp_indices[i] + def _init_numpy_properties(self, c, w): + """ + Initialize all properties in NumPy first. + """ + self._c = np.array(c) + self._w = np.array(w) + self._opp_indices = self._construct_opposite_indices() + self._cc = self._construct_lattice_moment() + self._c32 = self._c.astype(np.float64) + self._qi = self._construct_qi() + + # Constants in NumPy + self.cs = np.float64(math.sqrt(3) / 3.0) + self.cs2 = np.float64(1.0 / 3.0) + self.inv_cs2 = np.float64(3.0) + + # Indices self.main_indices = self._construct_main_indices() self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() - self.qi = self._construct_qi() - # Make warp constants for these vectors - # TODO: Following warp updates these may not be necessary - self.wp_c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self.c)) - self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow - self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) - self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) - self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c)) - self.wp_qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.qi)) + def _init_warp_properties(self): + """ + Convert NumPy properties to Warp-specific properties. + """ + dtype = self.precision_policy.compute_precision.wp_dtype + self.c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self._c)) + self.w = wp.constant(wp.vec(self.q, dtype=dtype)(self._w)) + self.opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self._opp_indices)) + self.cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._cc)) + self.c32 = wp.constant(wp.mat((self.d, self.q), dtype=dtype)(self._c32)) + self.qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._qi)) + + def _init_jax_properties(self): + """ + Convert NumPy properties to JAX-specific properties. + """ + dtype = self.precision_policy.compute_precision.jax_dtype + self.c = jnp.array(self._c, dtype=dtype) + self.w = jnp.array(self._w, dtype=dtype) + self.opp_indices = jnp.array(self._opp_indices, dtype=jnp.int32) + self.cc = jnp.array(self._cc, dtype=dtype) + self.c32 = jnp.array(self._c32, dtype=dtype) + self.qi = jnp.array(self._qi, dtype=dtype) + + def _init_backend_constants(self): + """ + Initialize the constants for the backend. + """ + if self.backend == ComputeBackend.WARP: + dtype = self.precision_policy.compute_precision.wp_dtype + self.cs = wp.constant(dtype(self.cs)) + self.cs2 = wp.constant(dtype(self.cs2)) + self.inv_cs2 = wp.constant(dtype(self.inv_cs2)) + elif self.backend == ComputeBackend.JAX: + dtype = self.precision_policy.compute_precision.jax_dtype + self.cs = jnp.array(self.cs, dtype=dtype) + self.cs2 = jnp.array(self.cs2, dtype=dtype) + self.inv_cs2 = jnp.array(self.inv_cs2, dtype=dtype) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) @@ -64,13 +117,11 @@ def warp_stream_mat(self, dtype): def _construct_qi(self): # Qi = cc - cs^2*I dim = self.d - Qi = self.cc.copy() + Qi = self._cc.copy() if dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) + diagonal, offdiagonal = (0, 3, 5), (1, 2, 4) elif dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) + diagonal, offdiagonal = (0, 2), (1,) else: raise ValueError(f"dim = {dim} not supported") @@ -92,19 +143,18 @@ def _construct_lattice_moment(self): cc: numpy.ndarray The moments of the lattice. """ - c = self.c.T + c = self._c.T # Counter for the loop cntr = 0 - + c = self._c.T # nt: number of independent elements of a symmetric tensor nt = self.d * (self.d + 1) // 2 - cc = np.zeros((self.q, nt)) - for a in range(0, self.d): + cntr = 0 + for a in range(self.d): for b in range(a, self.d): cc[:, cntr] = c[:, a] * c[:, b] cntr += 1 - return cc def _construct_opposite_indices(self): @@ -119,9 +169,8 @@ def _construct_opposite_indices(self): opposite: numpy.ndarray The indices of the opposite velocities. """ - c = self.c.T - opposite = np.array([c.tolist().index((-c[i]).tolist()) for i in range(self.q)]) - return opposite + c = self._c.T + return np.array([c.tolist().index((-c[i]).tolist()) for i in range(self.q)]) def _construct_main_indices(self): """ @@ -134,10 +183,9 @@ def _construct_main_indices(self): numpy.ndarray The indices of the main velocities. """ - c = self.c.T + c = self._c.T if self.d == 2: return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] - elif self.d == 3: return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] @@ -151,8 +199,7 @@ def _construct_right_indices(self): numpy.ndarray The indices of the right velocities. """ - c = self.c.T - return np.nonzero(c[:, 0] == 1)[0] + return np.nonzero(self._c.T[:, 0] == 1)[0] def _construct_left_indices(self): """ @@ -164,8 +211,7 @@ def _construct_left_indices(self): numpy.ndarray The indices of the left velocities. """ - c = self.c.T - return np.nonzero(c[:, 0] == -1)[0] + return np.nonzero(self._c.T[:, 0] == -1)[0] def __str__(self): """ @@ -178,3 +224,4 @@ def __repr__(self): This function returns the name of the lattice in the format of DxQy. """ return "D{}Q{}".format(self.d, self.q) + From fd759d3b75229ca1c551d728c8bb75c4c63a8d45 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 13 Sep 2024 18:11:41 -0400 Subject: [PATCH 103/144] Renamed _c32 --- xlb/operator/boundary_condition/bc_regularized.py | 8 ++++---- xlb/operator/boundary_condition/bc_zouhe.py | 6 +++--- xlb/velocity_set/velocity_set.py | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 069847b..86e7a08 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -141,9 +141,9 @@ def _construct_warp(self): _opp_indices = self.velocity_set.opp_indices _w = self.velocity_set.w _c = self.velocity_set.c - _c32 = self.velocity_set.c32 + _c_float = self.velocity_set.c32 _qi = self.velocity_set.qi - # TODO: related to _c32: this is way less than ideal. we should not be making new types + # TODO: related to _c_float: this is way less than ideal. we should not be making new types @wp.func def _get_fsum( @@ -165,7 +165,7 @@ def get_normal_vectors_2d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l]) @wp.func def get_normal_vectors_3d( @@ -173,7 +173,7 @@ def get_normal_vectors_3d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) @wp.func def bounceback_nonequilibrium( diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 56cd19a..00c4436 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -184,7 +184,7 @@ def _construct_warp(self): _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.opp_indices _c = self.velocity_set.c - _c32 = self.velocity_set.c32 + _c_float = self.velocity_set.c32 # TODO: this is way less than ideal. we should not be making new types @wp.func @@ -193,7 +193,7 @@ def get_normal_vectors_2d( ): l = lattice_direction if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - normals = -_u_vec(_c32[0, l], _c32[1, l]) + normals = -_u_vec(_c_float[0, l], _c_float[1, l]) return normals @wp.func @@ -216,7 +216,7 @@ def get_normal_vectors_3d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) @wp.func def bounceback_nonequilibrium( diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 91069d2..27ee5a0 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -53,7 +53,7 @@ def _init_numpy_properties(self, c, w): self._w = np.array(w) self._opp_indices = self._construct_opposite_indices() self._cc = self._construct_lattice_moment() - self._c32 = self._c.astype(np.float64) + self._c_float = self._c.astype(np.float64) self._qi = self._construct_qi() # Constants in NumPy @@ -75,7 +75,7 @@ def _init_warp_properties(self): self.w = wp.constant(wp.vec(self.q, dtype=dtype)(self._w)) self.opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self._opp_indices)) self.cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._cc)) - self.c32 = wp.constant(wp.mat((self.d, self.q), dtype=dtype)(self._c32)) + self.c32 = wp.constant(wp.mat((self.d, self.q), dtype=dtype)(self._c_float)) self.qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._qi)) def _init_jax_properties(self): @@ -87,7 +87,7 @@ def _init_jax_properties(self): self.w = jnp.array(self._w, dtype=dtype) self.opp_indices = jnp.array(self._opp_indices, dtype=jnp.int32) self.cc = jnp.array(self._cc, dtype=dtype) - self.c32 = jnp.array(self._c32, dtype=dtype) + self.c32 = jnp.array(self._c_float, dtype=dtype) self.qi = jnp.array(self._qi, dtype=dtype) def _init_backend_constants(self): From 9a37ab36a511a0c0eb3e2ca1141bd4c5a66a7512 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 13 Sep 2024 18:11:41 -0400 Subject: [PATCH 104/144] Renamed _c32 --- xlb/operator/boundary_condition/bc_regularized.py | 8 ++++---- xlb/operator/boundary_condition/bc_zouhe.py | 6 +++--- xlb/velocity_set/velocity_set.py | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 069847b..8fdec8f 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -141,9 +141,9 @@ def _construct_warp(self): _opp_indices = self.velocity_set.opp_indices _w = self.velocity_set.w _c = self.velocity_set.c - _c32 = self.velocity_set.c32 + _c_float = self.velocity_set.c_float _qi = self.velocity_set.qi - # TODO: related to _c32: this is way less than ideal. we should not be making new types + # TODO: related to _c_float: this is way less than ideal. we should not be making new types @wp.func def _get_fsum( @@ -165,7 +165,7 @@ def get_normal_vectors_2d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l]) @wp.func def get_normal_vectors_3d( @@ -173,7 +173,7 @@ def get_normal_vectors_3d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) @wp.func def bounceback_nonequilibrium( diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 56cd19a..db910f9 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -184,7 +184,7 @@ def _construct_warp(self): _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.opp_indices _c = self.velocity_set.c - _c32 = self.velocity_set.c32 + _c_float = self.velocity_set.c_float # TODO: this is way less than ideal. we should not be making new types @wp.func @@ -193,7 +193,7 @@ def get_normal_vectors_2d( ): l = lattice_direction if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - normals = -_u_vec(_c32[0, l], _c32[1, l]) + normals = -_u_vec(_c_float[0, l], _c_float[1, l]) return normals @wp.func @@ -216,7 +216,7 @@ def get_normal_vectors_3d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) @wp.func def bounceback_nonequilibrium( diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 91069d2..2405b36 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -53,7 +53,7 @@ def _init_numpy_properties(self, c, w): self._w = np.array(w) self._opp_indices = self._construct_opposite_indices() self._cc = self._construct_lattice_moment() - self._c32 = self._c.astype(np.float64) + self._c_float = self._c.astype(np.float64) self._qi = self._construct_qi() # Constants in NumPy @@ -75,7 +75,7 @@ def _init_warp_properties(self): self.w = wp.constant(wp.vec(self.q, dtype=dtype)(self._w)) self.opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self._opp_indices)) self.cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._cc)) - self.c32 = wp.constant(wp.mat((self.d, self.q), dtype=dtype)(self._c32)) + self.c_float = wp.constant(wp.mat((self.d, self.q), dtype=dtype)(self._c_float)) self.qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._qi)) def _init_jax_properties(self): @@ -87,7 +87,7 @@ def _init_jax_properties(self): self.w = jnp.array(self._w, dtype=dtype) self.opp_indices = jnp.array(self._opp_indices, dtype=jnp.int32) self.cc = jnp.array(self._cc, dtype=dtype) - self.c32 = jnp.array(self._c32, dtype=dtype) + self.c_float = jnp.array(self._c_float, dtype=dtype) self.qi = jnp.array(self._qi, dtype=dtype) def _init_backend_constants(self): From db392f163ecdf92a35f7f5f7c0d13e20364b00f6 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 13 Sep 2024 18:16:13 -0400 Subject: [PATCH 105/144] Removed instances of fp32 --- xlb/operator/boundary_condition/bc_equilibrium.py | 2 +- xlb/operator/boundary_condition/bc_regularized.py | 2 +- xlb/operator/boundary_condition/bc_zouhe.py | 2 +- xlb/operator/collision/kbc.py | 2 +- xlb/operator/force/momentum_transfer.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 6853c0e..c018a60 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -73,7 +73,7 @@ def jax_implementation(self, f_pre, f_post, boundary_map, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = wp.float32(self.rho) + _rho = self.compute_dtype(self.rho) _u = _u_vec(self.u[0], self.u[1], self.u[2]) if self.velocity_set.d == 3 else _u_vec(self.u[0], self.u[1]) # Construct the functional for this BC diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 8fdec8f..9734368 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -136,7 +136,7 @@ def _construct_warp(self): # compute Qi tensor and store it in self _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = wp.float32(rho) + _rho = self.compute_dtype(rho) _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.opp_indices _w = self.velocity_set.w diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index db910f9..782eb4c 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -180,7 +180,7 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = wp.float32(rho) + _rho = self.compute_dtype(rho) _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) _opp_indices = self.velocity_set.opp_indices _c = self.velocity_set.c diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index ddd7ecc..748ebea 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -251,7 +251,7 @@ def entropic_scalar_product( feq: Any, ): e = wp.cw_div(wp.cw_mul(x, y), feq) - e_sum = wp.float32(0.0) + e_sum = self.compute_dtype(0.0) for i in range(self.velocity_set.q): e_sum += e[i] return e_sum diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index d273f8a..d25baf7 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -144,7 +144,7 @@ def kernel2d( if _missing_mask[l] == wp.uint8(1): phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] for d in range(self.velocity_set.d): - m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) wp.atomic_add(force, 0, m) @@ -193,7 +193,7 @@ def kernel3d( if _missing_mask[l] == wp.uint8(1): phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] for d in range(self.velocity_set.d): - m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) wp.atomic_add(force, 0, m) From e2028cbce6e107cfbad475595f09e224fdd0c11e Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 13 Sep 2024 18:29:36 -0400 Subject: [PATCH 106/144] Operators work in fp64 and fp64 with no issues --- examples/cfd/flow_past_sphere_3d.py | 5 +++++ examples/cfd/lid_driven_cavity_2d.py | 5 +++++ examples/cfd/lid_driven_cavity_2d_distributed.py | 5 +++++ xlb/default_config.py | 6 +++--- xlb/operator/equilibrium/quadratic_equilibrium.py | 4 ++-- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index e09a4bf..5a70d20 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -19,6 +19,7 @@ import numpy as np import jax.numpy as jnp import time +import jax class FlowOverSphere: @@ -141,6 +142,10 @@ def post_process(self, i): grid_shape = (512 // 2, 128 // 2, 128 // 2) backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 837aa01..f73fb0e 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -8,6 +8,7 @@ from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image import warp as wp +import jax import jax.numpy as jnp import xlb.velocity_set @@ -106,6 +107,10 @@ def post_process(self, i): grid_shape = (grid_size, grid_size) backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 397d476..7efe907 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -1,4 +1,5 @@ import xlb +import jax from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy from xlb.operator.stepper import IncompressibleNavierStokesStepper @@ -27,6 +28,10 @@ def setup_stepper(self, omega): grid_shape = (grid_size, grid_size) backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/xlb/default_config.py b/xlb/default_config.py index f1ca25f..20eac44 100644 --- a/xlb/default_config.py +++ b/xlb/default_config.py @@ -1,5 +1,7 @@ +import jax from xlb.compute_backend import ComputeBackend from dataclasses import dataclass +from xlb.precision_policy import PrecisionPolicy @dataclass @@ -17,7 +19,7 @@ def init(velocity_set, default_backend, default_precision_policy): if default_backend == ComputeBackend.WARP: import warp as wp - wp.init() + wp.init() # TODO: Must be removed in the future versions of WARP elif default_backend == ComputeBackend.JAX: check_multi_gpu_support() else: @@ -29,8 +31,6 @@ def default_backend() -> ComputeBackend: def check_multi_gpu_support(): - import jax - gpus = jax.devices("gpu") if len(gpus) > 1: print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus))) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 2fe1526..0cce91b 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -52,10 +52,10 @@ def functional( cu *= self.compute_dtype(3.0) # Compute usqr - usqr = 1.5 * wp.dot(u, u) + usqr = self.compute_dtype(1.5) * wp.dot(u, u) # Compute feq - feq[l] = rho * _w[l] * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) + feq[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu * (self.compute_dtype(1.0) + self.compute_dtype(0.5) * cu) - usqr) return feq From b85ab6ace622c8a6befb5eea6c43a7fe171098eb Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 16 Sep 2024 12:35:48 -0400 Subject: [PATCH 107/144] fixing merge issues --- examples/performance/mlups_3d.py | 13 +++++++------ .../mask/test_bc_indices_masker_jax.py | 1 + tests/grids/test_grid_warp.py | 1 + tests/kernels/collision/test_bgk_collision_jax.py | 1 + tests/kernels/equilibrium/test_equilibrium_jax.py | 1 + tests/kernels/equilibrium/test_equilibrium_warp.py | 2 ++ xlb/velocity_set/velocity_set.py | 2 +- 7 files changed, 14 insertions(+), 7 deletions(-) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 52be973..04e1902 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -32,7 +32,7 @@ def setup_simulation(args): raise ValueError("Invalid precision") xlb.init( - velocity_set=xlb.velocity_set.D3Q19(), + velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend), default_backend=backend, default_precision_policy=precision_policy, ) @@ -55,7 +55,7 @@ def define_boundary_indices(grid): + grid.boundingBoxIndices["right"][i] + grid.boundingBoxIndices["front"][i] + grid.boundingBoxIndices["back"][i] - for i in range(xlb.velocity_set.D3Q19().d) + for i in range(len(grid.shape)) ] return lid, walls @@ -67,7 +67,7 @@ def setup_boundary_conditions(grid): return [bc_top, bc_walls] -def run(f_0, f_1, backend, grid, bc_id, missing_mask, num_steps): +def run(f_0, f_1, backend, precision_policy, grid, bc_id, missing_mask, num_steps): omega = 1.0 stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=setup_boundary_conditions(grid)) @@ -75,7 +75,7 @@ def run(f_0, f_1, backend, grid, bc_id, missing_mask, num_steps): stepper = distribute( stepper, grid, - xlb.velocity_set.D3Q19(), + xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend), ) start_time = time.time() @@ -98,10 +98,11 @@ def calculate_mlups(cube_edge, num_steps, elapsed_time): def main(): args = parse_arguments() backend, precision_policy = setup_simulation(args) + velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) grid, f_0, f_1, missing_mask, bc_id = create_grid_and_fields(args.cube_edge) - f_0 = initialize_eq(f_0, grid, xlb.velocity_set.D3Q19(), backend) + f_0 = initialize_eq(f_0, grid, velocity_set, backend) - elapsed_time = run(f_0, f_1, backend, grid, bc_id, missing_mask, args.num_steps) + elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_id, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index 30c17d2..b0eb6ae 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -15,6 +15,7 @@ def init_xlb_env(velocity_set): velocity_set=vel_set, ) + @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 11c8b2a..0d9bdca 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -15,6 +15,7 @@ def init_xlb_env(velocity_set): velocity_set=vel_set, ) + @pytest.mark.parametrize("grid_size", [50, 100, 150]) def test_warp_grid_create_field(grid_size): for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]: diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index aebc726..1672cd5 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -16,6 +16,7 @@ def init_xlb_env(velocity_set): velocity_set=vel_set, ) + @pytest.mark.parametrize( "dim,velocity_set,grid_shape,omega", [ diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index 50418bc..6dac5b1 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -15,6 +15,7 @@ def init_xlb_env(velocity_set): velocity_set=vel_set, ) + @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index 9759fb2..ecd3834 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -14,6 +14,8 @@ def init_xlb_env(velocity_set): default_backend=ComputeBackend.JAX, velocity_set=vel_set, ) + + @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 2405b36..50934c5 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -8,6 +8,7 @@ from xlb import DefaultConfig from xlb.compute_backend import ComputeBackend + class VelocitySet(object): """ Base class for the velocity set of the Lattice Boltzmann Method (LBM), e.g. D2Q9, D3Q27, etc. @@ -224,4 +225,3 @@ def __repr__(self): This function returns the name of the lattice in the format of DxQy. """ return "D{}Q{}".format(self.d, self.q) - From a7de4bbd6f18a5700c6a0ed0fd81d578080ab8e3 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 16 Sep 2024 17:31:52 -0400 Subject: [PATCH 108/144] resolved all unit test issues --- .../bc_equilibrium/test_bc_equilibrium_warp.py | 2 +- .../test_bc_fullway_bounce_back_jax.py | 5 +++-- .../test_bc_fullway_bounce_back_warp.py | 4 ++-- .../mask/test_bc_indices_masker_warp.py | 5 +++-- tests/grids/test_grid_warp.py | 2 +- tests/kernels/collision/test_bgk_collision_warp.py | 2 +- tests/kernels/equilibrium/test_equilibrium_jax.py | 2 +- tests/kernels/equilibrium/test_equilibrium_warp.py | 2 +- tests/kernels/macroscopic/test_macroscopic_warp.py | 2 +- tests/kernels/stream/test_stream_jax.py | 3 ++- tests/kernels/stream/test_stream_warp.py | 3 ++- .../boundary_condition/bc_halfway_bounce_back.py | 3 +++ xlb/operator/boundary_condition/bc_zouhe.py | 3 +++ .../boundary_condition/boundary_condition.py | 4 ++++ .../boundary_masker/indices_boundary_masker.py | 14 +++++++++----- xlb/velocity_set/velocity_set.py | 2 +- 16 files changed, 38 insertions(+), 20 deletions(-) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 597b841..a7ef555 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -11,7 +11,7 @@ def init_xlb_env(velocity_set): vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, + default_backend=ComputeBackend.WARP, velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index d17e0a5..84e3763 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -6,6 +6,7 @@ from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig +from xlb.operator.boundary_masker import IndicesBoundaryMasker def init_xlb_env(velocity_set): @@ -37,7 +38,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + indices_boundary_masker = IndicesBoundaryMasker() # Make indices for boundary conditions (sphere) sphere_radius = grid_shape[0] // 4 @@ -74,7 +75,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): for i in range(velocity_set.q): jnp.allclose( - f[velocity_set.get_opp_index(i)][tuple(indices)], + f[velocity_set.opp_indices[i]][tuple(indices)], f_pre[i][tuple(indices)], ) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index d57053d..b67ef43 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -12,7 +12,7 @@ def init_xlb_env(velocity_set): vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, + default_backend=ComputeBackend.WARP, velocity_set=vel_set, ) @@ -75,7 +75,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): for i in range(velocity_set.q): np.allclose( - f[velocity_set.get_opp_index(i)][tuple(indices)], + f[velocity_set.opp_indices[i]][tuple(indices)], f_post[i][tuple(indices)], ) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 381b9b7..41e9f23 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -4,13 +4,14 @@ from xlb.compute_backend import ComputeBackend from xlb import DefaultConfig from xlb.grid import grid_factory +from xlb.operator.boundary_masker import IndicesBoundaryMasker def init_xlb_env(velocity_set): vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, + default_backend=ComputeBackend.WARP, velocity_set=vel_set, ) @@ -35,7 +36,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() + indices_boundary_masker = IndicesBoundaryMasker() # Make indices for boundary conditions (sphere) sphere_radius = grid_shape[0] // 4 diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 0d9bdca..782434d 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -11,7 +11,7 @@ def init_xlb_env(velocity_set): vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, + default_backend=ComputeBackend.WARP, velocity_set=vel_set, ) diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 2743050..382e368 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -12,7 +12,7 @@ def init_xlb_env(velocity_set): vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, + default_backend=ComputeBackend.WARP, velocity_set=vel_set, ) diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index 6dac5b1..aa4f051 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index ecd3834..fdd796a 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -11,7 +11,7 @@ def init_xlb_env(velocity_set): vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, + default_backend=ComputeBackend.WARP, velocity_set=vel_set, ) diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index 6a97927..4f33bc2 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -12,7 +12,7 @@ def init_xlb_env(velocity_set): vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.JAX, + default_backend=ComputeBackend.WARP, velocity_set=vel_set, ) diff --git a/tests/kernels/stream/test_stream_jax.py b/tests/kernels/stream/test_stream_jax.py index ef635ea..c1cae52 100644 --- a/tests/kernels/stream/test_stream_jax.py +++ b/tests/kernels/stream/test_stream_jax.py @@ -8,10 +8,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py index b83368d..0d100cf 100644 --- a/tests/kernels/stream/test_stream_warp.py +++ b/tests/kernels/stream/test_stream_warp.py @@ -10,10 +10,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + velocity_set=vel_set, ) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 91b95d4..1e072a2 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -49,6 +49,9 @@ def __init__( mesh_vertices, ) + # This BC needs padding for finding missing directions when imposed on a geometry that is in the domain interior + self.needs_padding = True + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, bc_id, missing_mask): diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index e8fdbad..337c4aa 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -69,6 +69,9 @@ def __init__( self.prescribed_value = jnp.atleast_1d(prescribed_value)[(slice(None),) + (None,) * dim] # TODO: this won't work if the prescribed values are a profile with the length of bdry indices! + # This BC needs padding for finding missing directions when imposed on a geometry that is in the domain interior + self.needs_padding = True + @partial(jit, static_argnums=(0,), inline=True) def _get_known_middle_mask(self, missing_mask): known_mask = missing_mask[self.velocity_set.opp_indices] diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index faeac77..4f5cca5 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -48,6 +48,10 @@ def __init__( # Set the implementation step self.implementation_step = implementation_step + # A flag to indicate whether bc indices need to be padded in both normal directions to identify missing directions + # when inside/outside of the geoemtry is not known + self.needs_padding = False + if self.compute_backend == ComputeBackend.WARP: # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 6715294..7cb9c61 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -73,7 +73,7 @@ def jax_implementation(self, bclist, bc_id, missing_mask, start_index=None): local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] bmap = bmap.at[tuple(padded_indices)].set(id_number) - if any(self.are_indices_in_interior(bc.indices, domain_shape)): + if any(self.are_indices_in_interior(bc.indices, domain_shape)) and bc.needs_padding: # checking if all indices associated with this BC are in the interior of the domain (not at the boundary). # This flag is needed e.g. if the no-slip geometry is anywhere but at the boundaries of the computational domain. if dim == 2: @@ -83,7 +83,7 @@ def jax_implementation(self, bclist, bc_id, missing_mask, start_index=None): # Assign the boundary id to the push indices push_indices = padded_indices[:, :, None] + self.velocity_set.c[:, None, :] - push_indices = push_indices.reshape(3, -1) + push_indices = push_indices.reshape(dim, -1) bmap = bmap.at[tuple(push_indices)].set(id_number) # We are done with bc.indices. Remove them from BC objects @@ -132,12 +132,14 @@ def kernel2d( pull_index[d] = index[d] - _c[d, l] push_index[d] = index[d] + _c[d, l] + # set bc_id for all bc indices + bc_id[0, index[0], index[1]] = id_number[ii] + # check if pull index is out of bound # These directions will have missing information after streaming if pull_index[0] < 0 or pull_index[0] >= missing_mask.shape[1] or pull_index[1] < 0 or pull_index[1] >= missing_mask.shape[2]: # Set the missing mask missing_mask[l, index[0], index[1]] = True - bc_id[0, index[0], index[1]] = id_number[ii] # handling geometries in the interior of the computational domain elif ( @@ -188,6 +190,9 @@ def kernel3d( pull_index[d] = index[d] - _c[d, l] push_index[d] = index[d] + _c[d, l] + # set bc_id for all bc indices + bc_id[0, index[0], index[1], index[2]] = id_number[ii] + # check if pull index is out of bound # These directions will have missing information after streaming if ( @@ -200,7 +205,6 @@ def kernel3d( ): # Set the missing mask missing_mask[l, index[0], index[1], index[2]] = True - bc_id[0, index[0], index[1], index[2]] = id_number[ii] # handling geometries in the interior of the computational domain elif ( @@ -232,7 +236,7 @@ def warp_implementation(self, bclist, bc_id, missing_mask, start_index=None): for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) - is_interior += self.are_indices_in_interior(bc.indices, bc_id[0].shape) + is_interior += self.are_indices_in_interior(bc.indices, bc_id[0].shape) if bc.needs_padding else [False] * len(bc.indices[0]) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 50934c5..3b4a973 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -84,7 +84,7 @@ def _init_jax_properties(self): Convert NumPy properties to JAX-specific properties. """ dtype = self.precision_policy.compute_precision.jax_dtype - self.c = jnp.array(self._c, dtype=dtype) + self.c = jnp.array(self._c, dtype=jnp.int32) self.w = jnp.array(self._w, dtype=dtype) self.opp_indices = jnp.array(self._opp_indices, dtype=jnp.int32) self.cc = jnp.array(self._cc, dtype=dtype) From 1fa7ad6a1430d103a3d00c6be2d2726117d08eeb Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 16 Sep 2024 18:15:05 -0400 Subject: [PATCH 109/144] fixed dtype issues for fp64 --- examples/cfd/flow_past_sphere_3d.py | 4 -- examples/cfd/lid_driven_cavity_2d.py | 4 -- .../cfd/lid_driven_cavity_2d_distributed.py | 4 -- .../bc_extrapolation_outflow.py | 4 +- .../boundary_condition/bc_regularized.py | 14 +++--- xlb/operator/boundary_condition/bc_zouhe.py | 10 ++-- xlb/operator/collision/kbc.py | 50 +++++++++++-------- xlb/operator/macroscopic/second_moment.py | 2 +- xlb/operator/operator.py | 8 +++ xlb/velocity_set/velocity_set.py | 6 +++ 10 files changed, 57 insertions(+), 49 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 4286384..9645ce5 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -19,7 +19,6 @@ import numpy as np import jax.numpy as jnp import time -import jax class FlowOverSphere: @@ -143,9 +142,6 @@ def post_process(self, i): backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 - if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: - jax.config.update("jax_enable_x64", True) - velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 0178616..adee825 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -8,7 +8,6 @@ from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image import warp as wp -import jax import jax.numpy as jnp import xlb.velocity_set @@ -108,9 +107,6 @@ def post_process(self, i): backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 - if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: - jax.config.update("jax_enable_x64", True) - velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 7efe907..a7d43b4 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -1,5 +1,4 @@ import xlb -import jax from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy from xlb.operator.stepper import IncompressibleNavierStokesStepper @@ -29,9 +28,6 @@ def setup_stepper(self, omega): backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! precision_policy = PrecisionPolicy.FP32FP32 - if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: - jax.config.update("jax_enable_x64", True) - velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index e5a6128..9ca05ef 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -135,7 +135,7 @@ def apply_jax(self, f_pre, f_post, bc_id, missing_mask): def _construct_warp(self): # Set local constants - sound_speed = 1.0 / wp.sqrt(3.0) + sound_speed = self.compute_dtype(1.0 / wp.sqrt(3.0)) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _c = self.velocity_set.c _q = self.velocity_set.q @@ -187,7 +187,7 @@ def prepare_bc_auxilary_data( _f = f_post for l in range(self.velocity_set.q): if missing_mask[l] == wp.uint8(1): - _f[_opp_indices[l]] = (1.0 - sound_speed) * f_pre[l] + sound_speed * f_aux[l] + _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux[l] return _f # Construct the warp kernel diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 32f391b..48715ba 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -154,7 +154,7 @@ def _get_fsum( fsum_middle = self.compute_dtype(0.0) for l in range(_q): if missing_mask[_opp_indices[l]] == wp.uint8(1): - fsum_known += 2.0 * fpop[l] + fsum_known += self.compute_dtype(2.0) * fpop[l] elif missing_mask[l] != wp.uint8(1): fsum_middle += fpop[l] return fsum_known + fsum_middle @@ -202,13 +202,13 @@ def regularize_fpop( nt = _d * (_d + 1) // 2 QiPi1 = _f_vec() for l in range(_q): - QiPi1[l] = 0.0 + QiPi1[l] = self.compute_dtype(0.0) for t in range(nt): QiPi1[l] += _qi[l, t] * PiNeq[t] # assign all populations based on eq 45 of Latt et al (2008) # fneq ~ f^1 - fpop1 = 9.0 / 2.0 * _w[l] * QiPi1[l] + fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1[l] fpop[l] = feq[l] + fpop1 return fpop @@ -230,7 +230,7 @@ def functional3d_velocity( unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = fsum / (1.0 + unormal) + _rho = fsum / (self.compute_dtype(1.0) + unormal) # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) @@ -255,7 +255,7 @@ def functional3d_pressure( # calculate velocity fsum = _get_fsum(_f, missing_mask) - unormal = -1.0 + fsum / _rho + unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback @@ -284,7 +284,7 @@ def functional2d_velocity( unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = fsum / (1.0 + unormal) + _rho = fsum / (self.compute_dtype(1.0) + unormal) # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) @@ -309,7 +309,7 @@ def functional2d_pressure( # calculate velocity fsum = _get_fsum(_f, missing_mask) - unormal = -1.0 + fsum / _rho + unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 337c4aa..41b2380 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -208,7 +208,7 @@ def _get_fsum( fsum_middle = self.compute_dtype(0.0) for l in range(_q): if missing_mask[_opp_indices[l]] == wp.uint8(1): - fsum_known += 2.0 * fpop[l] + fsum_known += self.compute_dtype(2.0) * fpop[l] elif missing_mask[l] != wp.uint8(1): fsum_middle += fpop[l] return fsum_known + fsum_middle @@ -250,7 +250,7 @@ def functional3d_velocity( unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = fsum / (1.0 + unormal) + _rho = fsum / (self.compute_dtype(1.0) + unormal) # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) @@ -272,7 +272,7 @@ def functional3d_pressure( # calculate velocity fsum = _get_fsum(_f, missing_mask) - unormal = -1.0 + fsum / _rho + unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback @@ -298,7 +298,7 @@ def functional2d_velocity( unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] - _rho = fsum / (1.0 + unormal) + _rho = fsum / (self.compute_dtype(1.0) + unormal) # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) @@ -320,7 +320,7 @@ def functional2d_pressure( # calculate velocity fsum = _get_fsum(_f, missing_mask) - unormal = -1.0 + fsum / _rho + unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index c67b3da..7984829 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -217,30 +217,34 @@ def decompose_shear_d3q27( s = _f_vec() # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) - s[9] = (2.0 * nxz - nyz) / 6.0 - s[18] = (2.0 * nxz - nyz) / 6.0 - s[3] = (-nxz + 2.0 * nyz) / 6.0 - s[6] = (-nxz + 2.0 * nyz) / 6.0 - s[1] = (-nxz - nyz) / 6.0 - s[2] = (-nxz - nyz) / 6.0 + two = self.self.compute_dtype(2.0) + four = self.self.compute_dtype(4.0) + six = self.self.compute_dtype(6.0) + + s[9] = (two * nxz - nyz) / six + s[18] = (two * nxz - nyz) / six + s[3] = (-nxz + two * nyz) / six + s[6] = (-nxz + two * nyz) / six + s[1] = (-nxz - nyz) / six + s[2] = (-nxz - nyz) / six # For c = (i, j, 0) - s[12] = pi[1] / 4.0 - s[24] = pi[1] / 4.0 - s[21] = -pi[1] / 4.0 - s[15] = -pi[1] / 4.0 + s[12] = pi[1] / four + s[24] = pi[1] / four + s[21] = -pi[1] / four + s[15] = -pi[1] / four # For c = (i, 0, k) - s[10] = pi[2] / 4.0 - s[20] = pi[2] / 4.0 - s[19] = -pi[2] / 4.0 - s[11] = -pi[2] / 4.0 + s[10] = pi[2] / four + s[20] = pi[2] / four + s[19] = -pi[2] / four + s[11] = -pi[2] / four # For c = (0, j, k) - s[8] = pi[4] / 4.0 - s[4] = pi[4] / 4.0 - s[7] = -pi[4] / 4.0 - s[5] = -pi[4] / 4.0 + s[8] = pi[4] / four + s[4] = pi[4] / four + s[7] = -pi[4] / four + s[5] = -pi[4] / four return s @@ -272,10 +276,11 @@ def functional2d( # Perform collision delta_h = fneq - delta_s - gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( + two = self.compute_dtype(2.0) + gamma = _inv_beta - (two - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( _epsilon + entropic_scalar_product(delta_h, delta_h, feq) ) - fout = f - _beta * (2.0 * delta_s + gamma * delta_h) + fout = f - _beta * (two * delta_s + gamma * delta_h) return fout @@ -294,10 +299,11 @@ def functional3d( # Perform collision delta_h = fneq - delta_s - gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( + two = self.compute_dtype(2.0) + gamma = _inv_beta - (two - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( _epsilon + entropic_scalar_product(delta_h, delta_h, feq) ) - fout = f - _beta * (2.0 * delta_s + gamma * delta_h) + fout = f - _beta * (two * delta_s + gamma * delta_h) return fout diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index 5209d69..bda2a25 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -72,7 +72,7 @@ def functional( # Get second order moment (a symmetric tensore shaped into a vector) pi = _pi_vec() for d in range(_pi_dim): - pi[d] = 0.0 + pi[d] = self.compute_dtype(0.0) for q in range(self.velocity_set.q): pi[d] += _cc[q, d] * fneq[q] return pi diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 83c6538..6e8bbbb 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,8 +1,10 @@ import inspect import traceback +import jax from xlb.compute_backend import ComputeBackend from xlb import DefaultConfig +from xlb.precision_policy import PrecisionPolicy class Operator: @@ -28,6 +30,12 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non if self.compute_backend == ComputeBackend.WARP: self.warp_functional, self.warp_kernel = self._construct_warp() + # Updating JAX config in case fp64 is requested + if self.compute_backend == ComputeBackend.JAX and ( + precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32 + ): + jax.config.update("jax_enable_x64", True) + @classmethod def register_backend(cls, backend_name): """ diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 3b4a973..41db991 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -4,9 +4,11 @@ import numpy as np import warp as wp import jax.numpy as jnp +import jax from xlb import DefaultConfig from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy class VelocitySet(object): @@ -32,6 +34,10 @@ def __init__(self, d, q, c, w, precision_policy, backend): self.precision_policy = precision_policy self.backend = backend + # Updating JAX config in case fp64 is requested + if backend == ComputeBackend.JAX and (precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32): + jax.config.update("jax_enable_x64", True) + # Create all properties in NumPy first self._init_numpy_properties(c, w) From d6db85d25698c8081ac2ff6e010064b7d0623a68 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 16 Sep 2024 18:24:07 -0400 Subject: [PATCH 110/144] used numpy instead of lists for finding boundingBox indices --- examples/cfd/turbulent_channel_3d.py | 19 +++++--- xlb/grid/grid.py | 72 ++++++++++++++++------------ 2 files changed, 54 insertions(+), 37 deletions(-) diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index e160fd5..11e6bbd 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -77,7 +77,8 @@ def _setup(self): def define_boundary_indices(self): # top and bottom sides of the channel are no-slip and the other directions are periodic - walls = [self.grid.boundingBoxIndices["bottom"][i] + self.grid.boundingBoxIndices["top"][i] for i in range(self.velocity_set.d)] + boundingBoxIndices = self.grid.bounding_box_indices(remove_edges=True) + walls = [boundingBoxIndices["bottom"][i] + boundingBoxIndices["top"][i] for i in range(self.velocity_set.d)] return walls def setup_boundary_conditions(self): @@ -129,7 +130,11 @@ def post_process(self, i): else: f_0 = self.f_0 - macro = Macroscopic(compute_backend=ComputeBackend.JAX) + macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=self.precision_policy, + velocity_set=xlb.velocity_set.D3Q27(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + ) rho, u = macro(f_0) @@ -182,10 +187,10 @@ def plot_uplus(self, u, timestep): # Runtime & backend configurations backend = ComputeBackend.WARP - velocity_set = xlb.velocity_set.D3Q27() - precision_policy = PrecisionPolicy.FP32FP32 - num_steps = 100000 - print_interval = 1000 + precision_policy = PrecisionPolicy.FP64FP64 + velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend) + num_steps = 10000000 + print_interval = 100000 # Print simulation info print("\n" + "=" * 50 + "\n") @@ -199,4 +204,4 @@ def plot_uplus(self, u, timestep): print("\n" + "=" * 50 + "\n") simulation = TurbulentChannel3D(channel_half_width, Re_tau, u_tau, grid_shape, velocity_set, backend, precision_policy) - simulation.run(num_steps, print_interval, post_process_interval=1000) + simulation.run(num_steps, print_interval, post_process_interval=100000) diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 7d8a678..7494c3e 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Tuple +import numpy as np from xlb import DefaultConfig from xlb.compute_backend import ComputeBackend @@ -24,52 +25,63 @@ def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend): self.shape = shape self.dim = len(shape) self.compute_backend = compute_backend - self._bounding_box_indices() + self.boundingBoxIndices = self.bounding_box_indices() self._initialize_backend() @abstractmethod def _initialize_backend(self): pass - def _bounding_box_indices(self): + def bounding_box_indices(self, remove_edges=False): """ This function calculates the indices of the bounding box of a 2D or 3D grid. The bounding box is defined as the set of grid points on the outer edge of the grid. + Parameters + ---------- + remove_edges : bool, optional + If True, the nodes along the edges (not just the corners) are removed from the bounding box indices. + Default is False. + Returns ------- - boundingBox (dict): A dictionary where keys are the names of the bounding box faces - ("bottom", "top", "left", "right" for 2D; additional "front", "back" for 3D), and values - are numpy arrays of indices corresponding to each face. + boundingBox (dict): A dictionary where keys are the names of the bounding box faces + ("bottom", "top", "left", "right" for 2D; additional "front", "back" for 3D), and values + are numpy arrays of indices corresponding to each face. """ - def to_tuple(lst): - d = len(lst[0]) - return [tuple([sublist[i] for sublist in lst]) for i in range(d)] + # Get the shape of the grid + origin = np.array([0, 0, 0]) + bounds = np.array(self.shape) + if remove_edges: + origin += 1 + bounds -= 1 + slice_x = slice(origin[0], bounds[0]) + slice_y = slice(origin[1], bounds[1]) + dim = len(bounds) + + # Generate bounding box indices for each face + grid = np.indices(self.shape) + boundingBoxIndices = {} - if self.dim == 2: - # For a 2D grid, the bounding box consists of four edges: bottom, top, left, and right. - # Each edge is represented as an array of indices. For example, the bottom edge includes - # all points where the y-coordinate is 0, so its indices are [[i, 0] for i in range(nx)]. + if dim == 2: nx, ny = self.shape - self.boundingBoxIndices = { - "bottom": to_tuple([[i, 0] for i in range(nx)]), - "top": to_tuple([[i, ny - 1] for i in range(nx)]), - "left": to_tuple([[0, i] for i in range(ny)]), - "right": to_tuple([[nx - 1, i] for i in range(ny)]), + boundingBoxIndices = { + "bottom": grid[:, slice_x, 0], + "top": grid[:, slice_x, ny - 1], + "left": grid[:, 0, slice_y], + "right": grid[:, nx - 1, slice_y], } - - elif self.dim == 3: - # For a 3D grid, the bounding box consists of six faces: bottom, top, left, right, front, and back. - # Each face is represented as an array of indices. For example, the bottom face includes all points - # where the z-coordinate is 0, so its indices are [[i, j, 0] for i in range(nx) for j in range(ny)]. + elif dim == 3: nx, ny, nz = self.shape - self.boundingBoxIndices = { - "bottom": to_tuple([[i, j, 0] for i in range(nx) for j in range(ny)]), - "top": to_tuple([[i, j, nz - 1] for i in range(nx) for j in range(ny)]), - "left": to_tuple([[0, j, k] for j in range(ny) for k in range(nz)]), - "right": to_tuple([[nx - 1, j, k] for j in range(ny) for k in range(nz)]), - "front": to_tuple([[i, 0, k] for i in range(nx) for k in range(nz)]), - "back": to_tuple([[i, ny - 1, k] for i in range(nx) for k in range(nz)]), + slice_z = slice(origin[2], bounds[2]) + boundingBoxIndices = { + "bottom": grid[:, slice_x, slice_y, 0].reshape(3, -1), + "top": grid[:, slice_x, slice_y, nz - 1].reshape(3, -1), + "left": grid[:, 0, slice_y, slice_z].reshape(3, -1), + "right": grid[:, nx - 1, slice_y, slice_z].reshape(3, -1), + "front": grid[:, slice_x, 0, slice_z].reshape(3, -1), + "back": grid[:, slice_x, ny - 1, slice_z].reshape(3, -1), } - return + + return {k: v.tolist() for k, v in boundingBoxIndices.items()} From 29aa3a233faddac8b4d5c3e475616f307d084518 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 16 Sep 2024 18:43:28 -0400 Subject: [PATCH 111/144] renaming to bc_mask --- examples/cfd/flow_past_sphere_3d.py | 6 +-- examples/cfd/lid_driven_cavity_2d.py | 6 +-- examples/cfd/turbulent_channel_3d.py | 6 +-- examples/cfd/windtunnel_3d.py | 10 ++--- .../flow_past_sphere.py | 12 +++--- .../cfd_old_to_be_migrated/taylor_green.py | 6 +-- examples/performance/mlups_3d.py | 12 +++--- .../bc_equilibrium/test_bc_equilibrium_jax.py | 6 +-- .../test_bc_equilibrium_warp.py | 6 +-- .../test_bc_fullway_bounce_back_jax.py | 6 +-- .../test_bc_fullway_bounce_back_warp.py | 6 +-- .../mask/test_bc_indices_masker_jax.py | 24 ++++++------ .../mask/test_bc_indices_masker_warp.py | 28 +++++++------- xlb/helper/nse_solver.py | 4 +- .../boundary_condition/bc_do_nothing.py | 16 ++++---- .../boundary_condition/bc_equilibrium.py | 16 ++++---- .../bc_extrapolation_outflow.py | 20 +++++----- .../bc_fullway_bounce_back.py | 16 ++++---- .../bc_halfway_bounce_back.py | 16 ++++---- .../boundary_condition/bc_regularized.py | 16 ++++---- xlb/operator/boundary_condition/bc_zouhe.py | 16 ++++---- .../boundary_condition/boundary_condition.py | 10 ++--- .../indices_boundary_masker.py | 38 +++++++++---------- .../boundary_masker/mesh_boundary_masker.py | 26 ++++++------- xlb/operator/force/momentum_transfer.py | 20 +++++----- xlb/operator/stepper/nse_stepper.py | 20 +++++----- 26 files changed, 184 insertions(+), 184 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 9645ce5..8c99dff 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -34,7 +34,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_id = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -91,7 +91,7 @@ def setup_boundary_masker(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.bc_id, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_id, self.missing_mask, (0, 0, 0)) + self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -102,7 +102,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_id, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index adee825..c67ce8e 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -25,7 +25,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_id = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -58,7 +58,7 @@ def setup_boundary_masker(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.bc_id, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_id, self.missing_mask) + self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -68,7 +68,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_id, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index 11e6bbd..ea8cad7 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -55,7 +55,7 @@ def __init__(self, channel_half_width, Re_tau, u_tau, grid_shape, velocity_set, self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_id = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -92,7 +92,7 @@ def setup_boundary_masker(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.bc_id, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_id, self.missing_mask) + self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask) def initialize_fields(self): shape = (self.velocity_set.d,) + (self.grid_shape) @@ -113,7 +113,7 @@ def setup_stepper(self): def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_id, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 795cd42..92c94aa 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -36,7 +36,7 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_id = create_nse_fields(grid_shape) + self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -118,8 +118,8 @@ def setup_boundary_masker(self): bc_mesh = self.boundary_conditions[-1] dx = self.grid_spacing origin, spacing = (0, 0, 0), (dx, dx, dx) - self.bc_id, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_id, self.missing_mask) - self.bc_id, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_id, self.missing_mask) + self.bc_mask, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_mask, self.missing_mask) + self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) @@ -134,7 +134,7 @@ def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_id, self.missing_mask, i) + self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: @@ -169,7 +169,7 @@ def post_process(self, i): save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) # Compute lift and drag - boundary_force = self.momentum_transfer(self.f_0, self.bc_id, self.missing_mask) + boundary_force = self.momentum_transfer(self.f_0, self.bc_mask, self.missing_mask) drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane lift = boundary_force[2] c_d = 2.0 * drag / (self.wind_speed**2 * self.car_cross_section) diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py index ca404a9..1684266 100644 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ b/examples/cfd_old_to_be_migrated/flow_past_sphere.py @@ -75,7 +75,7 @@ def warp_implementation(self, rho, u, vel): u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - bc_id = grid.create_field(cardinality=1, dtype=wp.uint8) + bc_mask = grid.create_field(cardinality=1, dtype=wp.uint8) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) # Make operators @@ -154,19 +154,19 @@ def warp_implementation(self, rho, u, vel): indices = wp.from_numpy(indices, dtype=wp.int32) # Set boundary conditions on the indices - bc_id, missing_mask = indices_boundary_masker(indices, half_way_bc.id, bc_id, missing_mask, (0, 0, 0)) + bc_mask, missing_mask = indices_boundary_masker(indices, half_way_bc.id, bc_mask, missing_mask, (0, 0, 0)) # Set inlet bc lower_bound = (0, 0, 0) upper_bound = (0, nr, nr) direction = (1, 0, 0) - bc_id, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, bc_id, missing_mask, (0, 0, 0)) + bc_mask, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, bc_mask, missing_mask, (0, 0, 0)) # Set outlet bc lower_bound = (nr - 1, 0, 0) upper_bound = (nr - 1, nr, nr) direction = (-1, 0, 0) - bc_id, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, bc_id, missing_mask, (0, 0, 0)) + bc_mask, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, bc_mask, missing_mask, (0, 0, 0)) # Set initial conditions rho, u = initializer(rho, u, vel) @@ -181,7 +181,7 @@ def warp_implementation(self, rho, u, vel): num_steps = 1024 * 8 start = time.time() for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, bc_id, missing_mask, _) + f1 = stepper(f0, f1, bc_mask, missing_mask, _) f1, f0 = f0, f1 if (_ % plot_freq == 0) and (not compute_mlup): rho, u = macroscopic(f0, rho, u) @@ -191,7 +191,7 @@ def warp_implementation(self, rho, u, vel): plt.imshow(u[0, :, nr // 2, :].numpy()) plt.colorbar() plt.subplot(1, 2, 2) - plt.imshow(bc_id[0, :, nr // 2, :].numpy()) + plt.imshow(bc_mask[0, :, nr // 2, :].numpy()) plt.colorbar() plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py index 970a246..846ba30 100644 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ b/examples/cfd_old_to_be_migrated/taylor_green.py @@ -113,7 +113,7 @@ def run_taylor_green(backend, compute_mlup=True): u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - bc_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + bc_mask = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators @@ -149,10 +149,10 @@ def run_taylor_green(backend, compute_mlup=True): for _ in tqdm(range(num_steps)): # Time step if backend == "warp": - f1 = stepper(f0, f1, bc_id, missing_mask, _) + f1 = stepper(f0, f1, bc_mask, missing_mask, _) f1, f0 = f0, f1 elif backend == "jax": - f0 = stepper(f0, bc_id, missing_mask, _) + f0 = stepper(f0, bc_mask, missing_mask, _) # Plot if needed if (_ % plot_freq == 0) and (not compute_mlup): diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 04e1902..32c3d6d 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -42,9 +42,9 @@ def setup_simulation(args): def create_grid_and_fields(cube_edge): grid_shape = (cube_edge, cube_edge, cube_edge) - grid, f_0, f_1, missing_mask, bc_id = create_nse_fields(grid_shape) + grid, f_0, f_1, missing_mask, bc_mask = create_nse_fields(grid_shape) - return grid, f_0, f_1, missing_mask, bc_id + return grid, f_0, f_1, missing_mask, bc_mask def define_boundary_indices(grid): @@ -67,7 +67,7 @@ def setup_boundary_conditions(grid): return [bc_top, bc_walls] -def run(f_0, f_1, backend, precision_policy, grid, bc_id, missing_mask, num_steps): +def run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, num_steps): omega = 1.0 stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=setup_boundary_conditions(grid)) @@ -81,7 +81,7 @@ def run(f_0, f_1, backend, precision_policy, grid, bc_id, missing_mask, num_step start_time = time.time() for i in range(num_steps): - f_1 = stepper(f_0, f_1, bc_id, missing_mask, i) + f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i) f_0, f_1 = f_1, f_0 wp.synchronize() @@ -99,10 +99,10 @@ def main(): args = parse_arguments() backend, precision_policy = setup_simulation(args) velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) - grid, f_0, f_1, missing_mask, bc_id = create_grid_and_fields(args.cube_edge) + grid, f_0, f_1, missing_mask, bc_mask = create_grid_and_fields(args.cube_edge) f_0 = initialize_eq(f_0, grid, velocity_set, backend) - elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_id, missing_mask, args.num_steps) + elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index f9d7d33..1025e3b 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -33,7 +33,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -59,7 +59,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices=indices, ) - bc_id, missing_mask = indices_boundary_masker([equilibrium_bc], bc_id, missing_mask, start_index=None) + bc_mask, missing_mask = indices_boundary_masker([equilibrium_bc], bc_mask, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -67,7 +67,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, bc_id, missing_mask) + f = equilibrium_bc(f_pre, f_post, bc_mask, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index a7ef555..5eb0c10 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -32,7 +32,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -59,7 +59,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): indices=indices, ) - bc_id, missing_mask = indices_boundary_masker([equilibrium_bc], bc_id, missing_mask, start_index=None) + bc_mask, missing_mask = indices_boundary_masker([equilibrium_bc], bc_mask, missing_mask, start_index=None) f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -67,7 +67,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = equilibrium_bc(f_pre, f_post, bc_id, missing_mask) + f = equilibrium_bc(f_pre, f_post, bc_mask, missing_mask) f = f.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 84e3763..cd18975 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -36,7 +36,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -56,7 +56,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - bc_id, missing_mask = indices_boundary_masker([fullway_bc], bc_id, missing_mask, start_index=None) + bc_mask, missing_mask = indices_boundary_masker([fullway_bc], bc_mask, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0) # Generate a random field with the same shape @@ -69,7 +69,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f = fullway_bc(f_pre, f_post, bc_id, missing_mask) + f = fullway_bc(f_pre, f_post, bc_mask, missing_mask) assert f.shape == (velocity_set.q,) + grid_shape diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index b67ef43..10b9244 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -35,7 +35,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -55,7 +55,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - bc_id, missing_mask = indices_boundary_masker([fullway_bc], bc_id, missing_mask, start_index=None) + bc_mask, missing_mask = indices_boundary_masker([fullway_bc], bc_mask, missing_mask, start_index=None) # Generate a random field with the same shape random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32) @@ -66,7 +66,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0 ) # Arbitrary value so that we can check if the values are changed outside the boundary - f_pre = fullway_bc(f_pre, f_post, bc_id, missing_mask) + f_pre = fullway_bc(f_pre, f_post, bc_mask, missing_mask) f = f_pre.numpy() f_post = f_post.numpy() diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index b0eb6ae..79d56d8 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -35,7 +35,7 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker() @@ -57,26 +57,26 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - bc_id, missing_mask = indices_boundary_masker([test_bc], bc_id, missing_mask, start_index=None) + bc_mask, missing_mask = indices_boundary_masker([test_bc], bc_mask, missing_mask, start_index=None) assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype - assert bc_id.dtype == xlb.Precision.UINT8.jax_dtype + assert bc_mask.dtype == xlb.Precision.UINT8.jax_dtype - assert bc_id.shape == (1,) + grid_shape + assert bc_mask.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert jnp.all(bc_id[0, indices[0], indices[1]] == test_bc.id) - # assert that the rest of the bc_id is zero - bc_id = bc_id.at[0, indices[0], indices[1]].set(0) - assert jnp.all(bc_id == 0) + assert jnp.all(bc_mask[0, indices[0], indices[1]] == test_bc.id) + # assert that the rest of the bc_mask is zero + bc_mask = bc_mask.at[0, indices[0], indices[1]].set(0) + assert jnp.all(bc_mask == 0) if dim == 3: - assert jnp.all(bc_id[0, indices[0], indices[1], indices[2]] == test_bc.id) - # assert that the rest of the bc_id is zero - bc_id = bc_id.at[0, indices[0], indices[1], indices[2]].set(0) - assert jnp.all(bc_id == 0) + assert jnp.all(bc_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) + # assert that the rest of the bc_mask is zero + bc_mask = bc_mask.at[0, indices[0], indices[1], indices[2]].set(0) + assert jnp.all(bc_mask == 0) if __name__ == "__main__": diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 41e9f23..f56c9ce 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -34,7 +34,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) - bc_id = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) + bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) indices_boundary_masker = IndicesBoundaryMasker() @@ -56,33 +56,33 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - bc_id, missing_mask = indices_boundary_masker( + bc_mask, missing_mask = indices_boundary_masker( [test_bc], - bc_id, + bc_mask, missing_mask, start_index=(0, 0, 0) if dim == 3 else (0, 0), ) assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype - assert bc_id.dtype == xlb.Precision.UINT8.wp_dtype + assert bc_mask.dtype == xlb.Precision.UINT8.wp_dtype - bc_id = bc_id.numpy() + bc_mask = bc_mask.numpy() missing_mask = missing_mask.numpy() - assert bc_id.shape == (1,) + grid_shape + assert bc_mask.shape == (1,) + grid_shape assert missing_mask.shape == (velocity_set.q,) + grid_shape if dim == 2: - assert np.all(bc_id[0, indices[0], indices[1]] == test_bc.id) - # assert that the rest of the bc_id is zero - bc_id[0, indices[0], indices[1]] = 0 - assert np.all(bc_id == 0) + assert np.all(bc_mask[0, indices[0], indices[1]] == test_bc.id) + # assert that the rest of the bc_mask is zero + bc_mask[0, indices[0], indices[1]] = 0 + assert np.all(bc_mask == 0) if dim == 3: - assert np.all(bc_id[0, indices[0], indices[1], indices[2]] == test_bc.id) - # assert that the rest of the bc_id is zero - bc_id[0, indices[0], indices[1], indices[2]] = 0 - assert np.all(bc_id == 0) + assert np.all(bc_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) + # assert that the rest of the bc_mask is zero + bc_mask[0, indices[0], indices[1], indices[2]] = 0 + assert np.all(bc_mask == 0) if __name__ == "__main__": diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py index 58027ff..3d0ef3e 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_solver.py @@ -14,6 +14,6 @@ def create_nse_fields(grid_shape: Tuple[int, int, int], velocity_set=None, compu f_0 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) f_1 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL) - bc_id = grid.create_field(cardinality=1, dtype=Precision.UINT8) + bc_mask = grid.create_field(cardinality=1, dtype=Precision.UINT8) - return grid, f_0, f_1, missing_mask, bc_id + return grid, f_0, f_1, missing_mask, bc_mask diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 54371dc..dcdc8fd 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -48,8 +48,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, bc_id, missing_mask): - boundary = bc_id == self.id + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): + boundary = bc_mask == self.id return jnp.where(boundary, f_pre, f_post) def _construct_warp(self): @@ -67,7 +67,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.uint8), ): # Get the global index @@ -75,7 +75,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): @@ -93,7 +93,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -101,7 +101,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): @@ -119,11 +119,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, bc_id, missing_mask], + inputs=[f_pre, f_post, bc_mask, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 452dc5a..716fd8e 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -62,11 +62,11 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, bc_id, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) new_shape = feq.shape + (1,) * self.velocity_set.d feq = lax.broadcast_in_dim(feq, new_shape, [0]) - boundary = bc_id == self.id + boundary = bc_mask == self.id return jnp.where(boundary, feq, f_post) @@ -92,7 +92,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -100,7 +100,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): @@ -118,7 +118,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -126,7 +126,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): @@ -144,11 +144,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, bc_id, missing_mask], + inputs=[f_pre, f_post, bc_mask, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 9ca05ef..87c6850 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -94,13 +94,13 @@ def _roll(self, fld, vec): return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, bc_id, missing_mask): + def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): """ Prepare the auxilary distribution functions for the boundary condition. Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision """ sound_speed = 1.0 / jnp.sqrt(3.0) - boundary = bc_id == self.id + boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -123,8 +123,8 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, bc_id, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_id, missing_mask): - boundary = bc_id == self.id + def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( @@ -195,7 +195,7 @@ def prepare_bc_auxilary_data( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -203,7 +203,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) _f_aux = _f_vec() # special preparation of auxiliary data @@ -236,7 +236,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -244,7 +244,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) _f_aux = _f_vec() # special preparation of auxiliary data @@ -277,11 +277,11 @@ def kernel3d( return (functional, prepare_bc_auxilary_data), kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, bc_id, missing_mask], + inputs=[f_pre, f_post, bc_mask, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index a83b65c..85bffb1 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -48,8 +48,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_id, missing_mask): - boundary = bc_id == self.id + def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) @@ -77,14 +77,14 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): @@ -102,7 +102,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -110,7 +110,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): @@ -128,11 +128,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, bc_id, missing_mask], + inputs=[f_pre, f_post, bc_mask, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 1e072a2..c2483c9 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -54,8 +54,8 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_id, missing_mask): - boundary = bc_id == self.id + def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( @@ -91,7 +91,7 @@ def functional( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -99,7 +99,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): @@ -117,7 +117,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -125,7 +125,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): @@ -143,11 +143,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, bc_id, missing_mask], + inputs=[f_pre, f_post, bc_mask, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 48715ba..7879137 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -105,9 +105,9 @@ def regularize_fpop(self, fpop, feq): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_id, missing_mask): + def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): # creat a mask to slice boundary cells - boundary = bc_id == self.id + boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -325,7 +325,7 @@ def functional2d_pressure( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -333,7 +333,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): @@ -351,7 +351,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -359,7 +359,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): @@ -385,11 +385,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, bc_id, missing_mask], + inputs=[f_pre, f_post, bc_mask, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 41b2380..0bf68b8 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -159,9 +159,9 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_id, missing_mask): + def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): # creat a mask to slice boundary cells - boundary = bc_id == self.id + boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -333,7 +333,7 @@ def functional2d_pressure( def kernel2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), ): # Get the global index @@ -341,7 +341,7 @@ def kernel2d( index = wp.vec2i(i, j) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): @@ -359,7 +359,7 @@ def kernel2d( def kernel3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), ): # Get the global index @@ -367,7 +367,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_id, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == wp.uint8(self.id): @@ -393,11 +393,11 @@ def kernel3d( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_id, missing_mask): + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, bc_id, missing_mask], + inputs=[f_pre, f_post, bc_mask, missing_mask], dim=f_pre.shape[1:], ) return f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 4f5cca5..8d47929 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -70,14 +70,14 @@ def prepare_bc_auxilary_data( def _get_thread_data_2d( f_pre: wp.array3d(dtype=Any), f_post: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), index: wp.vec2i, ): # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_id = bc_id[0, index[0], index[1]] + _boundary_id = bc_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -95,14 +95,14 @@ def _get_thread_data_2d( def _get_thread_data_3d( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), index: wp.vec3i, ): # Get the boundary id and missing mask _f_pre = _f_vec() _f_post = _f_vec() - _boundary_id = bc_id[0, index[0], index[1], index[2]] + _boundary_id = bc_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations @@ -123,7 +123,7 @@ def _get_thread_data_3d( self.prepare_bc_auxilary_data = prepare_bc_auxilary_data @partial(jit, static_argnums=(0,), inline=True) - def prepare_bc_auxilary_data(self, f_pre, f_post, bc_id, missing_mask): + def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): """ A placeholder function for prepare the auxilary distribution functions for the boundary condition. currently being called after collision only. diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 7cb9c61..fdc4331 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -47,7 +47,7 @@ def are_indices_in_interior(self, indices, shape): @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) - def jax_implementation(self, bclist, bc_id, missing_mask, start_index=None): + def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None): # Pad the missing mask to create a grid mask to identify out of bound boundaries # Set padded regin to True (i.e. boundary) dim = missing_mask.ndim - 1 @@ -55,17 +55,17 @@ def jax_implementation(self, bclist, bc_id, missing_mask, start_index=None): pad_x, pad_y, pad_z = nDevices, 1, 1 if dim == 2: grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True) - bmap = jnp.pad(bc_id[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) + bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) if dim == 3: grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True) - bmap = jnp.pad(bc_id[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) + bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) # shift indices shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z) if start_index is None: start_index = (0,) * dim - domain_shape = bc_id[0].shape + domain_shape = bc_mask[0].shape for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" @@ -92,11 +92,11 @@ def jax_implementation(self, bclist, bc_id, missing_mask, start_index=None): grid_mask = self.stream(grid_mask) if dim == 2: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] - bc_id = bc_id.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y]) + bc_mask = bc_mask.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y]) if dim == 3: missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] - bc_id = bc_id.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]) - return bc_id, missing_mask + bc_mask = bc_mask.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]) + return bc_mask, missing_mask def _construct_warp(self): # Make constants for warp @@ -109,7 +109,7 @@ def kernel2d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), is_interior: wp.array1d(dtype=wp.bool), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), start_index: wp.vec2i, ): @@ -132,8 +132,8 @@ def kernel2d( pull_index[d] = index[d] - _c[d, l] push_index[d] = index[d] + _c[d, l] - # set bc_id for all bc indices - bc_id[0, index[0], index[1]] = id_number[ii] + # set bc_mask for all bc indices + bc_mask[0, index[0], index[1]] = id_number[ii] # check if pull index is out of bound # These directions will have missing information after streaming @@ -151,7 +151,7 @@ def kernel2d( ): # Set the missing mask missing_mask[l, push_index[0], push_index[1]] = True - bc_id[0, push_index[0], push_index[1]] = id_number[ii] + bc_mask[0, push_index[0], push_index[1]] = id_number[ii] # Construct the warp 3D kernel @wp.kernel @@ -159,7 +159,7 @@ def kernel3d( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), is_interior: wp.array1d(dtype=wp.bool), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -190,8 +190,8 @@ def kernel3d( pull_index[d] = index[d] - _c[d, l] push_index[d] = index[d] + _c[d, l] - # set bc_id for all bc indices - bc_id[0, index[0], index[1], index[2]] = id_number[ii] + # set bc_mask for all bc indices + bc_mask[0, index[0], index[1], index[2]] = id_number[ii] # check if pull index is out of bound # These directions will have missing information after streaming @@ -218,14 +218,14 @@ def kernel3d( ): # Set the missing mask missing_mask[l, push_index[0], push_index[1], push_index[2]] = True - bc_id[0, push_index[0], push_index[1], push_index[2]] = id_number[ii] + bc_mask[0, push_index[0], push_index[1], push_index[2]] = id_number[ii] kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, bclist, bc_id, missing_mask, start_index=None): + def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): dim = self.velocity_set.d index_list = [[] for _ in range(dim)] id_list = [] @@ -236,7 +236,7 @@ def warp_implementation(self, bclist, bc_id, missing_mask, start_index=None): for d in range(dim): index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) - is_interior += self.are_indices_in_interior(bc.indices, bc_id[0].shape) if bc.needs_padding else [False] * len(bc.indices[0]) + is_interior += self.are_indices_in_interior(bc.indices, bc_mask[0].shape) if bc.needs_padding else [False] * len(bc.indices[0]) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) @@ -255,11 +255,11 @@ def warp_implementation(self, bclist, bc_id, missing_mask, start_index=None): indices, id_number, is_interior, - bc_id, + bc_mask, missing_mask, start_index, ], dim=indices.shape[1], ) - return bc_id, missing_mask + return bc_mask, missing_mask diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 91b5c2e..b88d251 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -33,17 +33,17 @@ def jax_implementation( bc, origin, spacing, - bc_id, + bc_mask, missing_mask, start_index=(0, 0, 0), ): raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") # Use Warp backend even for this particular operation. wp.init() - bc_id = wp.from_jax(bc_id) + bc_mask = wp.from_jax(bc_mask) missing_mask = wp.from_jax(missing_mask) - bc_id, missing_mask = self.warp_implementation(bc, origin, spacing, bc_id, missing_mask, start_index) - return wp.to_jax(bc_id), wp.to_jax(missing_mask) + bc_mask, missing_mask = self.warp_implementation(bc, origin, spacing, bc_mask, missing_mask, start_index) + return wp.to_jax(bc_mask), wp.to_jax(missing_mask) def _construct_warp(self): # Make constants for warp @@ -57,7 +57,7 @@ def kernel( origin: wp.vec3, spacing: wp.vec3, id_number: wp.int32, - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), start_index: wp.vec3i, ): @@ -77,9 +77,9 @@ def kernel( # Compute the maximum length max_length = wp.sqrt( - (spacing[0] * wp.float32(bc_id.shape[1])) ** 2.0 - + (spacing[1] * wp.float32(bc_id.shape[2])) ** 2.0 - + (spacing[2] * wp.float32(bc_id.shape[3])) ** 2.0 + (spacing[0] * wp.float32(bc_mask.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(bc_mask.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(bc_mask.shape[3])) ** 2.0 ) # evaluate if point is inside mesh @@ -98,7 +98,7 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and missing_mask - bc_id[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) missing_mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel @@ -109,7 +109,7 @@ def warp_implementation( bc, origin, spacing, - bc_id, + bc_mask, missing_mask, start_index=(0, 0, 0), ): @@ -138,11 +138,11 @@ def warp_implementation( origin, spacing, id_number, - bc_id, + bc_mask, missing_mask, start_index, ], - dim=bc_id.shape[1:], + dim=bc_mask.shape[1:], ) - return bc_id, missing_mask + return bc_mask, missing_mask diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 5ee785d..da64e67 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -50,13 +50,13 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f, bc_id, missing_mask): + def jax_implementation(self, f, bc_mask, missing_mask): """ Parameters ---------- f : jax.numpy.ndarray The post-collision distribution function at each node in the grid. - bc_id : jax.numpy.ndarray + bc_mask : jax.numpy.ndarray A grid field with 0 everywhere except for boundary nodes which are designated by their respective boundary id's. missing_mask : jax.numpy.ndarray @@ -71,10 +71,10 @@ def jax_implementation(self, f, bc_id, missing_mask): # Give the input post-collision populations, streaming once and apply the BC the find post-stream values. f_post_collision = f f_post_stream = self.stream(f_post_collision) - f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_id, missing_mask) + f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_mask, missing_mask) # Compute momentum transfer - boundary = bc_id == self.no_slip_bc_instance.id + boundary = bc_mask == self.no_slip_bc_instance.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -103,7 +103,7 @@ def _construct_warp(self): @wp.kernel def kernel2d( f: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=wp.uint8), + bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), force: wp.array(dtype=Any), ): @@ -112,7 +112,7 @@ def kernel2d( index = wp.vec2i(i, j) # Get the boundary id - _boundary_id = bc_id[0, index[0], index[1]] + _boundary_id = bc_mask[0, index[0], index[1]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -153,7 +153,7 @@ def kernel2d( @wp.kernel def kernel3d( f: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), force: wp.array(dtype=Any), ): @@ -162,7 +162,7 @@ def kernel3d( index = wp.vec3i(i, j, k) # Get the boundary id - _boundary_id = bc_id[0, index[0], index[1], index[2]] + _boundary_id = bc_mask[0, index[0], index[1], index[2]] _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # TODO fix vec bool @@ -205,14 +205,14 @@ def kernel3d( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, bc_id, missing_mask): + def warp_implementation(self, f, bc_mask, missing_mask): # Allocate the force vector (the total integral value will be computed) force = wp.zeros((1), dtype=wp.vec3) if self.velocity_set.d == 3 else wp.zeros((1), dtype=wp.vec2) # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f, bc_id, missing_mask, force], + inputs=[f, bc_mask, missing_mask, force], dim=f.shape[1:], ) return force.numpy()[0] diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index d3fbb94..18b2167 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -44,7 +44,7 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK", forcing_ @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_0, f_1, bc_id, missing_mask, timestep): + def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ @@ -61,7 +61,7 @@ def jax_implementation(self, f_0, f_1, bc_id, missing_mask, timestep): f_post_stream = bc( f_0, f_post_stream, - bc_id, + bc_mask, missing_mask, ) @@ -76,12 +76,12 @@ def jax_implementation(self, f_0, f_1, bc_id, missing_mask, timestep): # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, bc_id, missing_mask) + f_post_collision = bc.prepare_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_post_stream, f_post_collision, - bc_id, + bc_mask, missing_mask, ) @@ -273,7 +273,7 @@ def get_bc_auxilary_data_3d( def kernel2d( f_0: wp.array3d(dtype=Any), f_1: wp.array3d(dtype=Any), - bc_id: wp.array3d(dtype=Any), + bc_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), bc_struct: Any, timestep: int, @@ -289,7 +289,7 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_id = bc_id[0, index[0], index[1]] + _boundary_id = bc_mask[0, index[0], index[1]] f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions @@ -321,7 +321,7 @@ def kernel2d( def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), - bc_id: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), bc_struct: Any, timestep: int, @@ -337,7 +337,7 @@ def kernel3d( f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) - _boundary_id = bc_id[0, index[0], index[1], index[2]] + _boundary_id = bc_mask[0, index[0], index[1], index[2]] f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions @@ -365,7 +365,7 @@ def kernel3d( return BoundaryConditionIDStruct, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, bc_id, missing_mask, timestep): + def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): # Get the boundary condition ids from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry @@ -406,7 +406,7 @@ def warp_implementation(self, f_0, f_1, bc_id, missing_mask, timestep): inputs=[ f_0, f_1, - bc_id, + bc_mask, missing_mask, bc_struct, timestep, From c4c994bfcf543d753b081c0b9cc5dd282968d07d Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 19 Sep 2024 18:16:19 -0400 Subject: [PATCH 112/144] Separated zero and first kernels --- xlb/operator/macroscopic/__init__.py | 6 +- xlb/operator/macroscopic/first_moment.py | 83 +++++++++++++++++++++++ xlb/operator/macroscopic/macroscopic.py | 85 ++++++++++++++++++++++++ xlb/operator/macroscopic/zero_moment.py | 69 +++++++++++++++++++ 4 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 xlb/operator/macroscopic/first_moment.py create mode 100644 xlb/operator/macroscopic/macroscopic.py create mode 100644 xlb/operator/macroscopic/zero_moment.py diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 38195cd..75dec9e 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1,2 +1,4 @@ -from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic -from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment +from xlb.operator.macroscopic.macroscopic import Macroscopic +from xlb.operator.macroscopic.second_moment import SecondMoment +from xlb.operator.macroscopic.zero_moment import ZeroMoment +from xlb.operator.macroscopic.first_moment import FirstMoment diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py new file mode 100644 index 0000000..218458e --- /dev/null +++ b/xlb/operator/macroscopic/first_moment.py @@ -0,0 +1,83 @@ +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + +class FirstMoment(Operator): + """A class to compute the first moment (velocity) of distribution functions.""" + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), inline=True) + def jax_implementation(self, f, rho): + u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho + return u + + def _construct_warp(self): + _c = self.velocity_set.c + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + + @wp.func + def functional(f: _f_vec, rho: float): + u = _u_vec() + for l in range(self.velocity_set.q): + for d in range(self.velocity_set.d): + if _c[d, l] == 1: + u[d] += f[l] + elif _c[d, l] == -1: + u[d] -= f[l] + u /= rho + return u + + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + ): + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] + _u = functional(_f, _rho) + + for d in range(self.velocity_set.d): + u[d, index[0], index[1], index[2]] = _u[d] + + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + ): + i, j = wp.tid() + index = wp.vec2i(i, j) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _rho = rho[0, index[0], index[1]] + _u = functional(_f, _rho) + + for d in range(self.velocity_set.d): + u[d, index[0], index[1]] = _u[d] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, rho, u): + wp.launch( + self.warp_kernel, + inputs=[f, rho, u], + dim=u.shape[1:], + ) + return u \ No newline at end of file diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py new file mode 100644 index 0000000..b585b5e --- /dev/null +++ b/xlb/operator/macroscopic/macroscopic.py @@ -0,0 +1,85 @@ +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.macroscopic.zero_moment import ZeroMoment +from xlb.operator.macroscopic.first_moment import FirstMoment + +class Macroscopic(Operator): + """A class to compute both zero and first moments of distribution functions (rho, u).""" + + def __init__(self, *args, **kwargs): + self.zero_moment = ZeroMoment(*args, **kwargs) + self.first_moment = FirstMoment(*args, **kwargs) + super().__init__(*args, **kwargs) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), inline=True) + def jax_implementation(self, f): + rho = self.zero_moment(f) + u = self.first_moment(f, rho) + return rho, u + + def _construct_warp(self): + zero_moment_func = self.zero_moment.warp_functional + first_moment_func = self.first_moment.warp_functional + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + @wp.func + def functional(f: _f_vec): + rho = zero_moment_func(f) + u = first_moment_func(f, rho) + return rho, u + + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + ): + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _rho, _u = functional(_f) + + rho[0, index[0], index[1], index[2]] = _rho + for d in range(self.velocity_set.d): + u[d, index[0], index[1], index[2]] = _u[d] + + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + ): + i, j = wp.tid() + index = wp.vec2i(i, j) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _rho, _u = functional(_f) + + rho[0, index[0], index[1]] = _rho + for d in range(self.velocity_set.d): + u[d, index[0], index[1]] = _u[d] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, rho, u): + wp.launch( + self.warp_kernel, + inputs=[f, rho, u], + dim=rho.shape[1:], + ) + return rho, u \ No newline at end of file diff --git a/xlb/operator/macroscopic/zero_moment.py b/xlb/operator/macroscopic/zero_moment.py new file mode 100644 index 0000000..a37ede7 --- /dev/null +++ b/xlb/operator/macroscopic/zero_moment.py @@ -0,0 +1,69 @@ +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + +class ZeroMoment(Operator): + """A class to compute the zeroth moment (density) of distribution functions.""" + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), inline=True) + def jax_implementation(self, f): + return jnp.sum(f, axis=0, keepdims=True) + + def _construct_warp(self): + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + @wp.func + def functional(f: _f_vec): + rho = self.compute_dtype(0.0) + for l in range(self.velocity_set.q): + rho += f[l] + return rho + + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + ): + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _rho = functional(_f) + + rho[0, index[0], index[1], index[2]] = _rho + + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + ): + i, j = wp.tid() + index = wp.vec2i(i, j) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _rho = functional(_f) + + rho[0, index[0], index[1]] = _rho + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, rho): + wp.launch( + self.warp_kernel, + inputs=[f, rho], + dim=rho.shape[1:], + ) + return rho \ No newline at end of file From e6de3371067bb230456b41f551c60a511faf6fd0 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 20 Sep 2024 13:33:49 -0400 Subject: [PATCH 113/144] Deleted old file --- .../macroscopic/zero_first_moments.py | 129 ------------------ 1 file changed, 129 deletions(-) delete mode 100644 xlb/operator/macroscopic/zero_first_moments.py diff --git a/xlb/operator/macroscopic/zero_first_moments.py b/xlb/operator/macroscopic/zero_first_moments.py deleted file mode 100644 index 48cf108..0000000 --- a/xlb/operator/macroscopic/zero_first_moments.py +++ /dev/null @@ -1,129 +0,0 @@ -# Base class for all equilibriums - -from functools import partial -import jax.numpy as jnp -from jax import jit -import warp as wp -from typing import Any - -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator - - -class ZeroAndFirstMoments(Operator): - """ - A class to compute first and zeroth moments of distribution functions. - - TODO: Currently this is only used for the standard rho and u moments. - In the future, this should be extended to include higher order moments - and other physic types (e.g. temperature, electromagnetism, etc...) - """ - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), inline=True) - def jax_implementation(self, f): - """ - Apply the macroscopic operator to the lattice distribution function - TODO: Check if the following implementation is more efficient ( - as the compiler may be able to remove operations resulting in zero) - c_x = tuple(self.velocity_set.c[0]) - c_y = tuple(self.velocity_set.c[1]) - - u_x = 0.0 - u_y = 0.0 - - rho = jnp.sum(f, axis=0, keepdims=True) - - for i in range(self.velocity_set.q): - u_x += c_x[i] * f[i, ...] - u_y += c_y[i] * f[i, ...] - return rho, jnp.stack((u_x, u_y)) - """ - rho = jnp.sum(f, axis=0, keepdims=True) - u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho - - return rho, u - - def _construct_warp(self): - # Make constants for warp - _c = self.velocity_set.c - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - - # Construct the functional - @wp.func - def functional(f: _f_vec): - # Compute rho and u - rho = self.compute_dtype(0.0) - u = _u_vec() - for l in range(self.velocity_set.q): - rho += f[l] - for d in range(self.velocity_set.d): - if _c[d, l] == 1: - u[d] += f[l] - elif _c[d, l] == -1: - u[d] -= f[l] - u /= rho - - return rho, u - - # Construct the kernel - @wp.kernel - def kernel3d( - f: wp.array4d(dtype=Any), - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - - # Get the equilibrium - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1], index[2]] - (_rho, _u) = functional(_f) - - # Set the output - rho[0, index[0], index[1], index[2]] = _rho - for d in range(self.velocity_set.d): - u[d, index[0], index[1], index[2]] = _u[d] - - @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - rho: wp.array3d(dtype=Any), - u: wp.array3d(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # Get the equilibrium - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - (_rho, _u) = functional(_f) - - # Set the output - rho[0, index[0], index[1]] = _rho - for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = _u[d] - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, rho, u): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - f, - rho, - u, - ], - dim=rho.shape[1:], - ) - return rho, u From 448883f07999751d03389d7da71e894e64b8588a Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 20 Sep 2024 14:55:30 -0400 Subject: [PATCH 114/144] Changed rho type --- xlb/operator/macroscopic/first_moment.py | 8 ++++++-- xlb/operator/macroscopic/macroscopic.py | 3 ++- xlb/operator/macroscopic/second_moment.py | 9 +-------- xlb/operator/macroscopic/zero_moment.py | 9 +++------ 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py index 218458e..561fe7a 100644 --- a/xlb/operator/macroscopic/first_moment.py +++ b/xlb/operator/macroscopic/first_moment.py @@ -7,6 +7,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator + class FirstMoment(Operator): """A class to compute the first moment (velocity) of distribution functions.""" @@ -22,7 +23,10 @@ def _construct_warp(self): _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) @wp.func - def functional(f: _f_vec, rho: float): + def functional( + f: _f_vec, + rho: Any, + ): u = _u_vec() for l in range(self.velocity_set.q): for d in range(self.velocity_set.d): @@ -80,4 +84,4 @@ def warp_implementation(self, f, rho, u): inputs=[f, rho, u], dim=u.shape[1:], ) - return u \ No newline at end of file + return u diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index b585b5e..495a6a0 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -9,6 +9,7 @@ from xlb.operator.macroscopic.zero_moment import ZeroMoment from xlb.operator.macroscopic.first_moment import FirstMoment + class Macroscopic(Operator): """A class to compute both zero and first moments of distribution functions (rho, u).""" @@ -82,4 +83,4 @@ def warp_implementation(self, f, rho, u): inputs=[f, rho, u], dim=rho.shape[1:], ) - return rho, u \ No newline at end of file + return rho, u diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index bda2a25..917c86a 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -123,12 +123,5 @@ def kernel2d( @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f, pi): # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - f, - pi, - ], - dim=pi.shape[1:], - ) + wp.launch(self.warp_kernel, inputs=[f, pi], dim=pi.shape[1:]) return pi diff --git a/xlb/operator/macroscopic/zero_moment.py b/xlb/operator/macroscopic/zero_moment.py index a37ede7..d0fbf51 100644 --- a/xlb/operator/macroscopic/zero_moment.py +++ b/xlb/operator/macroscopic/zero_moment.py @@ -7,6 +7,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator + class ZeroMoment(Operator): """A class to compute the zeroth moment (density) of distribution functions.""" @@ -61,9 +62,5 @@ def kernel2d( @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f, rho): - wp.launch( - self.warp_kernel, - inputs=[f, rho], - dim=rho.shape[1:], - ) - return rho \ No newline at end of file + wp.launch(self.warp_kernel, inputs=[f, rho], dim=rho.shape[1:]) + return rho From 9496f6fb80bbccab2541cd5092889773908df23c Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Tue, 24 Sep 2024 09:38:02 -0400 Subject: [PATCH 115/144] Fixed mixed-precision for Warp --- examples/cfd/flow_past_sphere_3d.py | 2 +- examples/cfd/lid_driven_cavity_2d.py | 2 +- examples/cfd/turbulent_channel_3d.py | 2 +- examples/cfd/windtunnel_3d.py | 2 +- examples/performance/mlups_3d.py | 2 +- xlb/helper/initializers.py | 6 +- .../boundary_condition/bc_do_nothing.py | 4 +- .../boundary_condition/bc_equilibrium.py | 4 +- .../bc_extrapolation_outflow.py | 4 +- .../bc_fullway_bounce_back.py | 4 +- .../bc_halfway_bounce_back.py | 4 +- .../boundary_condition/bc_regularized.py | 4 +- xlb/operator/boundary_condition/bc_zouhe.py | 4 +- .../boundary_condition/boundary_condition.py | 8 +-- xlb/operator/collision/bgk.py | 4 +- xlb/operator/collision/kbc.py | 4 +- .../equilibrium/quadratic_equilibrium.py | 4 +- xlb/operator/force/exact_difference_force.py | 4 +- xlb/operator/macroscopic/first_moment.py | 4 +- xlb/operator/macroscopic/macroscopic.py | 8 +-- xlb/operator/macroscopic/second_moment.py | 4 +- xlb/operator/stepper/nse_stepper.py | 57 +++++++++---------- xlb/operator/stream/stream.py | 9 +-- xlb/precision_policy.py | 11 +--- 24 files changed, 74 insertions(+), 87 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 8c99dff..2487919 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -94,7 +94,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0)) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index c67ce8e..f94e209 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -61,7 +61,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index ea8cad7..65b56bf 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -102,7 +102,7 @@ def initialize_fields(self): u_init = jnp.full(shape=shape, fill_value=1e-2 * u_init) else: u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype) - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend, u=u_init) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init) def setup_stepper(self): force = self.get_force() diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 92c94aa..077ae98 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -122,7 +122,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 32c3d6d..907c1f2 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -100,7 +100,7 @@ def main(): backend, precision_policy = setup_simulation(args) velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) grid, f_0, f_1, missing_mask, bc_mask = create_grid_and_fields(args.cube_edge) - f_0 = initialize_eq(f_0, grid, velocity_set, backend) + f_0 = initialize_eq(f_0, grid, velocity_set, precision_policy, backend) elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) diff --git a/xlb/helper/initializers.py b/xlb/helper/initializers.py index c8439d9..ccb4a82 100644 --- a/xlb/helper/initializers.py +++ b/xlb/helper/initializers.py @@ -2,11 +2,11 @@ from xlb.operator.equilibrium import QuadraticEquilibrium -def initialize_eq(f, grid, velocity_set, backend, rho=None, u=None): +def initialize_eq(f, grid, velocity_set, precision_policy, backend, rho=None, u=None): if rho is None: - rho = grid.create_field(cardinality=1, fill_value=1.0) + rho = grid.create_field(cardinality=1, fill_value=1.0, dtype=precision_policy.compute_precision) if u is None: - u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0) + u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0, dtype=precision_policy.compute_precision) equilibrium = QuadraticEquilibrium() if backend == ComputeBackend.JAX: diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index dcdc8fd..0ddbcfc 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -86,7 +86,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -112,7 +112,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 716fd8e..6d4e3ed 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -111,7 +111,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -137,7 +137,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 87c6850..55a5851 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -229,7 +229,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -270,7 +270,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 85bffb1..57d29fd 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -95,7 +95,7 @@ def kernel2d( # Write the result to the output for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -121,7 +121,7 @@ def kernel3d( # Write the result to the output for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index c2483c9..e723570 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -110,7 +110,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -136,7 +136,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 7879137..a42b695 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -344,7 +344,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -370,7 +370,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d if self.velocity_set.d == 3 and self.bc_type == "velocity": diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 0bf68b8..4e9fe29 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -352,7 +352,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -378,7 +378,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d if self.velocity_set.d == 3 and self.bc_type == "velocity": diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 8d47929..9f6ef5d 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -81,8 +81,8 @@ def _get_thread_data_2d( _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1]] - _f_post[l] = f_post[l, index[0], index[1]] + _f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1]]) + _f_post[l] = self.compute_dtype(f_post[l, index[0], index[1]]) # TODO fix vec bool if missing_mask[l, index[0], index[1]]: @@ -106,8 +106,8 @@ def _get_thread_data_3d( _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1], index[2]] - _f_post[l] = f_post[l, index[0], index[1], index[2]] + _f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1], index[2]]) + _f_post[l] = self.compute_dtype(f_post[l, index[0], index[1], index[2]]) # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 196e3ba..60f63ef 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -59,7 +59,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -86,7 +86,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 7984829..bc731c6 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -337,7 +337,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -369,7 +369,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 0cce91b..ba337f0 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -79,7 +79,7 @@ def kernel3d( # Set the output for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = feq[l] + f[l, index[0], index[1], index[2]] = self.store_dtype(feq[l]) @wp.kernel def kernel2d( @@ -100,7 +100,7 @@ def kernel2d( # Set the output for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = feq[l] + f[l, index[0], index[1]] = self.store_dtype(feq[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/force/exact_difference_force.py b/xlb/operator/force/exact_difference_force.py index f148e12..b4da602 100644 --- a/xlb/operator/force/exact_difference_force.py +++ b/xlb/operator/force/exact_difference_force.py @@ -108,7 +108,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -134,7 +134,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py index 561fe7a..329a71f 100644 --- a/xlb/operator/macroscopic/first_moment.py +++ b/xlb/operator/macroscopic/first_moment.py @@ -53,7 +53,7 @@ def kernel3d( _u = functional(_f, _rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1], index[2]] = _u[d] + u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d]) @wp.kernel def kernel2d( @@ -71,7 +71,7 @@ def kernel2d( _u = functional(_f, _rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = _u[d] + u[d, index[0], index[1]] = self.store_dtype(_u[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 495a6a0..b574436 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -50,9 +50,9 @@ def kernel3d( _f[l] = f[l, index[0], index[1], index[2]] _rho, _u = functional(_f) - rho[0, index[0], index[1], index[2]] = _rho + rho[0, index[0], index[1], index[2]] = self.store_dtype(_rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1], index[2]] = _u[d] + u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d]) @wp.kernel def kernel2d( @@ -68,9 +68,9 @@ def kernel2d( _f[l] = f[l, index[0], index[1]] _rho, _u = functional(_f) - rho[0, index[0], index[1]] = _rho + rho[0, index[0], index[1]] = self.store_dtype(_rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = _u[d] + u[d, index[0], index[1]] = self.store_dtype(_u[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index 917c86a..687b38a 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -95,7 +95,7 @@ def kernel3d( # Set the output for d in range(_pi_dim): - pi[d, index[0], index[1], index[2]] = _pi[d] + pi[d, index[0], index[1], index[2]] = self.store_dtype(_pi[d]) @wp.kernel def kernel2d( @@ -114,7 +114,7 @@ def kernel2d( # Set the output for d in range(_pi_dim): - pi[d, index[0], index[1]] = _pi[d] + pi[d, index[0], index[1]] = self.store_dtype(_pi[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 18b2167..05cee7b 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -190,18 +190,18 @@ def get_thread_data_2d( index: Any, ): # Get the boundary id and missing mask - f_post_collision = _f_vec() + _f_post_collision = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f_post_collision[l] = f_0[l, index[0], index[1]] + _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1]]) # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f_post_collision, _missing_mask + return _f_post_collision, _missing_mask @wp.func def get_thread_data_3d( @@ -210,18 +210,18 @@ def get_thread_data_3d( index: Any, ): # Get the boundary id and missing mask - f_post_collision = _f_vec() + _f_post_collision = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f_post_collision[l] = f_0[l, index[0], index[1], index[2]] + _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1], index[2]]) # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f_post_collision, _missing_mask + return _f_post_collision, _missing_mask @wp.func def get_bc_auxilary_data_2d( @@ -243,7 +243,7 @@ def get_bc_auxilary_data_2d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1]] + f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) return f_auxiliary @wp.func @@ -266,7 +266,7 @@ def get_bc_auxilary_data_3d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1], pull_index[2]] + f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) return f_auxiliary @wp.kernel @@ -283,38 +283,33 @@ def kernel2d( index = wp.vec2i(i, j) # TODO warp should fix this # Read thread data for populations and missing mask - f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) + _f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f_0, index) + _f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) _boundary_id = bc_mask[0, index[0], index[1]] - f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u - rho, u = self.macroscopic.warp_functional(f_post_stream) + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) # Compute equilibrium - feq = self.equilibrium.warp_functional(rho, u) + _feq = self.equilibrium.warp_functional(_rho, _u) # Apply collision - f_post_collision = self.collision.warp_functional( - f_post_stream, - feq, - rho, - u, - ) + _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1]] = f_post_collision[l] + f_1[l, index[0], index[1]] = self.store_dtype(_f_post_collision[l]) # Construct the kernel @wp.kernel @@ -331,33 +326,33 @@ def kernel3d( index = wp.vec3i(i, j, k) # TODO warp should fix this # Read thread data for populations and missing mask - f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) + _f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f_0, index) + _f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) _boundary_id = bc_mask[0, index[0], index[1], index[2]] - f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u - rho, u = self.macroscopic.warp_functional(f_post_stream) + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) # Compute equilibrium - feq = self.equilibrium.warp_functional(rho, u) + _feq = self.equilibrium.warp_functional(_rho, _u) # Apply collision - f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) + _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1], index[2]] = f_post_collision[l] + f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index d96c307..dc2417a 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -76,7 +76,7 @@ def functional2d( pull_index[d] = 0 # Read the distribution function - _f[l] = f[l, pull_index[0], pull_index[1]] + _f[l] = self.compute_dtype(f[l, pull_index[0], pull_index[1]]) return _f @@ -94,7 +94,7 @@ def kernel2d( # Write the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1]] = _f[l] + f_1[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the funcional to get streamed indices @wp.func @@ -117,7 +117,8 @@ def functional3d( pull_index[d] = 0 # Read the distribution function - _f[l] = f[l, pull_index[0], pull_index[1], pull_index[2]] + # Unlike other functionals, we need to cast the type here since we read from the buffer + _f[l] = self.compute_dtype(f[l, pull_index[0], pull_index[1], pull_index[2]]) return _f @@ -136,7 +137,7 @@ def kernel3d( # Write the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1], index[2]] = _f[l] + f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 3b0f85f..d85deed 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -1,7 +1,6 @@ # Enum for precision policy from enum import Enum, auto - import jax.numpy as jnp import warp as wp @@ -87,12 +86,4 @@ def cast_to_compute_jax(self, array): def cast_to_store_jax(self, array): store_precision = self.store_precision - return jnp.array(array, dtype=store_precision.jax_dtype) - - def cast_to_compute_warp(self, array): - compute_precision = self.compute_precision - return wp.array(array, dtype=compute_precision.wp_dtype) - - def cast_to_store_warp(self, array): - store_precision = self.store_precision - return wp.array(array, dtype=store_precision.wp_dtype) + return jnp.array(array, dtype=store_precision.jax_dtype) \ No newline at end of file From 53910bc61a138e577196df49b20110303773f56b Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 26 Sep 2024 10:33:57 -0400 Subject: [PATCH 116/144] minor renaming --- xlb/operator/boundary_condition/bc_extrapolation_outflow.py | 4 ++-- xlb/operator/boundary_condition/bc_fullway_bounce_back.py | 2 +- xlb/operator/boundary_condition/bc_halfway_bounce_back.py | 2 +- xlb/operator/boundary_condition/bc_regularized.py | 2 +- xlb/operator/boundary_condition/bc_zouhe.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 55a5851..9e4812c 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -123,7 +123,7 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) @@ -182,7 +182,7 @@ def prepare_bc_auxilary_data( missing_mask: Any, ): # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and - # f_pre (posti-streaming values of the current voxel). We use directions that leave the domain + # f_pre (post-streaming values of the current voxel). We use directions that leave the domain # for storing this prepared data. _f = f_post for l in range(self.velocity_set.q): diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 57d29fd..ec298b2 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -48,7 +48,7 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index e723570..4a6a97f 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -54,7 +54,7 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index a42b695..1854dce 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -105,7 +105,7 @@ def regularize_fpop(self, fpop, feq): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): # creat a mask to slice boundary cells boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 4e9fe29..f40cb22 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -159,7 +159,7 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def apply_jax(self, f_pre, f_post, bc_mask, missing_mask): + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): # creat a mask to slice boundary cells boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] From b913f3505dc6e6ba2af9b9bdb562954f5a3c183e Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 26 Sep 2024 15:26:56 -0400 Subject: [PATCH 117/144] WIP: adding boundary_distance for curved and moving boundaries --- xlb/operator/boundary_masker/__init__.py | 3 + .../boundary_masker/mesh_boundary_masker.py | 9 +- .../mesh_grid_boundary_distance.py | 147 ++++++++++++++++++ 3 files changed, 153 insertions(+), 6 deletions(-) create mode 100644 xlb/operator/boundary_masker/mesh_grid_boundary_distance.py diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index 20b16b5..c2b0358 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -4,3 +4,6 @@ from xlb.operator.boundary_masker.mesh_boundary_masker import ( MeshBoundaryMasker as MeshBoundaryMasker, ) +from xlb.operator.boundary_masker.mesh_grid_boundary_distance import ( + MeshGridBoundaryDistance as MeshGridBoundaryDistance, +) diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index b88d251..ac97111 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -83,13 +83,10 @@ def kernel( ) # evaluate if point is inside mesh - face_index = int(0) - face_u = float(0.0) - face_v = float(0.0) - sign = float(0.0) - if wp.mesh_query_point_sign_winding_number(mesh_id, pos, max_length, sign, face_index, face_u, face_v): + query = wp.mesh_query_point_sign_winding_number(mesh_id, pos, max_length) + if query.result: # set point to be solid - if sign <= 0: # TODO: fix this + if query.sign <= 0: # TODO: fix this # Stream indices for l in range(_q): # Get the index of the streaming direction diff --git a/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py b/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py new file mode 100644 index 0000000..e040db4 --- /dev/null +++ b/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py @@ -0,0 +1,147 @@ +# Base class for all equilibriums + +import numpy as np +import warp as wp +import jax +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + + +class MeshGridBoundaryDistance(Operator): + """ + Operator for creating a boundary missing_mask from an STL file + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.WARP, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + # Also using Warp kernels for JAX implementation + if self.compute_backend == ComputeBackend.JAX: + self.warp_functional, self.warp_kernel = self._construct_warp() + + @Operator.register_backend(ComputeBackend.JAX) + def jax_implementation( + self, + mesh_vertices, + origin, + spacing, + missing_mask, + boundary_distance, + start_index=(0, 0, 0), + ): + raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.c + _q = wp.constant(self.velocity_set.q) + + @wp.func + def index_to_position(index: wp.vec3i, origin: wp.vec3, spacing: wp.vec3): + # position of the point + ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2])) + ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center + pos = wp.cw_mul(ijk, spacing) + origin + return pos + + # Construct the warp kernel + @wp.kernel + def kernel( + mesh_id: wp.uint64, + origin: wp.vec3, + spacing: wp.vec3, + missing_mask: wp.array4d(dtype=wp.bool), + boundary_distance: wp.array4d(dtype=wp.float32), + start_index: wp.vec3i, + ): + # get index + i, j, k = wp.tid() + + # Get local indices + index = wp.vec3i() + index[0] = i - start_index[0] + index[1] = j - start_index[1] + index[2] = k - start_index[2] + + # position of the point + pos_solid_cell = index_to_position(index, origin, spacing) + + # Compute the maximum length + max_length = wp.sqrt( + (spacing[0] * wp.float32(missing_mask.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(missing_mask.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(missing_mask.shape[3])) ** 2.0 + ) + + # evaluate if point is inside mesh + query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_solid_cell, max_length) + if query.result: + # set point to be solid + if query.sign <= 0: # TODO: fix this + # get position of the mesh triangle that intersects with the solid cell + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + + # Stream indices + for l in range(1, _q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + + # Set the boundary id and missing_mask + if missing_mask[l, push_index[0], push_index[1], push_index[2]]: + pos_fluid_cell = index_to_position(push_index, origin, spacing) + query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_fluid_cell, max_length) + if query.result and query.sign > 0: + # get signed-distance field of the fluid voxel (i.e. sdf_f) + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + weight = wp.length(pos_fluid_cell - pos_mesh) / wp.length(pos_fluid_cell - pos_solid_cell) + boundary_distance[l, push_index[0], push_index[1], push_index[2]] = weight + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + mesh_vertices, + origin, + spacing, + missing_mask, + boundary_distance, + start_index=(0, 0, 0), + ): + assert mesh_vertices is not None, "Please provide the mesh vertices for which the boundary_distace wrt grid is sought!" + assert mesh_vertices.shape[1] == self.velocity_set.d, "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + assert ( + boundary_distance is not None and boundary_distance.shape == missing_mask.shape + ), 'To compute "boundary_distance" for this BC a field with the same shape as "missing_mask" must be prvided!' + + mesh_indices = np.arange(mesh_vertices.shape[0]) + mesh = wp.Mesh( + points=wp.array(mesh_vertices, dtype=wp.vec3), + indices=wp.array(mesh_indices, dtype=int), + ) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + mesh.id, + origin, + spacing, + missing_mask, + boundary_distance, + start_index, + ], + dim=missing_mask.shape[1:], + ) + + return boundary_distance From 8cfbd501c5fa4c23d092939ada0f2d6cdf56b9bc Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Sun, 29 Sep 2024 18:46:26 -0400 Subject: [PATCH 118/144] storing mesh distance in f_0 --- .../boundary_masker/mesh_boundary_masker.py | 3 +- .../mesh_grid_boundary_distance.py | 70 +++++++++++++------ xlb/precision_policy.py | 2 +- 3 files changed, 50 insertions(+), 25 deletions(-) diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index ac97111..701fc13 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -33,6 +33,7 @@ def jax_implementation( bc, origin, spacing, + id_number, bc_mask, missing_mask, start_index=(0, 0, 0), @@ -110,7 +111,7 @@ def warp_implementation( missing_mask, start_index=(0, 0, 0), ): - assert bc.mesh_vertices is not None, f'Please provide the mesh points for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' + assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" assert ( bc.mesh_vertices.shape[1] == self.velocity_set.d diff --git a/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py b/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py index e040db4..b63e664 100644 --- a/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py +++ b/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py @@ -3,6 +3,7 @@ import numpy as np import warp as wp import jax +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend @@ -30,11 +31,13 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) def jax_implementation( self, - mesh_vertices, + bc, origin, spacing, + id_number, + bc_mask, missing_mask, - boundary_distance, + f_0, start_index=(0, 0, 0), ): raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") @@ -43,6 +46,7 @@ def _construct_warp(self): # Make constants for warp _c = self.velocity_set.c _q = wp.constant(self.velocity_set.q) + _opp_indices = self.velocity_set.opp_indices @wp.func def index_to_position(index: wp.vec3i, origin: wp.vec3, spacing: wp.vec3): @@ -58,8 +62,10 @@ def kernel( mesh_id: wp.uint64, origin: wp.vec3, spacing: wp.vec3, + id_number: wp.int32, + bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - boundary_distance: wp.array4d(dtype=wp.float32), + f_0: wp.array4d(dtype=Any), start_index: wp.vec3i, ): # get index @@ -86,10 +92,8 @@ def kernel( if query.result: # set point to be solid if query.sign <= 0: # TODO: fix this - # get position of the mesh triangle that intersects with the solid cell - pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) - # Stream indices + missing_mask[0, index[0], index[1], index[2]] = True for l in range(1, _q): # Get the index of the streaming direction push_index = wp.vec3i() @@ -97,32 +101,44 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and missing_mask - if missing_mask[l, push_index[0], push_index[1], push_index[2]]: - pos_fluid_cell = index_to_position(push_index, origin, spacing) - query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_fluid_cell, max_length) - if query.result and query.sign > 0: - # get signed-distance field of the fluid voxel (i.e. sdf_f) - pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) - weight = wp.length(pos_fluid_cell - pos_mesh) / wp.length(pos_fluid_cell - pos_solid_cell) - boundary_distance[l, push_index[0], push_index[1], push_index[2]] = weight + bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + missing_mask[l, push_index[0], push_index[1], push_index[2]] = True + + # find neighbouring fluid cell + pos_fluid_cell = index_to_position(push_index, origin, spacing) + query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_fluid_cell, max_length) + if query.result and query.sign > 0: + # get position of the mesh triangle that intersects with the solid cell + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + weight = wp.length(pos_fluid_cell - pos_mesh) / wp.length(pos_fluid_cell - pos_solid_cell) + f_0[_opp_indices[l], push_index[0], push_index[1], push_index[2]] = weight return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation( self, - mesh_vertices, + bc, origin, spacing, + bc_mask, missing_mask, - boundary_distance, + f_0, start_index=(0, 0, 0), ): - assert mesh_vertices is not None, "Please provide the mesh vertices for which the boundary_distace wrt grid is sought!" - assert mesh_vertices.shape[1] == self.velocity_set.d, "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' + assert bc.indices is None, f"The boundary_distance of {bc.__class__.__name__} cannot be found without a mesh!" + assert ( + bc.mesh_vertices.shape[1] == self.velocity_set.d + ), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" assert ( - boundary_distance is not None and boundary_distance.shape == missing_mask.shape - ), 'To compute "boundary_distance" for this BC a field with the same shape as "missing_mask" must be prvided!' + f_0 is not None and f_0.shape == missing_mask.shape + ), 'To compute and store the "boundary_distance" for this BC, input the population field "f_0"!' + mesh_vertices = bc.mesh_vertices + id_number = bc.id + + # We are done with bc.mesh_vertices. Remove them from BC objects + bc.__dict__.pop("mesh_vertices", None) mesh_indices = np.arange(mesh_vertices.shape[0]) mesh = wp.Mesh( @@ -130,18 +146,26 @@ def warp_implementation( indices=wp.array(mesh_indices, dtype=int), ) + # Convert input tuples to warp vectors + origin = wp.vec3(origin[0], origin[1], origin[2]) + spacing = wp.vec3(spacing[0], spacing[1], spacing[2]) + start_index = wp.vec3i(start_index[0], start_index[1], start_index[2]) + mesh_id = wp.uint64(mesh.id) + # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ - mesh.id, + mesh_id, origin, spacing, + id_number, + bc_mask, missing_mask, - boundary_distance, + f_0, start_index, ], dim=missing_mask.shape[1:], ) - return boundary_distance + return bc_mask, missing_mask, f_0 diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index d85deed..7d31c8a 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -86,4 +86,4 @@ def cast_to_compute_jax(self, array): def cast_to_store_jax(self, array): store_precision = self.store_precision - return jnp.array(array, dtype=store_precision.jax_dtype) \ No newline at end of file + return jnp.array(array, dtype=store_precision.jax_dtype) From 9cb6bf39101998116f9a5330769f581fcf0a52ed Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Sun, 29 Sep 2024 18:50:29 -0400 Subject: [PATCH 119/144] renaming file --- xlb/operator/boundary_masker/__init__.py | 4 ++-- ...ce.py => mesh_distance_boundary_masker.py} | 20 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) rename xlb/operator/boundary_masker/{mesh_grid_boundary_distance.py => mesh_distance_boundary_masker.py} (90%) diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index c2b0358..fbe851d 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -4,6 +4,6 @@ from xlb.operator.boundary_masker.mesh_boundary_masker import ( MeshBoundaryMasker as MeshBoundaryMasker, ) -from xlb.operator.boundary_masker.mesh_grid_boundary_distance import ( - MeshGridBoundaryDistance as MeshGridBoundaryDistance, +from xlb.operator.boundary_masker.mesh_distance_boundary_masker import ( + MeshDistanceBoundaryMasker as MeshDistanceBoundaryMasker, ) diff --git a/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py similarity index 90% rename from xlb/operator/boundary_masker/mesh_grid_boundary_distance.py rename to xlb/operator/boundary_masker/mesh_distance_boundary_masker.py index b63e664..6176edb 100644 --- a/xlb/operator/boundary_masker/mesh_grid_boundary_distance.py +++ b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py @@ -10,7 +10,7 @@ from xlb.operator.operator import Operator -class MeshGridBoundaryDistance(Operator): +class MeshDistanceBoundaryMasker(Operator): """ Operator for creating a boundary missing_mask from an STL file """ @@ -37,7 +37,7 @@ def jax_implementation( id_number, bc_mask, missing_mask, - f_0, + f_field, start_index=(0, 0, 0), ): raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") @@ -65,7 +65,7 @@ def kernel( id_number: wp.int32, bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), - f_0: wp.array4d(dtype=Any), + f_field: wp.array4d(dtype=Any), start_index: wp.vec3i, ): # get index @@ -111,7 +111,7 @@ def kernel( # get position of the mesh triangle that intersects with the solid cell pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) weight = wp.length(pos_fluid_cell - pos_mesh) / wp.length(pos_fluid_cell - pos_solid_cell) - f_0[_opp_indices[l], push_index[0], push_index[1], push_index[2]] = weight + f_field[_opp_indices[l], push_index[0], push_index[1], push_index[2]] = self.store_dtype(weight) return None, kernel @@ -123,17 +123,17 @@ def warp_implementation( spacing, bc_mask, missing_mask, - f_0, + f_field, start_index=(0, 0, 0), ): assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' - assert bc.indices is None, f"The boundary_distance of {bc.__class__.__name__} cannot be found without a mesh!" + assert bc.indices is None, f"Cannot find the implicit distance to the boundary for {bc.__class__.__name__} without a mesh!" assert ( bc.mesh_vertices.shape[1] == self.velocity_set.d ), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" assert ( - f_0 is not None and f_0.shape == missing_mask.shape - ), 'To compute and store the "boundary_distance" for this BC, input the population field "f_0"!' + f_field is not None and f_field.shape == missing_mask.shape + ), 'To compute and store the implicit distance to the boundary for this BC, use a population field!' mesh_vertices = bc.mesh_vertices id_number = bc.id @@ -162,10 +162,10 @@ def warp_implementation( id_number, bc_mask, missing_mask, - f_0, + f_field, start_index, ], dim=missing_mask.shape[1:], ) - return bc_mask, missing_mask, f_0 + return bc_mask, missing_mask, f_field From cc09003988f989ae37519fd6cd74f4a0e78df2bf Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Mon, 30 Sep 2024 11:18:17 -0400 Subject: [PATCH 120/144] WIP: reformulated Grads BC --- .../bc_grads_approximation.py | 344 ++++++++++++++++++ .../boundary_condition/bc_regularized.py | 7 +- xlb/operator/collision/kbc.py | 6 +- xlb/operator/stepper/nse_stepper.py | 109 +++--- 4 files changed, 409 insertions(+), 57 deletions(-) create mode 100644 xlb/operator/boundary_condition/bc_grads_approximation.py diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py new file mode 100644 index 0000000..4193253 --- /dev/null +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -0,0 +1,344 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit +import jax.lax as lax +from functools import partial +import warp as wp +from typing import Any +from collections import Counter +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.macroscopic.zero_moment import ZeroMoment +from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + + +class GradsApproximationBC(BoundaryCondition): + """ + Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and + Dirichlet BCs [2] + [1] S. Chikatamarla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann + simulations", Europhys. Lett. 74, 215 (2006). + [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and + stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. + + """ + + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + indices=None, + mesh_vertices=None, + ): + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + indices, + mesh_vertices, + ) + + # Instantiate the operator for computing macroscopic values + self.macroscopic = Macroscopic() + self.zero_moment = ZeroMoment() + self.equilibrium = QuadraticEquilibrium() + self.momentum_flux = MomentumFlux() + + # if indices is not None: + # # this BC would be limited to stationary boundaries + # # assert mesh_vertices is None + # if mesh_vertices is not None: + # # this BC would be applicable for stationary and moving boundaries + # assert indices is None + # if mesh_velocity_function is not None: + # # mesh is moving and/or deforming + + assert self.compute_backend == ComputeBackend.WARP, "This BC is currently only implemented with the Warp backend!" + + # Unpack the two warp functionals needed for this BC! + if self.compute_backend == ComputeBackend.WARP: + self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # TODO + raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") + return + + def _construct_warp(self): + # Set local variables and constants + _c = self.velocity_set.c + _q = self.velocity_set.q + _d = self.velocity_set.d + _w = self.velocity_set.w + _qi = self.velocity_set.qi + _opp_indices = self.velocity_set.opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _u_wall = _u_vec(self.u[0], self.u[1], self.u[2]) if _d == 3 else _u_vec(self.u[0], self.u[1]) + diagonal = wp.vec3i(0, 3, 5) if _d == 3 else wp.vec2i(0, 2) + + @wp.func + def grads_approximate_fpop( + f_post: Any, + missing_mask: Any, + ): + """ + Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and + Dirichlet BCs [2] + [1] S. Chikatax`marla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann + simulations", Europhys. Lett. 74, 215 (2006). + [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and + stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. + + Note: See also self.regularize_fpop function which is somewhat similar. + """ + # Compute density, velocity and pressure tensor Pi using all f_post-streaming values + rho, u = self.macroscopic.warp_functional(f_post) + Pi = self.momentum_flux.warp_functional(f_post) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + for l in range(_q): + # If the mask is missing then use f_post + if missing_mask[l] == wp.uint8(1): + QiPi = self.compute_dtype(0.0) + for t in range(nt): + if t in diagonal: + Pi[t] -= rho/3. + QiPi += _qi[l, t] * Pi[t] + + # Compute c.u + cu = self.compute_dtype(0.0) + for d in range(self.velocity_set.d): + if _c[d, l] == 1: + cu += u[d] + elif _c[d, l] == -1: + cu -= u[d] + cu *= self.compute_dtype(3.0) + + # change f_post using the Grad's approximation + f_post[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu + self.compute_dtype(4.5) * QiPi / rho) + + return f_post + + # Construct the functionals for this BC + @wp.func + def functional( + f_pre: Any, + f_post: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + index: Any, + ): + # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper. + # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming). + # NOTE: f_aux should contain populations at "x_f" (see their fig 1) in the missign direction of the BC which amounts + # to post-collision values being pulled from appropriate cells like ExtrapolationBC + # + # here I need to compute all terms in Eq (10) + # Strategy: + # 1) "weights" should have been stored somewhere to be used here. + # 2) Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) + # NOTE: in the original paper "u_target" is associated with the previous time step not current time. + # 3) Given "weights" use differentiable interpolated BB to find f_missing as I had before: + # fmissing = ((1. - weights) * f_poststreaming_iknown + weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + weights) + # 4) Add contribution due to u_w to f_missing as is usual in regular Bouzidi BC (ie. -6.0 * self.lattice.w * jnp.dot(self.vel, c) + # 5) Compute rho_target = \sum(f_ibb) based on these values + # 6) Compute feq using feq = self.equilibrium(rho_target, u_target) + # 7) Compute Pi_neq and Pi_eq using all f_post-streaming values as per: + # Pi_neq = self.momentum_flux(fneq) and Pi_eq = self.momentum_flux(feq) + # 8) Compute Grad's appriximation using full equation as in Eq (10) + # NOTE: this is very similar to the regularization procedure. + + # _f_nbr = _f_vec() + # u_target = _u_vec(0.0, 0.0, 0.0) if _d == 3 else _u_vec(0.0, 0.0) + # num_missing = 0 + one = self.compute_dtype(1.0) + for l in range(_q): + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + + # # Find the neighbour and its velocity value + # for ll in range(_q): + # # f_0 is the post-collision values of the current time-step + # # Get index associated with the fluid neighbours + # fluid_nbr_index = type(index)() + # for d in range(_d): + # fluid_nbr_index[d] = index[d] + _c[d, l] + # # The following is the post-collision values of the fluid neighbor cell + # _f_nbr[ll] = self.compute_dtype(f_0[ll, fluid_nbr_index[0], fluid_nbr_index[1], fluid_nbr_index[2]]) + + # # Compute the velocity vector at the fluid neighbouring cells + # _, u_f = self.macroscopic.warp_functional(_f_nbr) + + # # Record the number of missing directions + # num_missing += 1 + + # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 + weight = f_1[_opp_indices[l], index[0], index[1], index[2]] + + # # Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) + # for d in range(_d): + # u_target[d] += (weight * u_f[d] + _u_wall[d]) / (one + weight) + + # Use differentiable interpolated BB to find f_missing: + f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) + + # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC + for ll in range(_q): + # Compute cu + cu = self.compute_dtype(0.0) + for d in range(_d): + if _c[d, l] == 1: + cu += _u_wall[d] + elif _c[d, l] == -1: + cu -= _u_wall[d] + cu *= self.compute_dtype(-6.0) * self.velocity_set.w + f_post[l] += cu + + # Compute rho_target = \sum(f_ibb) based on these values + # rho_target = self.zero_moment.warp_functional(f_post) + # for d in range(_d): + # u_target[d] /= num_missing + + # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al. + f_post = grads_approximate_fpop(f_post, missing_mask) + + return f_post + + # Construct the warp kernel + @wp.kernel + def kernel2d( + f_pre: wp.array3d(dtype=Any), + f_post: wp.array3d(dtype=Any), + bc_mask: wp.array3d(dtype=wp.uint8), + missing_mask: wp.array3d(dtype=wp.bool), + ): + # Get the global index + i, j = wp.tid() + index = wp.vec2i(i, j) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) + _f_aux = _f_vec() + + # special preparation of auxiliary data + if _boundary_id == wp.uint8(GradsApproximation.id): + nv = get_normal_vectors_2d(_missing_mask) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]] + + # Apply the boundary condition + if _boundary_id == wp.uint8(GradsApproximation.id): + # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both + # collision and streaming? + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) + + # Construct the warp kernel + @wp.kernel + def kernel3d( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # we need to store fractional distance during the initialization + mesh_indices = np.arange(mesh_vertices.shape[0]) + mesh = wp.Mesh( + points=wp.array(mesh_vertices, dtype=wp.vec3), + indices=wp.array(mesh_indices, dtype=int), + ) + # Compute the maximum length + max_length = wp.sqrt( + (grid_spacing[0] * wp.float32(grid_shape[0])) ** 2.0 + + (grid_spacing[1] * wp.float32(grid_shape[1])) ** 2.0 + + (grid_spacing[2] * wp.float32(grid_shape[2])) ** 2.0 + ) + + query = wp.mesh_query_point_sign_normal(mesh.id, xpred, max_length) + if query.result: + p = wp.mesh_eval_position(mesh, query.face, query.u, query.v) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) + _f_aux = _f_vec() + + # special preparation of auxiliary data + if _boundary_id == wp.uint8(GradsApproximation.id): + nv = get_normal_vectors_3d(_missing_mask) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] + + # Apply the boundary condition + if _boundary_id == wp.uint8(GradsApproximation.id): + # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both + # collision and streaming? + _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + else: + _f = _f_post + + # Write the distribution function + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 1854dce..8d8c0fc 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -200,15 +200,14 @@ def regularize_fpop( # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) nt = _d * (_d + 1) // 2 - QiPi1 = _f_vec() for l in range(_q): - QiPi1[l] = self.compute_dtype(0.0) + QiPi1 = self.compute_dtype(0.0) for t in range(nt): - QiPi1[l] += _qi[l, t] * PiNeq[t] + QiPi1 += _qi[l, t] * PiNeq[t] # assign all populations based on eq 45 of Latt et al (2008) # fneq ~ f^1 - fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1[l] + fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1 fpop[l] = feq[l] + fpop1 return fpop diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index bc731c6..cc2fb04 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -217,9 +217,9 @@ def decompose_shear_d3q27( s = _f_vec() # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) - two = self.self.compute_dtype(2.0) - four = self.self.compute_dtype(4.0) - six = self.self.compute_dtype(6.0) + two = self.compute_dtype(2.0) + four = self.compute_dtype(4.0) + six = self.compute_dtype(6.0) s[9] = (two * nxz - nyz) / six s[18] = (two * nxz - nyz) / six diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 05cee7b..bf8e45f 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -111,6 +111,7 @@ class BoundaryConditionIDStruct: id_RegularizedBC_velocity: wp.uint8 id_RegularizedBC_pressure: wp.uint8 id_ExtrapolationOutflowBC: wp.uint8 + id_GradsApproximationBC: wp.uint8 @wp.func def apply_post_streaming_bc( @@ -188,10 +189,15 @@ def get_thread_data_2d( f_0: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), index: Any, + _boundary_id: Any, + bc_struct: Any, ): # Get the boundary id and missing mask _f_post_collision = _f_vec() _missing_mask = _missing_mask_vec() + + # special preparation of auxiliary data + _f_aux_bc = _f_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1]]) @@ -199,19 +205,48 @@ def get_thread_data_2d( # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) + + # get _f_aux_bc data + if _boundary_id == bc_struct.id_GradsApproximationBC: + # f_0 is the post-collision values of the current time-step + # Get index associated with the fluid neighbours + fluid_nbr_index = type(index)() + for d in range(self.velocity_set.d): + fluid_nbr_index[d] = index[d] + _c[d, l] + # The following is the post-collision values of the fluid neighbor cell + _f_aux_bc[l] = self.compute_dtype(f_0[l, fluid_nbr_index[0], fluid_nbr_index[1]]) else: _missing_mask[l] = wp.uint8(0) - return _f_post_collision, _missing_mask + + # special treatment for obtaining _f_aux_bc in cases where all missing directions need to be known + if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: + nv = get_normal_vectors_2d(_missing_mask) + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + _f_aux_bc[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) + + return _f_post_collision, _missing_mask, _f_aux_bc @wp.func def get_thread_data_3d( f_0: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), index: Any, + _boundary_id: Any, + bc_struct: Any, ): # Get the boundary id and missing mask _f_post_collision = _f_vec() _missing_mask = _missing_mask_vec() + + # special preparation of auxiliary data + _f_aux_bc = _f_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1], index[2]]) @@ -219,43 +254,20 @@ def get_thread_data_3d( # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) - return _f_post_collision, _missing_mask - @wp.func - def get_bc_auxilary_data_2d( - f_0: wp.array3d(dtype=Any), - index: Any, - _boundary_id: Any, - _missing_mask: Any, - bc_struct: Any, - ): - # special preparation of auxiliary data - f_auxiliary = _f_vec() - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - nv = get_normal_vectors_2d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): + # get _f_aux_bc data + if _boundary_id == bc_struct.id_GradsApproximationBC: # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() + # Get index associated with the fluid neighbours + fluid_nbr_index = type(index)() for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) - return f_auxiliary + fluid_nbr_index[d] = index[d] + _c[d, l] + # The following is the post-collision values of the fluid neighbor cell + _f_aux_bc[l] = self.compute_dtype(f_0[l, fluid_nbr_index[0], fluid_nbr_index[1], fluid_nbr_index[2]]) + else: + _missing_mask[l] = wp.uint8(0) - @wp.func - def get_bc_auxilary_data_3d( - f_0: wp.array4d(dtype=Any), - index: Any, - _boundary_id: Any, - _missing_mask: Any, - bc_struct: Any, - ): - # special preparation of auxiliary data - f_auxiliary = _f_vec() + # special treatment for obtaining _f_aux_bc in cases where all missing directions need to be known if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: nv = get_normal_vectors_3d(_missing_mask) for l in range(self.velocity_set.q): @@ -266,8 +278,9 @@ def get_bc_auxilary_data_3d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) - return f_auxiliary + _f_aux_bc[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) + + return _f_post_collision, _missing_mask, _f_aux_bc @wp.kernel def kernel2d( @@ -283,17 +296,15 @@ def kernel2d( index = wp.vec2i(i, j) # TODO warp should fix this # Read thread data for populations and missing mask - _f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) + # Also get auxilary data for BC (if applicable) + _boundary_id = bc_mask[0, index[0], index[1]] + _f_post_collision, _missing_mask, _f_aux_bc = get_thread_data_2d(f_0, missing_mask, index, _boundary_id, bc_struct) # Apply streaming (pull method) _f_post_stream = self.stream.warp_functional(f_0, index) - # Prepare auxilary data for BC (if applicable) - _boundary_id = bc_mask[0, index[0], index[1]] - _f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) - # Apply post-streaming type boundary conditions - _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_aux_bc, _missing_mask, _boundary_id, bc_struct) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -305,7 +316,7 @@ def kernel2d( _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_aux_bc, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): @@ -326,17 +337,15 @@ def kernel3d( index = wp.vec3i(i, j, k) # TODO warp should fix this # Read thread data for populations and missing mask - _f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) + # Also get auxilary data for BC (if applicable) + _boundary_id = bc_mask[0, index[0], index[1], index[2]] + _f_post_collision, _missing_mask, _f_aux_bc = get_thread_data_3d(f_0, missing_mask, index, _boundary_id, bc_struct) # Apply streaming (pull method) _f_post_stream = self.stream.warp_functional(f_0, index) - # Prepare auxilary data for BC (if applicable) - _boundary_id = bc_mask[0, index[0], index[1], index[2]] - _f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) - # Apply post-streaming type boundary conditions - _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_aux_bc, _missing_mask, _boundary_id, bc_struct) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -348,7 +357,7 @@ def kernel3d( _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_aux_bc, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): From 6e327062eecc5f41301934ea13b2f7b99b0875fb Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 1 Oct 2024 00:17:38 -0400 Subject: [PATCH 121/144] major improvement to BC functionals and nse_stepper --- .../boundary_condition/bc_do_nothing.py | 15 +- .../boundary_condition/bc_equilibrium.py | 15 +- .../bc_extrapolation_outflow.py | 82 +++++--- .../bc_fullway_bounce_back.py | 15 +- .../bc_grads_approximation.py | 3 +- .../bc_halfway_bounce_back.py | 15 +- .../boundary_condition/bc_regularized.py | 36 ++-- xlb/operator/boundary_condition/bc_zouhe.py | 36 ++-- .../boundary_condition/boundary_condition.py | 7 +- .../mesh_distance_boundary_masker.py | 2 +- xlb/operator/force/momentum_transfer.py | 9 +- xlb/operator/stepper/nse_stepper.py | 175 ++++++------------ 12 files changed, 205 insertions(+), 205 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 0ddbcfc..55ce9ed 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -56,10 +56,13 @@ def _construct_warp(self): # Construct the functional for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): return f_pre @@ -79,8 +82,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -105,8 +108,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(DoNothingBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 6d4e3ed..8c33d29 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -79,10 +79,13 @@ def _construct_warp(self): # Construct the functional for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f @@ -104,8 +107,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -130,8 +133,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(EquilibriumBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 9e4812c..8b5f139 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -160,10 +160,13 @@ def get_normal_vectors_3d( # Construct the functionals for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -171,23 +174,60 @@ def functional( # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): _f[l] = f_pre[_opp_indices[l]] - return _f @wp.func - def prepare_bc_auxilary_data( + def prepare_bc_auxilary_data_2d( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, + ): + # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and + # f_pre (post-streaming values of the current voxel). We use directions that leave the domain + # for storing this prepared data. + _f = f_post + nv = get_normal_vectors_2d(missing_mask) + for l in range(self.velocity_set.q): + if missing_mask[l] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) + _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux + return _f + + @wp.func + def prepare_bc_auxilary_data_3d( + index: Any, + timestep: Any, missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, ): # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and # f_pre (post-streaming values of the current voxel). We use directions that leave the domain # for storing this prepared data. _f = f_post + nv = get_normal_vectors_3d(missing_mask) for l in range(self.velocity_set.q): if missing_mask[l] == wp.uint8(1): - _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux[l] + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - (_c[d, l] + nv[d]) + # The following is the post-streaming values of the neighbor cell + f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) + _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux return _f # Construct the warp kernel @@ -201,29 +241,20 @@ def kernel2d( # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) + timestep = 0 # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - _f_aux = _f_vec() # special preparation of auxiliary data if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - nv = get_normal_vectors_2d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]] + _f_pre = prepare_bc_auxilary_data_2d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post) # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both # collision and streaming? - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -242,6 +273,7 @@ def kernel3d( # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) + timestep = 0 # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) @@ -249,22 +281,13 @@ def kernel3d( # special preparation of auxiliary data if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - nv = get_normal_vectors_3d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] + _f_pre = prepare_bc_auxilary_data_3d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post) # Apply the boundary condition if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both # collision and streaming? - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -273,6 +296,7 @@ def kernel3d( f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + prepare_bc_auxilary_data = prepare_bc_auxilary_data_3d if self.velocity_set.d == 3 else prepare_bc_auxilary_data_2d return (functional, prepare_bc_auxilary_data), kernel diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index ec298b2..29f83c1 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -63,10 +63,13 @@ def _construct_warp(self): # Construct the functional for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): fliped_f = _f_vec() for l in range(_q): @@ -88,8 +91,8 @@ def kernel2d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f_aux = _f_vec() - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -114,8 +117,8 @@ def kernel3d( # Check if the boundary is active if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - _f_aux = _f_vec() - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index 4193253..ba26d88 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -127,7 +127,7 @@ def grads_approximate_fpop( QiPi = self.compute_dtype(0.0) for t in range(nt): if t in diagonal: - Pi[t] -= rho/3. + Pi[t] -= rho / 3.0 QiPi += _qi[l, t] * Pi[t] # Compute c.u @@ -181,7 +181,6 @@ def functional( for l in range(_q): # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): - # # Find the neighbour and its velocity value # for ll in range(_q): # # f_0 is the post-collision values of the current time-step diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 4a6a97f..6e787c2 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -71,10 +71,13 @@ def _construct_warp(self): # Construct the functional for this BC @wp.func def functional( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -103,8 +106,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -129,8 +132,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 8d8c0fc..bb4b5f0 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -213,10 +213,13 @@ def regularize_fpop( @wp.func def functional3d_velocity( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -241,10 +244,13 @@ def functional3d_velocity( @wp.func def functional3d_pressure( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -267,10 +273,13 @@ def functional3d_pressure( @wp.func def functional2d_velocity( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -295,10 +304,13 @@ def functional2d_velocity( @wp.func def functional2d_pressure( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -336,8 +348,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f_aux = _f_vec() - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -362,8 +374,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f_aux = _f_vec() - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index f40cb22..66b6377 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -234,10 +234,13 @@ def bounceback_nonequilibrium( @wp.func def functional3d_velocity( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -259,10 +262,13 @@ def functional3d_velocity( @wp.func def functional3d_pressure( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -282,10 +288,13 @@ def functional3d_pressure( @wp.func def functional2d_velocity( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -307,10 +316,13 @@ def functional2d_velocity( @wp.func def functional2d_pressure( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): # Post-streaming values are only modified at missing direction _f = f_post @@ -345,8 +357,8 @@ def kernel2d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -371,8 +383,8 @@ def kernel3d( # Apply the boundary condition if _boundary_id == wp.uint8(self.id): - _f_aux = _f_post - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 9f6ef5d..90ee127 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -59,10 +59,13 @@ def __init__( @wp.func def prepare_bc_auxilary_data( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, - f_aux: Any, - missing_mask: Any, ): return f_post diff --git a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py index 6176edb..fadd23f 100644 --- a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py @@ -133,7 +133,7 @@ def warp_implementation( ), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" assert ( f_field is not None and f_field.shape == missing_mask.shape - ), 'To compute and store the implicit distance to the boundary for this BC, use a population field!' + ), "To compute and store the implicit distance to the boundary for this BC, use a population field!" mesh_vertices = bc.mesh_vertices id_number = bc.id diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index da64e67..7067ad0 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -152,7 +152,7 @@ def kernel2d( # Construct the warp kernel @wp.kernel def kernel3d( - f: wp.array4d(dtype=Any), + fpop: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), force: wp.array(dtype=Any), @@ -183,11 +183,12 @@ def kernel3d( # Get the distribution function f_post_collision = _f_vec() for l in range(self.velocity_set.q): - f_post_collision[l] = f[l, index[0], index[1], index[2]] + f_post_collision[l] = fpop[l, index[0], index[1], index[2]] # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + timestep = 0 + f_post_stream = self.stream.warp_functional(fpop, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, fpop, fpop, f_post_collision, f_post_stream) # Compute the momentum transfer for d in range(self.velocity_set.d): diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index bf8e45f..ad97e94 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -92,10 +92,10 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update + _q = self.velocity_set.q + _d = self.velocity_set.d _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - _c = self.velocity_set.c - _q = self.velocity_set.q @wp.struct class BoundaryConditionIDStruct: @@ -115,172 +115,103 @@ class BoundaryConditionIDStruct: @wp.func def apply_post_streaming_bc( - f_pre: Any, - f_post: Any, - f_aux: Any, - missing_mask: Any, + index: Any, + timestep: Any, _boundary_id: Any, bc_struct: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, ): # Apply post-streaming type boundary conditions + # NOTE: 'f_pre' is included here as an input to the BC functionals for consistency with the BC API, + # particularly when compared to post-collision boundary conditions (see below). + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition - f_post = self.EquilibriumBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition - f_post = self.DoNothingBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition - f_post = self.HalfwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_ZouHeBC_velocity: # Zouhe boundary condition (bc type = velocity) - f_post = self.ZouHeBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_ZouHeBC_pressure: # Zouhe boundary condition (bc type = pressure) - f_post = self.ZouHeBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_RegularizedBC_velocity: # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_velocity.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_RegularizedBC_pressure: # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # Regularized boundary condition (bc type = velocity) - f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) return f_post @wp.func def apply_post_collision_bc( - f_pre: Any, - f_post: Any, - f_aux: Any, - missing_mask: Any, + index: Any, + timestep: Any, _boundary_id: Any, bc_struct: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, ): + # Apply post-collision type boundary conditions or special boundary preparations if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition - f_post = self.FullwayBounceBackBC.warp_functional(f_pre, f_post, f_aux, missing_mask) + f_post = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - # f_aux is the neighbour's post-streaming values # Storing post-streaming data in directions that leave the domain - f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(f_pre, f_post, f_aux, missing_mask) + f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) return f_post - @wp.func - def get_normal_vectors_2d( - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -wp.vec2i(_c[0, l], _c[1, l]) - - @wp.func - def get_normal_vectors_3d( - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -wp.vec3i(_c[0, l], _c[1, l], _c[2, l]) - @wp.func def get_thread_data_2d( - f_0: wp.array3d(dtype=Any), + f_buffer: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), index: Any, - _boundary_id: Any, - bc_struct: Any, ): - # Get the boundary id and missing mask - _f_post_collision = _f_vec() + # Read thread data for populations and missing mask + f_thread = _f_vec() _missing_mask = _missing_mask_vec() - - # special preparation of auxiliary data - _f_aux_bc = _f_vec() for l in range(self.velocity_set.q): - # q-sized vector of pre-streaming populations - _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1]]) - - # TODO fix vec bool + f_thread[l] = self.compute_dtype(f_buffer[l, index[0], index[1]]) if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) - - # get _f_aux_bc data - if _boundary_id == bc_struct.id_GradsApproximationBC: - # f_0 is the post-collision values of the current time-step - # Get index associated with the fluid neighbours - fluid_nbr_index = type(index)() - for d in range(self.velocity_set.d): - fluid_nbr_index[d] = index[d] + _c[d, l] - # The following is the post-collision values of the fluid neighbor cell - _f_aux_bc[l] = self.compute_dtype(f_0[l, fluid_nbr_index[0], fluid_nbr_index[1]]) else: _missing_mask[l] = wp.uint8(0) - - # special treatment for obtaining _f_aux_bc in cases where all missing directions need to be known - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - nv = get_normal_vectors_2d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - _f_aux_bc[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) - - return _f_post_collision, _missing_mask, _f_aux_bc + return f_thread, _missing_mask @wp.func def get_thread_data_3d( - f_0: wp.array4d(dtype=Any), + f_buffer: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), index: Any, - _boundary_id: Any, - bc_struct: Any, ): - # Get the boundary id and missing mask - _f_post_collision = _f_vec() + # Read thread data for populations + f_thread = _f_vec() _missing_mask = _missing_mask_vec() - - # special preparation of auxiliary data - _f_aux_bc = _f_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1], index[2]]) - - # TODO fix vec bool + f_thread[l] = self.compute_dtype(f_buffer[l, index[0], index[1], index[2]]) if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) - - # get _f_aux_bc data - if _boundary_id == bc_struct.id_GradsApproximationBC: - # f_0 is the post-collision values of the current time-step - # Get index associated with the fluid neighbours - fluid_nbr_index = type(index)() - for d in range(self.velocity_set.d): - fluid_nbr_index[d] = index[d] + _c[d, l] - # The following is the post-collision values of the fluid neighbor cell - _f_aux_bc[l] = self.compute_dtype(f_0[l, fluid_nbr_index[0], fluid_nbr_index[1], fluid_nbr_index[2]]) else: _missing_mask[l] = wp.uint8(0) - # special treatment for obtaining _f_aux_bc in cases where all missing directions need to be known - if _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - nv = get_normal_vectors_3d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - _f_aux_bc[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) - - return _f_post_collision, _missing_mask, _f_aux_bc + return f_thread, _missing_mask @wp.kernel def kernel2d( @@ -295,16 +226,17 @@ def kernel2d( i, j = wp.tid() index = wp.vec2i(i, j) # TODO warp should fix this - # Read thread data for populations and missing mask - # Also get auxilary data for BC (if applicable) + # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1]] - _f_post_collision, _missing_mask, _f_aux_bc = get_thread_data_2d(f_0, missing_mask, index, _boundary_id, bc_struct) # Apply streaming (pull method) _f_post_stream = self.stream.warp_functional(f_0, index) # Apply post-streaming type boundary conditions - _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_aux_bc, _missing_mask, _boundary_id, bc_struct) + _f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) + _f_post_stream = apply_post_streaming_bc( + index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream + ) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -316,7 +248,9 @@ def kernel2d( _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_aux_bc, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc( + index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision + ) # Set the output for l in range(self.velocity_set.q): @@ -336,16 +270,17 @@ def kernel3d( i, j, k = wp.tid() index = wp.vec3i(i, j, k) # TODO warp should fix this - # Read thread data for populations and missing mask - # Also get auxilary data for BC (if applicable) + # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1], index[2]] - _f_post_collision, _missing_mask, _f_aux_bc = get_thread_data_3d(f_0, missing_mask, index, _boundary_id, bc_struct) # Apply streaming (pull method) _f_post_stream = self.stream.warp_functional(f_0, index) # Apply post-streaming type boundary conditions - _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_aux_bc, _missing_mask, _boundary_id, bc_struct) + _f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) + _f_post_stream = apply_post_streaming_bc( + index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream + ) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -357,7 +292,9 @@ def kernel3d( _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_aux_bc, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc( + index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision + ) # Set the output for l in range(self.velocity_set.q): From 8c0fe1d9e017c34c2d57b78ae5f1819452fc7ed1 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 1 Oct 2024 09:33:19 -0400 Subject: [PATCH 122/144] finalizing reformulated Grads BC --- xlb/operator/boundary_condition/__init__.py | 1 + .../bc_grads_approximation.py | 195 +++++++++--------- xlb/operator/stepper/nse_stepper.py | 4 +- 3 files changed, 99 insertions(+), 101 deletions(-) diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index b7ede03..925dfdc 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -9,3 +9,4 @@ from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC as ExtrapolationOutflowBC +from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC as GradsApproximationBC diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index ba26d88..a57954f 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -76,9 +76,6 @@ def __init__( assert self.compute_backend == ComputeBackend.WARP, "This BC is currently only implemented with the Warp backend!" - # Unpack the two warp functionals needed for this BC! - if self.compute_backend == ComputeBackend.WARP: - self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) @@ -102,21 +99,21 @@ def _construct_warp(self): @wp.func def grads_approximate_fpop( - f_post: Any, missing_mask: Any, + rho: Any, + u: Any, + f_post: Any, ): - """ - Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and - Dirichlet BCs [2] - [1] S. Chikatax`marla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann - simulations", Europhys. Lett. 74, 215 (2006). - [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and - stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. - - Note: See also self.regularize_fpop function which is somewhat similar. - """ - # Compute density, velocity and pressure tensor Pi using all f_post-streaming values - rho, u = self.macroscopic.warp_functional(f_post) + # Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and + # Dirichlet BCs [2] + # [1] S. Chikatax`marla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann + # simulations", Europhys. Lett. 74, 215 (2006). + # [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and + # stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. + + # Note: See also self.regularize_fpop function which is somewhat similar. + + # Compute pressure tensor Pi using all f_post-streaming values Pi = self.momentum_flux.warp_functional(f_post) # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) @@ -146,13 +143,54 @@ def grads_approximate_fpop( # Construct the functionals for this BC @wp.func - def functional( + def functional_method1( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, f_pre: Any, f_post: Any, + ): + # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper. + # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming). + one = self.compute_dtype(1.0) + for l in range(_q): + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 + weight = f_1[_opp_indices[l], index[0], index[1], index[2]] + + # Use differentiable interpolated BB to find f_missing: + f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) + + # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC + cu = self.compute_dtype(0.0) + for d in range(_d): + if _c[d, l] == 1: + cu += _u_wall[d] + elif _c[d, l] == -1: + cu -= _u_wall[d] + cu *= self.compute_dtype(-6.0) * self.velocity_set.w + f_post[l] += cu + + # Compute density, velocity using all f_post-streaming values + rho, u = self.macroscopic.warp_functional(f_post) + + # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al. + f_post = grads_approximate_fpop(missing_mask, rho, u, f_post) + return f_post + + # Construct the functionals for this BC + @wp.func + def functional_method2( + index: Any, + timestep: Any, missing_mask: Any, f_0: Any, f_1: Any, - index: Any, + f_pre: Any, + f_post: Any, ): # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper. # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming). @@ -174,59 +212,56 @@ def functional( # 8) Compute Grad's appriximation using full equation as in Eq (10) # NOTE: this is very similar to the regularization procedure. - # _f_nbr = _f_vec() - # u_target = _u_vec(0.0, 0.0, 0.0) if _d == 3 else _u_vec(0.0, 0.0) - # num_missing = 0 + _f_nbr = _f_vec() + u_target = _u_vec(0.0, 0.0, 0.0) if _d == 3 else _u_vec(0.0, 0.0) + num_missing = 0 one = self.compute_dtype(1.0) for l in range(_q): # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): - # # Find the neighbour and its velocity value - # for ll in range(_q): - # # f_0 is the post-collision values of the current time-step - # # Get index associated with the fluid neighbours - # fluid_nbr_index = type(index)() - # for d in range(_d): - # fluid_nbr_index[d] = index[d] + _c[d, l] - # # The following is the post-collision values of the fluid neighbor cell - # _f_nbr[ll] = self.compute_dtype(f_0[ll, fluid_nbr_index[0], fluid_nbr_index[1], fluid_nbr_index[2]]) - - # # Compute the velocity vector at the fluid neighbouring cells - # _, u_f = self.macroscopic.warp_functional(_f_nbr) - - # # Record the number of missing directions - # num_missing += 1 + # Find the neighbour and its velocity value + for ll in range(_q): + # f_0 is the post-collision values of the current time-step + # Get index associated with the fluid neighbours + fluid_nbr_index = type(index)() + for d in range(_d): + fluid_nbr_index[d] = index[d] + _c[d, l] + # The following is the post-collision values of the fluid neighbor cell + _f_nbr[ll] = self.compute_dtype(f_0[ll, fluid_nbr_index[0], fluid_nbr_index[1], fluid_nbr_index[2]]) + + # Compute the velocity vector at the fluid neighbouring cells + _, u_f = self.macroscopic.warp_functional(_f_nbr) + + # Record the number of missing directions + num_missing += 1 # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 weight = f_1[_opp_indices[l], index[0], index[1], index[2]] - # # Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) - # for d in range(_d): - # u_target[d] += (weight * u_f[d] + _u_wall[d]) / (one + weight) + # Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) + for d in range(_d): + u_target[d] += (weight * u_f[d] + _u_wall[d]) / (one + weight) # Use differentiable interpolated BB to find f_missing: f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC - for ll in range(_q): - # Compute cu - cu = self.compute_dtype(0.0) - for d in range(_d): - if _c[d, l] == 1: - cu += _u_wall[d] - elif _c[d, l] == -1: - cu -= _u_wall[d] - cu *= self.compute_dtype(-6.0) * self.velocity_set.w + cu = self.compute_dtype(0.0) + for d in range(_d): + if _c[d, l] == 1: + cu += _u_wall[d] + elif _c[d, l] == -1: + cu -= _u_wall[d] + cu *= self.compute_dtype(-6.0) * _w[l] f_post[l] += cu # Compute rho_target = \sum(f_ibb) based on these values - # rho_target = self.zero_moment.warp_functional(f_post) - # for d in range(_d): - # u_target[d] /= num_missing + rho_target = self.zero_moment.warp_functional(f_post) + for d in range(_d): + u_target[d] /= num_missing # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al. - f_post = grads_approximate_fpop(f_post, missing_mask) - + f_post = grads_approximate_fpop(missing_mask, rho_target, u_target, f_post) return f_post # Construct the warp kernel @@ -240,29 +275,17 @@ def kernel2d( # Get the global index i, j = wp.tid() index = wp.vec2i(i, j) + timestep = 0 # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) _f_aux = _f_vec() - # special preparation of auxiliary data - if _boundary_id == wp.uint8(GradsApproximation.id): - nv = get_normal_vectors_2d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]] - # Apply the boundary condition - if _boundary_id == wp.uint8(GradsApproximation.id): + if _boundary_id == wp.uint8(GradsApproximationBC.id): # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both # collision and streaming? - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -281,46 +304,17 @@ def kernel3d( # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) - - # we need to store fractional distance during the initialization - mesh_indices = np.arange(mesh_vertices.shape[0]) - mesh = wp.Mesh( - points=wp.array(mesh_vertices, dtype=wp.vec3), - indices=wp.array(mesh_indices, dtype=int), - ) - # Compute the maximum length - max_length = wp.sqrt( - (grid_spacing[0] * wp.float32(grid_shape[0])) ** 2.0 - + (grid_spacing[1] * wp.float32(grid_shape[1])) ** 2.0 - + (grid_spacing[2] * wp.float32(grid_shape[2])) ** 2.0 - ) - - query = wp.mesh_query_point_sign_normal(mesh.id, xpred, max_length) - if query.result: - p = wp.mesh_eval_position(mesh, query.face, query.u, query.v) + timestep = 0 # read tid data _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) _f_aux = _f_vec() - # special preparation of auxiliary data - if _boundary_id == wp.uint8(GradsApproximation.id): - nv = get_normal_vectors_3d(_missing_mask) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - _f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]] - # Apply the boundary condition - if _boundary_id == wp.uint8(GradsApproximation.id): + if _boundary_id == wp.uint8(GradsApproximationBC.id): # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both # collision and streaming? - _f = functional(_f_pre, _f_post, _f_aux, _missing_mask) + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) else: _f = _f_post @@ -329,6 +323,7 @@ def kernel3d( f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + functional = functional_method1 return functional, kernel diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index ad97e94..57988f5 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -153,6 +153,9 @@ def apply_post_streaming_bc( elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # Regularized boundary condition (bc type = velocity) f_post = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + elif _boundary_id == bc_struct.id_GradsApproximationBC: + # Reformulated Grads boundary condition + f_post = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) return f_post @wp.func @@ -174,7 +177,6 @@ def apply_post_collision_bc( elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: # Storing post-streaming data in directions that leave the domain f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - return f_post @wp.func From 0448d2b4308b8ed0c4cd3fa38aaeab5d2de28484 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Wed, 2 Oct 2024 16:44:26 -0400 Subject: [PATCH 123/144] fixed out-of-bound issues and added large bc_id for skipping lbm on solid voxels --- .../mesh_distance_boundary_masker.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py index fadd23f..6dab0ee 100644 --- a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py @@ -48,6 +48,11 @@ def _construct_warp(self): _q = wp.constant(self.velocity_set.q) _opp_indices = self.velocity_set.opp_indices + @wp.func + def check_index_bounds(index: wp.vec3i, shape: wp.vec3i): + is_in_bounds = index[0] >= 0 and index[0] < shape[0] and index[1] >= 0 and index[1] < shape[1] and index[2] >= 0 and index[2] < shape[2] + return is_in_bounds + @wp.func def index_to_position(index: wp.vec3i, origin: wp.vec3, spacing: wp.vec3): # position of the point @@ -89,25 +94,26 @@ def kernel( # evaluate if point is inside mesh query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_solid_cell, max_length) - if query.result: - # set point to be solid - if query.sign <= 0: # TODO: fix this - # Stream indices - missing_mask[0, index[0], index[1], index[2]] = True - for l in range(1, _q): - # Get the index of the streaming direction - push_index = wp.vec3i() - for d in range(self.velocity_set.d): - push_index[d] = index[d] + _c[d, l] - - # Set the boundary id and missing_mask - bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) - missing_mask[l, push_index[0], push_index[1], push_index[2]] = True - + if query.result and query.sign <= 0: # TODO: fix this + # Set bc_mask of solid to a large number to enable skipping LBM operations + bc_mask[0, index[0], index[1], index[2]] = wp.uint8(255) + + # Find neighboring fluid cells along each lattice direction and the their fractional distance to the mesh + for l in range(1, _q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3]) + if check_index_bounds(push_index, shape): # find neighbouring fluid cell pos_fluid_cell = index_to_position(push_index, origin, spacing) query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_fluid_cell, max_length) if query.result and query.sign > 0: + # Set the boundary id and missing_mask + bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + missing_mask[l, push_index[0], push_index[1], push_index[2]] = True + # get position of the mesh triangle that intersects with the solid cell pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) weight = wp.length(pos_fluid_cell - pos_mesh) / wp.length(pos_fluid_cell - pos_solid_cell) From ad33b3eb210f6b4f8352d32741781a1660989cad Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Wed, 2 Oct 2024 17:10:39 -0400 Subject: [PATCH 124/144] adding a flag to assert mesh distance masker is used for bc_grads --- xlb/operator/boundary_condition/bc_grads_approximation.py | 4 +++- xlb/operator/boundary_condition/boundary_condition.py | 3 +++ xlb/operator/boundary_masker/mesh_boundary_masker.py | 3 +++ xlb/operator/boundary_masker/mesh_distance_boundary_masker.py | 3 +++ 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index a57954f..b2cd2c1 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -65,6 +65,9 @@ def __init__( self.equilibrium = QuadraticEquilibrium() self.momentum_flux = MomentumFlux() + # This BC needs implicit distance to the mesh + self.needs_mesh_distance = True + # if indices is not None: # # this BC would be limited to stationary boundaries # # assert mesh_vertices is None @@ -76,7 +79,6 @@ def __init__( assert self.compute_backend == ComputeBackend.WARP, "This BC is currently only implemented with the Warp backend!" - @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 90ee127..be920bf 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -52,6 +52,9 @@ def __init__( # when inside/outside of the geoemtry is not known self.needs_padding = False + # A flag for BCs that need implicit boundary distance between the grid and a mesh (to be set to True if applicable inside each BC) + self.needs_mesh_distance = False + if self.compute_backend == ComputeBackend.WARP: # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 701fc13..ee3df68 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -122,6 +122,9 @@ def warp_implementation( # We are done with bc.mesh_vertices. Remove them from BC objects bc.__dict__.pop("mesh_vertices", None) + # Ensure this masker is called only for BCs that need implicit distance to the mesh + assert not bc.needs_mesh_distance, 'Please use "MeshDistanceBoundaryMasker" if this BC needs mesh distance!' + mesh_indices = np.arange(mesh_vertices.shape[0]) mesh = wp.Mesh( points=wp.array(mesh_vertices, dtype=wp.vec3), diff --git a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py index 6dab0ee..45825b1 100644 --- a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py @@ -146,6 +146,9 @@ def warp_implementation( # We are done with bc.mesh_vertices. Remove them from BC objects bc.__dict__.pop("mesh_vertices", None) + # Ensure this masker is called only for BCs that need implicit distance to the mesh + assert bc.needs_mesh_distance, 'Please use "MeshBoundaryMasker" if this BC does NOT need mesh distance!' + mesh_indices = np.arange(mesh_vertices.shape[0]) mesh = wp.Mesh( points=wp.array(mesh_vertices, dtype=wp.vec3), From 974848064ba71d5786777695d6d9c0334fdcf223 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 3 Oct 2024 11:36:16 -0400 Subject: [PATCH 125/144] fixing some remaining type issues --- .../boundary_masker/mesh_boundary_masker.py | 2 +- xlb/operator/force/momentum_transfer.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index ee3df68..228ea4d 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -89,7 +89,7 @@ def kernel( # set point to be solid if query.sign <= 0: # TODO: fix this # Stream indices - for l in range(_q): + for l in range(1, _q): # Get the index of the streaming direction push_index = wp.vec3i() for d in range(self.velocity_set.d): diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 7067ad0..79d75a4 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -90,6 +90,7 @@ def _construct_warp(self): _c = self.velocity_set.c _opp_indices = self.velocity_set.opp_indices _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool _no_slip_id = self.no_slip_bc_instance.id @@ -128,7 +129,7 @@ def kernel2d( is_edge = wp.bool(True) # If the boundary is an edge then add the momentum transfer - m = wp.vec2() + m = _u_vec() if is_edge: # Get the distribution function f_post_collision = _f_vec() @@ -145,7 +146,10 @@ def kernel2d( for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] - m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) + if _c[d, _opp_indices[l]] == 1: + m[d] += phi + elif _c[d, _opp_indices[l]] == -1: + m[d] -= phi wp.atomic_add(force, 0, m) @@ -178,7 +182,7 @@ def kernel3d( is_edge = wp.bool(True) # If the boundary is an edge then add the momentum transfer - m = wp.vec3() + m = _u_vec() if is_edge: # Get the distribution function f_post_collision = _f_vec() @@ -196,7 +200,10 @@ def kernel3d( for l in range(self.velocity_set.q): if _missing_mask[l] == wp.uint8(1): phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] - m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) + if _c[d, _opp_indices[l]] == 1: + m[d] += phi + elif _c[d, _opp_indices[l]] == -1: + m[d] -= phi wp.atomic_add(force, 0, m) @@ -208,7 +215,8 @@ def kernel3d( @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f, bc_mask, missing_mask): # Allocate the force vector (the total integral value will be computed) - force = wp.zeros((1), dtype=wp.vec3) if self.velocity_set.d == 3 else wp.zeros((1), dtype=wp.vec2) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + force = wp.zeros((1), dtype=_u_vec) # Launch the warp kernel wp.launch( From eb26a59d716473778e6b1975c34fe798fa69d60c Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 3 Oct 2024 14:50:10 -0400 Subject: [PATCH 126/144] Finalized a novel hybrid interpolated bounceback and regularized bc for stationary and moving BCs --- .../bc_grads_approximation.py | 89 +++++++++++++------ 1 file changed, 61 insertions(+), 28 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index b2cd2c1..bdfc27c 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -49,6 +49,9 @@ def __init__( indices=None, mesh_vertices=None, ): + # TODO: the input velocity must be suitably stored elesewhere when mesh is moving. + self.u = (0, 0, 0) + # Call the parent constructor super().__init__( ImplementationStep.STREAMING, @@ -97,7 +100,35 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _u_wall = _u_vec(self.u[0], self.u[1], self.u[2]) if _d == 3 else _u_vec(self.u[0], self.u[1]) - diagonal = wp.vec3i(0, 3, 5) if _d == 3 else wp.vec2i(0, 2) + # diagonal = wp.vec3i(0, 3, 5) if _d == 3 else wp.vec2i(0, 2) + + @wp.func + def regularize_fpop( + missing_mask: Any, + rho: Any, + u: Any, + fpop: Any, + ): + """ + Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. + """ + # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} + feq = self.equilibrium.warp_functional(rho, u) + f_neq = fpop - feq + PiNeq = self.momentum_flux.warp_functional(f_neq) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + for l in range(_q): + QiPi1 = self.compute_dtype(0.0) + for t in range(nt): + QiPi1 += _qi[l, t] * PiNeq[t] + + # assign all populations based on eq 45 of Latt et al (2008) + # fneq ~ f^1 + fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1 + fpop[l] = feq[l] + fpop1 + return fpop @wp.func def grads_approximate_fpop( @@ -121,25 +152,25 @@ def grads_approximate_fpop( # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) nt = _d * (_d + 1) // 2 for l in range(_q): - # If the mask is missing then use f_post - if missing_mask[l] == wp.uint8(1): - QiPi = self.compute_dtype(0.0) - for t in range(nt): - if t in diagonal: - Pi[t] -= rho / 3.0 + # if missing_mask[l] == wp.uint8(1): + QiPi = self.compute_dtype(0.0) + for t in range(nt): + if t == 0 or t == 3 or t == 5: + QiPi += _qi[l, t] * (Pi[t] - rho / self.compute_dtype(3.0)) + else: QiPi += _qi[l, t] * Pi[t] - # Compute c.u - cu = self.compute_dtype(0.0) - for d in range(self.velocity_set.d): - if _c[d, l] == 1: - cu += u[d] - elif _c[d, l] == -1: - cu -= u[d] - cu *= self.compute_dtype(3.0) + # Compute c.u + cu = self.compute_dtype(0.0) + for d in range(self.velocity_set.d): + if _c[d, l] == 1: + cu += u[d] + elif _c[d, l] == -1: + cu -= u[d] + cu *= self.compute_dtype(3.0) - # change f_post using the Grad's approximation - f_post[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu + self.compute_dtype(4.5) * QiPi / rho) + # change f_post using the Grad's approximation + f_post[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu) + _w[l] * self.compute_dtype(4.5) * QiPi return f_post @@ -161,26 +192,28 @@ def functional_method1( # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 - weight = f_1[_opp_indices[l], index[0], index[1], index[2]] + # weight = f_1[_opp_indices[l], index[0], index[1], index[2]] + weight = self.compute_dtype(0.5) # Use differentiable interpolated BB to find f_missing: f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) - # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC - cu = self.compute_dtype(0.0) - for d in range(_d): - if _c[d, l] == 1: - cu += _u_wall[d] - elif _c[d, l] == -1: - cu -= _u_wall[d] - cu *= self.compute_dtype(-6.0) * self.velocity_set.w - f_post[l] += cu + # # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC + # cu = self.compute_dtype(0.0) + # for d in range(_d): + # if _c[d, l] == 1: + # cu += _u_wall[d] + # elif _c[d, l] == -1: + # cu -= _u_wall[d] + # cu *= self.compute_dtype(-6.0) * _w[l] + # f_post[l] += cu # Compute density, velocity using all f_post-streaming values rho, u = self.macroscopic.warp_functional(f_post) # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al. - f_post = grads_approximate_fpop(missing_mask, rho, u, f_post) + f_post = regularize_fpop(missing_mask, rho, u, f_post) + # f_post = grads_approximate_fpop(missing_mask, rho, u, f_post) return f_post # Construct the functionals for this BC From 0dbe84058122600595f9f14efd80acf12b378360 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 3 Oct 2024 16:15:23 -0400 Subject: [PATCH 127/144] added necessary changes to the nse_stepper --- examples/cfd/flow_past_sphere_3d.py | 2 +- examples/cfd/lid_driven_cavity_2d.py | 2 +- examples/cfd/turbulent_channel_3d.py | 2 +- examples/cfd/windtunnel_3d.py | 25 +++++++++++----- examples/performance/mlups_3d.py | 2 +- xlb/operator/stepper/nse_stepper.py | 44 +++++++++++++++++++--------- 6 files changed, 52 insertions(+), 25 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 2487919..2e0df95 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -102,7 +102,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index f94e209..20f3b7c 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -68,7 +68,7 @@ def setup_stepper(self, omega): def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index 65b56bf..2ec5560 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -113,7 +113,7 @@ def setup_stepper(self): def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 077ae98..338cf8d 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -12,10 +12,11 @@ RegularizedBC, HalfwayBounceBackBC, ExtrapolationOutflowBC, + GradsApproximationBC, ) from xlb.operator.force.momentum_transfer import MomentumTransfer from xlb.operator.macroscopic import Macroscopic -from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker +from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker, MeshDistanceBoundaryMasker from xlb.utils import save_fields_vtk, save_image import warp as wp import numpy as np @@ -51,9 +52,10 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.lift_coefficients = [] def _setup(self): + # NOTE: it is important to initialize fields before setup_boundary_masker is called because f_0 or f_1 might be used to store BC information + self.initialize_fields() self.setup_boundary_conditions() self.setup_boundary_masker() - self.initialize_fields() self.setup_stepper() def voxelize_stl(self, stl_filename, length_lbm_unit): @@ -99,7 +101,8 @@ def setup_boundary_conditions(self): # bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) - bc_car = HalfwayBounceBackBC(mesh_vertices=car) + # bc_car = HalfwayBounceBackBC(mesh_vertices=car) + bc_car = GradsApproximationBC(mesh_vertices=car) # bc_car = FullwayBounceBackBC(mesh_vertices=car) self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] @@ -109,7 +112,12 @@ def setup_boundary_masker(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - mesh_boundary_masker = MeshBoundaryMasker( + # mesh_boundary_masker = MeshBoundaryMasker( + # velocity_set=self.velocity_set, + # precision_policy=self.precision_policy, + # compute_backend=self.backend, + # ) + mesh_distance_boundary_masker = MeshDistanceBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.backend, @@ -119,10 +127,12 @@ def setup_boundary_masker(self): dx = self.grid_spacing origin, spacing = (0, 0, 0), (dx, dx, dx) self.bc_mask, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_mask, self.missing_mask) - self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) + # self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) + self.bc_mask, self.missing_mask, self.f_1 = mesh_distance_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask, self.f_1) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) + self.f_1 = initialize_eq(self.f_1, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") @@ -134,7 +144,7 @@ def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: @@ -226,7 +236,8 @@ def plot_drag_coefficient(self): print_interval = 1000 # Set up Reynolds number and deduce relaxation time (omega) - Re = 50000.0 + # Re = 50000.0 + Re = 500000000000.0 clength = grid_size_x - 1 visc = wind_speed * clength / Re omega = 1.0 / (3.0 * visc + 0.5) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 907c1f2..1812d95 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -81,7 +81,7 @@ def run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, num_st start_time = time.time() for i in range(num_steps): - f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i) + f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i) f_0, f_1 = f_1, f_0 wp.synchronize() diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 57988f5..7b1c4da 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -88,14 +88,13 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): # Copy back to store precision f_1 = self.precision_policy.cast_to_store_jax(f_post_collision) - return f_1 + return f_0, f_1 def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _q = self.velocity_set.q - _d = self.velocity_set.d _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _opp_indices = self.velocity_set.opp_indices @wp.struct class BoundaryConditionIDStruct: @@ -181,39 +180,45 @@ def apply_post_collision_bc( @wp.func def get_thread_data_2d( - f_buffer: wp.array3d(dtype=Any), + f0_buffer: wp.array3d(dtype=Any), + f1_buffer: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), index: Any, ): # Read thread data for populations and missing mask - f_thread = _f_vec() + f0_thread = _f_vec() + f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): - f_thread[l] = self.compute_dtype(f_buffer[l, index[0], index[1]]) + f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) + f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f_thread, _missing_mask + return f0_thread, f1_thread, _missing_mask @wp.func def get_thread_data_3d( - f_buffer: wp.array4d(dtype=Any), + f0_buffer: wp.array4d(dtype=Any), + f1_buffer: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), index: Any, ): # Read thread data for populations - f_thread = _f_vec() + f0_thread = _f_vec() + f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f_thread[l] = self.compute_dtype(f_buffer[l, index[0], index[1], index[2]]) + f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) + f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f_thread, _missing_mask + return f0_thread, f1_thread, _missing_mask @wp.kernel def kernel2d( @@ -230,12 +235,15 @@ def kernel2d( # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1]] + if _boundary_id == wp.uint8(255): + return # Apply streaming (pull method) _f_post_stream = self.stream.warp_functional(f_0, index) # Apply post-streaming type boundary conditions - _f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) + f0_thread, f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) + _f_post_collision = f0_thread _f_post_stream = apply_post_streaming_bc( index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream ) @@ -274,12 +282,15 @@ def kernel3d( # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1], index[2]] + if _boundary_id == wp.uint8(255): + return # Apply streaming (pull method) _f_post_stream = self.stream.warp_functional(f_0, index) # Apply post-streaming type boundary conditions - _f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) + f0_thread, f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) + _f_post_collision = f0_thread _f_post_stream = apply_post_streaming_bc( index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream ) @@ -300,6 +311,11 @@ def kernel3d( # Set the output for l in range(self.velocity_set.q): + # TODO 1: fix the perf drop due to l324-l236 even in cases where this BC is not used. + # TODO 2: is there better way to move these lines to a function inside BC class like "restore_bc_data" + if _boundary_id == bc_struct.id_GradsApproximationBC: + if _missing_mask[l] == wp.uint8(1): + f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel @@ -356,4 +372,4 @@ def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): ], dim=f_0.shape[1:], ) - return f_1 + return f_0, f_1 From 371dfd8e7c0245472abf313b2d4c28620b6778eb Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 3 Oct 2024 16:59:12 -0400 Subject: [PATCH 128/144] added not implemented error for 2d settings --- .../bc_grads_approximation.py | 58 +++++-------------- .../boundary_masker/mesh_boundary_masker.py | 4 ++ .../mesh_distance_boundary_masker.py | 4 ++ 3 files changed, 24 insertions(+), 42 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index bdfc27c..e549668 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -71,6 +71,10 @@ def __init__( # This BC needs implicit distance to the mesh self.needs_mesh_distance = True + # Raise error if used for 2d examples: + if self.velocity_set.d == 2: + raise NotImplementedError("This BC is not implemented in 2D!") + # if indices is not None: # # this BC would be limited to stationary boundaries # # assert mesh_vertices is None @@ -192,21 +196,21 @@ def functional_method1( # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 - # weight = f_1[_opp_indices[l], index[0], index[1], index[2]] - weight = self.compute_dtype(0.5) + weight = f_1[_opp_indices[l], index[0], index[1], index[2]] + # weight = self.compute_dtype(0.5) # Use differentiable interpolated BB to find f_missing: f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) - # # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC - # cu = self.compute_dtype(0.0) - # for d in range(_d): - # if _c[d, l] == 1: - # cu += _u_wall[d] - # elif _c[d, l] == -1: - # cu -= _u_wall[d] - # cu *= self.compute_dtype(-6.0) * _w[l] - # f_post[l] += cu + # # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC + # cu = self.compute_dtype(0.0) + # for d in range(_d): + # if _c[d, l] == 1: + # cu += _u_wall[d] + # elif _c[d, l] == -1: + # cu -= _u_wall[d] + # cu *= self.compute_dtype(-6.0) * _w[l] + # f_post[l] += cu # Compute density, velocity using all f_post-streaming values rho, u = self.macroscopic.warp_functional(f_post) @@ -301,36 +305,7 @@ def functional_method2( # Construct the warp kernel @wp.kernel - def kernel2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - timestep = 0 - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - _f_aux = _f_vec() - - # Apply the boundary condition - if _boundary_id == wp.uint8(GradsApproximationBC.id): - # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both - # collision and streaming? - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( + def kernel( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), @@ -357,7 +332,6 @@ def kernel3d( for l in range(self.velocity_set.q): f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d functional = functional_method1 return functional, kernel diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 228ea4d..edfdb42 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -23,6 +23,10 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) + # Raise error if used for 2d examples: + if self.velocity_set.d == 2: + raise NotImplementedError("This Operator is not implemented in 2D!") + # Also using Warp kernels for JAX implementation if self.compute_backend == ComputeBackend.JAX: self.warp_functional, self.warp_kernel = self._construct_warp() diff --git a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py index 45825b1..87af94c 100644 --- a/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_distance_boundary_masker.py @@ -24,6 +24,10 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) + # Raise error if used for 2d examples: + if self.velocity_set.d == 2: + raise NotImplementedError("This Operator is not implemented in 2D!") + # Also using Warp kernels for JAX implementation if self.compute_backend == ComputeBackend.JAX: self.warp_functional, self.warp_kernel = self._construct_warp() From 96d80f08758c8dae475c120ceb7e719a055e3240 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 3 Oct 2024 17:47:58 -0400 Subject: [PATCH 129/144] fixed ruff issue and commented out the perf drop lines --- xlb/operator/boundary_condition/bc_grads_approximation.py | 6 +++--- xlb/operator/stepper/nse_stepper.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index e549668..2e6e597 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -196,8 +196,8 @@ def functional_method1( # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 - weight = f_1[_opp_indices[l], index[0], index[1], index[2]] - # weight = self.compute_dtype(0.5) + # weight = f_1[_opp_indices[l], index[0], index[1], index[2]] + weight = self.compute_dtype(0.5) # Use differentiable interpolated BB to find f_missing: f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) @@ -242,7 +242,7 @@ def functional_method2( # 2) Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) # NOTE: in the original paper "u_target" is associated with the previous time step not current time. # 3) Given "weights" use differentiable interpolated BB to find f_missing as I had before: - # fmissing = ((1. - weights) * f_poststreaming_iknown + weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + weights) + # fmissing = ((1. - weights) * f_poststreaming_iknown + weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + weights) # 4) Add contribution due to u_w to f_missing as is usual in regular Bouzidi BC (ie. -6.0 * self.lattice.w * jnp.dot(self.vel, c) # 5) Compute rho_target = \sum(f_ibb) based on these values # 6) Compute feq using feq = self.equilibrium(rho_target, u_target) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 7b1c4da..62790a6 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -313,9 +313,9 @@ def kernel3d( for l in range(self.velocity_set.q): # TODO 1: fix the perf drop due to l324-l236 even in cases where this BC is not used. # TODO 2: is there better way to move these lines to a function inside BC class like "restore_bc_data" - if _boundary_id == bc_struct.id_GradsApproximationBC: - if _missing_mask[l] == wp.uint8(1): - f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(f1_thread[_opp_indices[l]]) + # if _boundary_id == bc_struct.id_GradsApproximationBC: + # if _missing_mask[l] == wp.uint8(1): + # f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel From 0f912d42780e11c94942442b922ca0a7725003db Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 4 Oct 2024 11:01:46 -0400 Subject: [PATCH 130/144] enabled the new BC to work with indices bc masker --- xlb/operator/boundary_condition/bc_grads_approximation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index 2e6e597..870635e 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -71,6 +71,12 @@ def __init__( # This BC needs implicit distance to the mesh self.needs_mesh_distance = True + # If this BC is defined using indices, it would need padding in order to find missing directions + # when imposed on a geometry that is in the domain interior + if self.mesh_vertices is None: + assert self.indices is not None + self.needs_padding = True + # Raise error if used for 2d examples: if self.velocity_set.d == 2: raise NotImplementedError("This BC is not implemented in 2D!") From f215c5b1558a6bfb2971967d1c18abfbc434e455 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 4 Oct 2024 12:32:00 -0400 Subject: [PATCH 131/144] added recent changes to the momentum exchange method --- xlb/operator/force/momentum_transfer.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 79d75a4..8b0aacf 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -103,7 +103,8 @@ def _construct_warp(self): # Construct the warp kernel @wp.kernel def kernel2d( - f: wp.array3d(dtype=Any), + f_0: wp.array3d(dtype=Any), + f_1: wp.array3d(dtype=Any), bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), force: wp.array(dtype=Any), @@ -134,11 +135,12 @@ def kernel2d( # Get the distribution function f_post_collision = _f_vec() for l in range(self.velocity_set.q): - f_post_collision[l] = f[l, index[0], index[1]] + f_post_collision[l] = f_0[l, index[0], index[1]] # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + timestep = 0 + f_post_stream = self.stream.warp_functional(f_0, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) # Compute the momentum transfer for d in range(self.velocity_set.d): @@ -156,7 +158,8 @@ def kernel2d( # Construct the warp kernel @wp.kernel def kernel3d( - fpop: wp.array4d(dtype=Any), + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), force: wp.array(dtype=Any), @@ -187,12 +190,12 @@ def kernel3d( # Get the distribution function f_post_collision = _f_vec() for l in range(self.velocity_set.q): - f_post_collision[l] = fpop[l, index[0], index[1], index[2]] + f_post_collision[l] = f_0[l, index[0], index[1], index[2]] # Apply streaming (pull method) timestep = 0 - f_post_stream = self.stream.warp_functional(fpop, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, fpop, fpop, f_post_collision, f_post_stream) + f_post_stream = self.stream.warp_functional(f_0, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) # Compute the momentum transfer for d in range(self.velocity_set.d): @@ -213,7 +216,7 @@ def kernel3d( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, bc_mask, missing_mask): + def warp_implementation(self, f_0, f_1, bc_mask, missing_mask): # Allocate the force vector (the total integral value will be computed) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) force = wp.zeros((1), dtype=_u_vec) @@ -221,7 +224,7 @@ def warp_implementation(self, f, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f, bc_mask, missing_mask, force], - dim=f.shape[1:], + inputs=[f_0, f_1, bc_mask, missing_mask, force], + dim=f_0.shape[1:], ) return force.numpy()[0] From a786ccd02c41e9c8f46e854711c3e3e5b8c1f9e8 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 4 Oct 2024 15:28:44 -0400 Subject: [PATCH 132/144] missed from prev commit --- examples/cfd/windtunnel_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 338cf8d..140e756 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -179,7 +179,7 @@ def post_process(self, i): save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i) # Compute lift and drag - boundary_force = self.momentum_transfer(self.f_0, self.bc_mask, self.missing_mask) + boundary_force = self.momentum_transfer(self.f_0, self.f_1, self.bc_mask, self.missing_mask) drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane lift = boundary_force[2] c_d = 2.0 * drag / (self.wind_speed**2 * self.car_cross_section) From 1d1560bb6c12e2fe9c2e1a1aeabc5691aad396ec Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Sun, 6 Oct 2024 16:58:35 -0400 Subject: [PATCH 133/144] Significantly simplified boundary application --- .../flow_past_sphere.py | 203 -------------- .../cfd_old_to_be_migrated/taylor_green.py | 181 ------------- .../boundary_condition/bc_regularized.py | 2 + xlb/operator/stepper/nse_stepper.py | 248 +++++++----------- 4 files changed, 93 insertions(+), 541 deletions(-) delete mode 100644 examples/cfd_old_to_be_migrated/flow_past_sphere.py delete mode 100644 examples/cfd_old_to_be_migrated/taylor_green.py diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py deleted file mode 100644 index 1684266..0000000 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ /dev/null @@ -1,203 +0,0 @@ -# Simple flow past sphere example using the functional interface to xlb - -import time -from tqdm import tqdm -import os -import matplotlib.pyplot as plt -from typing import Any -import numpy as np - -from xlb.compute_backend import ComputeBackend - -import warp as wp - -import xlb - -xlb.init( - default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=xlb.velocity_set.D2Q9, -) - - -from xlb.operator import Operator - - -class UniformInitializer(Operator): - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - vel: float, - ): - # Get the global index - i, j, k = wp.tid() - - # Set the velocity - u[0, i, j, k] = vel - u[1, i, j, k] = 0.0 - u[2, i, j, k] = 0.0 - - # Set the density - rho[0, i, j, k] = 1.0 - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u, vel): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - rho, - u, - vel, - ], - dim=rho.shape[1:], - ) - return rho, u - - -if __name__ == "__main__": - # Set parameters - compute_backend = xlb.ComputeBackend.WARP - precision_policy = xlb.PrecisionPolicy.FP32FP32 - velocity_set = xlb.velocity_set.D3Q19() - - # Make feilds - nr = 256 - vel = 0.05 - shape = (nr, nr, nr) - grid = xlb.grid.grid_factory(shape=shape) - rho = grid.create_field(cardinality=1) - u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) - f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - bc_mask = grid.create_field(cardinality=1, dtype=wp.uint8) - missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) - - # Make operators - initializer = UniformInitializer( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - collision = xlb.operator.collision.BGK( - omega=1.95, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - stream = xlb.operator.stream.Stream( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( - rho=1.0, - u=(vel, 0.0, 0.0), - equilibrium_operator=equilibrium, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, - equilibrium=equilibrium, - macroscopic=macroscopic, - stream=stream, - equilibrium_bc=equilibrium_bc, - do_nothing_bc=do_nothing_bc, - half_way_bc=half_way_bc, - ) - indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - - # Make indices for boundary conditions (sphere) - sphere_radius = 32 - x = np.arange(nr) - y = np.arange(nr) - z = np.arange(nr) - X, Y, Z = np.meshgrid(x, y, z) - indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) - indices = np.array(indices).T - indices = wp.from_numpy(indices, dtype=wp.int32) - - # Set boundary conditions on the indices - bc_mask, missing_mask = indices_boundary_masker(indices, half_way_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set inlet bc - lower_bound = (0, 0, 0) - upper_bound = (0, nr, nr) - direction = (1, 0, 0) - bc_mask, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, equilibrium_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set outlet bc - lower_bound = (nr - 1, 0, 0) - upper_bound = (nr - 1, nr, nr) - direction = (-1, 0, 0) - bc_mask, missing_mask = planar_boundary_masker(lower_bound, upper_bound, direction, do_nothing_bc.id, bc_mask, missing_mask, (0, 0, 0)) - - # Set initial conditions - rho, u = initializer(rho, u, vel) - f0 = equilibrium(rho, u, f0) - - # Time stepping - plot_freq = 512 - save_dir = "flow_past_sphere" - os.makedirs(save_dir, exist_ok=True) - # compute_mlup = False # Plotting results - compute_mlup = True - num_steps = 1024 * 8 - start = time.time() - for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, bc_mask, missing_mask, _) - f1, f0 = f0, f1 - if (_ % plot_freq == 0) and (not compute_mlup): - rho, u = macroscopic(f0, rho, u) - - # Plot the velocity field and boundary id side by side - plt.subplot(1, 2, 1) - plt.imshow(u[0, :, nr // 2, :].numpy()) - plt.colorbar() - plt.subplot(1, 2, 2) - plt.imshow(bc_mask[0, :, nr // 2, :].numpy()) - plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") - plt.close() - - wp.synchronize() - end = time.time() - - # Print MLUPS - print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py deleted file mode 100644 index 846ba30..0000000 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ /dev/null @@ -1,181 +0,0 @@ -# Simple Taylor green example using the functional interface to xlb - -import time -from tqdm import tqdm -import os -import matplotlib.pyplot as plt -from typing import Any -import jax.numpy as jnp -import warp as wp - -wp.init() - -import xlb -from xlb.operator import Operator - - -class TaylorGreenInitializer(Operator): - """ - Initialize the Taylor-Green vortex. - """ - - @Operator.register_backend(xlb.ComputeBackend.JAX) - # @partial(jit, static_argnums=(0)) - def jax_implementation(self, vel, nr): - # Make meshgrid - x = jnp.linspace(0, 2 * jnp.pi, nr) - y = jnp.linspace(0, 2 * jnp.pi, nr) - z = jnp.linspace(0, 2 * jnp.pi, nr) - X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij") - - # Compute u - u = jnp.stack( - [ - vel * jnp.sin(X) * jnp.cos(Y) * jnp.cos(Z), - -vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), - jnp.zeros_like(X), - ], - axis=0, - ) - - # Compute rho - rho = 3.0 * vel * vel * (1.0 / 16.0) * (jnp.cos(2.0 * X) + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0))) + 1.0 - rho = jnp.expand_dims(rho, axis=0) - - return rho, u - - def _construct_warp(self): - # Construct the warp kernel - @wp.kernel - def kernel( - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), - vel: float, - nr: int, - ): - # Get the global index - i, j, k = wp.tid() - - # Get real pos - x = 2.0 * wp.pi * wp.float(i) / wp.float(nr) - y = 2.0 * wp.pi * wp.float(j) / wp.float(nr) - z = 2.0 * wp.pi * wp.float(k) / wp.float(nr) - - # Compute u - u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) - u[1, i, j, k] = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) - u[2, i, j, k] = 0.0 - - # Compute rho - rho[0, i, j, k] = 3.0 * vel * vel * (1.0 / 16.0) * (wp.cos(2.0 * x) + (wp.cos(2.0 * y) * (wp.cos(2.0 * z) + 2.0))) + 1.0 - - return None, kernel - - @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, rho, u, vel, nr): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[ - rho, - u, - vel, - nr, - ], - dim=rho.shape[1:], - ) - return rho, u - - -def run_taylor_green(backend, compute_mlup=True): - # Set the compute backend - if backend == "warp": - compute_backend = xlb.ComputeBackend.WARP - elif backend == "jax": - compute_backend = xlb.ComputeBackend.JAX - - # Set the precision policy - precision_policy = xlb.PrecisionPolicy.FP32FP32 - - # Set the velocity set - velocity_set = xlb.velocity_set.D3Q19() - - # Make grid - nr = 128 - shape = (nr, nr, nr) - if backend == "jax": - grid = xlb.grid.JaxGrid(shape=shape) - elif backend == "warp": - grid = xlb.grid.WarpGrid(shape=shape) - - # Make feilds - rho = grid.create_field(cardinality=1, precision=xlb.Precision.FP32) - u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) - f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) - bc_mask = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) - missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) - - # Make operators - initializer = TaylorGreenInitializer(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - collision = xlb.operator.collision.BGK(omega=1.9, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend - ) - macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - stream = xlb.operator.stream.Stream(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) - stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, stream=stream - ) - - # Parrallelize the stepper TODO: Add this functionality - # stepper = grid.parallelize_operator(stepper) - - # Set initial conditions - if backend == "warp": - rho, u = initializer(rho, u, 0.1, nr) - f0 = equilibrium(rho, u, f0) - elif backend == "jax": - rho, u = initializer(0.1, nr) - f0 = equilibrium(rho, u) - - # Time stepping - plot_freq = 32 - save_dir = "taylor_green" - os.makedirs(save_dir, exist_ok=True) - num_steps = 8192 - start = time.time() - - for _ in tqdm(range(num_steps)): - # Time step - if backend == "warp": - f1 = stepper(f0, f1, bc_mask, missing_mask, _) - f1, f0 = f0, f1 - elif backend == "jax": - f0 = stepper(f0, bc_mask, missing_mask, _) - - # Plot if needed - if (_ % plot_freq == 0) and (not compute_mlup): - if backend == "warp": - rho, u = macroscopic(f0, rho, u) - local_u = u.numpy() - elif backend == "jax": - rho, local_u = macroscopic(f0) - - plt.imshow(local_u[0, :, nr // 2, :]) - plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") - plt.close() - wp.synchronize() - end = time.time() - - # Print MLUPS - print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") - - -if __name__ == "__main__": - # Run Taylor-Green vortex on different backends - backends = ["warp", "jax"] - # backends = ["jax"] - for backend in backends: - run_taylor_green(backend, compute_mlup=True) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index bb4b5f0..e1505b7 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -62,6 +62,8 @@ def __init__( mesh_vertices, ) + self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__) + # The operator to compute the momentum flux self.momentum_flux = MomentumFlux() diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 62790a6..f977519 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -14,7 +14,7 @@ from xlb.operator.macroscopic import Macroscopic from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition.boundary_condition import ImplementationStep -from xlb.operator.boundary_condition import DoNothingBC as DummyBC +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.collision import ForcedCollision @@ -40,6 +40,9 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK", forcing_ operators = [self.macroscopic, self.equilibrium, self.collision, self.stream] + self.boundary_conditions = boundary_conditions + self.active_bcs = set(type(bc).__name__ for bc in boundary_conditions) + super().__init__(operators, boundary_conditions) @Operator.register_backend(ComputeBackend.JAX) @@ -91,92 +94,84 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): return f_0, f_1 def _construct_warp(self): - # Set local constants TODO: This is a hack and should be fixed with warp update + # Set local constants _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) _opp_indices = self.velocity_set.opp_indices - @wp.struct - class BoundaryConditionIDStruct: - # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning - # One needs to manually add the names of additional BC's as they are added. - # TODO: Any way to improve this? - id_EquilibriumBC: wp.uint8 - id_DoNothingBC: wp.uint8 - id_HalfwayBounceBackBC: wp.uint8 - id_FullwayBounceBackBC: wp.uint8 - id_ZouHeBC_velocity: wp.uint8 - id_ZouHeBC_pressure: wp.uint8 - id_RegularizedBC_velocity: wp.uint8 - id_RegularizedBC_pressure: wp.uint8 - id_ExtrapolationOutflowBC: wp.uint8 - id_GradsApproximationBC: wp.uint8 + # Read the list of bc_to_id created upon instantiation + bc_to_id = boundary_condition_registry.bc_to_id + id_to_bc = boundary_condition_registry.id_to_bc + + for bc in self.boundary_conditions: + bc_name = id_to_bc[bc.id] + setattr(self, bc_name, bc) @wp.func def apply_post_streaming_bc( index: Any, timestep: Any, _boundary_id: Any, - bc_struct: Any, missing_mask: Any, f_0: Any, f_1: Any, f_pre: Any, f_post: Any, ): - # Apply post-streaming type boundary conditions - # NOTE: 'f_pre' is included here as an input to the BC functionals for consistency with the BC API, - # particularly when compared to post-collision boundary conditions (see below). - - if _boundary_id == bc_struct.id_EquilibriumBC: - # Equilibrium boundary condition - f_post = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_DoNothingBC: - # Do nothing boundary condition - f_post = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: - # Half way boundary condition - f_post = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ZouHeBC_velocity: - # Zouhe boundary condition (bc type = velocity) - f_post = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ZouHeBC_pressure: - # Zouhe boundary condition (bc type = pressure) - f_post = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_RegularizedBC_velocity: - # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_RegularizedBC_pressure: - # Regularized boundary condition (bc type = velocity) - f_post = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - # Regularized boundary condition (bc type = velocity) - f_post = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_GradsApproximationBC: - # Reformulated Grads boundary condition - f_post = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - return f_post + f_result = f_post + + if wp.static("EquilibriumBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["EquilibriumBC"]): + f_result = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("DoNothingBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["DoNothingBC"]): + f_result = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("HalfwayBounceBackBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["HalfwayBounceBackBC"]): + f_result = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("ZouHeBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ZouHeBC"]): + f_result = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("RegularizedBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["RegularizedBC"]): + f_result = self.RegularizedBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("ExtrapolationOutflowBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ExtrapolationOutflowBC"]): + f_result = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("GradsApproximationBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): + f_result = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + return f_result @wp.func def apply_post_collision_bc( index: Any, timestep: Any, _boundary_id: Any, - bc_struct: Any, missing_mask: Any, f_0: Any, f_1: Any, f_pre: Any, f_post: Any, ): - # Apply post-collision type boundary conditions or special boundary preparations - if _boundary_id == bc_struct.id_FullwayBounceBackBC: - # Full way boundary condition - f_post = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC: - # Storing post-streaming data in directions that leave the domain - f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - return f_post + f_result = f_post + + if wp.static("FullwayBounceBackBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["FullwayBounceBackBC"]): + f_result = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("ExtrapolationOutflowBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ExtrapolationOutflowBC"]): + f_result = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + return f_result @wp.func def get_thread_data_2d( @@ -186,17 +181,17 @@ def get_thread_data_2d( index: Any, ): # Read thread data for populations and missing mask - f0_thread = _f_vec() - f1_thread = _f_vec() + _f0_thread = _f_vec() + _f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): - f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) - f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) + _f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) + _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f0_thread, f1_thread, _missing_mask + return _f0_thread, _f1_thread, _missing_mask @wp.func def get_thread_data_3d( @@ -206,19 +201,19 @@ def get_thread_data_3d( index: Any, ): # Read thread data for populations - f0_thread = _f_vec() - f1_thread = _f_vec() + _f0_thread = _f_vec() + _f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) - f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) + _f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) + _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f0_thread, f1_thread, _missing_mask + return _f0_thread, _f1_thread, _missing_mask @wp.kernel def kernel2d( @@ -226,27 +221,23 @@ def kernel2d( f_1: wp.array3d(dtype=Any), bc_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), - bc_struct: Any, timestep: int, ): - # Get the global index i, j = wp.tid() - index = wp.vec2i(i, j) # TODO warp should fix this + index = wp.vec2i(i, j) - # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1]] if _boundary_id == wp.uint8(255): return - # Apply streaming (pull method) + # Apply streaming _f_post_stream = self.stream.warp_functional(f_0, index) - # Apply post-streaming type boundary conditions - f0_thread, f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) - _f_post_collision = f0_thread - _f_post_stream = apply_post_streaming_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream - ) + _f0_thread, _f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) + _f_post_collision = _f0_thread + + # Apply post-streaming boundary conditions + _f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -257,119 +248,62 @@ def kernel2d( # Apply collision _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) - # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision - ) + # Apply post-collision boundary conditions + _f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision) - # Set the output + # Store the result in f_1 for l in range(self.velocity_set.q): f_1[l, index[0], index[1]] = self.store_dtype(_f_post_collision[l]) - # Construct the kernel @wp.kernel def kernel3d( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), - bc_struct: Any, timestep: int, ): - # Get the global index i, j, k = wp.tid() - index = wp.vec3i(i, j, k) # TODO warp should fix this + index = wp.vec3i(i, j, k) - # Get the boundary id _boundary_id = bc_mask[0, index[0], index[1], index[2]] if _boundary_id == wp.uint8(255): return - # Apply streaming (pull method) + # Apply streaming _f_post_stream = self.stream.warp_functional(f_0, index) - # Apply post-streaming type boundary conditions - f0_thread, f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) - _f_post_collision = f0_thread - _f_post_stream = apply_post_streaming_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream - ) + _f0_thread, _f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) + _f_post_collision = _f0_thread - # Compute rho and u - _rho, _u = self.macroscopic.warp_functional(_f_post_stream) + # Apply post-streaming boundary conditions + _f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream) - # Compute equilibrium + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) _feq = self.equilibrium.warp_functional(_rho, _u) - - # Apply collision _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) - # Apply post-collision type boundary conditions - _f_post_collision = apply_post_collision_bc( - index, timestep, _boundary_id, bc_struct, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision - ) + # Apply post-collision boundary conditions + _f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision) - # Set the output + # Store the result in f_1 for l in range(self.velocity_set.q): - # TODO 1: fix the perf drop due to l324-l236 even in cases where this BC is not used. - # TODO 2: is there better way to move these lines to a function inside BC class like "restore_bc_data" - # if _boundary_id == bc_struct.id_GradsApproximationBC: - # if _missing_mask[l] == wp.uint8(1): - # f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(f1_thread[_opp_indices[l]]) + if wp.static("GradsApproximationBC" in self.active_bcs): + if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): + if _missing_mask[l] == wp.uint8(1): + f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return BoundaryConditionIDStruct, kernel + return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): - # Get the boundary condition ids - from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry - - # Read the list of bc_to_id created upon instantiation - bc_to_id = boundary_condition_registry.bc_to_id - id_to_bc = boundary_condition_registry.id_to_bc - bc_struct = self.warp_functional() - active_bc_list = [] - for bc in self.boundary_conditions: - # Setting the Struct attributes and active BC classes based on the BC class names - bc_name = id_to_bc[bc.id] - setattr(self, bc_name, bc) - setattr(bc_struct, "id_" + bc_name, bc_to_id[bc_name]) - active_bc_list.append("id_" + bc_name) - - # Check if boundary_conditions is an empty list (e.g. all periodic and no BC) - # TODO: There is a huge issue here with perf. when boundary_conditions list - # is empty and is initialized with a dummy BC. If it is not empty, no perf - # loss ocurrs. The following code at least prevents syntax error for periodic examples. - if self.boundary_conditions: - bc_dummy = self.boundary_conditions[0] - else: - bc_dummy = DummyBC() - - # Setting the Struct attributes for inactive BC classes - for var in vars(bc_struct): - if var not in active_bc_list and not var.startswith("_"): - # set unassigned boundaries to the maximum integer in uint8 - setattr(bc_struct, var, 255) - - # Assing a fall-back BC for inactive BCs. This is just to ensure Warp codegen does not - # produce error when a particular BC is not used in an example. - setattr(self, var.replace("id_", ""), bc_dummy) - - # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[ - f_0, - f_1, - bc_mask, - missing_mask, - bc_struct, - timestep, - ], + inputs=[f_0, f_1, bc_mask, missing_mask, timestep], dim=f_0.shape[1:], ) return f_0, f_1 From d73a6d6cfaa5affd2b360a6f4d71c206e8c1c3f5 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Mon, 7 Oct 2024 18:01:57 -0400 Subject: [PATCH 134/144] Fixed zouhe and regularized bcs but there are limitations in applying the same bc type multiple times --- .../boundary_condition/bc_do_nothing.py | 2 - .../boundary_condition/bc_equilibrium.py | 2 - .../bc_extrapolation_outflow.py | 2 - .../bc_fullway_bounce_back.py | 2 - .../bc_grads_approximation.py | 3 +- .../bc_halfway_bounce_back.py | 2 - .../boundary_condition/bc_regularized.py | 5 +- xlb/operator/boundary_condition/bc_zouhe.py | 6 ++- .../boundary_condition/boundary_condition.py | 3 +- xlb/operator/stepper/nse_stepper.py | 53 +++++++++++-------- 10 files changed, 40 insertions(+), 40 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 55ce9ed..4b7ac90 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -27,8 +27,6 @@ class DoNothingBC(BoundaryCondition): boundary nodes. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 8c33d29..4dd4b9e 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -29,8 +29,6 @@ class EquilibriumBC(BoundaryCondition): Full Bounce-back boundary condition for a lattice Boltzmann method simulation. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, rho: float, diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 8b5f139..53645c6 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -38,8 +38,6 @@ class ExtrapolationOutflowBC(BoundaryCondition): doi:10.1016/j.camwa.2015.05.001. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 29f83c1..8569e84 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -27,8 +27,6 @@ class FullwayBounceBackBC(BoundaryCondition): Full Bounce-back boundary condition for a lattice Boltzmann method simulation. """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index 870635e..f4af5a8 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -39,8 +39,6 @@ class GradsApproximationBC(BoundaryCondition): """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, @@ -49,6 +47,7 @@ def __init__( indices=None, mesh_vertices=None, ): + # TODO: the input velocity must be suitably stored elesewhere when mesh is moving. self.u = (0, 0, 0) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 6e787c2..e8df6b7 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -29,8 +29,6 @@ class HalfwayBounceBackBC(BoundaryCondition): TODO: Implement moving boundary conditions for this """ - id = boundary_condition_registry.register_boundary_condition(__qualname__) - def __init__( self, velocity_set: VelocitySet = None, diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index e1505b7..f90ca60 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -61,9 +61,8 @@ def __init__( indices, mesh_vertices, ) - - self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__) - + # Overwrite the boundary condition registry id with the bc_type in the name + self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + bc_type) # The operator to compute the momentum flux self.momentum_flux = MomentumFlux() diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 66b6377..a1b79c2 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -35,6 +35,8 @@ class ZouHeBC(BoundaryCondition): Reynolds numbers. One needs to use "Regularized" BC at higher Reynolds. """ + + def __init__( self, bc_type, @@ -48,7 +50,6 @@ def __init__( # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC # may have different types (velocity or pressure). assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." - self.id = boundary_condition_registry.register_boundary_condition(__class__.__name__ + "_" + bc_type) self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() self.prescribed_value = prescribed_value @@ -63,6 +64,9 @@ def __init__( mesh_vertices, ) + # Overwrite the boundary condition registry id with the bc_type in the name + self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + bc_type) + # Set the prescribed value for pressure or velocity dim = self.velocity_set.d if self.compute_backend == ComputeBackend.JAX: diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index be920bf..17f8226 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -13,7 +13,7 @@ from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator from xlb import DefaultConfig - +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry # Enum for implementation step class ImplementationStep(Enum): @@ -35,6 +35,7 @@ def __init__( indices=None, mesh_vertices=None, ): + self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__) velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy compute_backend = compute_backend or DefaultConfig.default_backend diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index f977519..b52b721 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -40,9 +40,6 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK", forcing_ operators = [self.macroscopic, self.equilibrium, self.collision, self.stream] - self.boundary_conditions = boundary_conditions - self.active_bcs = set(type(bc).__name__ for bc in boundary_conditions) - super().__init__(operators, boundary_conditions) @Operator.register_backend(ComputeBackend.JAX) @@ -103,6 +100,8 @@ def _construct_warp(self): bc_to_id = boundary_condition_registry.bc_to_id id_to_bc = boundary_condition_registry.id_to_bc + active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions) + for bc in self.boundary_conditions: bc_name = id_to_bc[bc.id] setattr(self, bc_name, bc) @@ -120,32 +119,40 @@ def apply_post_streaming_bc( ): f_result = f_post - if wp.static("EquilibriumBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["EquilibriumBC"]): + if wp.static("EquilibriumBC" in active_bcs): + if _boundary_id == wp.static(bc_to_id["EquilibriumBC"]): f_result = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - if wp.static("DoNothingBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["DoNothingBC"]): + if wp.static("DoNothingBC" in active_bcs): + if _boundary_id == wp.static(bc_to_id["DoNothingBC"]): f_result = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - if wp.static("HalfwayBounceBackBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["HalfwayBounceBackBC"]): + if wp.static("HalfwayBounceBackBC" in active_bcs): + if _boundary_id == wp.static(bc_to_id["HalfwayBounceBackBC"]): f_result = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - if wp.static("ZouHeBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ZouHeBC"]): + if wp.static("ZouHeBC_pressure" in active_bcs): + if _boundary_id == wp.static(bc_to_id["ZouHeBC_pressure"]): + f_result = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("ZouHeBC_velocity" in active_bcs): + if _boundary_id == wp.static(bc_to_id["ZouHeBC_velocity"]): f_result = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - if wp.static("RegularizedBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["RegularizedBC"]): - f_result = self.RegularizedBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + if wp.static("RegularizedBC_pressure" in active_bcs): + if _boundary_id == wp.static(bc_to_id["RegularizedBC_pressure"]): + f_result = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + + if wp.static("RegularizedBC_velocity" in active_bcs): + if _boundary_id == wp.static(bc_to_id["RegularizedBC_velocity"]): + f_result = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - if wp.static("ExtrapolationOutflowBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ExtrapolationOutflowBC"]): + if wp.static("ExtrapolationOutflowBC" in active_bcs): + if _boundary_id == wp.static(bc_to_id["ExtrapolationOutflowBC"]): f_result = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - if wp.static("GradsApproximationBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): + if wp.static("GradsApproximationBC" in active_bcs): + if _boundary_id == wp.static(bc_to_id["GradsApproximationBC"]): f_result = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) return f_result @@ -163,12 +170,12 @@ def apply_post_collision_bc( ): f_result = f_post - if wp.static("FullwayBounceBackBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["FullwayBounceBackBC"]): + if wp.static("FullwayBounceBackBC" in active_bcs): + if _boundary_id == wp.static(bc_to_id["FullwayBounceBackBC"]): f_result = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - if wp.static("ExtrapolationOutflowBC" in self.active_bcs): - if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["ExtrapolationOutflowBC"]): + if wp.static("ExtrapolationOutflowBC" in active_bcs): + if _boundary_id == wp.static(bc_to_id["ExtrapolationOutflowBC"]): f_result = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) return f_result @@ -288,7 +295,7 @@ def kernel3d( # Store the result in f_1 for l in range(self.velocity_set.q): - if wp.static("GradsApproximationBC" in self.active_bcs): + if wp.static("GradsApproximationBC" in active_bcs): if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): if _missing_mask[l] == wp.uint8(1): f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) From 01237489e7ff83a5e353b721cf0ad4419d9ae184 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Mon, 7 Oct 2024 21:25:53 -0400 Subject: [PATCH 135/144] Extremely simplified BC implementation! --- .../bc_grads_approximation.py | 2 +- .../boundary_condition/bc_regularized.py | 2 - xlb/operator/boundary_condition/bc_zouhe.py | 5 - .../boundary_condition/boundary_condition.py | 2 +- .../boundary_condition_registry.py | 1 + xlb/operator/stepper/nse_stepper.py | 96 ++++++------------- 6 files changed, 31 insertions(+), 77 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index f4af5a8..3d60879 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -47,7 +47,7 @@ def __init__( indices=None, mesh_vertices=None, ): - + # TODO: the input velocity must be suitably stored elesewhere when mesh is moving. self.u = (0, 0, 0) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index f90ca60..065a0b0 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -62,8 +62,6 @@ def __init__( mesh_vertices, ) # Overwrite the boundary condition registry id with the bc_type in the name - self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + bc_type) - # The operator to compute the momentum flux self.momentum_flux = MomentumFlux() @partial(jit, static_argnums=(0,), inline=True) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index a1b79c2..4be2cf2 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -35,8 +35,6 @@ class ZouHeBC(BoundaryCondition): Reynolds numbers. One needs to use "Regularized" BC at higher Reynolds. """ - - def __init__( self, bc_type, @@ -64,9 +62,6 @@ def __init__( mesh_vertices, ) - # Overwrite the boundary condition registry id with the bc_type in the name - self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + bc_type) - # Set the prescribed value for pressure or velocity dim = self.velocity_set.d if self.compute_backend == ComputeBackend.JAX: diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 17f8226..6d72fc0 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -35,7 +35,7 @@ def __init__( indices=None, mesh_vertices=None, ): - self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__) + self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + str(hash(self))) velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy compute_backend = compute_backend or DefaultConfig.default_backend diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py index 5b1e092..6238fc5 100644 --- a/xlb/operator/boundary_condition/boundary_condition_registry.py +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -23,6 +23,7 @@ def register_boundary_condition(self, boundary_condition): self.next_id += 1 self.id_to_bc[_id] = boundary_condition self.bc_to_id[boundary_condition] = _id + print(f"registered bc {boundary_condition} with id {_id}") return _id diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index b52b721..99431eb 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -100,65 +100,16 @@ def _construct_warp(self): bc_to_id = boundary_condition_registry.bc_to_id id_to_bc = boundary_condition_registry.id_to_bc + # Gather IDs of ExtrapolationOutflowBC boundary conditions + extrapolation_outflow_bc_ids = [] + for bc_name, bc_id in bc_to_id.items(): + if bc_name.startswith("ExtrapolationOutflowBC"): + extrapolation_outflow_bc_ids.append(bc_id) + # Group active boundary conditions active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions) - for bc in self.boundary_conditions: - bc_name = id_to_bc[bc.id] - setattr(self, bc_name, bc) - - @wp.func - def apply_post_streaming_bc( - index: Any, - timestep: Any, - _boundary_id: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - f_result = f_post - - if wp.static("EquilibriumBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["EquilibriumBC"]): - f_result = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("DoNothingBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["DoNothingBC"]): - f_result = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("HalfwayBounceBackBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["HalfwayBounceBackBC"]): - f_result = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("ZouHeBC_pressure" in active_bcs): - if _boundary_id == wp.static(bc_to_id["ZouHeBC_pressure"]): - f_result = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("ZouHeBC_velocity" in active_bcs): - if _boundary_id == wp.static(bc_to_id["ZouHeBC_velocity"]): - f_result = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("RegularizedBC_pressure" in active_bcs): - if _boundary_id == wp.static(bc_to_id["RegularizedBC_pressure"]): - f_result = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("RegularizedBC_velocity" in active_bcs): - if _boundary_id == wp.static(bc_to_id["RegularizedBC_velocity"]): - f_result = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("ExtrapolationOutflowBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["ExtrapolationOutflowBC"]): - f_result = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("GradsApproximationBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["GradsApproximationBC"]): - f_result = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - return f_result - @wp.func - def apply_post_collision_bc( + def apply_bc( index: Any, timestep: Any, _boundary_id: Any, @@ -167,17 +118,25 @@ def apply_post_collision_bc( f_1: Any, f_pre: Any, f_post: Any, + is_post_streaming: bool, ): f_result = f_post - if wp.static("FullwayBounceBackBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["FullwayBounceBackBC"]): - f_result = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("ExtrapolationOutflowBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["ExtrapolationOutflowBC"]): - f_result = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - + # Unroll the loop over boundary conditions + for i in range(wp.static(len(self.boundary_conditions))): + if is_post_streaming: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + else: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].prepare_bc_auxilary_data)( + index, timestep, missing_mask, f_0, f_1, f_pre, f_post + ) return f_result @wp.func @@ -244,7 +203,7 @@ def kernel2d( _f_post_collision = _f0_thread # Apply post-streaming boundary conditions - _f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream) + _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -256,7 +215,7 @@ def kernel2d( _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision boundary conditions - _f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision) + _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) # Store the result in f_1 for l in range(self.velocity_set.q): @@ -284,17 +243,18 @@ def kernel3d( _f_post_collision = _f0_thread # Apply post-streaming boundary conditions - _f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream) + _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True) _rho, _u = self.macroscopic.warp_functional(_f_post_stream) _feq = self.equilibrium.warp_functional(_rho, _u) _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision boundary conditions - _f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision) + _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) # Store the result in f_1 for l in range(self.velocity_set.q): + # TODO: Improve this later if wp.static("GradsApproximationBC" in active_bcs): if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): if _missing_mask[l] == wp.uint8(1): From e16ba2445007a5e860097f132d91fdfc00c5a1e6 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 17 Oct 2024 22:15:39 -0400 Subject: [PATCH 136/144] fixed race conditioning in indices_boundary_masker due to duplicate bc indices at corners and edges. --- examples/cfd/flow_past_sphere_3d.py | 24 ++++++++-- examples/cfd/lid_driven_cavity_2d.py | 2 +- examples/cfd/windtunnel_3d.py | 4 +- .../bc_grads_approximation.py | 1 - .../boundary_condition/boundary_condition.py | 1 + .../indices_boundary_masker.py | 46 ++++++++----------- xlb/utils/utils.py | 8 ++-- 7 files changed, 49 insertions(+), 37 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 2e0df95..994d18b 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -80,10 +80,9 @@ def setup_boundary_conditions(self): bc_outlet = ExtrapolationOutflowBC(indices=outlet) bc_sphere = HalfwayBounceBackBC(indices=sphere) - self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls] - # Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because + self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere] + # Note: it is important to add bc_walls before bc_outlet/bc_inlet because # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. - # TODO: how to ensure about this behind in the src code? def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( @@ -105,6 +104,8 @@ def run(self, num_steps, post_process_interval=100): self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 + if i == 0: + self.check_boundary_mask() if i % post_process_interval == 0 or i == num_steps - 1: self.post_process(i) end_time = time.time() @@ -134,6 +135,23 @@ def post_process(self, i): # save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) + return + + def check_boundary_mask(self): + # Write the results. We'll use JAX backend for the post-processing + if not isinstance(self.f_0, jnp.ndarray): + bmask = wp.to_jax(self.bc_mask)[0] + else: + bmask = self.bc_mask[0] + + # save_fields_vtk(fields, timestep=i) + save_image(bmask[0, :, :], prefix="00_left") + save_image(bmask[self.grid_shape[0] - 1, :, :], prefix="00_right") + save_image(bmask[:, :, self.grid_shape[2] - 1], prefix="00_top") + save_image(bmask[:, :, 0], prefix="00_bottom") + save_image(bmask[:, 0, :], prefix="00_front") + save_image(bmask[:, self.grid_shape[1] - 1, :], prefix="00_back") + save_image(bmask[:, self.grid_shape[1] // 2, :], prefix="00_middle") if __name__ == "__main__": diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 20f3b7c..300614d 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -50,7 +50,7 @@ def setup_boundary_conditions(self): lid, walls = self.define_boundary_indices() bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid) bc_walls = HalfwayBounceBackBC(indices=walls) - self.boundary_conditions = [bc_top, bc_walls] + self.boundary_conditions = [bc_walls, bc_top] def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 140e756..b0ee5b9 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -104,7 +104,9 @@ def setup_boundary_conditions(self): # bc_car = HalfwayBounceBackBC(mesh_vertices=car) bc_car = GradsApproximationBC(mesh_vertices=car) # bc_car = FullwayBounceBackBC(mesh_vertices=car) - self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car] + self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car] + # Note: it is important to add bc_walls before bc_outlet/bc_inlet because + # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. def setup_boundary_masker(self): indices_boundary_masker = IndicesBoundaryMasker( diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index 3d60879..bc29851 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -47,7 +47,6 @@ def __init__( indices=None, mesh_vertices=None, ): - # TODO: the input velocity must be suitably stored elesewhere when mesh is moving. self.u = (0, 0, 0) diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 6d72fc0..008cc9a 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -15,6 +15,7 @@ from xlb import DefaultConfig from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + # Enum for implementation step class ImplementationStep(Enum): COLLISION = auto() diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index fdc4331..10d72c4 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -66,7 +66,7 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None): start_index = (0,) * dim domain_shape = bc_mask[0].shape - for bc in bclist: + for bc in reversed(bclist): assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" id_number = bc.id @@ -103,6 +103,11 @@ def _construct_warp(self): _c = self.velocity_set.c _q = wp.constant(self.velocity_set.q) + @wp.func + def check_index_bounds(index: wp.vec3i, shape: wp.vec3i): + is_in_bounds = index[0] >= 0 and index[0] < shape[0] and index[1] >= 0 and index[1] < shape[1] and index[2] >= 0 and index[2] < shape[2] + return is_in_bounds + # Construct the warp 2D kernel @wp.kernel def kernel2d( @@ -173,14 +178,8 @@ def kernel3d( index[2] = indices[2, ii] - start_index[2] # Check if index is in bounds - if ( - index[0] >= 0 - and index[0] < missing_mask.shape[1] - and index[1] >= 0 - and index[1] < missing_mask.shape[2] - and index[2] >= 0 - and index[2] < missing_mask.shape[3] - ): + shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3]) + if check_index_bounds(index, shape): # Stream indices for l in range(_q): # Get the index of the streaming direction @@ -195,27 +194,12 @@ def kernel3d( # check if pull index is out of bound # These directions will have missing information after streaming - if ( - pull_index[0] < 0 - or pull_index[0] >= missing_mask.shape[1] - or pull_index[1] < 0 - or pull_index[1] >= missing_mask.shape[2] - or pull_index[2] < 0 - or pull_index[2] >= missing_mask.shape[3] - ): + if not check_index_bounds(pull_index, shape): # Set the missing mask missing_mask[l, index[0], index[1], index[2]] = True # handling geometries in the interior of the computational domain - elif ( - is_interior[ii] - and push_index[0] >= 0 - and push_index[0] < missing_mask.shape[1] - and push_index[1] >= 0 - and push_index[1] < missing_mask.shape[2] - and push_index[2] >= 0 - and push_index[2] < missing_mask.shape[3] - ): + elif check_index_bounds(pull_index, shape) and is_interior[ii]: # Set the missing mask missing_mask[l, push_index[0], push_index[1], push_index[2]] = True bc_mask[0, push_index[0], push_index[1], push_index[2]] = id_number[ii] @@ -241,8 +225,14 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) - indices = wp.array2d(index_list, dtype=wp.int32) - id_number = wp.array1d(id_list, dtype=wp.uint8) + # Remove duplicates indices to avoid race conditioning + index_arr, unique_loc = np.unique(index_list, axis=-1, return_index=True) + id_arr = np.array(id_list)[unique_loc] + is_interior = np.array(is_interior)[unique_loc] + + # convert to warp arrays + indices = wp.array2d(index_arr, dtype=wp.int32) + id_number = wp.array1d(id_arr, dtype=wp.uint8) is_interior = wp.array1d(is_interior, dtype=wp.bool) if start_index is None: diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index 074177e..0a9858a 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -44,7 +44,7 @@ def downsample_field(field, factor, method="bicubic"): return jnp.stack(downsampled_components, axis=-1) -def save_image(fld, timestep, prefix=None): +def save_image(fld, timestep=None, prefix=None, **kwargs): """ Save an image of a field at a given timestep. @@ -74,7 +74,8 @@ def save_image(fld, timestep, prefix=None): else: fname = prefix - fname = fname + "_" + str(timestep).zfill(4) + if timestep is not None: + fname = fname + "_" + str(timestep).zfill(4) if len(fld.shape) > 3: raise ValueError("The input field should be 2D!") @@ -82,7 +83,8 @@ def save_image(fld, timestep, prefix=None): fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2) plt.clf() - plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower") + kwargs.pop("cmap", None) + plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower", **kwargs) def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"): From bd62072fb03cc1d7ad8edd8aa525aefa07ffd4ea Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 18 Oct 2024 09:36:18 -0400 Subject: [PATCH 137/144] fixed a bug in 2d kbc warp --- xlb/operator/collision/kbc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index cc2fb04..bbc5d15 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -178,7 +178,7 @@ def decompose_shear_d2q9_jax(self, fneq): def _construct_warp(self): # Raise error if velocity set is not supported - if not isinstance(self.velocity_set, D3Q27): + if not (isinstance(self.velocity_set, D3Q27) or isinstance(self.velocity_set, D2Q9)): raise NotImplementedError("Velocity set not supported for warp backend: {}".format(type(self.velocity_set))) # Set local constants TODO: This is a hack and should be fixed with warp update @@ -192,7 +192,7 @@ def _construct_warp(self): def decompose_shear_d2q9(fneq: Any): pi = self.momentum_flux.warp_functional(fneq) N = pi[0] - pi[1] - s = wp.vec9(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + s = _f_vec() s[3] = N s[6] = N s[2] = -N From ec68bfc095091802f8774f27fe244363f85479b7 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 18 Oct 2024 16:27:39 -0400 Subject: [PATCH 138/144] moved np.unique to examples and a helper function --- examples/cfd/flow_past_sphere_3d.py | 23 ++++++++----------- examples/cfd/lid_driven_cavity_2d.py | 17 ++++++++------ examples/cfd/turbulent_channel_3d.py | 4 ++-- examples/cfd/windtunnel_3d.py | 22 ++++++++---------- examples/performance/mlups_3d.py | 13 ++++------- xlb/grid/grid.py | 1 - xlb/helper/__init__.py | 1 + xlb/helper/check_boundary_overlaps.py | 22 ++++++++++++++++++ .../indices_boundary_masker.py | 9 ++------ 9 files changed, 61 insertions(+), 51 deletions(-) create mode 100644 xlb/helper/check_boundary_overlaps.py diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 994d18b..8b958fc 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -1,7 +1,7 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq +from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import ( FullwayBounceBackBC, @@ -48,15 +48,12 @@ def _setup(self, omega): self.setup_stepper(omega) def define_boundary_indices(self): - inlet = self.grid.boundingBoxIndices["left"] - outlet = self.grid.boundingBoxIndices["right"] - walls = [ - self.grid.boundingBoxIndices["bottom"][i] - + self.grid.boundingBoxIndices["top"][i] - + self.grid.boundingBoxIndices["front"][i] - + self.grid.boundingBoxIndices["back"][i] - for i in range(self.velocity_set.d) - ] + box = self.grid.bounding_box_indices() + box_noedge = self.grid.bounding_box_indices(remove_edges=True) + inlet = box_noedge["left"] + outlet = box_noedge["right"] + walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)] + walls = np.unique(np.array(walls), axis=-1).tolist() sphere_radius = self.grid_shape[1] // 12 x = np.arange(self.grid_shape[0]) @@ -79,12 +76,12 @@ def setup_boundary_conditions(self): # bc_outlet = DoNothingBC(indices=outlet) bc_outlet = ExtrapolationOutflowBC(indices=outlet) bc_sphere = HalfwayBounceBackBC(indices=sphere) - self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere] - # Note: it is important to add bc_walls before bc_outlet/bc_inlet because - # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. def setup_boundary_masker(self): + # check boundary condition list for duplicate indices before creating bc mask + check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend) + indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 300614d..a77cc62 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -1,15 +1,16 @@ import xlb from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq +from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps from xlb.operator.boundary_masker import IndicesBoundaryMasker from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image +import xlb.velocity_set import warp as wp import jax.numpy as jnp -import xlb.velocity_set +import numpy as np class LidDrivenCavity2D: @@ -39,11 +40,11 @@ def _setup(self, omega): self.setup_stepper(omega) def define_boundary_indices(self): - lid = self.grid.boundingBoxIndices["top"] - walls = [ - self.grid.boundingBoxIndices["bottom"][i] + self.grid.boundingBoxIndices["left"][i] + self.grid.boundingBoxIndices["right"][i] - for i in range(self.velocity_set.d) - ] + box = self.grid.bounding_box_indices() + box_noedge = self.grid.bounding_box_indices(remove_edges=True) + lid = box_noedge["top"] + walls = [box["bottom"][i] + box["left"][i] + box["right"][i] for i in range(self.velocity_set.d)] + walls = np.unique(np.array(walls), axis=-1).tolist() return lid, walls def setup_boundary_conditions(self): @@ -53,6 +54,8 @@ def setup_boundary_conditions(self): self.boundary_conditions = [bc_walls, bc_top] def setup_boundary_masker(self): + # check boundary condition list for duplicate indices before creating bc mask + check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend) indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index 2ec5560..eb73fdc 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -77,8 +77,8 @@ def _setup(self): def define_boundary_indices(self): # top and bottom sides of the channel are no-slip and the other directions are periodic - boundingBoxIndices = self.grid.bounding_box_indices(remove_edges=True) - walls = [boundingBoxIndices["bottom"][i] + boundingBoxIndices["top"][i] for i in range(self.velocity_set.d)] + box = self.grid.bounding_box_indices(remove_edges=True) + walls = [box["bottom"][i] + box["top"][i] for i in range(self.velocity_set.d)] return walls def setup_boundary_conditions(self): diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index b0ee5b9..a7567d0 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -3,7 +3,7 @@ import time from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy -from xlb.helper import create_nse_fields, initialize_eq +from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import ( FullwayBounceBackBC, @@ -67,15 +67,12 @@ def voxelize_stl(self, stl_filename, length_lbm_unit): return mesh_matrix, pitch def define_boundary_indices(self): - inlet = self.grid.boundingBoxIndices["left"] - outlet = self.grid.boundingBoxIndices["right"] - walls = [ - self.grid.boundingBoxIndices["bottom"][i] - + self.grid.boundingBoxIndices["top"][i] - + self.grid.boundingBoxIndices["front"][i] - + self.grid.boundingBoxIndices["back"][i] - for i in range(self.velocity_set.d) - ] + box = self.grid.bounding_box_indices() + box_noedge = self.grid.bounding_box_indices(remove_edges=True) + inlet = box_noedge["left"] + outlet = box_noedge["right"] + walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)] + walls = np.unique(np.array(walls), axis=-1).tolist() # Load the mesh stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl" @@ -105,10 +102,11 @@ def setup_boundary_conditions(self): bc_car = GradsApproximationBC(mesh_vertices=car) # bc_car = FullwayBounceBackBC(mesh_vertices=car) self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car] - # Note: it is important to add bc_walls before bc_outlet/bc_inlet because - # of the corner nodes. This way the corners are treated as wall and not inlet/outlet. def setup_boundary_masker(self): + # check boundary condition list for duplicate indices before creating bc mask + check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend) + indices_boundary_masker = IndicesBoundaryMasker( velocity_set=self.velocity_set, precision_policy=self.precision_policy, diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 1812d95..4c337e8 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -48,15 +48,10 @@ def create_grid_and_fields(cube_edge): def define_boundary_indices(grid): - lid = grid.boundingBoxIndices["top"] - walls = [ - grid.boundingBoxIndices["bottom"][i] - + grid.boundingBoxIndices["left"][i] - + grid.boundingBoxIndices["right"][i] - + grid.boundingBoxIndices["front"][i] - + grid.boundingBoxIndices["back"][i] - for i in range(len(grid.shape)) - ] + box = grid.bounding_box_indices() + box_noedge = grid.bounding_box_indices(remove_edges=True) + lid = box_noedge["top"] + walls = [box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] for i in range(len(grid.shape))] return lid, walls diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 7494c3e..53139fc 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -25,7 +25,6 @@ def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend): self.shape = shape self.dim = len(shape) self.compute_backend = compute_backend - self.boundingBoxIndices = self.bounding_box_indices() self._initialize_backend() @abstractmethod diff --git a/xlb/helper/__init__.py b/xlb/helper/__init__.py index 92d3583..4c63aa6 100644 --- a/xlb/helper/__init__.py +++ b/xlb/helper/__init__.py @@ -1,2 +1,3 @@ from xlb.helper.nse_solver import create_nse_fields as create_nse_fields from xlb.helper.initializers import initialize_eq as initialize_eq +from xlb.helper.check_boundary_overlaps import check_bc_overlaps as check_bc_overlaps diff --git a/xlb/helper/check_boundary_overlaps.py b/xlb/helper/check_boundary_overlaps.py new file mode 100644 index 0000000..18e17a1 --- /dev/null +++ b/xlb/helper/check_boundary_overlaps.py @@ -0,0 +1,22 @@ +import numpy as np +from xlb.compute_backend import ComputeBackend + + +def check_bc_overlaps(bclist, dim, backend): + index_list = [[] for _ in range(dim)] + for bc in bclist: + if bc.indices is None: + continue + # Detect duplicates within bc.indices + index_arr = np.unique(bc.indices, axis=-1) + if index_arr.shape[-1] != len(bc.indices[0]): + if backend == ComputeBackend.WARP: + raise ValueError(f"Boundary condition {bc.__class__.__name__} has duplicate indices!") + for d in range(dim): + index_list[d] += bc.indices[d] + + # Detect duplicates within bclist + index_arr = np.unique(index_list, axis=-1) + if index_arr.shape[-1] != len(index_list[0]): + if backend == ComputeBackend.WARP: + raise ValueError("Boundary condition list containes duplicate indices!") diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 10d72c4..848d74e 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -225,14 +225,9 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) - # Remove duplicates indices to avoid race conditioning - index_arr, unique_loc = np.unique(index_list, axis=-1, return_index=True) - id_arr = np.array(id_list)[unique_loc] - is_interior = np.array(is_interior)[unique_loc] - # convert to warp arrays - indices = wp.array2d(index_arr, dtype=wp.int32) - id_number = wp.array1d(id_arr, dtype=wp.uint8) + indices = wp.array2d(index_list, dtype=wp.int32) + id_number = wp.array1d(id_list, dtype=wp.uint8) is_interior = wp.array1d(is_interior, dtype=wp.bool) if start_index is None: From 8f869d23d0d7d0a19c5e995eb673395a401451bd Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 18 Oct 2024 17:02:15 -0400 Subject: [PATCH 139/144] minor --- xlb/operator/boundary_masker/indices_boundary_masker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 848d74e..e7e50d7 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -66,7 +66,7 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None): start_index = (0,) * dim domain_shape = bc_mask[0].shape - for bc in reversed(bclist): + for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" id_number = bc.id From cee77b9691be618bec7203831545be31740d36e4 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Fri, 18 Oct 2024 17:31:38 -0400 Subject: [PATCH 140/144] addressing PR reviews --- examples/cfd/flow_past_sphere_3d.py | 25 +++------------------ examples/cfd/lid_driven_cavity_2d.py | 4 ++-- examples/cfd/windtunnel_3d.py | 6 ++--- examples/performance/mlups_3d.py | 4 ++-- xlb/distribute/__init__.py | 2 +- xlb/experimental/ooc/__init__.py | 4 ++-- xlb/helper/__init__.py | 6 ++--- xlb/helper/check_boundary_overlaps.py | 2 ++ xlb/operator/__init__.py | 4 ++-- xlb/operator/boundary_condition/__init__.py | 22 +++++++++--------- xlb/operator/boundary_masker/__init__.py | 12 +++------- xlb/operator/collision/__init__.py | 8 +++---- xlb/operator/equilibrium/__init__.py | 5 +---- xlb/operator/force/__init__.py | 4 ++-- xlb/operator/precision_caster/__init__.py | 2 +- xlb/operator/stepper/__init__.py | 4 ++-- xlb/operator/stream/__init__.py | 2 +- xlb/utils/__init__.py | 14 ++++++------ xlb/velocity_set/__init__.py | 8 +++---- 19 files changed, 55 insertions(+), 83 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 8b958fc..1b5905e 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -49,9 +49,9 @@ def _setup(self, omega): def define_boundary_indices(self): box = self.grid.bounding_box_indices() - box_noedge = self.grid.bounding_box_indices(remove_edges=True) - inlet = box_noedge["left"] - outlet = box_noedge["right"] + box_no_edge = self.grid.bounding_box_indices(remove_edges=True) + inlet = box_no_edge["left"] + outlet = box_no_edge["right"] walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)] walls = np.unique(np.array(walls), axis=-1).tolist() @@ -101,8 +101,6 @@ def run(self, num_steps, post_process_interval=100): self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 - if i == 0: - self.check_boundary_mask() if i % post_process_interval == 0 or i == num_steps - 1: self.post_process(i) end_time = time.time() @@ -132,23 +130,6 @@ def post_process(self, i): # save_fields_vtk(fields, timestep=i) save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) - return - - def check_boundary_mask(self): - # Write the results. We'll use JAX backend for the post-processing - if not isinstance(self.f_0, jnp.ndarray): - bmask = wp.to_jax(self.bc_mask)[0] - else: - bmask = self.bc_mask[0] - - # save_fields_vtk(fields, timestep=i) - save_image(bmask[0, :, :], prefix="00_left") - save_image(bmask[self.grid_shape[0] - 1, :, :], prefix="00_right") - save_image(bmask[:, :, self.grid_shape[2] - 1], prefix="00_top") - save_image(bmask[:, :, 0], prefix="00_bottom") - save_image(bmask[:, 0, :], prefix="00_front") - save_image(bmask[:, self.grid_shape[1] - 1, :], prefix="00_back") - save_image(bmask[:, self.grid_shape[1] // 2, :], prefix="00_middle") if __name__ == "__main__": diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index a77cc62..dfb1092 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -41,8 +41,8 @@ def _setup(self, omega): def define_boundary_indices(self): box = self.grid.bounding_box_indices() - box_noedge = self.grid.bounding_box_indices(remove_edges=True) - lid = box_noedge["top"] + box_no_edge = self.grid.bounding_box_indices(remove_edges=True) + lid = box_no_edge["top"] walls = [box["bottom"][i] + box["left"][i] + box["right"][i] for i in range(self.velocity_set.d)] walls = np.unique(np.array(walls), axis=-1).tolist() return lid, walls diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index a7567d0..c83e2c9 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -68,9 +68,9 @@ def voxelize_stl(self, stl_filename, length_lbm_unit): def define_boundary_indices(self): box = self.grid.bounding_box_indices() - box_noedge = self.grid.bounding_box_indices(remove_edges=True) - inlet = box_noedge["left"] - outlet = box_noedge["right"] + box_no_edge = self.grid.bounding_box_indices(remove_edges=True) + inlet = box_no_edge["left"] + outlet = box_no_edge["right"] walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)] walls = np.unique(np.array(walls), axis=-1).tolist() diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 4c337e8..2001fb2 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -49,8 +49,8 @@ def create_grid_and_fields(cube_edge): def define_boundary_indices(grid): box = grid.bounding_box_indices() - box_noedge = grid.bounding_box_indices(remove_edges=True) - lid = box_noedge["top"] + box_no_edge = grid.bounding_box_indices(remove_edges=True) + lid = box_no_edge["top"] walls = [box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] for i in range(len(grid.shape))] return lid, walls diff --git a/xlb/distribute/__init__.py b/xlb/distribute/__init__.py index 25fa0af..dd9f33d 100644 --- a/xlb/distribute/__init__.py +++ b/xlb/distribute/__init__.py @@ -1 +1 @@ -from .distribute import distribute as distribute +from .distribute import distribute diff --git a/xlb/experimental/ooc/__init__.py b/xlb/experimental/ooc/__init__.py index 801683d..5206cc1 100644 --- a/xlb/experimental/ooc/__init__.py +++ b/xlb/experimental/ooc/__init__.py @@ -1,2 +1,2 @@ -from xlb.experimental.ooc.out_of_core import OOCmap as OOCmap -from xlb.experimental.ooc.ooc_array import OOCArray as OOCArray +from xlb.experimental.ooc.out_of_core import OOCmap +from xlb.experimental.ooc.ooc_array import OOCArray diff --git a/xlb/helper/__init__.py b/xlb/helper/__init__.py index 4c63aa6..d52f206 100644 --- a/xlb/helper/__init__.py +++ b/xlb/helper/__init__.py @@ -1,3 +1,3 @@ -from xlb.helper.nse_solver import create_nse_fields as create_nse_fields -from xlb.helper.initializers import initialize_eq as initialize_eq -from xlb.helper.check_boundary_overlaps import check_bc_overlaps as check_bc_overlaps +from xlb.helper.nse_solver import create_nse_fields +from xlb.helper.initializers import initialize_eq +from xlb.helper.check_boundary_overlaps import check_bc_overlaps diff --git a/xlb/helper/check_boundary_overlaps.py b/xlb/helper/check_boundary_overlaps.py index 18e17a1..1adceb2 100644 --- a/xlb/helper/check_boundary_overlaps.py +++ b/xlb/helper/check_boundary_overlaps.py @@ -12,6 +12,7 @@ def check_bc_overlaps(bclist, dim, backend): if index_arr.shape[-1] != len(bc.indices[0]): if backend == ComputeBackend.WARP: raise ValueError(f"Boundary condition {bc.__class__.__name__} has duplicate indices!") + print(f"WARNING: there are duplicate indices in {bc.__class__.__name__} and hence the order in bc list matters!") for d in range(dim): index_list[d] += bc.indices[d] @@ -20,3 +21,4 @@ def check_bc_overlaps(bclist, dim, backend): if index_arr.shape[-1] != len(index_list[0]): if backend == ComputeBackend.WARP: raise ValueError("Boundary condition list containes duplicate indices!") + print("WARNING: there are duplicate indices in the boundary condition list and hence the order in this list matters!") diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index c88ef83..02b8a59 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.operator import Operator as Operator -from xlb.operator.parallel_operator import ParallelOperator as ParallelOperator +from xlb.operator.operator import Operator +from xlb.operator.parallel_operator import ParallelOperator diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 925dfdc..4782ea0 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -1,12 +1,10 @@ -from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition as BoundaryCondition -from xlb.operator.boundary_condition.boundary_condition_registry import ( - BoundaryConditionRegistry as BoundaryConditionRegistry, -) -from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC as EquilibriumBC -from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC as DoNothingBC -from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC -from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC -from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC -from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC -from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC as ExtrapolationOutflowBC -from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC as GradsApproximationBC +from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition +from xlb.operator.boundary_condition.boundary_condition_registry import BoundaryConditionRegistry +from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC +from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC +from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC +from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC +from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC +from xlb.operator.boundary_condition.bc_regularized import RegularizedBC +from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC +from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index fbe851d..d03636a 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -1,9 +1,3 @@ -from xlb.operator.boundary_masker.indices_boundary_masker import ( - IndicesBoundaryMasker as IndicesBoundaryMasker, -) -from xlb.operator.boundary_masker.mesh_boundary_masker import ( - MeshBoundaryMasker as MeshBoundaryMasker, -) -from xlb.operator.boundary_masker.mesh_distance_boundary_masker import ( - MeshDistanceBoundaryMasker as MeshDistanceBoundaryMasker, -) +from xlb.operator.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker +from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker +from xlb.operator.boundary_masker.mesh_distance_boundary_masker import MeshDistanceBoundaryMasker diff --git a/xlb/operator/collision/__init__.py b/xlb/operator/collision/__init__.py index 0526c8a..2f92bb4 100644 --- a/xlb/operator/collision/__init__.py +++ b/xlb/operator/collision/__init__.py @@ -1,4 +1,4 @@ -from xlb.operator.collision.collision import Collision as Collision -from xlb.operator.collision.bgk import BGK as BGK -from xlb.operator.collision.kbc import KBC as KBC -from xlb.operator.collision.forced_collision import ForcedCollision as ForcedCollision +from xlb.operator.collision.collision import Collision +from xlb.operator.collision.bgk import BGK +from xlb.operator.collision.kbc import KBC +from xlb.operator.collision.forced_collision import ForcedCollision diff --git a/xlb/operator/equilibrium/__init__.py b/xlb/operator/equilibrium/__init__.py index b9f9f08..987aa74 100644 --- a/xlb/operator/equilibrium/__init__.py +++ b/xlb/operator/equilibrium/__init__.py @@ -1,4 +1 @@ -from xlb.operator.equilibrium.quadratic_equilibrium import ( - Equilibrium as Equilibrium, - QuadraticEquilibrium as QuadraticEquilibrium, -) +from xlb.operator.equilibrium.quadratic_equilibrium import Equilibrium, QuadraticEquilibrium diff --git a/xlb/operator/force/__init__.py b/xlb/operator/force/__init__.py index 2f3e3da..ba8a13c 100644 --- a/xlb/operator/force/__init__.py +++ b/xlb/operator/force/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.force.momentum_transfer import MomentumTransfer as MomentumTransfer -from xlb.operator.force.exact_difference_force import ExactDifference as ExactDifference +from xlb.operator.force.momentum_transfer import MomentumTransfer +from xlb.operator.force.exact_difference_force import ExactDifference diff --git a/xlb/operator/precision_caster/__init__.py b/xlb/operator/precision_caster/__init__.py index c333ab7..a027c52 100644 --- a/xlb/operator/precision_caster/__init__.py +++ b/xlb/operator/precision_caster/__init__.py @@ -1 +1 @@ -from xlb.operator.precision_caster.precision_caster import PrecisionCaster as PrecisionCaster +from xlb.operator.precision_caster.precision_caster import PrecisionCaster diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py index 528375d..e5d159c 100644 --- a/xlb/operator/stepper/__init__.py +++ b/xlb/operator/stepper/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.stepper.stepper import Stepper as Stepper -from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper as IncompressibleNavierStokesStepper +from xlb.operator.stepper.stepper import Stepper +from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper diff --git a/xlb/operator/stream/__init__.py b/xlb/operator/stream/__init__.py index 2f5b2f3..9093da7 100644 --- a/xlb/operator/stream/__init__.py +++ b/xlb/operator/stream/__init__.py @@ -1 +1 @@ -from xlb.operator.stream.stream import Stream as Stream +from xlb.operator.stream.stream import Stream diff --git a/xlb/utils/__init__.py b/xlb/utils/__init__.py index 6f1f61a..3c8032e 100644 --- a/xlb/utils/__init__.py +++ b/xlb/utils/__init__.py @@ -1,9 +1,9 @@ from .utils import ( - downsample_field as downsample_field, - save_image as save_image, - save_fields_vtk as save_fields_vtk, - save_BCs_vtk as save_BCs_vtk, - rotate_geometry as rotate_geometry, - voxelize_stl as voxelize_stl, - axangle2mat as axangle2mat, + downsample_field, + save_image, + save_fields_vtk, + save_BCs_vtk, + rotate_geometry, + voxelize_stl, + axangle2mat, ) diff --git a/xlb/velocity_set/__init__.py b/xlb/velocity_set/__init__.py index c1338db..5b7b737 100644 --- a/xlb/velocity_set/__init__.py +++ b/xlb/velocity_set/__init__.py @@ -1,4 +1,4 @@ -from xlb.velocity_set.velocity_set import VelocitySet as VelocitySet -from xlb.velocity_set.d2q9 import D2Q9 as D2Q9 -from xlb.velocity_set.d3q19 import D3Q19 as D3Q19 -from xlb.velocity_set.d3q27 import D3Q27 as D3Q27 +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.velocity_set.d2q9 import D2Q9 +from xlb.velocity_set.d3q19 import D3Q19 +from xlb.velocity_set.d3q27 import D3Q27 From de67c099e5c224186e828495ca7d8de4a8e31e0b Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 18 Oct 2024 17:53:22 -0400 Subject: [PATCH 141/144] merged 2D and 3D Warp kernels --- examples/cfd/lid_driven_cavity_2d.py | 3 +- requirements.txt | 25 ++- tests/grids/test_grid_warp.py | 7 +- tests/kernels/stream/test_stream_warp.py | 7 +- xlb/grid/warp_grid.py | 4 +- .../boundary_condition/bc_do_nothing.py | 65 +------- .../boundary_condition/bc_equilibrium.py | 66 +------- .../bc_extrapolation_outflow.py | 121 +------------- .../bc_fullway_bounce_back.py | 64 +------ .../bc_grads_approximation.py | 41 +---- .../bc_halfway_bounce_back.py | 65 +------- .../boundary_condition/bc_regularized.py | 156 ++---------------- xlb/operator/boundary_condition/bc_zouhe.py | 150 ++--------------- .../boundary_condition/boundary_condition.py | 30 +--- .../indices_boundary_masker.py | 137 ++++++--------- xlb/operator/collision/bgk.py | 31 +--- xlb/operator/collision/forced_collision.py | 36 +--- xlb/operator/collision/kbc.py | 62 +------ .../equilibrium/quadratic_equilibrium.py | 25 +-- xlb/operator/force/exact_difference_force.py | 29 +--- xlb/operator/force/momentum_transfer.py | 60 +------ xlb/operator/macroscopic/first_moment.py | 22 +-- xlb/operator/macroscopic/macroscopic.py | 22 +-- xlb/operator/macroscopic/second_moment.py | 23 +-- xlb/operator/macroscopic/zero_moment.py | 19 +-- xlb/operator/stepper/nse_stepper.py | 69 +------- xlb/operator/stream/stream.py | 50 +----- 27 files changed, 127 insertions(+), 1262 deletions(-) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 20f3b7c..383a110 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -77,7 +77,8 @@ def run(self, num_steps, post_process_interval=100): def post_process(self, i): # Write the results. We'll use JAX backend for the post-processing if not isinstance(self.f_0, jnp.ndarray): - f_0 = wp.to_jax(self.f_0) + # If the backend is warp, we need to drop the last dimension added by warp for 2D simulations + f_0 = wp.to_jax(self.f_0)[..., 0] else: f_0 = self.f_0 diff --git a/requirements.txt b/requirements.txt index ee107af..4d0cd2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,10 @@ -jax==0.4.20 -jaxlib==0.4.20 -matplotlib==3.8.0 -numpy==1.26.1 -pyvista==0.43.4 -Rtree==1.0.1 -trimesh==4.4.1 -orbax-checkpoint==0.4.1 -termcolor==2.3.0 -PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git -tqdm==4.66.2 -warp-lang==1.0.2 -numpy-stl==3.1.1 -pydantic==2.7.0 -ruff==0.5.6 \ No newline at end of file +jax[cuda] +matplotlib +numpy +pyvista +Rtree +trimesh +warp-lang +numpy-stl +pydantic +ruff \ No newline at end of file diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 782434d..61b27d4 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -22,8 +22,10 @@ def test_warp_grid_create_field(grid_size): init_xlb_env(xlb.velocity_set.D3Q19) my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9, dtype=Precision.FP32) - - assert f.shape == (9,) + grid_shape, "Field shape is incorrect" + if len(grid_shape) == 2: + assert f.shape == (9,) + grid_shape + (1,), "Field shape is incorrect got {}".format(f.shape) + else: + assert f.shape == (9,) + grid_shape, "Field shape is incorrect got {}".format(f.shape) assert isinstance(f, wp.array), "Field should be a Warp ndarray" @@ -37,7 +39,6 @@ def test_warp_grid_create_field_fill_value(): assert isinstance(f, wp.array), "Field should be a Warp ndarray" f = f.numpy() - assert f.shape == (9,) + grid_shape, "Field shape is incorrect" assert np.allclose(f, fill_value), "Field not properly initialized with fill_value" diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py index 0d100cf..95fcc05 100644 --- a/tests/kernels/stream/test_stream_warp.py +++ b/tests/kernels/stream/test_stream_warp.py @@ -61,7 +61,7 @@ def test_stream_operator_warp(dim, velocity_set, grid_shape): expected = jnp.stack(expected, axis=0) if dim == 2: - f_initial_warp = wp.array(f_initial) + f_initial_warp = wp.array(f_initial[..., np.newaxis]) elif dim == 3: f_initial_warp = wp.array(f_initial) @@ -71,7 +71,10 @@ def test_stream_operator_warp(dim, velocity_set, grid_shape): f_streamed = my_grid_warp.create_field(cardinality=velocity_set.q) f_streamed = stream_op(f_initial_warp, f_streamed) - assert jnp.allclose(f_streamed.numpy(), np.array(expected)), "Streaming did not occur as expected" + if len(grid_shape) == 2: + assert jnp.allclose(f_streamed.numpy()[..., 0], np.array(expected)), "Streaming did not occur as expected" + else: + assert jnp.allclose(f_streamed.numpy(), np.array(expected)), "Streaming did not occur as expected" if __name__ == "__main__": diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index 5018962..c74fc2f 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -21,7 +21,9 @@ def create_field( fill_value=None, ): dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype - shape = (cardinality,) + (self.shape) + + # Check if shape is 2D, and if so, append a singleton dimension to the shape + shape = (cardinality,) + (self.shape if len(self.shape) != 2 else self.shape + (1,)) if fill_value is None: f = wp.zeros(shape, dtype=dtype) diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 4b7ac90..67b343d 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -64,67 +64,4 @@ def functional( ): return f_pre - @wp.kernel - def kernel2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.uint8), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(DoNothingBC.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the result - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(DoNothingBC.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the result - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post + return functional, None diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 4dd4b9e..b4b957a 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -88,68 +88,4 @@ def functional( _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f - # Construct the warp kernel - @wp.kernel - def kernel2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(EquilibriumBC.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the result - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(EquilibriumBC.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the result - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post + return functional, None diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 53645c6..4a96c73 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -140,15 +140,7 @@ def _construct_warp(self): _opp_indices = self.velocity_set.opp_indices @wp.func - def get_normal_vectors_2d( - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -wp.vec2i(_c[0, l], _c[1, l]) - - @wp.func - def get_normal_vectors_3d( + def get_normal_vectors( missing_mask: Any, ): for l in range(_q): @@ -175,7 +167,7 @@ def functional( return _f @wp.func - def prepare_bc_auxilary_data_2d( + def prepare_bc_auxilary_data( index: Any, timestep: Any, missing_mask: Any, @@ -188,34 +180,7 @@ def prepare_bc_auxilary_data_2d( # f_pre (post-streaming values of the current voxel). We use directions that leave the domain # for storing this prepared data. _f = f_post - nv = get_normal_vectors_2d(missing_mask) - for l in range(self.velocity_set.q): - if missing_mask[l] == wp.uint8(1): - # f_0 is the post-collision values of the current time-step - # Get pull index associated with the "neighbours" pull_index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - (_c[d, l] + nv[d]) - # The following is the post-streaming values of the neighbor cell - f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) - _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux - return _f - - @wp.func - def prepare_bc_auxilary_data_3d( - index: Any, - timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - # Preparing the formulation for this BC using the neighbour's populations stored in f_aux and - # f_pre (post-streaming values of the current voxel). We use directions that leave the domain - # for storing this prepared data. - _f = f_post - nv = get_normal_vectors_3d(missing_mask) + nv = get_normal_vectors(missing_mask) for l in range(self.velocity_set.q): if missing_mask[l] == wp.uint8(1): # f_0 is the post-collision values of the current time-step @@ -228,82 +193,4 @@ def prepare_bc_auxilary_data_3d( _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux return _f - # Construct the warp kernel - @wp.kernel - def kernel2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - timestep = 0 - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - - # special preparation of auxiliary data - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - _f_pre = prepare_bc_auxilary_data_2d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post) - - # Apply the boundary condition - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - # TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both - # collision and streaming? - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - timestep = 0 - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) - _f_aux = _f_vec() - - # special preparation of auxiliary data - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - _f_pre = prepare_bc_auxilary_data_3d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post) - - # Apply the boundary condition - if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id): - # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both - # collision and streaming? - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - prepare_bc_auxilary_data = prepare_bc_auxilary_data_3d if self.velocity_set.d == 3 else prepare_bc_auxilary_data_2d - - return (functional, prepare_bc_auxilary_data), kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post + return (functional, prepare_bc_auxilary_data), None \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 8569e84..afe05de 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -74,66 +74,4 @@ def functional( fliped_f[l] = f_pre[_opp_indices[l]] return fliped_f - @wp.kernel - def kernel2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - ): # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - - # Check if the boundary is active - if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the result to the output - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) - - # Check if the boundary is active - if _boundary_id == wp.uint8(FullwayBounceBackBC.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the result to the output - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post + return functional, None \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index 3d60879..f5dc343 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -308,45 +308,6 @@ def functional_method2( f_post = grads_approximate_fpop(missing_mask, rho_target, u_target, f_post) return f_post - # Construct the warp kernel - @wp.kernel - def kernel( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - timestep = 0 - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) - _f_aux = _f_vec() - - # Apply the boundary condition - if _boundary_id == wp.uint8(GradsApproximationBC.id): - # TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both - # collision and streaming? - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - functional = functional_method1 - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post + return functional, None \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index e8df6b7..ee68b50 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -87,68 +87,5 @@ def functional( return _f - # Construct the warp kernel - @wp.kernel - def kernel2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) + return functional, None - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 065a0b0..12622e2 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -159,15 +159,7 @@ def _get_fsum( return fsum_known + fsum_middle @wp.func - def get_normal_vectors_2d( - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l]) - - @wp.func - def get_normal_vectors_3d( + def get_normal_vectors( missing_mask: Any, ): for l in range(_q): @@ -211,7 +203,7 @@ def regularize_fpop( return fpop @wp.func - def functional3d_velocity( + def functional_velocity( index: Any, timestep: Any, missing_mask: Any, @@ -224,7 +216,7 @@ def functional3d_velocity( _f = f_post # Find normal vector - normals = get_normal_vectors_3d(missing_mask) + normals = get_normal_vectors(missing_mask) # calculate rho fsum = _get_fsum(_f, missing_mask) @@ -242,7 +234,7 @@ def functional3d_velocity( return _f @wp.func - def functional3d_pressure( + def functional_pressure( index: Any, timestep: Any, missing_mask: Any, @@ -255,7 +247,7 @@ def functional3d_pressure( _f = f_post # Find normal vector - normals = get_normal_vectors_3d(missing_mask) + normals = get_normal_vectors(missing_mask) # calculate velocity fsum = _get_fsum(_f, missing_mask) @@ -270,136 +262,8 @@ def functional3d_pressure( _f = regularize_fpop(_f, feq) return _f - @wp.func - def functional2d_velocity( - index: Any, - timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - # Post-streaming values are only modified at missing direction - _f = f_post - - # Find normal vector - normals = get_normal_vectors_2d(missing_mask) - - # calculate rho - fsum = _get_fsum(_f, missing_mask) - unormal = self.compute_dtype(0.0) - for d in range(_d): - unormal += _u[d] * normals[d] - _rho = fsum / (self.compute_dtype(1.0) + unormal) - - # impose non-equilibrium bounceback - feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) - - # Regularize the boundary fpop - _f = regularize_fpop(_f, feq) - return _f - - @wp.func - def functional2d_pressure( - index: Any, - timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - # Post-streaming values are only modified at missing direction - _f = f_post - - # Find normal vector - normals = get_normal_vectors_2d(missing_mask) - - # calculate velocity - fsum = _get_fsum(_f, missing_mask) - unormal = -self.compute_dtype(1.0) + fsum / _rho - _u = unormal * normals - - # impose non-equilibrium bounceback - feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) - - # Regularize the boundary fpop - _f = regularize_fpop(_f, feq) - return _f - - # Construct the warp kernel - @wp.kernel - def kernel2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - if self.velocity_set.d == 3 and self.bc_type == "velocity": - functional = functional3d_velocity - elif self.velocity_set.d == 3 and self.bc_type == "pressure": - functional = functional3d_pressure - elif self.bc_type == "velocity": - functional = functional2d_velocity - else: - functional = functional2d_pressure - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post + if self.bc_type == "velocity": + functional = functional_velocity + elif self.bc_type == "pressure": + functional = functional_pressure + return functional, None \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 4be2cf2..c5d9498 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -189,15 +189,6 @@ def _construct_warp(self): _c_float = self.velocity_set.c_float # TODO: this is way less than ideal. we should not be making new types - @wp.func - def get_normal_vectors_2d( - lattice_direction: Any, - ): - l = lattice_direction - if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - normals = -_u_vec(_c_float[0, l], _c_float[1, l]) - return normals - @wp.func def _get_fsum( fpop: Any, @@ -213,7 +204,7 @@ def _get_fsum( return fsum_known + fsum_middle @wp.func - def get_normal_vectors_3d( + def get_normal_vectors( missing_mask: Any, ): for l in range(_q): @@ -232,7 +223,7 @@ def bounceback_nonequilibrium( return fpop @wp.func - def functional3d_velocity( + def functional_velocity( index: Any, timestep: Any, missing_mask: Any, @@ -245,7 +236,7 @@ def functional3d_velocity( _f = f_post # Find normal vector - normals = get_normal_vectors_3d(missing_mask) + normals = get_normal_vectors(missing_mask) # calculate rho fsum = _get_fsum(_f, missing_mask) @@ -260,7 +251,7 @@ def functional3d_velocity( return _f @wp.func - def functional3d_pressure( + def functional_pressure( index: Any, timestep: Any, missing_mask: Any, @@ -273,7 +264,7 @@ def functional3d_pressure( _f = f_post # Find normal vector - normals = get_normal_vectors_3d(missing_mask) + normals = get_normal_vectors(missing_mask) # calculate velocity fsum = _get_fsum(_f, missing_mask) @@ -285,130 +276,11 @@ def functional3d_pressure( _f = bounceback_nonequilibrium(_f, feq, missing_mask) return _f - @wp.func - def functional2d_velocity( - index: Any, - timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - # Post-streaming values are only modified at missing direction - _f = f_post - - # Find normal vector - normals = get_normal_vectors_2d(missing_mask) - - # calculate rho - fsum = _get_fsum(_f, missing_mask) - unormal = self.compute_dtype(0.0) - for d in range(_d): - unormal += _u[d] * normals[d] - _rho = fsum / (self.compute_dtype(1.0) + unormal) - - # impose non-equilibrium bounceback - feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) - return _f - - @wp.func - def functional2d_pressure( - index: Any, - timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - # Post-streaming values are only modified at missing direction - _f = f_post - - # Find normal vector - normals = get_normal_vectors_2d(missing_mask) - - # calculate velocity - fsum = _get_fsum(_f, missing_mask) - unormal = -self.compute_dtype(1.0) + fsum / _rho - _u = unormal * normals - - # impose non-equilibrium bounceback - feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) - return _f - - # Construct the warp kernel - @wp.kernel - def kernel2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - - # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index) - - # Apply the boundary condition - if _boundary_id == wp.uint8(self.id): - timestep = 0 - _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) - else: - _f = _f_post - - # Write the distribution function - for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - if self.velocity_set.d == 3 and self.bc_type == "velocity": - functional = functional3d_velocity - elif self.velocity_set.d == 3 and self.bc_type == "pressure": - functional = functional3d_pressure + if self.bc_type == "velocity": + functional = functional_velocity + elif self.bc_type == "pressure": + functional = functional_pressure elif self.bc_type == "velocity": - functional = functional2d_velocity - else: - functional = functional2d_pressure + functional = functional_pressure - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post + return functional, None \ No newline at end of file diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 6d72fc0..e724b27 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -74,32 +74,7 @@ def prepare_bc_auxilary_data( return f_post @wp.func - def _get_thread_data_2d( - f_pre: wp.array3d(dtype=Any), - f_post: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - index: wp.vec2i, - ): - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = bc_mask[0, index[0], index[1]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1]]) - _f_post[l] = self.compute_dtype(f_post[l, index[0], index[1]]) - - # TODO fix vec bool - if missing_mask[l, index[0], index[1]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) - return _f_pre, _f_post, _boundary_id, _missing_mask - - @wp.func - def _get_thread_data_3d( + def _get_thread_data( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), @@ -125,8 +100,7 @@ def _get_thread_data_3d( # Construct some helper warp functions for getting tid data if self.compute_backend == ComputeBackend.WARP: - self._get_thread_data_2d = _get_thread_data_2d - self._get_thread_data_3d = _get_thread_data_3d + self._get_thread_data = _get_thread_data self.prepare_bc_auxilary_data = prepare_bc_auxilary_data @partial(jit, static_argnums=(0,), inline=True) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index fdc4331..16650fe 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -31,18 +31,12 @@ def are_indices_in_interior(self, indices, shape): Check if each 2D or 3D index is inside the bounds of the domain with the given shape and not at its boundary. - :param indices: List of tuples, where each tuple contains indices for each dimension. + :param indices: Array of indices, where each column contains indices for each dimension. :param shape: Tuple representing the shape of the domain (nx, ny) for 2D or (nx, ny, nz) for 3D. - :return: List of boolean flags where each flag indicates whether the corresponding index is inside the bounds. + :return: Array of boolean flags where each flag indicates whether the corresponding index is inside the bounds. """ - # Ensure that the number of dimensions in indices matches the domain shape - dim = len(shape) - if len(indices) != dim: - raise ValueError(f"Indices tuple must have {dim} dimensions to match the domain shape.") - - # Check each index tuple and return a list of boolean flags - flags = [all(0 < idx[d] < shape[d] - 1 for d in range(dim)) for idx in np.array(indices).T] - return flags + shape_array = np.array(shape) + return np.all((indices > 0) & (indices < shape_array[:, np.newaxis] - 1), axis=0) @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! @@ -70,11 +64,12 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None): assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" id_number = bc.id - local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] + bc_indices = np.array(bc.indices) + local_indices = bc_indices - np.array(start_index)[:, np.newaxis] padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] bmap = bmap.at[tuple(padded_indices)].set(id_number) - if any(self.are_indices_in_interior(bc.indices, domain_shape)) and bc.needs_padding: - # checking if all indices associated with this BC are in the interior of the domain (not at the boundary). + if any(self.are_indices_in_interior(bc_indices, domain_shape)) and bc.needs_padding: + # checking if all indices associated with this BC are in the interior of the domain. # This flag is needed e.g. if the no-slip geometry is anywhere but at the boundaries of the computational domain. if dim == 2: grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) @@ -103,59 +98,9 @@ def _construct_warp(self): _c = self.velocity_set.c _q = wp.constant(self.velocity_set.q) - # Construct the warp 2D kernel - @wp.kernel - def kernel2d( - indices: wp.array2d(dtype=wp.int32), - id_number: wp.array1d(dtype=wp.uint8), - is_interior: wp.array1d(dtype=wp.bool), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - start_index: wp.vec2i, - ): - # Get the index of indices - ii = wp.tid() - - # Get local indices - index = wp.vec2i() - index[0] = indices[0, ii] - start_index[0] - index[1] = indices[1, ii] - start_index[1] - - # Check if index is in bounds - if index[0] >= 0 and index[0] < missing_mask.shape[1] and index[1] >= 0 and index[1] < missing_mask.shape[2]: - # Stream indices - for l in range(_q): - # Get the index of the streaming direction - pull_index = wp.vec2i() - push_index = wp.vec2i() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - _c[d, l] - push_index[d] = index[d] + _c[d, l] - - # set bc_mask for all bc indices - bc_mask[0, index[0], index[1]] = id_number[ii] - - # check if pull index is out of bound - # These directions will have missing information after streaming - if pull_index[0] < 0 or pull_index[0] >= missing_mask.shape[1] or pull_index[1] < 0 or pull_index[1] >= missing_mask.shape[2]: - # Set the missing mask - missing_mask[l, index[0], index[1]] = True - - # handling geometries in the interior of the computational domain - elif ( - is_interior[ii] - and push_index[0] >= 0 - and push_index[0] < missing_mask.shape[1] - and push_index[1] >= 0 - and push_index[1] < missing_mask.shape[2] - ): - # Set the missing mask - missing_mask[l, push_index[0], push_index[1]] = True - bc_mask[0, push_index[0], push_index[1]] = id_number[ii] - # Construct the warp 3D kernel @wp.kernel - def kernel3d( + def kernel( indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), is_interior: wp.array1d(dtype=wp.bool), @@ -220,46 +165,72 @@ def kernel3d( missing_mask[l, push_index[0], push_index[1], push_index[2]] = True bc_mask[0, push_index[0], push_index[1], push_index[2]] = id_number[ii] - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): - dim = self.velocity_set.d - index_list = [[] for _ in range(dim)] - id_list = [] - is_interior = [] + # Pre-allocate arrays with maximum possible size + max_size = sum(len(bc.indices[0]) if isinstance(bc.indices, list) else bc.indices.shape[1] for bc in bclist if bc.indices is not None) + indices = np.zeros((3, max_size), dtype=np.int32) + id_numbers = np.zeros(max_size, dtype=np.uint8) + is_interior = np.zeros(max_size, dtype=bool) + + current_index = 0 for bc in bclist: assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC using keyword "indices"!' assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" - for d in range(dim): - index_list[d] += bc.indices[d] - id_list += [bc.id] * len(bc.indices[0]) - is_interior += self.are_indices_in_interior(bc.indices, bc_mask[0].shape) if bc.needs_padding else [False] * len(bc.indices[0]) - # We are done with bc.indices. Remove them from BC objects + bc_indices = np.asarray(bc.indices) + num_indices = bc_indices.shape[1] + + # Ensure indices are 3D + if bc_indices.shape[0] == 2: + bc_indices = np.vstack([bc_indices, np.zeros(num_indices, dtype=int)]) + + # Add indices to the pre-allocated array + indices[:, current_index : current_index + num_indices] = bc_indices + + # Set id numbers + id_numbers[current_index : current_index + num_indices] = bc.id + + # Set is_interior flags + if bc.needs_padding: + is_interior[current_index : current_index + num_indices] = self.are_indices_in_interior(bc_indices, bc_mask[0].shape) + else: + is_interior[current_index : current_index + num_indices] = False + + current_index += num_indices + + # Remove indices from BC objects bc.__dict__.pop("indices", None) - indices = wp.array2d(index_list, dtype=wp.int32) - id_number = wp.array1d(id_list, dtype=wp.uint8) - is_interior = wp.array1d(is_interior, dtype=wp.bool) + # Trim arrays to actual size + indices = indices[:, :current_index] + id_numbers = id_numbers[:current_index] + is_interior = is_interior[:current_index] + + # Convert to Warp arrays + wp_indices = wp.array(indices, dtype=wp.int32) + wp_id_numbers = wp.array(id_numbers, dtype=wp.uint8) + wp_is_interior = wp.array(is_interior, dtype=wp.bool) if start_index is None: - start_index = (0,) * dim + start_index = wp.vec3i(0, 0, 0) + else: + start_index = wp.vec3i(*start_index) # Launch the warp kernel wp.launch( self.warp_kernel, + dim=current_index, inputs=[ - indices, - id_number, - is_interior, + wp_indices, + wp_id_numbers, + wp_is_interior, bc_mask, missing_mask, start_index, ], - dim=indices.shape[1], ) return bc_mask, missing_mask diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 60f63ef..115ed9a 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -36,34 +36,7 @@ def functional(f: Any, feq: Any, rho: Any, u: Any): # Construct the warp kernel @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - feq: wp.array3d(dtype=Any), - fout: wp.array3d(dtype=Any), - rho: wp.array3d(dtype=Any), - u: wp.array3d(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # Load needed values - _f = _f_vec() - _feq = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - _feq[l] = feq[l, index[0], index[1]] - - # Compute the collision - _fout = functional(_f, _feq, rho, u) - - # Write the result - for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( + def kernel( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), @@ -88,8 +61,6 @@ def kernel3d( for l in range(self.velocity_set.q): fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/collision/forced_collision.py b/xlb/operator/collision/forced_collision.py index 31ef392..2036bab 100644 --- a/xlb/operator/collision/forced_collision.py +++ b/xlb/operator/collision/forced_collision.py @@ -52,39 +52,7 @@ def functional(f: Any, feq: Any, rho: Any, u: Any): # Construct the warp kernel @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - feq: wp.array3d(dtype=Any), - fout: wp.array3d(dtype=Any), - rho: wp.array3d(dtype=Any), - u: wp.array3d(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) # TODO: Warp needs to fix this - - # Load needed values - _f = _f_vec() - _feq = _f_vec() - _d = self.velocity_set.d - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - _feq[l] = feq[l, index[0], index[1]] - _u = _u_vec() - for l in range(_d): - _u[l] = u[l, index[0], index[1]] - _rho = rho[0, index[0], index[1]] - - # Compute the collision - _fout = functional(_f, _feq, _rho, _u) - - # Write the result - for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] - - # Construct the warp kernel - @wp.kernel - def kernel3d( + def kernel( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), @@ -114,8 +82,6 @@ def kernel3d( for l in range(self.velocity_set.q): fout[l, index[0], index[1], index[2]] = _fout[l] - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index cc2fb04..d94f5eb 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -263,30 +263,7 @@ def entropic_scalar_product( # Construct the functional @wp.func - def functional2d( - f: Any, - feq: Any, - rho: Any, - u: Any, - ): - # Compute shear and delta_s - fneq = f - feq - shear = decompose_shear_d2q9(fneq) - delta_s = shear * rho # TODO: Check this - - # Perform collision - delta_h = fneq - delta_s - two = self.compute_dtype(2.0) - gamma = _inv_beta - (two - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( - _epsilon + entropic_scalar_product(delta_h, delta_h, feq) - ) - fout = f - _beta * (two * delta_s + gamma * delta_h) - - return fout - - # Construct the functional - @wp.func - def functional3d( + def functional( f: Any, feq: Any, rho: Any, @@ -309,39 +286,7 @@ def functional3d( # Construct the warp kernel @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - feq: wp.array3d(dtype=Any), - fout: wp.array3d(dtype=Any), - rho: wp.array3d(dtype=Any), - u: wp.array3d(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) # TODO: Warp needs to fix this - - # Load needed values - _f = _f_vec() - _feq = _f_vec() - _d = self.velocity_set.d - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - _feq[l] = feq[l, index[0], index[1]] - _u = _u_vec() - for l in range(_d): - _u[l] = u[l, index[0], index[1]] - _rho = rho[0, index[0], index[1]] - - # Compute the collision - _fout = functional(_f, _feq, _rho, _u) - - # Write the result - for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( + def kernel( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), @@ -371,9 +316,6 @@ def kernel3d( for l in range(self.velocity_set.q): fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) - functional = functional3d if self.velocity_set.d == 3 else functional2d - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index ba337f0..62cc041 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -61,7 +61,7 @@ def functional( # Construct the warp kernel @wp.kernel - def kernel3d( + def kernel( rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), f: wp.array4d(dtype=Any), @@ -81,29 +81,6 @@ def kernel3d( for l in range(self.velocity_set.q): f[l, index[0], index[1], index[2]] = self.store_dtype(feq[l]) - @wp.kernel - def kernel2d( - rho: wp.array3d(dtype=Any), - u: wp.array3d(dtype=Any), - f: wp.array3d(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # Get the equilibrium - _u = _u_vec() - for d in range(self.velocity_set.d): - _u[d] = u[d, index[0], index[1]] - _rho = rho[0, index[0], index[1]] - feq = functional(_rho, _u) - - # Set the output - for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = self.store_dtype(feq[l]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/force/exact_difference_force.py b/xlb/operator/force/exact_difference_force.py index b4da602..ec1ef5b 100644 --- a/xlb/operator/force/exact_difference_force.py +++ b/xlb/operator/force/exact_difference_force.py @@ -86,33 +86,7 @@ def functional(f_postcollision: Any, feq: Any, rho: Any, u: Any): # Construct the warp kernel @wp.kernel - def kernel2d( - f_postcollision: Any, - feq: Any, - fout: wp.array3d(dtype=Any), - rho: wp.array3d(dtype=Any), - u: wp.array3d(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # Load needed values - _u = _u_vec() - for l in range(_d): - _u[l] = u[l, index[0], index[1]] - _rho = rho[0, index[0], index[1]] - - # Compute the collision - _fout = functional(f_postcollision, feq, _rho, _u) - - # Write the result - for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) - - # Construct the warp kernel - @wp.kernel - def kernel3d( + def kernel( f_postcollision: Any, feq: Any, fout: wp.array4d(dtype=Any), @@ -136,7 +110,6 @@ def kernel3d( for l in range(self.velocity_set.q): fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 8b0aacf..dbd5307 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -102,62 +102,7 @@ def _construct_warp(self): # Construct the warp kernel @wp.kernel - def kernel2d( - f_0: wp.array3d(dtype=Any), - f_1: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=wp.uint8), - missing_mask: wp.array3d(dtype=wp.bool), - force: wp.array(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # Get the boundary id - _boundary_id = bc_mask[0, index[0], index[1]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # TODO fix vec bool - if missing_mask[l, index[0], index[1]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) - - # Determin if boundary is an edge by checking if center is missing - is_edge = wp.bool(False) - if _boundary_id == wp.uint8(_no_slip_id): - if _missing_mask[_zero_index] == wp.uint8(0): - is_edge = wp.bool(True) - - # If the boundary is an edge then add the momentum transfer - m = _u_vec() - if is_edge: - # Get the distribution function - f_post_collision = _f_vec() - for l in range(self.velocity_set.q): - f_post_collision[l] = f_0[l, index[0], index[1]] - - # Apply streaming (pull method) - timestep = 0 - f_post_stream = self.stream.warp_functional(f_0, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) - - # Compute the momentum transfer - for d in range(self.velocity_set.d): - m[d] = self.compute_dtype(0.0) - for l in range(self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] - if _c[d, _opp_indices[l]] == 1: - m[d] += phi - elif _c[d, _opp_indices[l]] == -1: - m[d] -= phi - - wp.atomic_add(force, 0, m) - - # Construct the warp kernel - @wp.kernel - def kernel3d( + def kernel( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), @@ -210,9 +155,6 @@ def kernel3d( wp.atomic_add(force, 0, m) - # Return the correct kernel - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return None, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py index 329a71f..cb99a9f 100644 --- a/xlb/operator/macroscopic/first_moment.py +++ b/xlb/operator/macroscopic/first_moment.py @@ -38,7 +38,7 @@ def functional( return u @wp.kernel - def kernel3d( + def kernel( f: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), @@ -55,26 +55,6 @@ def kernel3d( for d in range(self.velocity_set.d): u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d]) - @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - rho: wp.array3d(dtype=Any), - u: wp.array3d(dtype=Any), - ): - i, j = wp.tid() - index = wp.vec2i(i, j) - - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - _rho = rho[0, index[0], index[1]] - _u = functional(_f, _rho) - - for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = self.store_dtype(_u[d]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index b574436..ab1193b 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -37,7 +37,7 @@ def functional(f: _f_vec): return rho, u @wp.kernel - def kernel3d( + def kernel( f: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), @@ -54,26 +54,6 @@ def kernel3d( for d in range(self.velocity_set.d): u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d]) - @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - rho: wp.array3d(dtype=Any), - u: wp.array3d(dtype=Any), - ): - i, j = wp.tid() - index = wp.vec2i(i, j) - - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - _rho, _u = functional(_f) - - rho[0, index[0], index[1]] = self.store_dtype(_rho) - for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = self.store_dtype(_u[d]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index 687b38a..6c7e70e 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -79,7 +79,7 @@ def functional( # Construct the kernel @wp.kernel - def kernel3d( + def kernel( f: wp.array4d(dtype=Any), pi: wp.array4d(dtype=Any), ): @@ -97,27 +97,6 @@ def kernel3d( for d in range(_pi_dim): pi[d, index[0], index[1], index[2]] = self.store_dtype(_pi[d]) - @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - pi: wp.array3d(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # Get the equilibrium - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - _pi = functional(_f) - - # Set the output - for d in range(_pi_dim): - pi[d, index[0], index[1]] = self.store_dtype(_pi[d]) - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/macroscopic/zero_moment.py b/xlb/operator/macroscopic/zero_moment.py index d0fbf51..8abb4de 100644 --- a/xlb/operator/macroscopic/zero_moment.py +++ b/xlb/operator/macroscopic/zero_moment.py @@ -27,7 +27,7 @@ def functional(f: _f_vec): return rho @wp.kernel - def kernel3d( + def kernel( f: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), ): @@ -41,23 +41,6 @@ def kernel3d( rho[0, index[0], index[1], index[2]] = _rho - @wp.kernel - def kernel2d( - f: wp.array3d(dtype=Any), - rho: wp.array3d(dtype=Any), - ): - i, j = wp.tid() - index = wp.vec2i(i, j) - - _f = _f_vec() - for l in range(self.velocity_set.q): - _f[l] = f[l, index[0], index[1]] - _rho = functional(_f) - - rho[0, index[0], index[1]] = _rho - - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 99431eb..e08e95c 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -140,27 +140,7 @@ def apply_bc( return f_result @wp.func - def get_thread_data_2d( - f0_buffer: wp.array3d(dtype=Any), - f1_buffer: wp.array3d(dtype=Any), - missing_mask: wp.array3d(dtype=Any), - index: Any, - ): - # Read thread data for populations and missing mask - _f0_thread = _f_vec() - _f1_thread = _f_vec() - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - _f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1]]) - _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1]]) - if missing_mask[l, index[0], index[1]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) - return _f0_thread, _f1_thread, _missing_mask - - @wp.func - def get_thread_data_3d( + def get_thread_data( f0_buffer: wp.array4d(dtype=Any), f1_buffer: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), @@ -182,47 +162,7 @@ def get_thread_data_3d( return _f0_thread, _f1_thread, _missing_mask @wp.kernel - def kernel2d( - f_0: wp.array3d(dtype=Any), - f_1: wp.array3d(dtype=Any), - bc_mask: wp.array3d(dtype=Any), - missing_mask: wp.array3d(dtype=Any), - timestep: int, - ): - i, j = wp.tid() - index = wp.vec2i(i, j) - - _boundary_id = bc_mask[0, index[0], index[1]] - if _boundary_id == wp.uint8(255): - return - - # Apply streaming - _f_post_stream = self.stream.warp_functional(f_0, index) - - _f0_thread, _f1_thread, _missing_mask = get_thread_data_2d(f_0, f_1, missing_mask, index) - _f_post_collision = _f0_thread - - # Apply post-streaming boundary conditions - _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True) - - # Compute rho and u - _rho, _u = self.macroscopic.warp_functional(_f_post_stream) - - # Compute equilibrium - _feq = self.equilibrium.warp_functional(_rho, _u) - - # Apply collision - _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) - - # Apply post-collision boundary conditions - _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) - - # Store the result in f_1 - for l in range(self.velocity_set.q): - f_1[l, index[0], index[1]] = self.store_dtype(_f_post_collision[l]) - - @wp.kernel - def kernel3d( + def kernel( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=Any), @@ -239,7 +179,7 @@ def kernel3d( # Apply streaming _f_post_stream = self.stream.warp_functional(f_0, index) - _f0_thread, _f1_thread, _missing_mask = get_thread_data_3d(f_0, f_1, missing_mask, index) + _f0_thread, _f1_thread, _missing_mask = get_thread_data(f_0, f_1, missing_mask, index) _f_post_collision = _f0_thread # Apply post-streaming boundary conditions @@ -261,9 +201,6 @@ def kernel3d( f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) - # Return the correct kernel - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return None, kernel @Operator.register_backend(ComputeBackend.WARP) diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index dc2417a..247fa5a 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -55,50 +55,9 @@ def _construct_warp(self): _c = self.velocity_set.c _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - # Construct the warp functional - @wp.func - def functional2d( - f: wp.array3d(dtype=Any), - index: Any, - ): - # Pull the distribution function - _f = _f_vec() - for l in range(self.velocity_set.q): - # Get pull index - pull_index = type(index)() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - _c[d, l] - - # impose periodicity for out of bound values - if pull_index[d] < 0: - pull_index[d] = f.shape[d + 1] - 1 - elif pull_index[d] >= f.shape[d + 1]: - pull_index[d] = 0 - - # Read the distribution function - _f[l] = self.compute_dtype(f[l, pull_index[0], pull_index[1]]) - - return _f - - @wp.kernel - def kernel2d( - f_0: wp.array3d(dtype=Any), - f_1: wp.array3d(dtype=Any), - ): - # Get the global index - i, j = wp.tid() - index = wp.vec2i(i, j) - - # Set the output - _f = functional2d(f_0, index) - - # Write the output - for l in range(self.velocity_set.q): - f_1[l, index[0], index[1]] = self.store_dtype(_f[l]) - # Construct the funcional to get streamed indices @wp.func - def functional3d( + def functional( f: wp.array4d(dtype=Any), index: Any, ): @@ -124,7 +83,7 @@ def functional3d( # Construct the warp kernel @wp.kernel - def kernel3d( + def kernel( f_0: wp.array4d(dtype=Any), f_1: wp.array4d(dtype=Any), ): @@ -133,15 +92,12 @@ def kernel3d( index = wp.vec3i(i, j, k) # Set the output - _f = functional3d(f_0, index) + _f = functional(f_0, index) # Write the output for l in range(self.velocity_set.q): f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) - functional = functional3d if self.velocity_set.d == 3 else functional2d - kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return functional, kernel @Operator.register_backend(ComputeBackend.WARP) From c0ea2a5be3ee92586d9293b7d9d243082caf4493 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 18 Oct 2024 18:00:43 -0400 Subject: [PATCH 142/144] Changed the exmaple mesh address --- examples/cfd/windtunnel_3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index c83e2c9..96c79f7 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -74,8 +74,8 @@ def define_boundary_indices(self): walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)] walls = np.unique(np.array(walls), axis=-1).tolist() - # Load the mesh - stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl" + # Load the mesh (replace with your own mesh) + stl_filename = "../stl-files/DrivAer-Notchback.stl" mesh = trimesh.load_mesh(stl_filename, process=False) mesh_vertices = mesh.vertices From 27c0205af5782b107ccde5a1ac133489b07873ec Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 18 Oct 2024 18:02:26 -0400 Subject: [PATCH 143/144] Fixed ruff issue --- xlb/operator/boundary_masker/indices_boundary_masker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 36f9a36..0c1f7e1 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -93,7 +93,6 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None): bc_mask = bc_mask.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]) return bc_mask, missing_mask - def _construct_warp(self): # Make constants for warp _c = self.velocity_set.c From ed5f6435182718a76b40e14c0a2a94dddc3b802d Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Mon, 21 Oct 2024 17:06:23 -0400 Subject: [PATCH 144/144] Added back BC kernels in an abstract manner --- .../test_bc_equilibrium_warp.py | 2 +- .../test_bc_fullway_bounce_back_warp.py | 7 ++-- .../boundary_condition/bc_do_nothing.py | 14 +++++++- .../boundary_condition/bc_equilibrium.py | 15 +++++++- .../bc_extrapolation_outflow.py | 14 +++++++- .../bc_fullway_bounce_back.py | 14 +++++++- .../bc_grads_approximation.py | 14 +++++++- .../bc_halfway_bounce_back.py | 13 ++++++- .../boundary_condition/bc_regularized.py | 14 +++++++- xlb/operator/boundary_condition/bc_zouhe.py | 14 +++++++- .../boundary_condition/boundary_condition.py | 35 +++++++++++++++++++ 11 files changed, 145 insertions(+), 11 deletions(-) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 5eb0c10..6bd9311 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -72,7 +72,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): f = f.numpy() f_post = f_post.numpy() - assert f.shape == (velocity_set.q,) + grid_shape + assert f.shape == (velocity_set.q,) + grid_shape if dim == 3 else (velocity_set.q, grid_shape[0], grid_shape[1], 1) # Assert that the values are correct in the indices of the sphere weights = velocity_set.w diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index 10b9244..59c6c9d 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -58,7 +58,10 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): bc_mask, missing_mask = indices_boundary_masker([fullway_bc], bc_mask, missing_mask, start_index=None) # Generate a random field with the same shape - random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32) + if dim == 2: + random_field = np.random.rand(velocity_set.q, grid_shape[0], grid_shape[1], 1).astype(np.float32) + else: + random_field = np.random.rand(velocity_set.q, grid_shape[0], grid_shape[1], grid_shape[2]).astype(np.float32) # Add the random field to f_pre f_pre = wp.array(random_field) @@ -71,7 +74,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): f = f_pre.numpy() f_post = f_post.numpy() - assert f.shape == (velocity_set.q,) + grid_shape + assert f.shape == (velocity_set.q,) + grid_shape if dim == 3 else (velocity_set.q, grid_shape[0], grid_shape[1], 1) for i in range(velocity_set.q): np.allclose( diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 67b343d..56a332f 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -64,4 +64,16 @@ def functional( ): return f_pre - return functional, None + kernel = self._construct_kernel(functional) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index b4b957a..77f408f 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -88,4 +88,17 @@ def functional( _f = self.equilibrium_operator.warp_functional(_rho, _u) return _f - return functional, None + # Use the parent class's kernel and pass the functional + kernel = self._construct_kernel(functional) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 4a96c73..38657e5 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -193,4 +193,16 @@ def prepare_bc_auxilary_data( _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux return _f - return (functional, prepare_bc_auxilary_data), None \ No newline at end of file + kernel = self._construct_kernel(functional) + + return (functional, prepare_bc_auxilary_data), kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index afe05de..19a3013 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -74,4 +74,16 @@ def functional( fliped_f[l] = f_pre[_opp_indices[l]] return fliped_f - return functional, None \ No newline at end of file + kernel = self._construct_kernel(functional) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index 5806375..94ddba3 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -309,4 +309,16 @@ def functional_method2( functional = functional_method1 - return functional, None \ No newline at end of file + kernel = self._construct_kernel(functional) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post \ No newline at end of file diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index ee68b50..bf04af0 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -87,5 +87,16 @@ def functional( return _f - return functional, None + kernel = self._construct_kernel(functional) + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 12622e2..af4c783 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -266,4 +266,16 @@ def functional_pressure( functional = functional_velocity elif self.bc_type == "pressure": functional = functional_pressure - return functional, None \ No newline at end of file + kernel = self._construct_kernel(functional) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index c5d9498..a92d909 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -283,4 +283,16 @@ def functional_pressure( elif self.bc_type == "velocity": functional = functional_pressure - return functional, None \ No newline at end of file + kernel = self._construct_kernel(functional) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, missing_mask], + dim=f_pre.shape[1:], + ) + return f_post diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 2cd2a11..bf1eef2 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -111,3 +111,38 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): currently being called after collision only. """ return f_post + + def _construct_kernel(self, functional): + """ + Constructs the warp kernel for the boundary condition. + The functional is specific to each boundary condition and should be passed as an argument. + """ + _id = wp.uint8(self.id) + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # read tid data + _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data(f_pre, f_post, bc_mask, missing_mask, index) + + # Apply the boundary condition + if _boundary_id == _id: + timestep = 0 + _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post) + else: + _f = _f_post + + # Write the result + for l in range(self.velocity_set.q): + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) + + return kernel