Skip to content

Commit

Permalink
example_mehdi works with no error
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdiataei committed Jan 26, 2024
1 parent 9fc7d20 commit a05441f
Show file tree
Hide file tree
Showing 29 changed files with 573 additions and 290 deletions.
37 changes: 37 additions & 0 deletions examples/refactor/example_mehdi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import xlb
from xlb.compute_backends import ComputeBackends
from xlb.precision_policy import Fp32Fp32

from xlb.solver import IncompressibleNavierStokes
from xlb.operator.equilibrium import QuadraticEquilibrium
from xlb.operator.stream import Stream
from xlb.global_config import GlobalConfig
from xlb.grid import Grid
from xlb.operator.initializer import EquilibriumInitializer, ConstInitializer

import numpy as np
import jax.numpy as jnp

xlb.init(precision_policy=Fp32Fp32, compute_backend=ComputeBackends.JAX, velocity_set=xlb.velocity_set.D2Q9)

grid_shape = (100, 100)
grid = Grid.create(grid_shape)

f_init = grid.create_field(cardinality=9, callback=EquilibriumInitializer(grid))

u_init = grid.create_field(cardinality=2, callback=ConstInitializer(grid, cardinality=2, const_value=0.0))
rho_init = grid.create_field(cardinality=1, callback=ConstInitializer(grid, cardinality=1, const_value=1.0))


st = Stream(grid)

f_init = st(f_init)
print("here")
solver = IncompressibleNavierStokes(grid)

num_steps = 100
f = f_init
for step in range(num_steps):
f = solver.step(f, timestep=step)
print(f"Step {step+1}/{num_steps} complete")

40 changes: 40 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import jax.numpy as jnp
from xlb.operator.macroscopic import Macroscopic
from xlb.operator.equilibrium import QuadraticEquilibrium
from xlb.operator.stream import Stream
from xlb.velocity_set import D2Q9, D3Q27
from xlb.operator.collision.kbc import KBC, BGK
from xlb.compute_backends import ComputeBackends
from xlb.grid import Grid
import xlb


xlb.init(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX)

collision = BGK(omega=0.6)

# eq = QuadraticEquilibrium(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX)

# macro = Macroscopic(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX)

# s = Stream(velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX)

Q = 19
# create random jnp arrays
f = jnp.ones((Q, 10, 10))
rho = jnp.ones((1, 10, 10))
u = jnp.zeros((2, 10, 10))
# feq = eq(rho, u)

print(collision(f, f))

grid = Grid.create(grid_shape=(10, 10), velocity_set=D2Q9(), compute_backend=ComputeBackends.JAX)

def advection_result(index):
return 1.0


f = grid.initialize_pop(advection_result)

print(f)
print(f.sharding)
12 changes: 11 additions & 1 deletion xlb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from xlb.compute_backends import ComputeBackends
from xlb.physics_type import PhysicsType


# Config
from .global_config import init


# Precision policy
import xlb.precision_policy

Expand All @@ -15,4 +20,9 @@
import xlb.operator.boundary_condition
# import xlb.operator.force
import xlb.operator.macroscopic
import xlb.operator.stepper

# Grids
import xlb.grid

# Solvers
import xlb.solver
10 changes: 10 additions & 0 deletions xlb/global_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class GlobalConfig:
precision_policy = None
velocity_set = None
compute_backend = None


def init(velocity_set, compute_backend, precision_policy):
GlobalConfig.velocity_set = velocity_set()
GlobalConfig.compute_backend = compute_backend
GlobalConfig.precision_policy = precision_policy()
1 change: 1 addition & 0 deletions xlb/grid/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from xlb.grid.grid import Grid
34 changes: 34 additions & 0 deletions xlb/grid/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from xlb.compute_backends import ComputeBackends
from xlb.global_config import GlobalConfig
from xlb.velocity_set import VelocitySet


class Grid(ABC):
def __init__(self, grid_shape, velocity_set, compute_backend):
self.velocity_set: VelocitySet = velocity_set
self.compute_backend = compute_backend
self.grid_shape = grid_shape
self.pop_shape = (self.velocity_set.q, *grid_shape)
self.u_shape = (self.velocity_set.d, *grid_shape)
self.rho_shape = (1, *grid_shape)
self.dim = self.velocity_set.d

