diff --git a/README.md b/README.md index 8d662b5..6bb7727 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![GitHub star chart](https://img.shields.io/github/stars/Autodesk/XLB?style=social)](https://star-history.com/#Autodesk/XLB)

- +

-# XLB: Distributed Multi-GPU Lattice Boltzmann Simulation Framework for Differentiable Scientific Machine Learning +# XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that leverages hardware acceleration. It's built on top of the [JAX](https://github.com/google/jax) library and is specifically designed to solve fluid dynamics problems in a computationally efficient and differentiable manner. Its unique combination of features positions it as an exceptionally suitable tool for applications in physics-based machine learning. @@ -18,7 +18,7 @@ If you use XLB in your research, please cite the following paper: ``` @article{ataei2023xlb, - title={{XLB}: Distributed Multi-GPU Lattice Boltzmann Simulation Framework for Differentiable Scientific Machine Learning}, + title={{XLB}: A Differentiable Massively Parallel Lattice Boltzmann Library in Python}, author={Ataei, Mohammadmehdi and Salehipour, Hesam}, journal={arXiv preprint arXiv:2311.16080}, year={2023}, @@ -33,7 +33,7 @@ If you use XLB in your research, please cite the following paper: - **User-Friendly Interface:** Written entirely in Python, XLB emphasizes a highly accessible interface that allows users to extend the library with ease and quickly set up and run new simulations. - **Leverages JAX Array and Shardmap:** The library incorporates the new JAX array unified array type and JAX shardmap, providing users with a numpy-like interface. This allows users to focus solely on the semantics, leaving performance optimizations to the compiler. - **Platform Versatility:** The same XLB code can be executed on a variety of platforms including multi-core CPUs, single or multi-GPU systems, TPUs, and it also supports distributed runs on multi-GPU systems or TPU Pod slices. -- **Visualization:** XLB provides a variety of visualization options including in-situ rendering using [PhantomGaze](https://github.com/loliverhennigh/PhantomGaze). +- **Visualization:** XLB provides a variety of visualization options including in-situ on GPU rendering using [PhantomGaze](https://github.com/loliverhennigh/PhantomGaze). ## Showcase @@ -153,4 +153,4 @@ git clone https://github.com/Autodesk/XLB cd XLB export PYTHONPATH=. python3 examples/CFD/cavity2d.py -``` +``` \ No newline at end of file diff --git a/examples/CFD/cylinder2d.py b/examples/CFD/cylinder2d.py index e1d5f64..2c9887d 100644 --- a/examples/CFD/cylinder2d.py +++ b/examples/CFD/cylinder2d.py @@ -16,6 +16,8 @@ 5. Visualization: The simulation outputs data in VTK format for visualization. It also generates images of the velocity field. The data can be visualized using software like ParaView. +# To run type: +nohup python3 examples/CFD/cylinder2d.py > logfile.log & """ import os import json @@ -50,8 +52,9 @@ def set_boundary_conditions(self): # Outflow BC outlet = self.boundingBoxIndices['right'] - rho_outlet = np.ones(outlet.shape[0], dtype=self.precisionPolicy.compute_dtype) + rho_outlet = np.ones((outlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) self.BCs.append(ExtrapolationOutflow(tuple(outlet.T), self.gridInfo, self.precisionPolicy)) + # self.BCs.append(ZouHe(tuple(outlet.T), self.gridInfo, self.precisionPolicy, 'pressure', rho_outlet)) # Inlet BC inlet = self.boundingBoxIndices['left'] @@ -78,7 +81,7 @@ def output_data(self, **kwargs): if timestep == 0: self.CL_max = 0.0 self.CD_max = 0.0 - if timestep > 0.8 * t_max: + if timestep > 0.5 * niter_max: # compute lift and drag over the cyliner cylinder = self.BCs[0] boundary_force = cylinder.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) @@ -94,7 +97,7 @@ def output_data(self, **kwargs): self.CL_max = max(self.CL_max, cl) self.CD_max = max(self.CD_max, cd) print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) - save_image(timestep, u) + # save_image(timestep, u) # Helper function to specify a parabolic poiseuille profile poiseuille_profile = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2)) @@ -131,9 +134,11 @@ def output_data(self, **kwargs): 'print_info_rate': int(10000 / scale_factor), 'return_fpost': True # Need to retain fpost-collision for computation of lift and drag } + # characteristic time + tc = prescribed_vel/diam + niter_max = int(100//tc) sim = Cylinder(**kwargs) - t_max = int(1000000 / scale_factor) - sim.run(t_max) + sim.run(niter_max) CL_list.append(sim.CL_max) CD_list.append(sim.CD_max) diff --git a/examples/CFD_refactor/windtunnel3d.py b/examples/CFD_refactor/windtunnel3d.py index 28d9208..9af75f3 100644 --- a/examples/CFD_refactor/windtunnel3d.py +++ b/examples/CFD_refactor/windtunnel3d.py @@ -14,6 +14,7 @@ from xlb.operator.boundary_condition import BounceBack, BounceBackHalfway, DoNothing, EquilibriumBC + class WindTunnel(IncompressibleNavierStokesSolver): """ This class extends the IncompressibleNavierStokesSolver class to define the boundary conditions for the wind tunnel simulation. diff --git a/examples/interfaces/boundary_conditions.py b/examples/interfaces/boundary_conditions.py new file mode 100644 index 0000000..70648bf --- /dev/null +++ b/examples/interfaces/boundary_conditions.py @@ -0,0 +1,156 @@ +# Simple script to run different boundary conditions with jax and warp backends +import time +from tqdm import tqdm +import os +import matplotlib.pyplot as plt +from typing import Any +import numpy as np +import jax.numpy as jnp +import warp as wp + +wp.init() + +import xlb + +def run_boundary_conditions(backend): + + # Set the compute backend + if backend == "warp": + compute_backend = xlb.ComputeBackend.WARP + elif backend == "jax": + compute_backend = xlb.ComputeBackend.JAX + + # Set the precision policy + precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set + velocity_set = xlb.velocity_set.D3Q19() + + # Make grid + nr = 256 + shape = (nr, nr, nr) + if backend == "jax": + grid = xlb.grid.JaxGrid(shape=shape) + elif backend == "warp": + grid = xlb.grid.WarpGrid(shape=shape) + + # Make feilds + f_pre = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + f_post = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + f = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) + + # Make needed operators + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(0.0, 0.0, 0.0), + equilibrium_operator=equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + halfway_bounce_back_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + fullway_bounce_back_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + # Make indices for boundary conditions (sphere) + sphere_radius = 32 + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + indices = np.array(indices).T + if backend == "jax": + indices = jnp.array(indices) + elif backend == "warp": + indices = wp.from_numpy(indices, dtype=wp.int32) + + # Test equilibrium boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, + equilibrium_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + if backend == "jax": + f = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f = equilibrium_bc(f_pre, f_post, boundary_id, missing_mask, f) + print(f"Equilibrium BC test passed for {backend}") + + # Test do nothing boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, + do_nothing_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + if backend == "jax": + f = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f = do_nothing_bc(f_pre, f_post, boundary_id, missing_mask, f) + + # Test halfway bounce back boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, + halfway_bounce_back_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + if backend == "jax": + f = halfway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f = halfway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask, f) + print(f"Halfway bounce back BC test passed for {backend}") + + # Test the full boundary condition + boundary_id, missing_mask = indices_boundary_masker( + indices, + fullway_bounce_back_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + if backend == "jax": + f = fullway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask) + elif backend == "warp": + f = fullway_bounce_back_bc(f_pre, f_post, boundary_id, missing_mask, f) + print(f"Fullway bounce back BC test passed for {backend}") + + +if __name__ == "__main__": + + # Test the boundary conditions + backends = ["warp", "jax"] + for backend in backends: + run_boundary_conditions(backend) diff --git a/examples/interfaces/flow_past_sphere.py b/examples/interfaces/flow_past_sphere.py new file mode 100644 index 0000000..8bfe945 --- /dev/null +++ b/examples/interfaces/flow_past_sphere.py @@ -0,0 +1,220 @@ +# Simple flow past sphere example using the functional interface to xlb + +import time +from tqdm import tqdm +import os +import matplotlib.pyplot as plt +from typing import Any +import numpy as np + +import warp as wp + +wp.init() + +import xlb +from xlb.operator import Operator + +class UniformInitializer(Operator): + + def _construct_warp(self): + # Construct the warp kernel + @wp.kernel + def kernel( + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + vel: float, + ): + # Get the global index + i, j, k = wp.tid() + + # Set the velocity + u[0, i, j, k] = vel + u[1, i, j, k] = 0.0 + u[2, i, j, k] = 0.0 + + # Set the density + rho[0, i, j, k] = 1.0 + + return None, kernel + + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, rho, u, vel): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + rho, + u, + vel, + ], + dim=rho.shape[1:], + ) + return rho, u + + +if __name__ == "__main__": + # Set parameters + compute_backend = xlb.ComputeBackend.WARP + precision_policy = xlb.PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q19() + + # Make feilds + nr = 256 + vel = 0.05 + shape = (nr, nr, nr) + grid = xlb.grid.WarpGrid(shape=shape) + rho = grid.create_field(cardinality=1, dtype=wp.float32) + u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) + f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) + f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) + boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) + missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) + + # Make operators + initializer = UniformInitializer( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + collision = xlb.operator.collision.BGK( + omega=1.95, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + stream = xlb.operator.stream.Stream( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(vel, 0.0, 0.0), + equilibrium_operator=equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=collision, + equilibrium=equilibrium, + macroscopic=macroscopic, + stream=stream, + equilibrium_bc=equilibrium_bc, + do_nothing_bc=do_nothing_bc, + half_way_bc=half_way_bc, + ) + indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + # Make indices for boundary conditions (sphere) + sphere_radius = 32 + x = np.arange(nr) + y = np.arange(nr) + z = np.arange(nr) + X, Y, Z = np.meshgrid(x, y, z) + indices = np.where( + (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 + < sphere_radius**2 + ) + indices = np.array(indices).T + indices = wp.from_numpy(indices, dtype=wp.int32) + + # Set boundary conditions on the indices + boundary_id, missing_mask = indices_boundary_masker( + indices, + half_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set inlet bc + lower_bound = (0, 0, 0) + upper_bound = (0, nr, nr) + direction = (1, 0, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + equilibrium_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set outlet bc + lower_bound = (nr-1, 0, 0) + upper_bound = (nr-1, nr, nr) + direction = (-1, 0, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + do_nothing_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set initial conditions + rho, u = initializer(rho, u, vel) + f0 = equilibrium(rho, u, f0) + + # Time stepping + plot_freq = 512 + save_dir = "flow_past_sphere" + os.makedirs(save_dir, exist_ok=True) + #compute_mlup = False # Plotting results + compute_mlup = True + num_steps = 1024 * 8 + start = time.time() + for _ in tqdm(range(num_steps)): + f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1, f0 = f0, f1 + if (_ % plot_freq == 0) and (not compute_mlup): + rho, u = macroscopic(f0, rho, u) + + # Plot the velocity field and boundary id side by side + plt.subplot(1, 2, 1) + plt.imshow(u[0, :, nr // 2, :].numpy()) + plt.colorbar() + plt.subplot(1, 2, 2) + plt.imshow(boundary_id[0, :, nr // 2, :].numpy()) + plt.colorbar() + plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") + plt.close() + + wp.synchronize() + end = time.time() + + # Print MLUPS + print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") diff --git a/examples/interfaces/ldc.py b/examples/interfaces/ldc.py new file mode 100644 index 0000000..e5ca559 --- /dev/null +++ b/examples/interfaces/ldc.py @@ -0,0 +1,290 @@ +# Simple flow past sphere example using the functional interface to xlb + +import time +from tqdm import tqdm +import os +import matplotlib.pyplot as plt +from typing import Any +import numpy as np + +import warp as wp + +wp.init() + +import xlb +from xlb.operator import Operator + +class UniformInitializer(Operator): + + def _construct_warp(self): + # Construct the warp kernel + @wp.kernel + def kernel( + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + + # Set the velocity + u[0, i, j, k] = 0.0 + u[1, i, j, k] = 0.0 + u[2, i, j, k] = 0.0 + + # Set the density + rho[0, i, j, k] = 1.0 + + return None, kernel + + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, rho, u): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + rho, + u, + ], + dim=rho.shape[1:], + ) + return rho, u + + +def run_ldc(backend, compute_mlup=True): + + # Set the compute backend + if backend == "warp": + compute_backend = xlb.ComputeBackend.WARP + elif backend == "jax": + compute_backend = xlb.ComputeBackend.JAX + + # Set the precision policy + precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set + velocity_set = xlb.velocity_set.D3Q19() + + # Make grid + nr = 128 + shape = (nr, nr, nr) + if backend == "jax": + grid = xlb.grid.JaxGrid(shape=shape) + elif backend == "warp": + grid = xlb.grid.WarpGrid(shape=shape) + + # Make feilds + rho = grid.create_field(cardinality=1, precision=xlb.Precision.FP32) + u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) + f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) + + # Make operators + initializer = UniformInitializer( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + collision = xlb.operator.collision.BGK( + omega=1.9, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + stream = xlb.operator.stream.Stream( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=1.0, + u=(0, 0.10, 0.0), + equilibrium_operator=equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=collision, + equilibrium=equilibrium, + macroscopic=macroscopic, + stream=stream, + #boundary_conditions=[equilibrium_bc, half_way_bc, full_way_bc], + boundary_conditions=[half_way_bc, full_way_bc, equilibrium_bc], + ) + planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + # Set inlet bc (bottom x face) + lower_bound = (0, 1, 1) + upper_bound = (0, nr-1, nr-1) + direction = (1, 0, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + equilibrium_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set outlet bc (top x face) + lower_bound = (nr-1, 0, 0) + upper_bound = (nr-1, nr, nr) + direction = (-1, 0, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + half_way_bc.id, + #full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set half way bc (bottom y face) + lower_bound = (0, 0, 0) + upper_bound = (nr, 0, nr) + direction = (0, 1, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + half_way_bc.id, + #full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set half way bc (top y face) + lower_bound = (0, nr-1, 0) + upper_bound = (nr, nr-1, nr) + direction = (0, -1, 0) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + half_way_bc.id, + #full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set half way bc (bottom z face) + lower_bound = (0, 0, 0) + upper_bound = (nr, nr, 0) + direction = (0, 0, 1) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + half_way_bc.id, + #full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set half way bc (top z face) + lower_bound = (0, 0, nr-1) + upper_bound = (nr, nr, nr-1) + direction = (0, 0, -1) + boundary_id, missing_mask = planar_boundary_masker( + lower_bound, + upper_bound, + direction, + half_way_bc.id, + #full_way_bc.id, + boundary_id, + missing_mask, + (0, 0, 0) + ) + + # Set initial conditions + if backend == "warp": + rho, u = initializer(rho, u) + f0 = equilibrium(rho, u, f0) + elif backend == "jax": + rho = rho + 1.0 + f0 = equilibrium(rho, u) + + # Time stepping + plot_freq = 128 + save_dir = "ldc" + os.makedirs(save_dir, exist_ok=True) + num_steps = nr * 16 + start = time.time() + + for _ in tqdm(range(num_steps)): + # Time step + if backend == "warp": + f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1, f0 = f0, f1 + elif backend == "jax": + f0 = stepper(f0, boundary_id, missing_mask, _) + + # Plot if necessary + if (_ % plot_freq == 0) and (not compute_mlup): + if backend == "warp": + rho, u = macroscopic(f0, rho, u) + local_rho = rho.numpy() + local_u = u.numpy() + local_boundary_id = boundary_id.numpy() + elif backend == "jax": + local_rho, local_u = macroscopic(f0) + local_boundary_id = boundary_id + + # Plot the velocity field, rho and boundary id side by side + plt.subplot(1, 3, 1) + plt.imshow(np.linalg.norm(local_u[:, :, nr // 2, :], axis=0)) + plt.colorbar() + plt.subplot(1, 3, 2) + plt.imshow(local_rho[0, :, nr // 2, :]) + plt.colorbar() + plt.subplot(1, 3, 3) + plt.imshow(local_boundary_id[0, :, nr // 2, :]) + plt.colorbar() + plt.savefig(f"{save_dir}/{backend}_{str(_).zfill(6)}.png") + plt.close() + + wp.synchronize() + end = time.time() + + # Print MLUPS + print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") + +if __name__ == "__main__": + + # Run the LDC example + backends = ["warp", "jax"] + compute_mlup = False + for backend in backends: + run_ldc(backend, compute_mlup=compute_mlup) diff --git a/examples/interfaces/functional_interface.py b/examples/interfaces/taylor_green.py similarity index 50% rename from examples/interfaces/functional_interface.py rename to examples/interfaces/taylor_green.py index e1419a1..f842107 100644 --- a/examples/interfaces/functional_interface.py +++ b/examples/interfaces/taylor_green.py @@ -4,22 +4,63 @@ from tqdm import tqdm import os import matplotlib.pyplot as plt - +from functools import partial +from typing import Any +import jax.numpy as jnp +from jax import jit import warp as wp + wp.init() import xlb from xlb.operator import Operator class TaylorGreenInitializer(Operator): + """ + Initialize the Taylor-Green vortex. + """ + + @Operator.register_backend(xlb.ComputeBackend.JAX) + #@partial(jit, static_argnums=(0)) + def jax_implementation(self, vel, nr): + # Make meshgrid + x = jnp.linspace(0, 2 * jnp.pi, nr) + y = jnp.linspace(0, 2 * jnp.pi, nr) + z = jnp.linspace(0, 2 * jnp.pi, nr) + X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij") + + # Compute u + u = jnp.stack( + [ + vel * jnp.sin(X) * jnp.cos(Y) * jnp.cos(Z), + - vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), + jnp.zeros_like(X), + ], + axis=0, + ) + + # Compute rho + rho = ( + 3.0 + * vel + * vel + * (1.0 / 16.0) + * ( + jnp.cos(2.0 * X) + + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0)) + ) + + 1.0 + ) + rho = jnp.expand_dims(rho, axis=0) + + return rho, u def _construct_warp(self): # Construct the warp kernel @wp.kernel def kernel( - f0: self._warp_array_type, - rho: self._warp_array_type, - u: self._warp_array_type, + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), vel: float, nr: int, ): @@ -53,12 +94,11 @@ def kernel( return None, kernel @Operator.register_backend(xlb.ComputeBackend.WARP) - def warp_implementation(self, f0, rho, u, vel, nr): + def warp_implementation(self, rho, u, vel, nr): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ - f0, rho, u, vel, @@ -68,23 +108,35 @@ def warp_implementation(self, f0, rho, u, vel, nr): ) return rho, u -if __name__ == "__main__": +def run_taylor_green(backend, compute_mlup=True): + + # Set the compute backend + if backend == "warp": + compute_backend = xlb.ComputeBackend.WARP + elif backend == "jax": + compute_backend = xlb.ComputeBackend.JAX - # Set parameters - compute_backend = xlb.ComputeBackend.WARP + # Set the precision policy precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set velocity_set = xlb.velocity_set.D3Q19() - # Make feilds - nr = 256 + # Make grid + nr = 128 shape = (nr, nr, nr) - grid = xlb.grid.WarpGrid(shape=shape) - rho = grid.create_field(cardinality=1, dtype=wp.float32) - u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32) - f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32) - boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8) - mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool) + if backend == "jax": + grid = xlb.grid.JaxGrid(shape=shape) + elif backend == "warp": + grid = xlb.grid.WarpGrid(shape=shape) + + # Make feilds + rho = grid.create_field(cardinality=1, precision=xlb.Precision.FP32) + u = grid.create_field(cardinality=velocity_set.d, precision=xlb.Precision.FP32) + f0 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + f1 = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.FP32) + boundary_id = grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators initializer = TaylorGreenInitializer( @@ -112,35 +164,57 @@ def warp_implementation(self, f0, rho, u, vel, nr): collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, - stream=stream, - boundary_conditions=[]) + stream=stream) - # Parrallelize the stepper + # Parrallelize the stepper TODO: Add this functionality #stepper = grid.parallelize_operator(stepper) # Set initial conditions - rho, u = initializer(f0, rho, u, 0.1, nr) - f0 = equilibrium(rho, u, f0) + if backend == "warp": + rho, u = initializer(rho, u, 0.1, nr) + f0 = equilibrium(rho, u, f0) + elif backend == "jax": + rho, u = initializer(0.1, nr) + f0 = equilibrium(rho, u) # Time stepping plot_freq = 32 save_dir = "taylor_green" os.makedirs(save_dir, exist_ok=True) - #compute_mlup = False # Plotting results - compute_mlup = True - num_steps = 1024 + num_steps = 8192 start = time.time() + for _ in tqdm(range(num_steps)): - f1 = stepper(f0, f1, boundary_id, mask, _) - f1, f0 = f0, f1 + # Time step + if backend == "warp": + f1 = stepper(f0, f1, boundary_id, missing_mask, _) + f1, f0 = f0, f1 + elif backend == "jax": + f0 = stepper(f0, boundary_id, missing_mask, _) + + # Plot if needed if (_ % plot_freq == 0) and (not compute_mlup): - rho, u = macroscopic(f0, rho, u) - plt.imshow(u[0, :, nr//2, :].numpy()) + if backend == "warp": + rho, u = macroscopic(f0, rho, u) + local_u = u.numpy() + elif backend == "jax": + rho, local_u = macroscopic(f0) + + + plt.imshow(local_u[0, :, nr//2, :]) plt.colorbar() - plt.savefig(f"{save_dir}/{str(_).zfill(4)}.png") + plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() wp.synchronize() end = time.time() # Print MLUPS print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") + +if __name__ == "__main__": + + # Run Taylor-Green vortex on different backends + backends = ["warp", "jax"] + #backends = ["jax"] + for backend in backends: + run_taylor_green(backend, compute_mlup=True) diff --git a/xlb/__init__.py b/xlb/__init__.py index 88dcff2..84d38c5 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -1,6 +1,6 @@ # Enum classes from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy +from xlb.precision_policy import PrecisionPolicy, Precision from xlb.physics_type import PhysicsType # Config diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 1698c34..d2b579a 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -6,6 +6,7 @@ from xlb.grid import Grid from xlb.compute_backend import ComputeBackend from xlb.operator import Operator +from xlb.precision_policy import Precision class JaxGrid(Grid): def __init__(self, shape): @@ -31,8 +32,8 @@ def _initialize_jax_backend(self): else NamedSharding(self.global_mesh, P("cardinality", "x", "y", "z")) ) self.grid_shape_per_gpu = ( - self.grid_shape[0] // self.nDevices, - ) + self.grid_shape[1:] + self.shape[0] // self.nDevices, + ) + self.shape[1:] def parallelize_operator(self, operator: Operator): @@ -73,15 +74,15 @@ def _parallel_operator(f): return f - def create_field(self, name: str, cardinality: int, callback=None): + def create_field(self, cardinality: int, precision: Precision, callback=None): # Get shape of the field shape = (cardinality,) + (self.shape) # Create field if callback is None: - f = jax.numpy.full(shape, 0.0, dtype=self.precision_policy) - if self.sharding is not None: - f = jax.make_sharded_array(self.sharding, f) + f = jax.numpy.full(shape, 0.0, dtype=precision.jax_dtype) + #if self.sharding is not None: + # f = jax.make_sharded_array(self.sharding, f) else: f = jax.make_array_from_callback(shape, self.sharding, callback) diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index e4d160e..97b337b 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -2,6 +2,7 @@ from xlb.grid import Grid from xlb.operator import Operator +from xlb.precision_policy import Precision class WarpGrid(Grid): def __init__(self, shape): @@ -11,12 +12,12 @@ def parallelize_operator(self, operator: Operator): # TODO: Implement parallelization of the operator raise NotImplementedError("Parallelization of the operator is not implemented yet for the WarpGrid") - def create_field(self, cardinality: int, dtype, callback=None): + def create_field(self, cardinality: int, precision: Precision, callback=None): # Get shape of the field shape = (cardinality,) + (self.shape) # Create the field - f = wp.zeros(shape, dtype=dtype) + f = wp.zeros(shape, dtype=precision.wp_dtype) # Raise error on callback if callback is not None: diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index 501a7af..c1232a3 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1,3 +1,4 @@ from xlb.operator.operator import Operator from xlb.operator.parallel_operator import ParallelOperator -import xlb.operator.stepper # +import xlb.operator.stepper +import xlb.operator.boundary_masker diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 3d10b59..27e0472 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -1,8 +1,8 @@ -from xlb.operator.boundary_condition.boundary_condition import ( - BoundaryCondition, - ImplementationStep, +from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition +from xlb.operator.boundary_condition.boundary_condition_registry import ( + BoundaryConditionRegistry, ) -from xlb.operator.boundary_condition.full_bounce_back import FullBounceBack -from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBack -from xlb.operator.boundary_condition.do_nothing import DoNothing -from xlb.operator.boundary_condition.equilibrium_boundary import EquilibriumBoundary +from xlb.operator.boundary_condition.equilibrium import EquilibriumBC +from xlb.operator.boundary_condition.do_nothing import DoNothingBC +from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBackBC +from xlb.operator.boundary_condition.fullway_bounce_back import FullwayBounceBackBC diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 95b1265..92a2c1f 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -6,22 +6,18 @@ from jax import jit, device_count from functools import partial import numpy as np -from enum import Enum +from enum import Enum, auto from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) # Enum for implementation step class ImplementationStep(Enum): - COLLISION = 1 - STREAMING = 2 + COLLISION = auto() + STREAMING = auto() class BoundaryCondition(Operator): @@ -32,83 +28,11 @@ class BoundaryCondition(Operator): def __init__( self, implementation_step: ImplementationStep, - boundary_masker: BoundaryMasker, velocity_set: VelocitySet, precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, + compute_backend: ComputeBackend, ): super().__init__(velocity_set, precision_policy, compute_backend) - # Set implementation step + # Set the implementation step self.implementation_step = implementation_step - - # Set boundary masker - self.boundary_masker = boundary_masker - - @classmethod - def from_function( - cls, - implementation_step: ImplementationStep, - boundary_function, - velocity_set, - precision_policy, - compute_backend, - ): - """ - Create a boundary condition from a function. - """ - # Create boundary mask - boundary_mask = BoundaryMasker.from_function( - boundary_function, velocity_set, precision_policy, compute_backend - ) - - # Create boundary condition - return cls( - implementation_step, - boundary_mask, - velocity_set, - precision_policy, - compute_backend, - ) - - @classmethod - def from_indices( - cls, - implementation_step: ImplementationStep, - indices: np.ndarray, - stream_indices: bool, - velocity_set, - precision_policy, - compute_backend, - ): - """ - Create a boundary condition from indices and boundary id. - """ - # Create boundary mask - boundary_mask = IndicesBoundaryMasker( - indices, stream_indices, velocity_set, precision_policy, compute_backend - ) - - # Create boundary condition - return cls( - implementation_step, - boundary_mask, - velocity_set, - precision_policy, - compute_backend, - ) - - @classmethod - def from_stl( - cls, - implementation_step: ImplementationStep, - stl_file: str, - stream_indices: bool, - velocity_set, - precision_policy, - compute_backend, - ): - """ - Create a boundary condition from an STL file. - """ - raise NotImplementedError diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py new file mode 100644 index 0000000..0a3b2c7 --- /dev/null +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -0,0 +1,29 @@ +""" +Registry for boundary conditions in a LBM simulation. +""" + + +class BoundaryConditionRegistry: + """ + Registry for boundary conditions in a LBM simulation. + """ + + def __init__( + self, + ): + self.id_to_bc = {} # Maps id number to boundary condition + self.bc_to_id = {} # Maps boundary condition to id number + self.next_id = 1 # 0 is reserved for no boundary condition + + def register_boundary_condition(self, boundary_condition): + """ + Register a boundary condition. + """ + id = self.next_id + self.next_id += 1 + self.id_to_bc[id] = boundary_condition + self.bc_to_id[boundary_condition] = id + return id + + +boundary_condition_registry = BoundaryConditionRegistry() diff --git a/xlb/operator/boundary_condition/boundary_masker/__init__.py b/xlb/operator/boundary_condition/boundary_masker/__init__.py deleted file mode 100644 index e33e509..0000000 --- a/xlb/operator/boundary_condition/boundary_masker/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from xlb.operator.boundary_condition.boundary_masker.boundary_masker import BoundaryMasker -from xlb.operator.boundary_condition.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker diff --git a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py deleted file mode 100644 index 20bf580..0000000 --- a/xlb/operator/boundary_condition/boundary_masker/boundary_masker.py +++ /dev/null @@ -1,34 +0,0 @@ -# Base class for all equilibriums - -import jax.numpy as jnp -from jax import jit -import warp as wp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator - - -class BoundaryMasker(Operator): - """ - Operator for creating a boundary mask - """ - - @classmethod - def from_jax_func( - cls, jax_func, precision_policy: PrecisionPolicy, velocity_set: VelocitySet - ): - """ - Create a boundary masker from a jax function - """ - raise NotImplementedError - - @classmethod - def from_warp_func( - cls, warp_func, precision_policy: PrecisionPolicy, velocity_set: VelocitySet - ): - """ - Create a boundary masker from a warp function - """ - raise NotImplementedError diff --git a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py deleted file mode 100644 index fdf8ced..0000000 --- a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py +++ /dev/null @@ -1,118 +0,0 @@ -# Base class for all equilibriums - -from functools import partial -import numpy as np -import jax.numpy as jnp -from jax import jit -import warp as wp -from typing import Tuple - -from xlb.global_config import GlobalConfig -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator -from xlb.operator.stream.stream import Stream - - -class IndicesBoundaryMasker(Operator): - """ - Operator for creating a boundary mask - """ - - def __init__( - self, - indices: np.ndarray, - stream_indices: bool, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Set indices - # TODO: handle multi-gpu case (this will usually implicitly work) - self.indices = indices - self.stream_indices = stream_indices - - # Make stream operator - self.stream = Stream(velocity_set, precision_policy, compute_backend) - - @staticmethod - def _indices_to_tuple(indices): - """ - Converts a tensor of indices to a tuple for indexing - TODO: Might be better to index - """ - return tuple([indices[:, i] for i in range(indices.shape[1])]) - - @Operator.register_backend(ComputeBackend.JAX) - #@partial(jit, static_argnums=(0), inline=True) TODO: Fix this - def jax_implementation(self, start_index, boundary_id, mask, id_number): - # Get local indices from the meshgrid and the indices - local_indices = self.indices - np.array(start_index)[np.newaxis, :] - - # Remove any indices that are out of bounds - local_indices = local_indices[ - (local_indices[:, 0] >= 0) - & (local_indices[:, 0] < mask.shape[0]) - & (local_indices[:, 1] >= 0) - & (local_indices[:, 1] < mask.shape[1]) - & (local_indices[:, 2] >= 0) - & (local_indices[:, 2] < mask.shape[2]) - ] - - # Set the boundary id - boundary_id = boundary_id.at[self._indices_to_tuple(local_indices)].set( - id_number - ) - - # Stream mask if necessary - if self.stream_indices: - # Make mask then stream to get the edge points - pre_stream_mask = jnp.zeros_like(mask) - pre_stream_mask = pre_stream_mask.at[ - self._indices_to_tuple(local_indices) - ].set(True) - post_stream_mask = self.stream(pre_stream_mask) - - # Set false for points inside the boundary - post_stream_mask = post_stream_mask.at[ - post_stream_mask[..., 0] == True - ].set(False) - - # Get indices on edges - edge_indices = jnp.argwhere(post_stream_mask) - - # Set the mask - mask = mask.at[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ].set( - post_stream_mask[ - edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : - ] - ) - - else: - # Set the mask - mask = mask.at[self._indices_to_tuple(local_indices)].set(True) - - return boundary_id, mask - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, start_index, boundary_id, mask, id_number): - # Reuse the jax implementation, TODO: implement a warp version - # Convert to jax - boundary_id = wp.jax.to_jax(boundary_id) - mask = wp.jax.to_jax(mask) - - # Call jax implementation - boundary_id, mask = self.jax_implementation( - start_index, boundary_id, mask, id_number - ) - - # Convert back to warp - boundary_id = wp.jax.to_warp(boundary_id) - mask = wp.jax.to_warp(mask) - - return boundary_id, mask diff --git a/xlb/operator/boundary_condition/do_nothing.py b/xlb/operator/boundary_condition/do_nothing.py index 6251660..46a6fdd 100644 --- a/xlb/operator/boundary_condition/do_nothing.py +++ b/xlb/operator/boundary_condition/do_nothing.py @@ -1,54 +1,121 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + import jax.numpy as jnp from jax import jit, device_count import jax.lax as lax from functools import partial import numpy as np +import warp as wp +from typing import Tuple, Any from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator from xlb.operator.boundary_condition.boundary_condition import ( - BoundaryCondition, ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, ) -class DoNothing(BoundaryCondition): +class DoNothingBC(BoundaryCondition): """ - A boundary condition that skips the streaming step. + Do nothing boundary condition. This boundary condition skips the streaming step for the + boundary nodes. """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) + def __init__( self, - set_boundary, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, ): super().__init__( - set_boundary=set_boundary, - implementation_step=ImplementationStep.STREAMING, - velocity_set=velocity_set, - compute_backend=compute_backend, + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, ) - @classmethod - def from_indices( - cls, - indices, - velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, - ): - """ - Creates a boundary condition from a list of indices. - """ - - return cls( - set_boundary=cls._set_boundary_from_indices(indices), - velocity_set=velocity_set, - compute_backend=compute_backend, - ) + @Operator.register_backend(ComputeBackend.JAX) + #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + # TODO: This is unoptimized + boundary = boundary_id == self.id + flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) + skipped_f = lax.select(flip, f_pre, f_post) + return skipped_f - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - do_nothing = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) - f = lax.select(do_nothing, f_pre, f_post) + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f: wp.array4d(dtype=Any), + missing_mask: Any, + index: Any, + ): + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + return _f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + f: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(DoNothingBC.id): + _f = functional(f_pre, _missing_mask, index) + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1], index[2]] + + # Write the result + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = _f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_id, missing_mask, f], + dim=f_pre.shape[1:], + ) return f diff --git a/xlb/operator/boundary_condition/equilibrium.py b/xlb/operator/boundary_condition/equilibrium.py new file mode 100644 index 0000000..6de68ec --- /dev/null +++ b/xlb/operator/boundary_condition/equilibrium.py @@ -0,0 +1,135 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp +from typing import Tuple, Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + + +class EquilibriumBC(BoundaryCondition): + """ + Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + rho: float, + u: Tuple[float, float, float], + equilibrium_operator: Operator, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + # Store the equilibrium information + self.rho = rho + self.u = u + self.equilibrium_operator = equilibrium_operator + + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + # TODO: This is unoptimized + feq = self.equilibrium_operator(jnp.array([self.rho]), jnp.array(self.u)) + feq = jnp.reshape(feq, (self.velocity_set.q, 1, 1, 1)) + feq = jnp.repeat(feq, f_pre.shape[1], axis=1) + feq = jnp.repeat(feq, f_pre.shape[2], axis=2) + feq = jnp.repeat(feq, f_pre.shape[3], axis=3) + boundary = boundary_id == self.id + boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + skipped_f = lax.select(boundary, feq, f_post) + return skipped_f + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _rho = wp.float32(self.rho) + _u = _u_vec(self.u[0], self.u[1], self.u[2]) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f: wp.array4d(dtype=Any), + missing_mask: Any, + index: Any, + ): + _f = self.equilibrium_operator.warp_functional(_rho, _u) + return _f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + f: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(EquilibriumBC.id): + _f = functional(f_pre, _missing_mask, index) + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1], index[2]] + + # Write the result + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = _f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_id, missing_mask, f], + dim=f_pre.shape[1:], + ) + return f diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py deleted file mode 100644 index 4b47980..0000000 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ /dev/null @@ -1,78 +0,0 @@ -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator import Operator -from xlb.operator.equilibrium.equilibrium import Equilibrium -from xlb.operator.boundary_condition.boundary_condition import ( - BoundaryCondition, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) - - - -class EquilibriumBoundary(BoundaryCondition): - """ - Equilibrium boundary condition for a lattice Boltzmann method simulation. - """ - - def __init__( - self, - set_boundary, - rho: float, - u: tuple[float, float], - equilibrium: Equilibrium, - boundary_masker: BoundaryMasker, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - super().__init__( - ImplementationStep.COLLISION, - implementation_step=ImplementationStep.STREAMING, - velocity_set=velocity_set, - compute_backend=compute_backend, - ) - self.f = equilibrium(rho, u) - - @classmethod - def from_indices( - cls, - indices: np.ndarray, - rho: float, - u: tuple[float, float], - equilibrium: Equilibrium, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - """ - Creates a boundary condition from a list of indices. - """ - - return cls( - set_boundary=cls._set_boundary_from_indices(indices), - rho=rho, - u=u, - equilibrium=equilibrium, - velocity_set=velocity_set, - compute_backend=compute_backend, - ) - - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - equilibrium_mask = jnp.repeat(boundary[..., None], self.velocity_set.q, axis=-1) - equilibrium_f = jnp.repeat(self.f[None, ...], boundary.shape[0], axis=0) - equilibrium_f = jnp.repeat(equilibrium_f[:, None], boundary.shape[1], axis=1) - equilibrium_f = jnp.repeat(equilibrium_f[:, :, None], boundary.shape[2], axis=2) - f = lax.select(equilibrium_mask, equilibrium_f, f_post) - return f diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py deleted file mode 100644 index ed0ec5a..0000000 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ /dev/null @@ -1,135 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. -""" - -import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax -from functools import partial -import numpy as np -import warp as wp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator import Operator -from xlb.operator.boundary_condition import ( - BoundaryCondition, - ImplementationStep, -) -from xlb.operator.boundary_condition.boundary_masker import ( - BoundaryMasker, - IndicesBoundaryMasker, -) - - -class FullBounceBack(BoundaryCondition): - """ - Full Bounce-back boundary condition for a lattice Boltzmann method simulation. - """ - - def __init__( - self, - boundary_masker: BoundaryMasker, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend, - ): - super().__init__( - ImplementationStep.COLLISION, - boundary_masker, - velocity_set, - precision_policy, - compute_backend, - ) - - @classmethod - def from_indices( - cls, indices: np.ndarray, velocity_set, precision_policy, compute_backend - ): - """ - Create a full bounce-back boundary condition from indices. - """ - # Create boundary mask - boundary_mask = IndicesBoundaryMasker( - indices, False, velocity_set, precision_policy, compute_backend - ) - - # Create boundary condition - return cls( - boundary_mask, - velocity_set, - precision_policy, - compute_backend, - ) - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) - flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) - return flipped_f - - def _construct_warp(self): - # Make constants for warp - _opp_indices = wp.constant(self._warp_int_lattice_vec(self.velocity_set.opp_indices)) - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - - # Construct the funcional to get streamed indices - @wp.func - def functional( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - mask: self._warp_bool_lattice_vec, - ): - fliped_f = self._warp_lattice_vec() - for l in range(_q): - fliped_f[l] = f_pre[_opp_indices[l]] - return fliped_f - - # Construct the warp kernel - @wp.kernel - def kernel( - f_pre: self._warp_array_type, - f_post: self._warp_array_type, - f: self._warp_array_type, - boundary: self._warp_bool_array_type, - mask: self._warp_bool_array_type, - ): - # Get the global index - i, j, k = wp.tid() - - # Make vectors for the lattice - _f_pre = self._warp_lattice_vec() - _f_post = self._warp_lattice_vec() - _mask = self._warp_bool_lattice_vec() - for l in range(_q): - _f_pre[l] = f_pre[l, i, j, k] - _f_post[l] = f_post[l, i, j, k] - - # TODO fix vec bool - if mask[l, i, j, k]: - _mask[l] = wp.uint8(1) - else: - _mask[l] = wp.uint8(0) - - # Check if the boundary is active - if boundary[i, j, k]: - _f = functional(_f_pre, _f_post, _mask) - else: - _f = _f_post - - # Write the result to the output - for l in range(_q): - f[l, i, j, k] = _f[l] - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, f, boundary, mask): - # Launch the warp kernel - wp.launch( - self._kernel, inputs=[f_pre, f_post, f, boundary, mask], dim=f_pre.shape[1:] - ) - return f diff --git a/xlb/operator/boundary_condition/fullway_bounce_back.py b/xlb/operator/boundary_condition/fullway_bounce_back.py new file mode 100644 index 0000000..547cde1 --- /dev/null +++ b/xlb/operator/boundary_condition/fullway_bounce_back.py @@ -0,0 +1,126 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + +import jax.numpy as jnp +from jax import jit, device_count +import jax.lax as lax +from functools import partial +import numpy as np +import warp as wp +from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator import Operator +from xlb.operator.boundary_condition.boundary_condition import ( + BoundaryCondition, + ImplementationStep, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, +) + + +class FullwayBounceBackBC(BoundaryCondition): + """ + Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """ + + id = boundary_condition_registry.register_boundary_condition(__qualname__) + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, + ): + super().__init__( + ImplementationStep.COLLISION, + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + boundary = boundary_id == self.id + boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + return lax.select(boundary, f_pre[self.velocity_set.opp_indices], f_post) + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _opp_indices = self.velocity_set.wp_opp_indices + _q = wp.constant(self.velocity_set.q) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f_pre: Any, + f_post: Any, + missing_mask: Any, + ): + fliped_f = _f_vec() + for l in range(_q): + fliped_f[l] = f_pre[_opp_indices[l]] + return fliped_f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + f: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + + # Make vectors for the lattice + _f_pre = _f_vec() + _f_post = _f_vec() + _mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + _f_pre[l] = f_pre[l, index[0], index[1], index[2]] + _f_post[l] = f_post[l, index[0], index[1], index[2]] + + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) + + # Check if the boundary is active + if _boundary_id == wp.uint8(FullwayBounceBackBC.id): + _f = functional(_f_pre, _f_post, _mask) + else: + _f = _f_post + + # Write the result to the output + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = _f[l] + + return functional, kernel + + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_id, missing_mask, f], + dim=f_pre.shape[1:], + ) + return f diff --git a/xlb/operator/boundary_condition/halfway_bounce_back.py b/xlb/operator/boundary_condition/halfway_bounce_back.py index 0937f1a..e47cc26 100644 --- a/xlb/operator/boundary_condition/halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/halfway_bounce_back.py @@ -1,40 +1,141 @@ +""" +Base class for boundary conditions in a LBM simulation. +""" + import jax.numpy as jnp from jax import jit, device_count import jax.lax as lax from functools import partial import numpy as np +import warp as wp +from typing import Tuple, Any from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend -from xlb.operator.stream.stream import Stream +from xlb.operator.operator import Operator from xlb.operator.boundary_condition.boundary_condition import ( - BoundaryCondition, ImplementationStep, + BoundaryCondition, +) +from xlb.operator.boundary_condition.boundary_condition_registry import ( + boundary_condition_registry, ) -class HalfwayBounceBack(BoundaryCondition): +class HalfwayBounceBackBC(BoundaryCondition): """ Halfway Bounce-back boundary condition for a lattice Boltzmann method simulation. + + TODO: Implement moving boundary conditions for this """ + id = boundary_condition_registry.register_boundary_condition(__qualname__) + def __init__( self, - set_boundary, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, ): + # Call the parent constructor super().__init__( - set_boundary=set_boundary, - implementation_step=ImplementationStep.STREAMING, - velocity_set=velocity_set, - compute_backend=compute_backend, + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, ) - @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) - def apply_jax(self, f_pre, f_post, boundary, mask): - flip_mask = boundary[..., jnp.newaxis] & mask - flipped_f = lax.select( - flip_mask, f_pre[..., self.velocity_set.opp_indices], f_post + @Operator.register_backend(ComputeBackend.JAX) + #@partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0)) + def apply_jax(self, f_pre, f_post, boundary_id, missing_mask): + boundary = boundary_id == self.id + boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + return lax.select(jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], f_post) + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _opp_indices = self.velocity_set.wp_opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f: wp.array4d(dtype=Any), + missing_mask: Any, + index: Any, + ): + # Pull the distribution function + _f = _f_vec() + for l in range(self.velocity_set.q): + # Get pull index + pull_index = type(index)() + + # If the mask is missing then take the opposite index + if missing_mask[l] == wp.uint8(1): + use_l = _opp_indices[l] + for d in range(self.velocity_set.d): + pull_index[d] = index[d] + + # Pull the distribution function + else: + use_l = l + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - _c[d, l] + + # Get the distribution function + _f[l] = f[use_l, pull_index[0], pull_index[1], pull_index[2]] + + return _f + + # Construct the warp kernel + @wp.kernel + def kernel( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + f: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Apply the boundary condition + if _boundary_id == wp.uint8(HalfwayBounceBackBC.id): + _f = functional(f_pre, _missing_mask, index) + else: + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f_post[l, index[0], index[1], index[2]] + + # Write the distribution function + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = _f[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, f_post, boundary_id, missing_mask, f): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, boundary_id, missing_mask, f], + dim=f_pre.shape[1:], ) - return flipped_f + return f diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py new file mode 100644 index 0000000..a9069c6 --- /dev/null +++ b/xlb/operator/boundary_masker/__init__.py @@ -0,0 +1,9 @@ +from xlb.operator.boundary_masker.boundary_masker import ( + BoundaryMasker, +) +from xlb.operator.boundary_masker.indices_boundary_masker import ( + IndicesBoundaryMasker, +) +from xlb.operator.boundary_masker.planar_boundary_masker import ( + PlanarBoundaryMasker, +) diff --git a/xlb/operator/boundary_masker/boundary_masker.py b/xlb/operator/boundary_masker/boundary_masker.py new file mode 100644 index 0000000..6fe487f --- /dev/null +++ b/xlb/operator/boundary_masker/boundary_masker.py @@ -0,0 +1,9 @@ +# Base class for all boundary masker operators + +from xlb.operator.operator import Operator + + +class BoundaryMasker(Operator): + """ + Operator for creating a boundary mask + """ diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py new file mode 100644 index 0000000..b9e9f5b --- /dev/null +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -0,0 +1,158 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream.stream import Stream + + +class IndicesBoundaryMasker(Operator): + """ + Operator for creating a boundary mask + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + # Make stream operator + self.stream = Stream(velocity_set, precision_policy, compute_backend) + + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + @staticmethod + def _indices_to_tuple(indices): + """ + Converts a tensor of indices to a tuple for indexing + TODO: Might be better to index + """ + return tuple([indices[:, i] for i in range(indices.shape[1])]) + + @Operator.register_backend(ComputeBackend.JAX) + # @partial(jit, static_argnums=(0), inline=True) TODO: Fix this + def jax_implementation( + self, indices, id_number, boundary_id, mask, start_index=(0, 0, 0) + ): + # TODO: This is somewhat untested and unoptimized + + # Get local indices from the meshgrid and the indices + local_indices = indices - np.array(start_index)[np.newaxis, :] + + # Remove any indices that are out of bounds + local_indices = local_indices[ + (local_indices[:, 0] >= 0) + & (local_indices[:, 0] < mask.shape[0]) + & (local_indices[:, 1] >= 0) + & (local_indices[:, 1] < mask.shape[1]) + & (local_indices[:, 2] >= 0) + & (local_indices[:, 2] < mask.shape[2]) + ] + + # Set the boundary id + boundary_id = boundary_id.at[0, self._indices_to_tuple(local_indices)].set( + id_number + ) + + # Make mask then stream to get the edge points + pre_stream_mask = jnp.zeros_like(mask) + pre_stream_mask = pre_stream_mask.at[self._indices_to_tuple(local_indices)].set( + True + ) + post_stream_mask = self.stream(pre_stream_mask) + + # Set false for points inside the boundary (NOTE: removing this to be more consistent with the other boundary maskers, maybe add back in later) + # post_stream_mask = post_stream_mask.at[ + # post_stream_mask[0, ...] == True + # ].set(False) + + # Get indices on edges + edge_indices = jnp.argwhere(post_stream_mask) + + # Set the mask + mask = mask.at[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ].set( + post_stream_mask[ + edge_indices[:, 0], edge_indices[:, 1], edge_indices[:, 2], : + ] + ) + + return boundary_id, mask + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.wp_c + _q = wp.constant(self.velocity_set.q) + + # Construct the warp kernel + @wp.kernel + def kernel( + indices: wp.array2d(dtype=wp.int32), + id_number: wp.int32, + boundary_id: wp.array4d(dtype=wp.uint8), + mask: wp.array4d(dtype=wp.bool), + start_index: wp.vec3i, + ): + # Get the index of indices + ii = wp.tid() + + # Get local indices + index = wp.vec3i() + index[0] = indices[ii, 0] - start_index[0] + index[1] = indices[ii, 1] - start_index[1] + index[2] = indices[ii, 2] - start_index[2] + + # Check if in bounds + if ( + index[0] >= 0 + and index[0] < mask.shape[1] + and index[1] >= 0 + and index[1] < mask.shape[2] + and index[2] >= 0 + and index[2] < mask.shape[3] + ): + # Stream indices + for l in range(_q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + + # Set the boundary id and mask + boundary_id[ + 0, push_index[0], push_index[1], push_index[2] + ] = wp.uint8(id_number) + mask[l, push_index[0], push_index[1], push_index[2]] = True + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, indices, id_number, boundary_id, missing_mask, start_index=(0, 0, 0) + ): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + indices, + id_number, + boundary_id, + missing_mask, + start_index, + ], + dim=indices.shape[0], + ) + + return boundary_id, missing_mask diff --git a/xlb/operator/boundary_masker/planar_boundary_masker.py b/xlb/operator/boundary_masker/planar_boundary_masker.py new file mode 100644 index 0000000..572f345 --- /dev/null +++ b/xlb/operator/boundary_masker/planar_boundary_masker.py @@ -0,0 +1,197 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream.stream import Stream + + +class PlanarBoundaryMasker(Operator): + """ + Operator for creating a boundary mask on a plane of the domain + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + @Operator.register_backend(ComputeBackend.JAX) + # @partial(jit, static_argnums=(0), inline=True) TODO: Fix this + def jax_implementation( + self, + lower_bound, + upper_bound, + direction, + id_number, + boundary_id, + mask, + start_index=(0, 0, 0), + ): + # TODO: Optimize this + + # x plane + if direction[0] != 0: + + # Set boundary id + boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) + + # Set mask + for l in range(self.velocity_set.q): + d_dot_c = ( + direction[0] * self.velocity_set.c[0, l] + + direction[1] * self.velocity_set.c[1, l] + + direction[2] * self.velocity_set.c[2, l] + ) + if d_dot_c >= 0: + mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(True) + + # y plane + elif direction[1] != 0: + + # Set boundary id + boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(id_number) + + # Set mask + for l in range(self.velocity_set.q): + d_dot_c = ( + direction[0] * self.velocity_set.c[0, l] + + direction[1] * self.velocity_set.c[1, l] + + direction[2] * self.velocity_set.c[2, l] + ) + if d_dot_c >= 0: + mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(True) + + # z plane + elif direction[2] != 0: + + # Set boundary id + boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(id_number) + + # Set mask + for l in range(self.velocity_set.q): + d_dot_c = ( + direction[0] * self.velocity_set.c[0, l] + + direction[1] * self.velocity_set.c[1, l] + + direction[2] * self.velocity_set.c[2, l] + ) + if d_dot_c >= 0: + mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(True) + + return boundary_id, mask + + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.wp_c + _q = wp.constant(self.velocity_set.q) + + # Construct the warp kernel + @wp.kernel + def kernel( + lower_bound: wp.vec3i, + upper_bound: wp.vec3i, + direction: wp.vec3i, + id_number: wp.int32, + boundary_id: wp.array4d(dtype=wp.uint8), + mask: wp.array4d(dtype=wp.bool), + start_index: wp.vec3i, + ): + # Get the indices of the plane to mask + plane_i, plane_j = wp.tid() + + # Get local indices + if direction[0] != 0: + i = lower_bound[0] - start_index[0] + j = plane_i + lower_bound[1] - start_index[1] + k = plane_j + lower_bound[2] - start_index[2] + elif direction[1] != 0: + i = plane_i + lower_bound[0] - start_index[0] + j = lower_bound[1] - start_index[1] + k = plane_j + lower_bound[2] - start_index[2] + elif direction[2] != 0: + i = plane_i + lower_bound[0] - start_index[0] + j = plane_j + lower_bound[1] - start_index[1] + k = lower_bound[2] - start_index[2] + + # Check if in bounds + if ( + i >= 0 + and i < mask.shape[1] + and j >= 0 + and j < mask.shape[2] + and k >= 0 + and k < mask.shape[3] + ): + # Set the boundary id + boundary_id[0, i, j, k] = wp.uint8(id_number) + + # Set mask for just directions coming from the boundary + for l in range(_q): + d_dot_c = ( + direction[0] * _c[0, l] + + direction[1] * _c[1, l] + + direction[2] * _c[2, l] + ) + if d_dot_c >= 0: + mask[l, i, j, k] = wp.bool(True) + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + lower_bound, + upper_bound, + direction, + id_number, + boundary_id, + mask, + start_index=(0, 0, 0), + ): + # Get plane dimensions + if direction[0] != 0: + dim = ( + upper_bound[1] - lower_bound[1], + upper_bound[2] - lower_bound[2], + ) + elif direction[1] != 0: + dim = ( + upper_bound[0] - lower_bound[0], + upper_bound[2] - lower_bound[2], + ) + elif direction[2] != 0: + dim = ( + upper_bound[0] - lower_bound[0], + upper_bound[1] - lower_bound[1], + ) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + lower_bound, + upper_bound, + direction, + id_number, + boundary_id, + mask, + start_index, + ], + dim=dim, + ) + + return boundary_id, mask diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py new file mode 100644 index 0000000..cda8c00 --- /dev/null +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -0,0 +1,114 @@ +# Base class for all equilibriums + +from functools import partial +import numpy as np +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Tuple + +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.stream.stream import Stream + + +class STLBoundaryMasker(Operator): + """ + Operator for creating a boundary mask from an STL file + """ + + def __init__( + self, + velocity_set: VelocitySet, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend.JAX, + ): + super().__init__(velocity_set, precision_policy, compute_backend) + + # TODO: Implement this + raise NotImplementedError + + # Make stream operator + self.stream = Stream(velocity_set, precision_policy, compute_backend) + + @Operator.register_backend(ComputeBackend.JAX) + def jax_implementation( + self, mesh, id_number, boundary_id, mask, start_index=(0, 0, 0) + ): + # TODO: Implement this + raise NotImplementedError + + def _construct_warp(self): + # Make constants for warp + _opp_indices = wp.constant( + self._warp_int_lattice_vec(self.velocity_set.opp_indices) + ) + _q = wp.constant(self.velocity_set.q) + _d = wp.constant(self.velocity_set.d) + _id = wp.constant(self.id) + + # Construct the warp kernel + @wp.kernel + def _voxelize_mesh( + voxels: wp.array3d(dtype=wp.uint8), + mesh: wp.uint64, + spacing: wp.vec3, + origin: wp.vec3, + shape: wp.vec(3, wp.uint32), + max_length: float, + material_id: int, + ): + # get index of voxel + i, j, k = wp.tid() + + # position of voxel + ijk = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) + ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center + pos = wp.cw_mul(ijk, spacing) + origin + + # Only evaluate voxel if not set yet + if voxels[i, j, k] != wp.uint8(0): + return + + # evaluate distance of point + face_index = int(0) + face_u = float(0.0) + face_v = float(0.0) + sign = float(0.0) + if wp.mesh_query_point( + mesh, pos, max_length, sign, face_index, face_u, face_v + ): + p = wp.mesh_eval_position(mesh, face_index, face_u, face_v) + delta = pos - p + norm = wp.sqrt(wp.dot(delta, delta)) + + # set point to be solid + if norm < wp.min(spacing): + voxels[i, j, k] = wp.uint8(255) + elif sign < 0: # TODO: fix this + voxels[i, j, k] = wp.uint8(material_id) + else: + pass + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, start_index, boundary_id, mask, id_number): + # Reuse the jax implementation, TODO: implement a warp version + # Convert to jax + boundary_id = wp.jax.to_jax(boundary_id) + mask = wp.jax.to_jax(mask) + + # Call jax implementation + boundary_id, mask = self.jax_implementation( + start_index, boundary_id, mask, id_number + ) + + # Convert back to warp + boundary_id = wp.jax.to_warp(boundary_id) + mask = wp.jax.to_warp(mask) + + return boundary_id, mask diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 4071345..69718aa 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -1,6 +1,7 @@ import jax.numpy as jnp from jax import jit import warp as wp +from typing import Any from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend @@ -32,21 +33,19 @@ def pallas_implementation( return fout def _construct_warp(self): - # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) - _d = wp.constant(self.velocity_set.d) + # Set local constants TODO: This is a hack and should be fixed with warp update + _w = self.velocity_set.wp_w _omega = wp.constant(self.compute_dtype(self.omega)) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) # Construct the functional @wp.func def functional( - f: self._warp_lattice_vec, - feq: self._warp_lattice_vec, - rho: self.compute_dtype, - u: self._warp_u_vec, - ) -> self._warp_lattice_vec: + f: Any, + feq: Any, + rho: Any, + u: Any, + ): fneq = f - feq fout = f - _omega * fneq return fout @@ -54,30 +53,33 @@ def functional( # Construct the warp kernel @wp.kernel def kernel( - f: self._warp_array_type, - feq: self._warp_array_type, - rho: self._warp_array_type, - u: self._warp_array_type, - fout: self._warp_array_type, + f: wp.array4d(dtype=Any), + feq: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + fout: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this - # Get the equilibrium - _f = self._warp_lattice_vec() - _feq = self._warp_lattice_vec() - for l in range(_q): - _f[l] = f[l, i, j, k] - _feq[l] = feq[l, i, j, k] + # Load needed values + _f = _f_vec() + _feq = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _feq[l] = feq[l, index[0], index[1], index[2]] _u = self._warp_u_vec() for l in range(_d): - _u[l] = u[l, i, j, k] - _rho = rho[0, i, j, k] + _u[l] = u[l, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] + + # Compute the collision _fout = functional(_f, _feq, _rho, _u) # Write the result - for l in range(_q): - fout[l, i, j, k] = _fout[l] + for l in range(self.velocity_set.q): + fout[l, index[0], index[1], index[2]] = _fout[l] return functional, kernel @@ -85,7 +87,7 @@ def kernel( def warp_implementation(self, f, feq, rho, u, fout): # Launch the warp kernel wp.launch( - self._kernel, + self.warp_kernel, inputs=[ f, feq, diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index a1245f1..d5bb20a 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -2,6 +2,7 @@ import jax.numpy as jnp from jax import jit import warp as wp +from typing import Any from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend @@ -14,8 +15,6 @@ class QuadraticEquilibrium(Equilibrium): """ Quadratic equilibrium of Boltzmann equation using hermite polynomials. Standard equilibrium model for LBM. - - TODO: move this to a separate file and lower and higher order equilibriums """ @Operator.register_backend(ComputeBackend.JAX) @@ -56,22 +55,26 @@ def pallas_implementation(self, rho, u): return jnp.array(eq) def _construct_warp(self): - # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) - _d = wp.constant(self.velocity_set.d) + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _w = self.velocity_set.wp_w + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) # Construct the equilibrium functional @wp.func def functional( - rho: self.compute_dtype, u: self._warp_u_vec - ) -> self._warp_lattice_vec: - feq = self._warp_lattice_vec() # empty lattice vector - for l in range(_q): - ## Compute cu + rho: Any, + u: Any, + ): + # Allocate the equilibrium + feq = _f_vec() + + # Compute the equilibrium + for l in range(self.velocity_set.q): + # Compute cu cu = self.compute_dtype(0.0) - for d in range(_d): + for d in range(self.velocity_set.d): if _c[d, l] == 1: cu += u[d] elif _c[d, l] == -1: @@ -89,23 +92,24 @@ def functional( # Construct the warp kernel @wp.kernel def kernel( - rho: self._warp_array_type, - u: self._warp_array_type, - f: self._warp_array_type, + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + f: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # Get the equilibrium - _u = self._warp_u_vec() - for d in range(_d): - _u[d] = u[d, i, j, k] - _rho = rho[0, i, j, k] + _u = _u_vec() + for d in range(self.velocity_set.d): + _u[d] = u[d, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] feq = functional(_rho, _u) # Set the output - for l in range(_q): - f[l, i, j, k] = feq[l] + for l in range(self.velocity_set.q): + f[l, index[0], index[1], index[2]] = feq[l] return functional, kernel diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 97bd10a..161705e 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jax import jit import warp as wp -from typing import Tuple +from typing import Tuple, Any from xlb.global_config import GlobalConfig from xlb.velocity_set.velocity_set import VelocitySet @@ -75,19 +75,19 @@ def pallas_implementation(self, f): def _construct_warp(self): # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) + _c = self.velocity_set.wp_c + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) # Construct the functional @wp.func - def functional(f: self._warp_lattice_vec): + def functional(f: _f_vec): # Compute rho and u rho = self.compute_dtype(0.0) - u = self._warp_u_vec() - for l in range(_q): + u = _u_vec() + for l in range(self.velocity_set.q): rho += f[l] - for d in range(_d): + for d in range(self.velocity_set.d): if _c[d, l] == 1: u[d] += f[l] elif _c[d, l] == -1: @@ -95,28 +95,28 @@ def functional(f: self._warp_lattice_vec): u /= rho return rho, u - # return u, rho # Construct the kernel @wp.kernel def kernel( - f: self._warp_array_type, - rho: self._warp_array_type, - u: self._warp_array_type, + f: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # Get the equilibrium - _f = self._warp_lattice_vec() - for l in range(_q): - _f[l] = f[l, i, j, k] + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] (_rho, _u) = functional(_f) # Set the output - rho[0, i, j, k] = _rho - for d in range(_d): - u[d, i, j, k] = _u[d] + rho[0, index[0], index[1], index[2]] = _rho + for d in range(self.velocity_set.d): + u[d, index[0], index[1], index[2]] = _u[d] return functional, kernel diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index f3ea901..1724ffc 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,5 +1,6 @@ # Base class for all operators, (collision, streaming, equilibrium, etc.) import warp as wp +from typing import Any from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy, Precision @@ -102,111 +103,26 @@ def compute_dtype(self): """ Returns the compute dtype """ - return self._precision_to_dtype(self.precision_policy.compute_precision) + if self.compute_backend == ComputeBackend.JAX: + return self.precision_policy.compute_precision.jax_dtype + elif self.compute_backend == ComputeBackend.WARP: + return self.precision_policy.compute_precision.wp_dtype @property def store_dtype(self): """ Returns the store dtype """ - return self._precision_to_dtype(self.precision_policy.store_precision) - - def _precision_to_dtype(self, precision): - """ - Convert the precision to the corresponding dtype - TODO: Maybe move this to precision policy? - """ - if precision == Precision.FP64: - return self.backend.float64 - elif precision == Precision.FP32: - return self.backend.float32 - elif precision == Precision.FP16: - return self.backend.float16 - - ### WARP specific types ### - # These are used to define the types for the warp backend - # TODO: There might be a better place to put these - @property - def _warp_u_vec(self): - """ - Returns the warp type for velocity - """ - return wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - - @property - def _warp_lattice_vec(self): - """ - Returns the warp type for the lattice - """ - return wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - - @property - def _warp_int_lattice_vec(self): - """ - Returns the warp type for the streaming matrix (c) - """ - return wp.vec(self.velocity_set.q, dtype=wp.int32) - - @property - def _warp_bool_lattice_vec(self): - """ - Returns the warp type for the streaming matrix (c) - """ - #return wp.vec(self.velocity_set.q, dtype=wp.bool) - return wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO bool breaks - - @property - def _warp_stream_mat(self): - """ - Returns the warp type for the streaming matrix (c) - """ - return wp.mat( - (self.velocity_set.d, self.velocity_set.q), dtype=self.compute_dtype - ) - - @property - def _warp_int_stream_mat(self): - """ - Returns the warp type for the streaming matrix (c) - """ - return wp.mat( - (self.velocity_set.d, self.velocity_set.q), dtype=wp.int32 - ) - - @property - def _warp_array_type(self): - """ - Returns the warp type for arrays - """ - if self.velocity_set.d == 2: - return wp.array3d(dtype=self.store_dtype) - elif self.velocity_set.d == 3: - return wp.array4d(dtype=self.store_dtype) - - @property - def _warp_uint8_array_type(self): - """ - Returns the warp type for arrays - """ - if self.velocity_set.d == 2: - return wp.array3d(dtype=wp.uint8) - elif self.velocity_set.d == 3: - return wp.array4d(dtype=wp.uint8) - - @property - def _warp_bool_array_type(self): - """ - Returns the warp type for arrays - """ - if self.velocity_set.d == 2: - return wp.array3d(dtype=wp.bool) - elif self.velocity_set.d == 3: - return wp.array4d(dtype=wp.bool) + if self.compute_backend == ComputeBackend.JAX: + return self.precision_policy.store_precision.jax_dtype + elif self.compute_backend == ComputeBackend.WARP: + return self.precision_policy.store_precision.wp_dtype def _construct_warp(self): """ Construct the warp functional and kernel of the operator TODO: Maybe a better way to do this? Maybe add this to the backend decorator? + Leave it for now, as it is not clear how the warp backend will evolve """ return None, None diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index 7b0acfe..9c6d56e 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -4,13 +4,13 @@ from functools import partial from jax import jit import warp as wp +from typing import Any from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator from xlb.operator.stepper import Stepper -from xlb.operator.boundary_condition import ImplementationStep -from xlb.operator.collision import BGK +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep class IncompressibleNavierStokesStepper(Stepper): @@ -19,17 +19,17 @@ class IncompressibleNavierStokesStepper(Stepper): """ @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0, 5)) - def apply_jax(self, f, boundary_id, mask, timestep): + @partial(jit, static_argnums=(0, 4), donate_argnums=(1)) + def apply_jax(self, f, boundary_id, missing_mask, timestep): """ Perform a single step of the lattice boltzmann method """ - # Cast to compute precision - f_pre_collision = self.precision_policy.cast_to_compute_jax(f) + # Cast to compute precision TODO add this back in + #f_pre_collision = self.precision_policy.cast_to_compute_jax(f) # Compute the macroscopic variables - rho, u = self.macroscopic(f_pre_collision) + rho, u = self.macroscopic(f) # Compute equilibrium feq = self.equilibrium(rho, u) @@ -43,43 +43,42 @@ def apply_jax(self, f, boundary_id, mask, timestep): ) # Apply collision type boundary conditions - for id_number, bc in self.collision_boundary_conditions.items(): - f_post_collision = bc( - f_pre_collision, - f_post_collision, - boundary_id == id_number, - mask, - ) - f_pre_streaming = f_post_collision + for bc in self.boundary_conditions: + if bc.implementation_step == ImplementationStep.COLLISION: + f_post_collision = bc( + f, + f_post_collision, + boundary_id, + missing_mask, + ) ## Apply forcing # if self.forcing_op is not None: # f = self.forcing_op.apply_jax(f, timestep) # Apply streaming - f_post_streaming = self.stream(f_pre_streaming) + f_post_streaming = self.stream(f_post_collision) # Apply boundary conditions - for id_number, bc in self.stream_boundary_conditions.items(): - f_post_streaming = bc( - f_pre_streaming, - f_post_streaming, - boundary_id == id_number, - mask, - ) + for bc in self.boundary_conditions: + if bc.implementation_step == ImplementationStep.STREAMING: + f_post_streaming = bc( + f_post_collision, + f_post_streaming, + boundary_id, + missing_mask, + ) # Copy back to store precision - f = self.precision_policy.cast_to_store_jax(f_post_streaming) + #f = self.precision_policy.cast_to_store_jax(f_post_streaming) - return f + return f_post_streaming @Operator.register_backend(ComputeBackend.PALLAS) @partial(jit, static_argnums=(0,)) - def apply_pallas(self, fin, boundary_id, mask, timestep): + def apply_pallas(self, fin, boundary_id, missing_mask, timestep): # Raise warning that the boundary conditions are not implemented - ################################################################ warning("Boundary conditions are not implemented for PALLAS backend currently") - ################################################################ from xlb.operator.parallel_operator import ParallelOperator @@ -131,90 +130,92 @@ def _pallas_collide_and_stream(f): return fout def _construct_warp(self): - # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - _nr_boundary_conditions = wp.constant(len(self.boundary_conditions)) + # Set local constants TODO: This is a hack and should be fixed with warp update + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool + + # Get the boundary condition ids + _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) + _do_nothing_bc = wp.uint8(self.do_nothing_bc.id) + _halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id) + _fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id) # Construct the kernel @wp.kernel def kernel( - f_0: self._warp_array_type, - f_1: self._warp_array_type, - boundary_id: self._warp_uint8_array_type, - mask: self._warp_bool_array_type, + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=Any), + missing_mask: wp.array4d(dtype=Any), timestep: int, - max_i: int, - max_j: int, - max_k: int, ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO warp should fix this - # Get the f, boundary id and mask - _f = self._warp_lattice_vec() - _boundary_id = boundary_id[0, i, j, k] - _mask = self._warp_bool_lattice_vec() - for l in range(_q): - _f[l] = f_0[l, i, j, k] - + # Get the boundary id and missing mask + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): # TODO fix vec bool - if mask[l, i, j, k]: - _mask[l] = wp.uint8(1) + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) else: - _mask[l] = wp.uint8(0) - + _missing_mask[l] = wp.uint8(0) + + # Apply streaming boundary conditions + if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: + # Regular streaming + f_post_stream = self.stream.warp_functional(f_0, index) + elif _boundary_id == _equilibrium_bc: + # Equilibrium boundary condition + f_post_stream = self.equilibrium_bc.warp_functional( + f_0, _missing_mask, index + ) + elif _boundary_id == _do_nothing_bc: + # Do nothing boundary condition + f_post_stream = self.do_nothing_bc.warp_functional( + f_0, _missing_mask, index + ) + elif _boundary_id == _halfway_bounce_back_bc: + # Half way boundary condition + f_post_stream = self.halfway_bounce_back_bc.warp_functional( + f_0, _missing_mask, index + ) + # Compute rho and u - rho, u = self.macroscopic.warp_functional(_f) + rho, u = self.macroscopic.warp_functional(f_post_stream) # Compute equilibrium feq = self.equilibrium.warp_functional(rho, u) # Apply collision f_post_collision = self.collision.warp_functional( - _f, + f_post_stream, feq, rho, u, ) - ## Apply collision type boundary conditions - #if _boundary_id != wp.uint8(0): - # f_post_collision = self.collision_boundary_conditions[ - # _boundary_id - # ].warp_functional( - # _f, - # f_post_collision, - # _mask, - # ) - f_pre_streaming = f_post_collision # store pre streaming vector - - # Apply forcing - # if self.forcing_op is not None: - # f = self.forcing.warp_functional(f, timestep) - - # Apply streaming - for l in range(_q): - # Get the streamed indices - streamed_i, streamed_j, streamed_k = self.stream.warp_functional( - l, i, j, k, max_i, max_j, max_k + # Apply collision type boundary conditions + if _boundary_id == _fullway_bounce_back_bc: + # Full way boundary condition + f_post_collision = self.fullway_bounce_back_bc.warp_functional( + f_post_stream, + f_post_collision, + _missing_mask, ) - streamed_l = l - - ## Modify the streamed indices based on streaming boundary condition - # if _boundary_id != 0: - # streamed_l, streamed_i, streamed_j, streamed_k = self.stream_boundary_conditions[id_number].warp_functional( - # streamed_l, streamed_i, streamed_j, streamed_k, self._warp_max_i, self._warp_max_j, self._warp_max_k - # ) - # Set the output - f_1[streamed_l, streamed_i, streamed_j, streamed_k] = f_pre_streaming[l] + # Set the output + for l in range(self.velocity_set.q): + f_1[l, index[0], index[1], index[2]] = f_post_collision[l] return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, boundary_id, mask, timestep): + def warp_implementation(self, f_0, f_1, boundary_id, missing_mask, timestep): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -222,11 +223,8 @@ def warp_implementation(self, f_0, f_1, boundary_id, mask, timestep): f_0, f_1, boundary_id, - mask, + missing_mask, timestep, - f_0.shape[1], - f_0.shape[2], - f_0.shape[3], ], dim=f_0.shape[1:], ) diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 005e54d..e1eed44 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -8,7 +8,6 @@ from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -from xlb.operator.boundary_condition import ImplementationStep from xlb.operator.precision_caster import PrecisionCaster @@ -24,9 +23,9 @@ def __init__( equilibrium, macroscopic, boundary_conditions=[], - forcing=None, + forcing=None, # TODO: Add forcing later ): - # Set parameters + # Add operators self.collision = collision self.stream = stream self.equilibrium = equilibrium @@ -40,8 +39,10 @@ def __init__( stream, equilibrium, macroscopic, - *boundary_conditions, + *self.boundary_conditions, ] + if forcing is not None: + self.operators.append(forcing) # Get velocity set, precision policy, and compute backend velocity_sets = set([op.velocity_set for op in self.operators]) @@ -54,160 +55,57 @@ def __init__( assert len(compute_backends) == 1, "All compute backends must be the same" compute_backend = compute_backends.pop() - # Get collision and stream boundary conditions - self.collision_boundary_conditions = {} - self.stream_boundary_conditions = {} - for id_number, bc in enumerate(self.boundary_conditions): - bc_id = id_number + 1 - if bc.implementation_step == ImplementationStep.COLLISION: - self.collision_boundary_conditions[bc_id] = bc - elif bc.implementation_step == ImplementationStep.STREAMING: - self.stream_boundary_conditions[bc_id] = bc - else: - raise ValueError("Boundary condition step not recognized") - - # Make operators for converting the precisions - #self.cast_to_compute = PrecisionCaster( - - # Make operator for setting boundary condition arrays - self.set_boundary = SetBoundary( - self.collision_boundary_conditions, - self.stream_boundary_conditions, - velocity_set, - precision_policy, - compute_backend, - ) - self.operators.append(self.set_boundary) + # Add boundary conditions + # Warp cannot handle lists of functions currently + # Because of this we manually unpack the boundary conditions + ############################################ + # TODO: Fix this later + ############################################ + from xlb.operator.boundary_condition.equilibrium import EquilibriumBC + from xlb.operator.boundary_condition.do_nothing import DoNothingBC + from xlb.operator.boundary_condition.halfway_bounce_back import HalfwayBounceBackBC + from xlb.operator.boundary_condition.fullway_bounce_back import FullwayBounceBackBC + self.equilibrium_bc = None + self.do_nothing_bc = None + self.halfway_bounce_back_bc = None + self.fullway_bounce_back_bc = None + for bc in boundary_conditions: + if isinstance(bc, EquilibriumBC): + self.equilibrium_bc = bc + elif isinstance(bc, DoNothingBC): + self.do_nothing_bc = bc + elif isinstance(bc, HalfwayBounceBackBC): + self.halfway_bounce_back_bc = bc + elif isinstance(bc, FullwayBounceBackBC): + self.fullway_bounce_back_bc = bc + if self.equilibrium_bc is None: + self.equilibrium_bc = EquilibriumBC( + rho=1.0, + u=(0.0, 0.0, 0.0), + equilibrium_operator=self.equilibrium, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend + ) + if self.do_nothing_bc is None: + self.do_nothing_bc = DoNothingBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend + ) + if self.halfway_bounce_back_bc is None: + self.halfway_bounce_back_bc = HalfwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend + ) + if self.fullway_bounce_back_bc is None: + self.fullway_bounce_back_bc = FullwayBounceBackBC( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend + ) + ############################################ # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) - - ###################################################### - # TODO: This is a hacky way to do this. Need to refactor - ###################################################### - """ - def _construct_warp_bc_functional(self): - # identity collision boundary condition - @wp.func - def identity( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - mask: self._warp_bool_lattice_vec, - ): - return f_post - def get_bc_functional(id_number, self.collision_boundary_conditions): - if id_number in self.collision_boundary_conditions.keys(): - return self.collision_boundary_conditions[id_number].warp_functional - else: - return identity - - # Manually set the boundary conditions TODO: Extremely hacky - collision_bc_functional_0 = get_bc_functional(0, self.collision_boundary_conditions) - collision_bc_functional_1 = get_bc_functional(1, self.collision_boundary_conditions) - collision_bc_functional_2 = get_bc_functional(2, self.collision_boundary_conditions) - collision_bc_functional_3 = get_bc_functional(3, self.collision_boundary_conditions) - collision_bc_functional_4 = get_bc_functional(4, self.collision_boundary_conditions) - collision_bc_functional_5 = get_bc_functional(5, self.collision_boundary_conditions) - collision_bc_functional_6 = get_bc_functional(6, self.collision_boundary_conditions) - collision_bc_functional_7 = get_bc_functional(7, self.collision_boundary_conditions) - collision_bc_functional_8 = get_bc_functional(8, self.collision_boundary_conditions) - - # Make the warp boundary condition functional - @wp.func - def warp_bc( - f_pre: self._warp_lattice_vec, - f_post: self._warp_lattice_vec, - mask: self._warp_bool_lattice_vec, - boundary_id: wp.uint8, - ): - if boundary_id == 0: - f_post = collision_bc_functional_0(f_pre, f_post, mask) - elif boundary_id == 1: - f_post = collision_bc_functional_1(f_pre, f_post, mask) - elif boundary_id == 2: - f_post = collision_bc_functional_2(f_pre, f_post, mask) - elif boundary_id == 3: - f_post = collision_bc_functional_3(f_pre, f_post, mask) - elif boundary_id == 4: - f_post = collision_bc_functional_4(f_pre, f_post, mask) - elif boundary_id == 5: - f_post = collision_bc_functional_5(f_pre, f_post, mask) - elif boundary_id == 6: - f_post = collision_bc_functional_6(f_pre, f_post, mask) - elif boundary_id == 7: - f_post = collision_bc_functional_7(f_pre, f_post, mask) - elif boundary_id == 8: - f_post = collision_bc_functional_8(f_pre, f_post, mask) - - return f_post - - - - - ###################################################### - """ - - - - - - - -class SetBoundary(Operator): - """ - Class that handles the construction of lattice boltzmann boundary condition operator - This will probably never be used directly and it might be better to refactor it - """ - - def __init__( - self, - collision_boundary_conditions, - stream_boundary_conditions, - velocity_set, - precision_policy, - compute_backend, - ): - super().__init__(velocity_set, precision_policy, compute_backend) - - # Set parameters - self.collision_boundary_conditions = collision_boundary_conditions - self.stream_boundary_conditions = stream_boundary_conditions - - def _apply_all_bc(self, ijk, boundary_id, mask, bc): - """ - Apply all boundary conditions - """ - for id_number, bc in self.collision_boundary_conditions.items(): - boundary_id, mask = bc.boundary_masker(ijk, boundary_id, mask, id_number) - for id_number, bc in self.stream_boundary_conditions.items(): - boundary_id, mask = bc.boundary_masker(ijk, boundary_id, mask, id_number) - return boundary_id, mask - - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, ijk): - """ - Set boundary condition arrays - These store the boundary condition information for each boundary - """ - boundary_id = jnp.zeros(ijk.shape[:-1], dtype=jnp.uint8) - mask = jnp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=jnp.bool_) - return self._apply_all_bc(ijk, boundary_id, mask, bc) - - @Operator.register_backend(ComputeBackend.PALLAS) - def pallas_implementation(self, ijk): - """ - Set boundary condition arrays - These store the boundary condition information for each boundary - """ - raise NotImplementedError("Pallas implementation not available") - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, ijk): - """ - Set boundary condition arrays - These store the boundary condition information for each boundary - """ - boundary_id = wp.zeros(ijk.shape[:-1], dtype=wp.uint8) - mask = wp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=wp.bool) - return self._apply_all_bc(ijk, boundary_id, mask, bc) diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 5033cac..8bb2568 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit, vmap import warp as wp +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend @@ -12,7 +13,7 @@ class Stream(Operator): """ - Base class for all streaming operators. + Base class for all streaming operators. This is used for pulling the distribution """ @Operator.register_backend(ComputeBackend.JAX) @@ -21,6 +22,8 @@ def jax_implementation(self, f): """ JAX implementation of the streaming step. + TODO: Make sure this works with pull scheme. + Parameters ---------- f: jax.numpy.ndarray @@ -42,7 +45,9 @@ def _streaming_jax_i(f, c): The updated distribution function after streaming. """ if self.velocity_set.d == 2: - return jnp.roll(f, (c[0], c[1]), axis=(0, 1)) + return jnp.roll( + f, (c[0], c[1]), axis=(0, 1) + ) elif self.velocity_set.d == 3: return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) @@ -51,57 +56,50 @@ def _streaming_jax_i(f, c): ) def _construct_warp(self): - # Make constants for warp - _c = wp.constant(self._warp_int_stream_mat(self.velocity_set.c)) - _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) # Construct the funcional to get streamed indices @wp.func def functional( - l: int, - i: int, - j: int, - k: int, - max_i: int, - max_j: int, - max_k: int, + f: wp.array4d(dtype=Any), + index: Any, ): - streamed_i = i + _c[0, l] - streamed_j = j + _c[1, l] - streamed_k = k + _c[2, l] - if streamed_i < 0: - streamed_i = max_i - 1 - elif streamed_i >= max_i: - streamed_i = 0 - if streamed_j < 0: - streamed_j = max_j - 1 - elif streamed_j >= max_j: - streamed_j = 0 - if streamed_k < 0: - streamed_k = max_k - 1 - elif streamed_k >= max_k: - streamed_k = 0 - return streamed_i, streamed_j, streamed_k + # Pull the distribution function + _f = _f_vec() + for l in range(self.velocity_set.q): + # Get pull index + pull_index = type(index)() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - _c[d, l] + + if pull_index[d] < 0: + pull_index[d] = f.shape[d + 1] - 1 + elif pull_index[d] >= f.shape[d + 1]: + pull_index[d] = 0 + + # Read the distribution function + _f[l] = f[l, pull_index[0], pull_index[1], pull_index[2]] + + return _f # Construct the warp kernel @wp.kernel def kernel( - f_0: self._warp_array_type, - f_1: self._warp_array_type, - max_i: int, - max_j: int, - max_k: int, + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), ): # Get the global index i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # Set the output - for l in range(_q): - streamed_i, streamed_j, streamed_k = functional( - l, i, j, k, max_i, max_j, max_k - ) - f_1[l, streamed_i, streamed_j, streamed_k] = f_0[l, i, j, k] + _f = functional(f_0, index) + + # Write the output + for l in range(self.velocity_set.q): + f_1[l, index[0], index[1], index[2]] = _f[l] return functional, kernel @@ -109,13 +107,10 @@ def kernel( def warp_implementation(self, f_0, f_1): # Launch the warp kernel wp.launch( - self._kernel, + self.warp_kernel, inputs=[ f_0, f_1, - f_0.shape[1], - f_0.shape[2], - f_0.shape[3], ], dim=f_0.shape[1:], ) diff --git a/xlb/operator/test/test.py b/xlb/operator/test/test.py deleted file mode 100644 index 7d4290a..0000000 --- a/xlb/operator/test/test.py +++ /dev/null @@ -1 +0,0 @@ -x = 1 diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 0ba6c1c..db8a422 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -2,12 +2,45 @@ from enum import Enum, auto +import jax.numpy as jnp +import warp as wp class Precision(Enum): FP64 = auto() FP32 = auto() FP16 = auto() + UINT8 = auto() + BOOL = auto() + @property + def wp_dtype(self): + if self == Precision.FP64: + return wp.float64 + elif self == Precision.FP32: + return wp.float32 + elif self == Precision.FP16: + return wp.float16 + elif self == Precision.UINT8: + return wp.uint8 + elif self == Precision.BOOL: + return wp.bool + else: + raise ValueError("Invalid precision") + + @property + def jax_dtype(self): + if self == Precision.FP64: + return jnp.float64 + elif self == Precision.FP32: + return jnp.float32 + elif self == Precision.FP16: + return jnp.float16 + elif self == Precision.UINT8: + return jnp.uint8 + elif self == Precision.BOOL: + return jnp.bool_ + else: + raise ValueError("Invalid precision") class PrecisionPolicy(Enum): FP64FP64 = auto() diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index 6c4c0c2..1b0ab43 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -6,9 +6,9 @@ from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend -from xlb.operator.boundary_condition import ImplementationStep -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.collision import BGK, KBC +from xlb.operator.equilibrium.quadratic_equilibrium import QuadraticEquilibrium +from xlb.operator.collision.bgk import BGK +from xlb.operator.collision.kbc import KBC from xlb.operator.stream import Stream from xlb.operator.macroscopic import Macroscopic from xlb.solver.solver import Solver diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py index 7d3db77..335fd72 100644 --- a/xlb/solver/solver.py +++ b/xlb/solver/solver.py @@ -1,9 +1,8 @@ # Base class for all stepper operators from xlb.compute_backend import ComputeBackend -from xlb.operator.boundary_condition import ImplementationStep from xlb.global_config import GlobalConfig -from xlb.operator import Operator +from xlb.operator.operator import Operator class Solver(Operator): diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 0564fde..a137b87 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -43,6 +43,13 @@ def __init__(self, d, q, c, w): self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() + # Make warp constants for these vectors + # TODO: Following warp updates these may not be necessary + self.wp_c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self.c)) + self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow + self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) + + def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype)