From a05441f173a557bcd22e3fe5cffedb4b600bcce5 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Fri, 26 Jan 2024 17:21:38 -0500 Subject: [PATCH] example_mehdi works with no error --- examples/refactor/example_mehdi.py | 37 ++++++ test.py | 40 +++++++ xlb/__init__.py | 12 +- xlb/global_config.py | 10 ++ xlb/grid/__init__.py | 1 + xlb/grid/grid.py | 34 ++++++ xlb/grid/jax_grid.py | 46 ++++++++ xlb/operator/collision/bgk.py | 6 +- xlb/operator/collision/collision.py | 5 +- xlb/operator/collision/kbc.py | 14 ++- xlb/operator/equilibrium/equilibrium.py | 5 +- .../equilibrium/quadratic_equilibrium.py | 16 +-- xlb/operator/initializer/__init__.py | 2 + xlb/operator/initializer/const_init.py | 28 +++++ xlb/operator/initializer/equilibrium_init.py | 28 +++++ xlb/operator/macroscopic/macroscopic.py | 17 ++- xlb/operator/operator.py | 33 +++--- xlb/operator/stepper/__init__.py | 2 - xlb/operator/stepper/nse.py | 92 --------------- xlb/operator/stepper/stepper.py | 84 -------------- xlb/operator/stream/stream.py | 95 ++++++++------- xlb/precision_policy/__init__.py | 2 +- xlb/precision_policy/base_precision_policy.py | 14 +++ .../jax_precision_policy.py | 12 +- .../jax_precision_policy/___init__.py | 1 - xlb/precision_policy/precision_policy.py | 81 +++++++++---- xlb/solver/__init__.py | 1 + xlb/solver/nse.py | 108 ++++++++++++++++++ xlb/solver/solver.py | 37 ++++++ 29 files changed, 573 insertions(+), 290 deletions(-) create mode 100644 examples/refactor/example_mehdi.py create mode 100644 test.py create mode 100644 xlb/global_config.py create mode 100644 xlb/grid/__init__.py create mode 100644 xlb/grid/grid.py create mode 100644 xlb/grid/jax_grid.py create mode 100644 xlb/operator/initializer/__init__.py create mode 100644 xlb/operator/initializer/const_init.py create mode 100644 xlb/operator/initializer/equilibrium_init.py delete mode 100644 xlb/operator/stepper/__init__.py delete mode 100644 xlb/operator/stepper/nse.py delete mode 100644 xlb/operator/stepper/stepper.py create mode 100644 xlb/precision_policy/base_precision_policy.py rename xlb/precision_policy/{jax_precision_policy => }/jax_precision_policy.py (85%) delete mode 100644 xlb/precision_policy/jax_precision_policy/___init__.py create mode 100644 xlb/solver/__init__.py create mode 100644 xlb/solver/nse.py create mode 100644 xlb/solver/solver.py diff --git a/examples/refactor/example_mehdi.py b/examples/refactor/example_mehdi.py new file mode 100644 index 0000000..b758d3e --- /dev/null +++ b/examples/refactor/example_mehdi.py @@ -0,0 +1,37 @@ +import xlb +from xlb.compute_backends import ComputeBackends +from xlb.precision_policy import Fp32Fp32 + +from xlb.solver import IncompressibleNavierStokes +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.stream import Stream +from xlb.global_config import GlobalConfig +from xlb.grid import Grid +from xlb.operator.initializer import EquilibriumInitializer, ConstInitializer + +import numpy as np +import jax.numpy as jnp + +xlb.init(precision_policy=Fp32Fp32, compute_backend=ComputeBackends.JAX, velocity_set=xlb.velocity_set.D2Q9) + +grid_shape = (100, 100) +grid = Grid.create(grid_shape) + +f_init = grid.create_field(cardinality=9, callback=EquilibriumInitializer(grid)) + +u_init = grid.create_field(cardinality=2, callback=ConstInitializer(grid, cardinality=2, const_value=0.0)) +rho_init = grid.create_field(cardinality=1, callback=ConstInitializer(grid, cardinality=1, const_value=1.0)) + + +st = Stream(grid) + +f_init = st(f_init) +print("here") +solver = IncompressibleNavierStokes(grid) + +num_steps = 100 +f = f_init +for step in range(num_steps): + f = solver.step(f, timestep=step) + print(f"Step {step+1}/{num_steps} complete") + diff --git a/test.py b/test.py new file mode 100644 index 0000000..5529ac3 --- /dev/null +++ b/test.py @@ -0,0 +1,40 @@ +import jax.numpy as jnp +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.stream import Stream +from xlb.velocity_set import D2Q9, D3Q27 +from xlb.operator.collision.kbc import KBC, BGK +from xlb.compute_backends import ComputeBackends +from xlb.grid import Grid +import xlb + + +xlb.init(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +collision = BGK(omega=0.6) + +# eq = QuadraticEquilibrium(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +# macro = Macroscopic(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +# s = Stream(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +Q = 19 +# create random jnp arrays +f = jnp.ones((Q, 10, 10)) +rho = jnp.ones((1, 10, 10)) +u = jnp.zeros((2, 10, 10)) +# feq = eq(rho, u) + +print(collision(f, f)) + +grid = Grid.create(grid_shape=(10, 10), velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX) + +def advection_result(index): + return 1.0 + + +f = grid.initialize_pop(advection_result) + +print(f) +print(f.sharding) \ No newline at end of file diff --git a/xlb/__init__.py b/xlb/__init__.py index e4edc5c..7f13a8e 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -2,6 +2,11 @@ from xlb.compute_backends import ComputeBackends from xlb.physics_type import PhysicsType + +# Config +from .global_config import init + + # Precision policy import xlb.precision_policy @@ -15,4 +20,9 @@ import xlb.operator.boundary_condition # import xlb.operator.force import xlb.operator.macroscopic -import xlb.operator.stepper + +# Grids +import xlb.grid + +# Solvers +import xlb.solver \ No newline at end of file diff --git a/xlb/global_config.py b/xlb/global_config.py new file mode 100644 index 0000000..dd3e705 --- /dev/null +++ b/xlb/global_config.py @@ -0,0 +1,10 @@ +class GlobalConfig: + precision_policy = None + velocity_set = None + compute_backend = None + + +def init(velocity_set, compute_backend, precision_policy): + GlobalConfig.velocity_set = velocity_set() + GlobalConfig.compute_backend = compute_backend + GlobalConfig.precision_policy = precision_policy() diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py new file mode 100644 index 0000000..583b72e --- /dev/null +++ b/xlb/grid/__init__.py @@ -0,0 +1 @@ +from xlb.grid.grid import Grid \ No newline at end of file diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py new file mode 100644 index 0000000..7a03950 --- /dev/null +++ b/xlb/grid/grid.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from xlb.compute_backends import ComputeBackends +from xlb.global_config import GlobalConfig +from xlb.velocity_set import VelocitySet + + +class Grid(ABC): + def __init__(self, grid_shape, velocity_set, compute_backend): + self.velocity_set: VelocitySet = velocity_set + self.compute_backend = compute_backend + self.grid_shape = grid_shape + self.pop_shape = (self.velocity_set.q, *grid_shape) + self.u_shape = (self.velocity_set.d, *grid_shape) + self.rho_shape = (1, *grid_shape) + self.dim = self.velocity_set.d + + @abstractmethod + def create_field(self, cardinality, callback=None): + pass + + @staticmethod + def create(grid_shape, velocity_set=None, compute_backend=None): + compute_backend = compute_backend or GlobalConfig.compute_backend + velocity_set = velocity_set or GlobalConfig.velocity_set + + if compute_backend == ComputeBackends.JAX: + from xlb.grid.jax_grid import JaxGrid # Avoids circular import + + return JaxGrid(grid_shape, velocity_set, compute_backend) + raise ValueError(f"Compute backend {compute_backend} is not supported") + + @abstractmethod + def global_to_local_shape(self, shape): + pass diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py new file mode 100644 index 0000000..57270c5 --- /dev/null +++ b/xlb/grid/jax_grid.py @@ -0,0 +1,46 @@ +from xlb.grid.grid import Grid +from xlb.compute_backends import ComputeBackends +from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding, Mesh +from jax.experimental import mesh_utils +from xlb.operator.initializer import ConstInitializer +import jax + + +class JaxGrid(Grid): + def __init__(self, grid_shape, velocity_set, compute_backend): + super().__init__(grid_shape, velocity_set, compute_backend) + self.initialize_jax_backend() + + def initialize_jax_backend(self): + self.nDevices = jax.device_count() + self.backend = jax.default_backend() + device_mesh = ( + mesh_utils.create_device_mesh((1, self.nDevices, 1)) + if self.dim == 2 + else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) + ) + self.global_mesh = ( + Mesh(device_mesh, axis_names=("cardinality", "x", "y")) + if self.dim == 2 + else Mesh(self.devices, axis_names=("cardinality", "x", "y", "z")) + ) + self.sharding = ( + NamedSharding(self.global_mesh, P("cardinality", "x", "y")) + if self.dim == 2 + else NamedSharding(self.global_mesh, P("cardinality", "x", "y", "z")) + ) + + def global_to_local_shape(self, shape): + if len(shape) < 2: + raise ValueError("Shape must have at least two dimensions") + + new_second_index = shape[1] // self.nDevices + + return shape[:1] + (new_second_index,) + shape[2:] + + def create_field(self, cardinality, callback=None): + if callback is None: + callback = ConstInitializer(self, cardinality, const_value=0.0) + shape = (cardinality,) + (self.grid_shape) + return jax.make_array_from_callback(shape, self.sharding, callback) diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 90943b8..9dfdf33 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -15,8 +15,8 @@ class BGK(Collision): def __init__( self, omega: float, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): super().__init__( omega=omega, velocity_set=velocity_set, compute_backend=compute_backend @@ -24,7 +24,7 @@ def __init__( @Operator.register_backend(ComputeBackends.JAX) @partial(jit, static_argnums=(0,)) - def jax_implementation_2(self, f: jnp.ndarray, feq: jnp.ndarray): + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray): fneq = f - feq fout = f - self.omega * fneq return fout diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py index 9c22895..3243e6f 100644 --- a/xlb/operator/collision/collision.py +++ b/xlb/operator/collision/collision.py @@ -1,7 +1,6 @@ """ Base class for Collision operators """ -from xlb.compute_backends import ComputeBackends from xlb.velocity_set import VelocitySet from xlb.operator import Operator @@ -23,8 +22,8 @@ class Collision(Operator): def __init__( self, omega: float, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): super().__init__(velocity_set, compute_backend) self.omega = omega diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index e79b6bd..1b24a6b 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -22,8 +22,8 @@ class KBC(Collision): def __init__( self, omega, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): super().__init__( omega=omega, velocity_set=velocity_set, compute_backend=compute_backend @@ -75,6 +75,16 @@ def jax_implementation( return fout + @Operator.register_backend(ComputeBackends.WARP) + @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) + def warp_implementation( + self, + f: jnp.ndarray, + feq: jnp.ndarray, + rho: jnp.ndarray, + ): + raise NotImplementedError("Warp implementation not yet implemented") + @partial(jit, static_argnums=(0,), inline=True) def entropic_scalar_product(self, x: jnp.ndarray, y: jnp.ndarray, feq: jnp.ndarray): """ diff --git a/xlb/operator/equilibrium/equilibrium.py b/xlb/operator/equilibrium/equilibrium.py index f60d9ee..4bed600 100644 --- a/xlb/operator/equilibrium/equilibrium.py +++ b/xlb/operator/equilibrium/equilibrium.py @@ -1,6 +1,5 @@ # Base class for all equilibriums from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator @@ -11,7 +10,7 @@ class Equilibrium(Operator): def __init__( self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): super().__init__(velocity_set, compute_backend) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index bc4282b..1fd7458 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -5,6 +5,7 @@ from xlb.operator.equilibrium.equilibrium import Equilibrium from functools import partial from xlb.operator import Operator +from xlb.global_config import GlobalConfig class QuadraticEquilibrium(Equilibrium): @@ -17,21 +18,20 @@ class QuadraticEquilibrium(Equilibrium): def __init__( self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, + velocity_set: VelocitySet = None, + compute_backend=None, ): + velocity_set = velocity_set or GlobalConfig.velocity_set + compute_backend = compute_backend or GlobalConfig.compute_backend + super().__init__(velocity_set, compute_backend) @Operator.register_backend(ComputeBackends.JAX) - # @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) + @partial(jit, static_argnums=(0), donate_argnums=(1, 2)) def jax_implementation(self, rho, u): cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0)) usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True) w = self.velocity_set.w.reshape(-1, 1, 1) - feq = ( - rho - * w - * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) - ) + feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq diff --git a/xlb/operator/initializer/__init__.py b/xlb/operator/initializer/__init__.py new file mode 100644 index 0000000..4d3f07d --- /dev/null +++ b/xlb/operator/initializer/__init__.py @@ -0,0 +1,2 @@ +from xlb.operator.initializer.equilibrium_init import EquilibriumInitializer +from xlb.operator.initializer.const_init import ConstInitializer \ No newline at end of file diff --git a/xlb/operator/initializer/const_init.py b/xlb/operator/initializer/const_init.py new file mode 100644 index 0000000..d3b13cb --- /dev/null +++ b/xlb/operator/initializer/const_init.py @@ -0,0 +1,28 @@ +from xlb.velocity_set import VelocitySet +from xlb.global_config import GlobalConfig +from xlb.compute_backends import ComputeBackends +from xlb.operator.operator import Operator +from xlb.grid.grid import Grid +import numpy as np +import jax + + +class ConstInitializer(Operator): + def __init__( + self, + grid: Grid, + cardinality, + const_value=0.0, + velocity_set: VelocitySet = None, + compute_backend: ComputeBackends = None, + ): + velocity_set = velocity_set or GlobalConfig.velocity_set + compute_backend = compute_backend or GlobalConfig.compute_backend + shape = (cardinality,) + (grid.grid_shape) + self.init_values = np.zeros(grid.global_to_local_shape(shape)) + const_value + + super().__init__(velocity_set, compute_backend) + + @Operator.register_backend(ComputeBackends.JAX) + def jax_implementation(self, index): + return self.init_values diff --git a/xlb/operator/initializer/equilibrium_init.py b/xlb/operator/initializer/equilibrium_init.py new file mode 100644 index 0000000..9d9dc56 --- /dev/null +++ b/xlb/operator/initializer/equilibrium_init.py @@ -0,0 +1,28 @@ +from xlb.velocity_set import VelocitySet +from xlb.global_config import GlobalConfig +from xlb.compute_backends import ComputeBackends +from xlb.operator.operator import Operator +from xlb.grid.grid import Grid +import numpy as np +import jax + + +class EquilibriumInitializer(Operator): + def __init__( + self, + grid: Grid, + velocity_set: VelocitySet = None, + compute_backend: ComputeBackends = None, + ): + velocity_set = velocity_set or GlobalConfig.velocity_set + compute_backend = compute_backend or GlobalConfig.compute_backend + local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1) + self.init_values = np.zeros( + grid.global_to_local_shape(grid.pop_shape) + ) + velocity_set.w.reshape(local_shape) + + super().__init__(velocity_set, compute_backend) + + @Operator.register_backend(ComputeBackends.JAX) + def jax_implementation(self, index): + return self.init_values diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 55121e6..fc04db2 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -19,20 +19,19 @@ class Macroscopic(Operator): """ def __init__( - self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, - ): + self, + velocity_set: VelocitySet, + compute_backend=ComputeBackends.JAX, + ): super().__init__(velocity_set, compute_backend) + @Operator.register_backend(ComputeBackends.JAX) @partial(jit, static_argnums=(0), inline=True) - def apply_jax(self, f): + def jax_implementation(self, f): """ Apply the macroscopic operator to the lattice distribution function """ - - rho = jnp.sum(f, axis=-1, keepdims=True) - c = jnp.array(self.velocity_set.c, dtype=f.dtype).T - u = jnp.dot(f, c) / rho + rho = jnp.sum(f, axis=0, keepdims=True) + u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho return rho, u diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 592961d..700de7e 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,9 +1,7 @@ # Base class for all operators, (collision, streaming, equilibrium, etc.) -from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backends import ComputeBackends - - +from xlb.global_config import GlobalConfig class Operator: """ Base class for all operators, collision, streaming, equilibrium, etc. @@ -14,9 +12,11 @@ class Operator: _backends = {} def __init__(self, velocity_set, compute_backend): - self.velocity_set = velocity_set - self.compute_backend = compute_backend - if compute_backend not in ComputeBackends: + # Set the default values from the global config + self.velocity_set = velocity_set or GlobalConfig.velocity_set + self.compute_backend = compute_backend or GlobalConfig.compute_backend + + if self.compute_backend not in ComputeBackends: raise ValueError(f"Compute backend {compute_backend} is not supported") @classmethod @@ -34,14 +34,25 @@ def decorator(func): return decorator - def __call__(self, *args, **kwargs): + def __call__(self, *args, callback=None, **kwargs): """ Calls the operator with the compute backend specified in the constructor. + If a callback is provided, it is called either with the result of the operation + or with the original arguments and keyword arguments if the backend modifies them by reference. """ key = (self.__class__.__name__, self.compute_backend) backend_method = self._backends.get(key) + if backend_method: - return backend_method(self, *args, **kwargs) + result = backend_method(self, *args, **kwargs) + + # Determine what to pass to the callback based on the backend behavior + callback_arg = result if result is not None else (args, kwargs) + + if callback and callable(callback): + callback(callback_arg) + + return result else: raise NotImplementedError(f"Backend {self.compute_backend} not implemented") @@ -63,9 +74,3 @@ def _is_method_overridden(self, method_name): def __repr__(self): return f"{self.__class__.__name__}()" - - def data_handler(self, *args, **kwargs): - """ - Handles data for the operator. - """ - raise NotImplementedError("Child class must implement data_handler") diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py deleted file mode 100644 index 0469f13..0000000 --- a/xlb/operator/stepper/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from xlb.operator.stepper.stepper import Stepper -from xlb.operator.stepper.nse import NSE diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py deleted file mode 100644 index a8863b2..0000000 --- a/xlb/operator/stepper/nse.py +++ /dev/null @@ -1,92 +0,0 @@ -# Base class for all stepper operators - -from functools import partial -from jax import jit - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends -from xlb.operator.stepper.stepper import Stepper -from xlb.operator.boundary_condition import ImplementationStep - - -class NSE(Stepper): - """ - Class that handles the construction of lattice boltzmann stepping operator for the Navier-Stokes equations - - TODO: Check that the given operators (collision, stream, equilibrium, macroscopic, ...) are compatible - with the Navier-Stokes equations - """ - - def __init__( - self, - collision, - stream, - equilibrium, - macroscopic, - boundary_conditions=[], - forcing=None, - precision_policy=None, - ): - super().__init__( - collision, - stream, - equilibrium, - macroscopic, - boundary_conditions, - forcing, - precision_policy, - ) - - @partial(jit, static_argnums=(0,)) - def apply_jax(self, f, boundary_id, mask, timestep): - """ - Perform a single step of the lattice boltzmann method - """ - - # Cast to compute precision - f_pre_collision = self.precision_policy.cast_to_compute_jax(f) - - # Compute the macroscopic variables - rho, u = self.macroscopic(f_pre_collision) - - # Compute equilibrium - feq = self.equilibrium(rho, u) - - # Apply collision - f_post_collision = self.collision( - f, - feq, - rho, - u, - ) - - # Apply collision type boundary conditions - for id_number, bc in self.collision_boundary_conditions.items(): - f_post_collision = bc( - f_pre_collision, - f_post_collision, - boundary_id == id_number, - mask, - ) - f_pre_streaming = f_post_collision - - ## Apply forcing - # if self.forcing_op is not None: - # f = self.forcing_op.apply_jax(f, timestep) - - # Apply streaming - f_post_streaming = self.stream(f_pre_streaming) - - # Apply boundary conditions - for id_number, bc in self.stream_boundary_conditions.items(): - f_post_streaming = bc( - f_pre_streaming, - f_post_streaming, - boundary_id == id_number, - mask, - ) - - # Copy back to store precision - f = self.precision_policy.cast_to_store_jax(f_post_streaming) - - return f \ No newline at end of file diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py deleted file mode 100644 index 8730478..0000000 --- a/xlb/operator/stepper/stepper.py +++ /dev/null @@ -1,84 +0,0 @@ -# Base class for all stepper operators - -import jax.numpy as jnp - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from xlb.operator.boundary_condition import ImplementationStep - - -class Stepper(Operator): - """ - Class that handles the construction of lattice boltzmann stepping operator - """ - - def __init__( - self, - collision, - stream, - equilibrium, - macroscopic, - boundary_conditions=[], - forcing=None, - precision_policy=None, - compute_backend=ComputeBackends.JAX, - ): - # Set parameters - self.collision = collision - self.stream = stream - self.equilibrium = equilibrium - self.macroscopic = macroscopic - self.boundary_conditions = boundary_conditions - self.forcing = forcing - self.precision_policy = precision_policy - self.compute_backend = compute_backend - - # Get collision and stream boundary conditions - self.collision_boundary_conditions = {} - self.stream_boundary_conditions = {} - for id_number, bc in enumerate(self.boundary_conditions): - bc_id = id_number + 1 - if bc.implementation_step == ImplementationStep.COLLISION: - self.collision_boundary_conditions[bc_id] = bc - elif bc.implementation_step == ImplementationStep.STREAMING: - self.stream_boundary_conditions[bc_id] = bc - else: - raise ValueError("Boundary condition step not recognized") - - # Get all operators for checking - self.operators = [ - collision, - stream, - equilibrium, - macroscopic, - *boundary_conditions, - ] - - # Get velocity set and backend - velocity_sets = set([op.velocity_set for op in self.operators]) - assert len(velocity_sets) == 1, "All velocity sets must be the same" - self.velocity_set = velocity_sets.pop() - compute_backends = set([op.compute_backend for op in self.operators]) - assert len(compute_backends) == 1, "All compute backends must be the same" - self.compute_backend = compute_backends.pop() - - # Initialize operator - super().__init__(self.velocity_set, self.compute_backend) - - def set_boundary(self, ijk): - """ - Set boundary condition arrays - These store the boundary condition information for each boundary - """ - # Empty boundary condition array - boundary_id = jnp.zeros(ijk.shape[:-1], dtype=jnp.uint8) - mask = jnp.zeros(ijk.shape[:-1] + (self.velocity_set.q,), dtype=jnp.bool_) - - # Set boundary condition arrays - for id_number, bc in self.collision_boundary_conditions.items(): - boundary_id, mask = bc.set_boundary(ijk, boundary_id, mask, id_number) - for id_number, bc in self.stream_boundary_conditions.items(): - boundary_id, mask = bc.set_boundary(ijk, boundary_id, mask, id_number) - - return boundary_id, mask diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 7a4a1d8..277e9a1 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -2,32 +2,27 @@ from functools import partial import jax.numpy as jnp -from jax import jit, vmap -import numba +from jax import jit, vmap, lax from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backends import ComputeBackends from xlb.operator.operator import Operator +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P class Stream(Operator): """ Base class for all streaming operators. - - TODO: Currently only this one streaming operator is implemented but - in the future we may have more streaming operators. For example, - one might want a multi-step streaming operator. """ - def __init__( - self, - velocity_set: VelocitySet, - compute_backend=ComputeBackends.JAX, - ): + def __init__(self, grid, velocity_set: VelocitySet = None, compute_backend=None): + self.grid = grid super().__init__(velocity_set, compute_backend) - @partial(jit, static_argnums=(0), donate_argnums=(1,)) - def apply_jax(self, f): + @Operator.register_backend(ComputeBackends.JAX) + # @partial(jit, static_argnums=(0)) + def jax_implementation(self, f): """ JAX implementation of the streaming step. @@ -36,8 +31,18 @@ def apply_jax(self, f): f: jax.numpy.ndarray The distribution function. """ - - def _streaming(f, c): + in_specs = P(*((None, "x") + (self.grid.dim - 1) * (None,))) + out_specs = in_specs + return shard_map( + self._streaming_jax_m, + mesh=self.grid.global_mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + )(f) + + def _streaming_jax_p(self, f): + def _streaming_jax_i(f, c): """ Perform individual streaming operation in a direction. @@ -56,33 +61,45 @@ def _streaming(f, c): elif self.velocity_set.d == 3: return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) - return vmap(_streaming, in_axes=(-1, 0), out_axes=-1)( + return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( f, jnp.array(self.velocity_set.c).T ) - def construct_numba(self, dtype=numba.float32): + def _streaming_jax_m(self, f): """ - Numba implementation of the streaming step. + This function performs the streaming step in the Lattice Boltzmann Method, which is + the propagation of the distribution functions in the lattice. + + To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the + distribution functions that need to be communicated to the neighboring processes. + + The function then sends the left boundary slice to the right neighboring process and the right + boundary slice to the left neighboring process. The received data is then set to the + corresponding indices in the receiving domain. + + Parameters + ---------- + f: jax.numpy.ndarray + The array holding the distribution functions for the simulation. + + Returns + ------- + jax.numpy.ndarray + The distribution functions after the streaming operation. """ + rightPerm = [(i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices)] + leftPerm = [((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices)] + + f = self._streaming_jax_p(f) + left_comm, right_comm = ( + f[self.velocity_set.right_indices, :1, ...], + f[self.velocity_set.left_indices, -1:, ...], + ) + left_comm, right_comm = ( + lax.ppermute(left_comm, perm=rightPerm, axis_name='x'), + lax.ppermute(right_comm, perm=leftPerm, axis_name='x'), + ) + f = f.at[self.velocity_set.right_indices, :1, ...].set(left_comm) + f = f.at[self.velocity_set.left_indices, -1:, ...].set(right_comm) - # Get needed values for numba functions - d = velocity_set.d - q = velocity_set.q - c = velocity_set.c.T - - # Make numba functions - @cuda.jit(device=True) - def _streaming(f_array, f, ijk): - # Stream to the next node - for _ in range(q): - if d == 2: - i = (ijk[0] + int32(c[_, 0])) % f_array.shape[0] - j = (ijk[1] + int32(c[_, 1])) % f_array.shape[1] - f_array[i, j, _] = f[_] - else: - i = (ijk[0] + int32(c[_, 0])) % f_array.shape[0] - j = (ijk[1] + int32(c[_, 1])) % f_array.shape[1] - k = (ijk[2] + int32(c[_, 2])) % f_array.shape[2] - f_array[i, j, k, _] = f[_] - - return _streaming + return f diff --git a/xlb/precision_policy/__init__.py b/xlb/precision_policy/__init__.py index 996d8e9..b1555aa 100644 --- a/xlb/precision_policy/__init__.py +++ b/xlb/precision_policy/__init__.py @@ -1 +1 @@ -from xlb.precision_policy.precision_policy import PrecisionPolicy \ No newline at end of file +from xlb.precision_policy.precision_policy import Fp64Fp64, Fp64Fp32, Fp32Fp32, Fp64Fp16, Fp32Fp16 \ No newline at end of file diff --git a/xlb/precision_policy/base_precision_policy.py b/xlb/precision_policy/base_precision_policy.py new file mode 100644 index 0000000..397e48d --- /dev/null +++ b/xlb/precision_policy/base_precision_policy.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod + +class PrecisionPolicy(ABC): + def __init__(self, compute_dtype, storage_dtype): + self.compute_dtype = compute_dtype + self.storage_dtype = storage_dtype + + @abstractmethod + def cast_to_compute(self, array): + pass + + @abstractmethod + def cast_to_store(self, array): + pass diff --git a/xlb/precision_policy/jax_precision_policy/jax_precision_policy.py b/xlb/precision_policy/jax_precision_policy.py similarity index 85% rename from xlb/precision_policy/jax_precision_policy/jax_precision_policy.py rename to xlb/precision_policy/jax_precision_policy.py index c9063c2..bd63adf 100644 --- a/xlb/precision_policy/jax_precision_policy/jax_precision_policy.py +++ b/xlb/precision_policy/jax_precision_policy.py @@ -1,4 +1,4 @@ -from xlb.precision_policy.precision_policy import PrecisionPolicy +from xlb.precision_policy.base_precision_policy import PrecisionPolicy from jax import jit from functools import partial import jax.numpy as jnp @@ -18,7 +18,7 @@ def cast_to_store(self, array): return array.astype(self.storage_dtype) -class Fp32Fp32(JaxPrecisionPolicy): +class JaxFp32Fp32(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation and storage precision both set to float32. @@ -32,7 +32,7 @@ def __init__(self): super().__init__(jnp.float32, jnp.float32) -class Fp64Fp64(JaxPrecisionPolicy): +class JaxFp64Fp64(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation and storage precision both set to float64. @@ -42,7 +42,7 @@ def __init__(self): super().__init__(jnp.float64, jnp.float64) -class Fp64Fp32(JaxPrecisionPolicy): +class JaxFp64Fp32(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation precision set to float64 and storage precision set to float32. @@ -52,7 +52,7 @@ def __init__(self): super().__init__(jnp.float64, jnp.float32) -class Fp64Fp16(JaxPrecisionPolicy): +class JaxFp64Fp16(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation precision set to float64 and storage precision set to float16. @@ -62,7 +62,7 @@ def __init__(self): super().__init__(jnp.float64, jnp.float16) -class Fp32Fp16(JaxPrecisionPolicy): +class JaxFp32Fp16(JaxPrecisionPolicy): """ Precision policy for lattice Boltzmann method with computation precision set to float32 and storage precision set to float16. diff --git a/xlb/precision_policy/jax_precision_policy/___init__.py b/xlb/precision_policy/jax_precision_policy/___init__.py deleted file mode 100644 index 1955b79..0000000 --- a/xlb/precision_policy/jax_precision_policy/___init__.py +++ /dev/null @@ -1 +0,0 @@ -from xlb.precision_policy.jax_precision_policy import Fp32Fp16, Fp32Fp32, Fp64Fp16, Fp64Fp32, Fp64Fp64 \ No newline at end of file diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index 9a4bdb0..75f4434 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -1,22 +1,59 @@ -class PrecisionPolicy(object): - """ - Base class for precision policy in lattice Boltzmann method. - Stores dtype information and provides an interface for casting operations. - """ - def __init__(self, compute_dtype, storage_dtype): - self.compute_dtype = compute_dtype - self.storage_dtype = storage_dtype - - def cast_to_compute(self, array): - """ - Cast the array to the computation precision. - To be implemented by subclass. - """ - raise NotImplementedError - - def cast_to_store(self, array): - """ - Cast the array to the storage precision. - To be implemented by subclass. - """ - raise NotImplementedError \ No newline at end of file +from abc import ABC, abstractmethod +from xlb.compute_backends import ComputeBackends +from xlb.global_config import GlobalConfig + +from xlb.precision_policy.jax_precision_policy import ( + JaxFp32Fp32, + JaxFp32Fp16, + JaxFp64Fp64, + JaxFp64Fp32, + JaxFp64Fp16, +) + + +class Fp64Fp64: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp64Fp64() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) + +class Fp64Fp32: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp64Fp32() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) + + +class Fp32Fp32: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp32Fp32() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) + + +class Fp64Fp16: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp64Fp16() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) + +class Fp32Fp16: + def __new__(cls): + if GlobalConfig.compute_backend == ComputeBackends.JAX: + return JaxFp32Fp16() + else: + raise ValueError( + f"Unsupported compute backend: {GlobalConfig.compute_backend}" + ) \ No newline at end of file diff --git a/xlb/solver/__init__.py b/xlb/solver/__init__.py new file mode 100644 index 0000000..62dfc30 --- /dev/null +++ b/xlb/solver/__init__.py @@ -0,0 +1 @@ +from xlb.solver.nse import IncompressibleNavierStokes diff --git a/xlb/solver/nse.py b/xlb/solver/nse.py new file mode 100644 index 0000000..96a529c --- /dev/null +++ b/xlb/solver/nse.py @@ -0,0 +1,108 @@ +# Base class for all stepper operators + +from functools import partial +from jax import jit + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.compute_backends import ComputeBackends +from xlb.operator.boundary_condition import ImplementationStep +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.collision import BGK, KBC +from xlb.operator.stream import Stream +from xlb.operator.macroscopic import Macroscopic +from xlb.solver.solver import Solver +from xlb.operator import Operator + + +class IncompressibleNavierStokes(Solver): + def __init__( + self, + grid, + velocity_set: VelocitySet = None, + compute_backend=None, + precision_policy=None, + boundary_conditions=[], + collision_kernel="BGK", + ): + self.grid = grid + self.collision_kernel = collision_kernel + super().__init__(velocity_set=velocity_set, compute_backend=compute_backend, precision_policy=precision_policy, boundary_conditions=boundary_conditions) + self.create_operators() + + # Operators + def create_operators(self): + self.macroscopic = Macroscopic( + velocity_set=self.velocity_set, compute_backend=self.compute_backend + ) + self.equilibrium = QuadraticEquilibrium( + velocity_set=self.velocity_set, compute_backend=self.compute_backend + ) + self.collision = ( + KBC( + omega=1.0, + velocity_set=self.velocity_set, + compute_backend=self.compute_backend, + ) + if self.collision_kernel == "KBC" + else BGK( + omega=1.0, + velocity_set=self.velocity_set, + compute_backend=self.compute_backend, + ) + ) + self.stream = Stream(self.grid, + velocity_set=self.velocity_set, compute_backend=self.compute_backend + ) + + @Operator.register_backend(ComputeBackends.JAX) + @partial(jit, static_argnums=(0,)) + def step(self, f, timestep): + """ + Perform a single step of the lattice boltzmann method + """ + + # Cast to compute precision + f = self.precision_policy.cast_to_compute(f) + + # Compute the macroscopic variables + rho, u = self.macroscopic(f) + + # Compute equilibrium + feq = self.equilibrium(rho, u) + + # Apply collision + f_post_collision = self.collision( + f, + feq, + ) + + # # Apply collision type boundary conditions + # for id_number, bc in self.collision_boundary_conditions.items(): + # f_post_collision = bc( + # f_pre_collision, + # f_post_collision, + # boundary_id == id_number, + # mask, + # ) + f_pre_streaming = f_post_collision + + ## Apply forcing + # if self.forcing_op is not None: + # f = self.forcing_op.apply_jax(f, timestep) + + # Apply streaming + f_post_streaming = self.stream(f_pre_streaming) + + # Apply boundary conditions + # for id_number, bc in self.stream_boundary_conditions.items(): + # f_post_streaming = bc( + # f_pre_streaming, + # f_post_streaming, + # boundary_id == id_number, + # mask, + # ) + + # Copy back to store precision + f = self.precision_policy.cast_to_store(f_post_streaming) + + return f diff --git a/xlb/solver/solver.py b/xlb/solver/solver.py new file mode 100644 index 0000000..9f4deaa --- /dev/null +++ b/xlb/solver/solver.py @@ -0,0 +1,37 @@ +# Base class for all stepper operators + +from xlb.compute_backends import ComputeBackends +from xlb.operator.boundary_condition import ImplementationStep +from xlb.global_config import GlobalConfig +from xlb.operator import Operator + + +class Solver(Operator): + """ + Abstract class for the construction of lattice boltzmann solvers + """ + + def __init__( + self, + velocity_set=None, + compute_backend=None, + precision_policy=None, + boundary_conditions=[], + ): + # Set parameters + self.velocity_set = velocity_set or GlobalConfig.velocity_set + self.compute_backend = compute_backend or GlobalConfig.compute_backend + self.precision_policy = precision_policy or GlobalConfig.precision_policy + self.boundary_conditions = boundary_conditions + + # Get collision and stream boundary conditions + self.collision_boundary_conditions = {} + self.stream_boundary_conditions = {} + for id_number, bc in enumerate(self.boundary_conditions): + bc_id = id_number + 1 + if bc.implementation_step == ImplementationStep.COLLISION: + self.collision_boundary_conditions[bc_id] = bc + elif bc.implementation_step == ImplementationStep.STREAMING: + self.stream_boundary_conditions[bc_id] = bc + else: + raise ValueError("Boundary condition step not recognized")