@abstractmethod
def create_field(self, cardinality, callback=None):
pass

@staticmethod
def create(grid_shape, velocity_set=None, compute_backend=None):
compute_backend = compute_backend or GlobalConfig.compute_backend
velocity_set = velocity_set or GlobalConfig.velocity_set

if compute_backend == ComputeBackends.JAX:
from xlb.grid.jax_grid import JaxGrid # Avoids circular import

return JaxGrid(grid_shape, velocity_set, compute_backend)
raise ValueError(f"Compute backend {compute_backend} is not supported")

@abstractmethod
def global_to_local_shape(self, shape):
pass
46 changes: 46 additions & 0 deletions xlb/grid/jax_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from xlb.grid.grid import Grid
from xlb.compute_backends import ComputeBackends
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding, Mesh
from jax.experimental import mesh_utils
from xlb.operator.initializer import ConstInitializer
import jax


class JaxGrid(Grid):
def __init__(self, grid_shape, velocity_set, compute_backend):
super().__init__(grid_shape, velocity_set, compute_backend)
self.initialize_jax_backend()

def initialize_jax_backend(self):
self.nDevices = jax.device_count()
self.backend = jax.default_backend()
device_mesh = (
mesh_utils.create_device_mesh((1, self.nDevices, 1))
if self.dim == 2
else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1))
)
self.global_mesh = (
Mesh(device_mesh, axis_names=("cardinality", "x", "y"))
if self.dim == 2
else Mesh(self.devices, axis_names=("cardinality", "x", "y", "z"))
)
self.sharding = (
NamedSharding(self.global_mesh, P("cardinality", "x", "y"))
if self.dim == 2
else NamedSharding(self.global_mesh, P("cardinality", "x", "y", "z"))
)

def global_to_local_shape(self, shape):
if len(shape) < 2:
raise ValueError("Shape must have at least two dimensions")

new_second_index = shape[1] // self.nDevices

return shape[:1] + (new_second_index,) + shape[2:]

def create_field(self, cardinality, callback=None):
if callback is None:
callback = ConstInitializer(self, cardinality, const_value=0.0)
shape = (cardinality,) + (self.grid_shape)
return jax.make_array_from_callback(shape, self.sharding, callback)
6 changes: 3 additions & 3 deletions xlb/operator/collision/bgk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ class BGK(Collision):
def __init__(
self,
omega: float,
velocity_set: VelocitySet,
compute_backend=ComputeBackends.JAX,
velocity_set: VelocitySet = None,
compute_backend=None,
):
super().__init__(
omega=omega, velocity_set=velocity_set, compute_backend=compute_backend
)

@Operator.register_backend(ComputeBackends.JAX)
@partial(jit, static_argnums=(0,))
def jax_implementation_2(self, f: jnp.ndarray, feq: jnp.ndarray):
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray):
fneq = f - feq
fout = f - self.omega * fneq
return fout
Expand Down
5 changes: 2 additions & 3 deletions xlb/operator/collision/collision.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
Base class for Collision operators
"""
from xlb.compute_backends import ComputeBackends
from xlb.velocity_set import VelocitySet
from xlb.operator import Operator

Expand All @@ -23,8 +22,8 @@ class Collision(Operator):
def __init__(
self,
omega: float,
velocity_set: VelocitySet,
compute_backend=ComputeBackends.JAX,
velocity_set: VelocitySet = None,
compute_backend=None,
):
super().__init__(velocity_set, compute_backend)
self.omega = omega
14 changes: 12 additions & 2 deletions xlb/operator/collision/kbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class KBC(Collision):
def __init__(
self,
omega,
velocity_set: VelocitySet,
compute_backend=ComputeBackends.JAX,
velocity_set: VelocitySet = None,
compute_backend=None,
):
super().__init__(
omega=omega, velocity_set=velocity_set, compute_backend=compute_backend
Expand Down Expand Up @@ -75,6 +75,16 @@ def jax_implementation(

return fout

@Operator.register_backend(ComputeBackends.WARP)
@partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3))
def warp_implementation(
self,
f: jnp.ndarray,
feq: jnp.ndarray,
rho: jnp.ndarray,
):
raise NotImplementedError("Warp implementation not yet implemented")

@partial(jit, static_argnums=(0,), inline=True)
def entropic_scalar_product(self, x: jnp.ndarray, y: jnp.ndarray, feq: jnp.ndarray):
"""
Expand Down
5 changes: 2 additions & 3 deletions xlb/operator/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Base class for all equilibriums
from xlb.velocity_set.velocity_set import VelocitySet
from xlb.compute_backends import ComputeBackends
from xlb.operator.operator import Operator


