diff --git a/examples/refactor/example_basic.py b/examples/refactor/example_basic.py new file mode 100644 index 0000000..9c9b033 --- /dev/null +++ b/examples/refactor/example_basic.py @@ -0,0 +1,62 @@ +import xlb +from xlb.compute_backends import ComputeBackends +from xlb.precision_policy import Fp32Fp32 + +from xlb.solver import IncompressibleNavierStokes +from xlb.grid import Grid +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.utils import save_fields_vtk, save_image + +xlb.init( + precision_policy=Fp32Fp32, + compute_backend=ComputeBackends.JAX, + velocity_set=xlb.velocity_set.D2Q9, +) + +grid_shape = (1000, 1000) +grid = Grid.create(grid_shape) + + +def initializer(): + rho = grid.create_field(cardinality=1) + 1.0 + u = grid.create_field(cardinality=2) + + circle_center = (grid_shape[0] // 2, grid_shape[1] // 2) + circle_radius = 10 + + for x in range(grid_shape[0]): + for y in range(grid_shape[1]): + if (x - circle_center[0]) ** 2 + ( + y - circle_center[1] + ) ** 2 <= circle_radius**2: + rho = rho.at[0, x, y].add(0.001) + + func_eq = QuadraticEquilibrium() + f_eq = func_eq(rho, u) + + return f_eq + + +f = initializer() + +compute_macro = Macroscopic() + +solver = IncompressibleNavierStokes(grid, omega=1.0) + + +def perform_io(f, step): + rho, u = compute_macro(f) + fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1]} + save_fields_vtk(fields, step) + save_image(rho[0], step) + print(f"Step {step + 1} complete") + + +num_steps = 1000 +io_rate = 100 +for step in range(num_steps): + f = solver.step(f, timestep=step) + + if step % io_rate == 0: + perform_io(f, step) diff --git a/examples/refactor/example_mehdi.py b/examples/refactor/example_mehdi.py deleted file mode 100644 index b758d3e..0000000 --- a/examples/refactor/example_mehdi.py +++ /dev/null @@ -1,37 +0,0 @@ -import xlb -from xlb.compute_backends import ComputeBackends -from xlb.precision_policy import Fp32Fp32 - -from xlb.solver import IncompressibleNavierStokes -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.stream import Stream -from xlb.global_config import GlobalConfig -from xlb.grid import Grid -from xlb.operator.initializer import EquilibriumInitializer, ConstInitializer - -import numpy as np -import jax.numpy as jnp - -xlb.init(precision_policy=Fp32Fp32, compute_backend=ComputeBackends.JAX, velocity_set=xlb.velocity_set.D2Q9) - -grid_shape = (100, 100) -grid = Grid.create(grid_shape) - -f_init = grid.create_field(cardinality=9, callback=EquilibriumInitializer(grid)) - -u_init = grid.create_field(cardinality=2, callback=ConstInitializer(grid, cardinality=2, const_value=0.0)) -rho_init = grid.create_field(cardinality=1, callback=ConstInitializer(grid, cardinality=1, const_value=1.0)) - - -st = Stream(grid) - -f_init = st(f_init) -print("here") -solver = IncompressibleNavierStokes(grid) - -num_steps = 100 -f = f_init -for step in range(num_steps): - f = solver.step(f, timestep=step) - print(f"Step {step+1}/{num_steps} complete") - diff --git a/examples/refactor/mlups3d.py b/examples/refactor/mlups3d.py new file mode 100644 index 0000000..f1207ff --- /dev/null +++ b/examples/refactor/mlups3d.py @@ -0,0 +1,53 @@ +import xlb +import time +import jax +import argparse +from xlb.compute_backends import ComputeBackends +from xlb.precision_policy import Fp32Fp32 +from xlb.operator.initializer import EquilibriumInitializer + +from xlb.solver import IncompressibleNavierStokes +from xlb.grid import Grid + +parser = argparse.ArgumentParser( + description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)" +) +parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") +parser.add_argument("num_steps", type=int, help="Timestep for the simulation") + +args = parser.parse_args() + +cube_edge = args.cube_edge +num_steps = args.num_steps + + +xlb.init( + precision_policy=Fp32Fp32, + compute_backend=ComputeBackends.JAX, + velocity_set=xlb.velocity_set.D3Q19, +) + +grid_shape = (cube_edge, cube_edge, cube_edge) +grid = Grid.create(grid_shape) + +f = grid.create_field(cardinality=19, callback=EquilibriumInitializer(grid)) + +solver = IncompressibleNavierStokes(grid, omega=1.0) + +# Ahead-of-Time Compilation to remove JIT overhead + + +if xlb.current_backend() == ComputeBackends.JAX: + lowered = jax.jit(solver.step).lower(f, timestep=0) + solver_step_compiled = lowered.compile() + +start_time = time.time() + +for step in range(num_steps): + f = solver_step_compiled(f, timestep=step) + +end_time = time.time() +total_lattice_updates = cube_edge**3 * num_steps +total_time_seconds = end_time - start_time +mlups = (total_lattice_updates / total_time_seconds) / 1e6 +print(f"MLUPS: {mlups}") diff --git a/xlb/__init__.py b/xlb/__init__.py index 7f13a8e..7845bb2 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -4,7 +4,7 @@ # Config -from .global_config import init +from .global_config import init, current_backend # Precision policy @@ -25,4 +25,7 @@ import xlb.grid # Solvers -import xlb.solver \ No newline at end of file +import xlb.solver + +# Utils +import xlb.utils \ No newline at end of file diff --git a/xlb/global_config.py b/xlb/global_config.py index dd3e705..c0047c9 100644 --- a/xlb/global_config.py +++ b/xlb/global_config.py @@ -8,3 +8,7 @@ def init(velocity_set, compute_backend, precision_policy): GlobalConfig.velocity_set = velocity_set() GlobalConfig.compute_backend = compute_backend GlobalConfig.precision_policy = precision_policy() + + +def current_backend(): + return GlobalConfig.compute_backend \ No newline at end of file diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 57270c5..af6d239 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -23,7 +23,7 @@ def initialize_jax_backend(self): self.global_mesh = ( Mesh(device_mesh, axis_names=("cardinality", "x", "y")) if self.dim == 2 - else Mesh(self.devices, axis_names=("cardinality", "x", "y", "z")) + else Mesh(device_mesh, axis_names=("cardinality", "x", "y", "z")) ) self.sharding = ( NamedSharding(self.global_mesh, P("cardinality", "x", "y")) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 1fd7458..3dc4993 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -31,7 +31,7 @@ def __init__( def jax_implementation(self, rho, u): cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0)) usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True) - w = self.velocity_set.w.reshape(-1, 1, 1) + w = self.velocity_set.w.reshape((-1,) + (1,) * (len(rho.shape) - 1)) feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq diff --git a/xlb/operator/initializer/equilibrium_init.py b/xlb/operator/initializer/equilibrium_init.py index 9d9dc56..bad7c85 100644 --- a/xlb/operator/initializer/equilibrium_init.py +++ b/xlb/operator/initializer/equilibrium_init.py @@ -17,6 +17,7 @@ def __init__( velocity_set = velocity_set or GlobalConfig.velocity_set compute_backend = compute_backend or GlobalConfig.compute_backend local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1) + self.init_values = np.zeros( grid.global_to_local_shape(grid.pop_shape) ) + velocity_set.w.reshape(local_shape) diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index fc04db2..733f8f7 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -1,13 +1,14 @@ # Base class for all equilibriums +from xlb.global_config import GlobalConfig +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backends import ComputeBackends +from xlb.operator.operator import Operator + from functools import partial import jax.numpy as jnp from jax import jit -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator - class Macroscopic(Operator): """ @@ -20,9 +21,12 @@ class Macroscopic(Operator): def __init__( self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): + self.velocity_set = velocity_set or GlobalConfig.velocity_set + self.compute_backend = compute_backend or GlobalConfig.compute_backend + super().__init__(velocity_set, compute_backend) @Operator.register_backend(ComputeBackends.JAX) @@ -30,6 +34,20 @@ def __init__( def jax_implementation(self, f): """ Apply the macroscopic operator to the lattice distribution function + TODO: Check if the following implementation is more efficient ( + as the compiler may be able to remove operations resulting in zero) + c_x = tuple(self.velocity_set.c[0]) + c_y = tuple(self.velocity_set.c[1]) + + u_x = 0.0 + u_y = 0.0 + + rho = jnp.sum(f, axis=0, keepdims=True) + + for i in range(self.velocity_set.q): + u_x += c_x[i] * f[i, ...] + u_y += c_y[i] * f[i, ...] + return rho, jnp.stack((u_x, u_y)) """ rho = jnp.sum(f, axis=0, keepdims=True) u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 277e9a1..e942961 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -21,7 +21,7 @@ def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None) super().__init__(velocity_set, compute_backend) @Operator.register_backend(ComputeBackends.JAX) - # @partial(jit, static_argnums=(0)) + @partial(jit, static_argnums=(0)) def jax_implementation(self, f): """ JAX implementation of the streaming step. @@ -38,7 +38,6 @@ def jax_implementation(self, f): mesh=self.grid.global_mesh, in_specs=in_specs, out_specs=out_specs, - check_rep=False, )(f) def _streaming_jax_p(self, f): diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py index 96a529c..251fe37 100644 --- a/xlb/solver/nse.py +++ b/xlb/solver/nse.py @@ -18,6 +18,7 @@ class IncompressibleNavierStokes(Solver): def __init__( self, grid, + omega, velocity_set: VelocitySet = None, compute_backend=None, precision_policy=None, @@ -25,6 +26,7 @@ def __init__( collision_kernel="BGK", ): self.grid = grid + self.omega = omega self.collision_kernel = collision_kernel super().__init__(velocity_set=velocity_set, compute_backend=compute_backend, precision_policy=precision_policy, boundary_conditions=boundary_conditions) self.create_operators() @@ -39,13 +41,13 @@ def create_operators(self): ) self.collision = ( KBC( - omega=1.0, + omega=self.omega, velocity_set=self.velocity_set, compute_backend=self.compute_backend, ) if self.collision_kernel == "KBC" else BGK( - omega=1.0, + omega=self.omega, velocity_set=self.velocity_set, compute_backend=self.compute_backend, ) diff --git a/xlb/utils/__init__.py b/xlb/utils/__init__.py new file mode 100644 index 0000000..2107fc8 --- /dev/null +++ b/xlb/utils/__init__.py @@ -0,0 +1 @@ +from .utils import downsample_field, save_image, save_fields_vtk, save_BCs_vtk, rotate_geometry, voxelize_stl, axangle2mat diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index d01c500..6d9c627 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -47,7 +47,7 @@ def downsample_field(field, factor, method="bicubic"): return jnp.stack(downsampled_components, axis=-1) -def save_image(timestep, fld, prefix=None): +def save_image(fld, timestep, prefix=None): """ Save an image of a field at a given timestep. @@ -78,13 +78,13 @@ def save_image(timestep, fld, prefix=None): if len(fld.shape) > 3: raise ValueError("The input field should be 2D!") elif len(fld.shape) == 3: - fld = np.sqrt(fld[..., 0] ** 2 + fld[..., 1] ** 2) + fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2) plt.clf() plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower") -def save_fields_vtk(timestep, fields, output_dir=".", prefix="fields"): +def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"): """ Save VTK fields to the specified directory. @@ -111,7 +111,7 @@ def save_fields_vtk(timestep, fields, output_dir=".", prefix="fields"): will be saved as 'fields_0000010.vtk'in the specified directory. """ - # Assert that all fields have the same dimensions except for the last dimension assuming fields is a dictionary + # Assert that all fields have the same dimensions for key, value in fields.items(): if key == list(fields.keys())[0]: dimensions = value.shape @@ -140,53 +140,6 @@ def save_fields_vtk(timestep, fields, output_dir=".", prefix="fields"): grid.save(output_filename, binary=True) print(f"Saved {output_filename} in {time() - start:.6f} seconds.") - -def live_volume_randering(timestep, field): - # WORK IN PROGRESS - """ - Live rendering of a 3D volume using pyvista. - - Parameters - ---------- - field (np.ndarray): A 3D array containing the field to be rendered. - - Returns - ------- - None - - Notes - ----- - This function uses pyvista to render a 3D volume. The volume is rendered with a colormap based on the field values. - The colormap is updated every 0.1 seconds to reflect changes to the field. - - """ - # Create a uniform grid (Note that the field must be 3D) otherwise raise error - if field.ndim != 3: - raise ValueError("The input field must be 3D!") - dimensions = field.shape - grid = pv.UniformGrid(dimensions=dimensions) - - # Add the field to the grid - grid["field"] = field.flatten(order="F") - - # Create the rendering scene - if timestep == 0: - plt.ion() - plt.figure(figsize=(10, 10)) - plt.axis("off") - plt.title("Live rendering of the field") - pl = pv.Plotter(off_screen=True) - pl.add_volume(grid, cmap="nipy_spectral", opacity="sigmoid_10", shade=False) - plt.imshow(pl.screenshot()) - - else: - pl = pv.Plotter(off_screen=True) - pl.add_volume(grid, cmap="nipy_spectral", opacity="sigmoid_10", shade=False) - # Update the rendering scene every 0.1 seconds - plt.imshow(pl.screenshot()) - plt.pause(0.1) - - def save_BCs_vtk(timestep, BCs, gridInfo, output_dir="."): """ Save boundary conditions as VTK format to the specified directory.