Expand All @@ -11,7 +10,7 @@ class Equilibrium(Operator):

def __init__(
self,
velocity_set: VelocitySet,
compute_backend=ComputeBackends.JAX,
velocity_set: VelocitySet = None,
compute_backend=None,
):
super().__init__(velocity_set, compute_backend)
16 changes: 8 additions & 8 deletions xlb/operator/equilibrium/quadratic_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xlb.operator.equilibrium.equilibrium import Equilibrium
from functools import partial
from xlb.operator import Operator
from xlb.global_config import GlobalConfig


class QuadraticEquilibrium(Equilibrium):
Expand All @@ -17,21 +18,20 @@ class QuadraticEquilibrium(Equilibrium):

def __init__(
self,
velocity_set: VelocitySet,
compute_backend=ComputeBackends.JAX,
velocity_set: VelocitySet = None,
compute_backend=None,
):
velocity_set = velocity_set or GlobalConfig.velocity_set
compute_backend = compute_backend or GlobalConfig.compute_backend

super().__init__(velocity_set, compute_backend)

@Operator.register_backend(ComputeBackends.JAX)
# @partial(jit, static_argnums=(0), donate_argnums=(1, 2))
@partial(jit, static_argnums=(0), donate_argnums=(1, 2))
def jax_implementation(self, rho, u):
cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0))
usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True)
w = self.velocity_set.w.reshape(-1, 1, 1)

feq = (
rho
* w
* (1.0 + cu * (1.0 + 0.5 * cu) - usqr)
)
feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr)
return feq
2 changes: 2 additions & 0 deletions xlb/operator/initializer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from xlb.operator.initializer.equilibrium_init import EquilibriumInitializer
from xlb.operator.initializer.const_init import ConstInitializer
28 changes: 28 additions & 0 deletions xlb/operator/initializer/const_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from xlb.velocity_set import VelocitySet
from xlb.global_config import GlobalConfig
from xlb.compute_backends import ComputeBackends
from xlb.operator.operator import Operator
from xlb.grid.grid import Grid
import numpy as np
import jax


class ConstInitializer(Operator):
def __init__(
self,
grid: Grid,
cardinality,
const_value=0.0,
velocity_set: VelocitySet = None,
compute_backend: ComputeBackends = None,
):
velocity_set = velocity_set or GlobalConfig.velocity_set
compute_backend = compute_backend or GlobalConfig.compute_backend
shape = (cardinality,) + (grid.grid_shape)
self.init_values = np.zeros(grid.global_to_local_shape(shape)) + const_value

super().__init__(velocity_set, compute_backend)

@Operator.register_backend(ComputeBackends.JAX)
def jax_implementation(self, index):
return self.init_values
28 changes: 28 additions & 0 deletions xlb/operator/initializer/equilibrium_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from xlb.velocity_set import VelocitySet
from xlb.global_config import GlobalConfig
from xlb.compute_backends import ComputeBackends
from xlb.operator.operator import Operator
from xlb.grid.grid import Grid
import numpy as np
import jax


class EquilibriumInitializer(Operator):
def __init__(
self,
grid: Grid,
velocity_set: VelocitySet = None,
compute_backend: ComputeBackends = None,
):
velocity_set = velocity_set or GlobalConfig.velocity_set
compute_backend = compute_backend or GlobalConfig.compute_backend
local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1)
self.init_values = np.zeros(
grid.global_to_local_shape(grid.pop_shape)
) + velocity_set.w.reshape(local_shape)

super().__init__(velocity_set, compute_backend)

@Operator.register_backend(ComputeBackends.JAX)
def jax_implementation(self, index):
return self.init_values
Loading

0 comments on commit a05441f

Please sign in to comment.