Skip to content

Commit

Permalink
stepper functional
Browse files Browse the repository at this point in the history
  • Loading branch information
loliverhennigh committed Feb 28, 2024
1 parent 9ebf244 commit 299c3cc
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 55 deletions.
136 changes: 136 additions & 0 deletions examples/interfaces/functional_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Simple Taylor green example using the functional interface to xlb

import time
from tqdm import tqdm
import matplotlib.pyplot as plt

import warp as wp
wp.init()

import xlb
from xlb.operator import Operator

class TaylorGreenInitializer(Operator):

def _construct_warp(self):
# Construct the warp kernel
@wp.kernel
def kernel(
rho: self._warp_array_type,
u: self._warp_array_type,
vel: float,
nr: int,
):
# Get the global index
i, j, k = wp.tid()

# Get real pos
x = 2.0 * wp.pi * wp.float(i) / wp.float(nr)
y = 2.0 * wp.pi * wp.float(j) / wp.float(nr)
z = 2.0 * wp.pi * wp.float(k) / wp.float(nr)

# Compute u
u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z)
u[1, i, j, k] = - vel * wp.cos(x) * wp.sin(y) * wp.cos(z)
u[2, i, j, k] = 0.0

# Compute rho
rho[0, i, j, k] = (
3.0
* vel
* vel
* (1.0 / 16.0)
* (
wp.cos(2.0 * x)
+ (wp.cos(2.0 * y)
* (wp.cos(2.0 * z) + 2.0))
)
+ 1.0
)

return None, kernel

@Operator.register_backend(xlb.ComputeBackend.WARP)
def warp_implementation(self, rho, u, vel, nr):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[
rho,
u,
vel,
nr,
],
dim=rho.shape[1:],
)
return rho, u

if __name__ == "__main__":

# Set parameters
compute_backend = xlb.ComputeBackend.WARP
precision_policy = xlb.PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q19()

# Make feilds
nr = 256
shape = (nr, nr, nr)
grid = xlb.grid.WarpGrid(shape=shape)
rho = grid.create_field(cardinality=1, dtype=wp.float32)
u = grid.create_field(cardinality=velocity_set.d, dtype=wp.float32)
f0 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32)
f1 = grid.create_field(cardinality=velocity_set.q, dtype=wp.float32)
boundary_id = grid.create_field(cardinality=1, dtype=wp.uint8)
mask = grid.create_field(cardinality=velocity_set.q, dtype=wp.bool)

# Make operators
initializer = TaylorGreenInitializer(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
collision = xlb.operator.collision.BGK(
omega=1.0,
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
macroscopic = xlb.operator.macroscopic.Macroscopic(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
stream = xlb.operator.stream.Stream(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper(
collision=collision,
equilibrium=equilibrium,
macroscopic=macroscopic,
stream=stream,
boundary_conditions=[])

# Parrallelize the stepper
#stepper = grid.parallelize_operator(stepper)

# Set initial conditions
rho, u = initializer(rho, u, 0.1, nr)
f0 = equilibrium(rho, u, f0)

# Plot initial conditions
#plt.imshow(f0[0, nr//2, :, :].numpy())
#plt.show()

# Time stepping
num_steps = 1024
start = time.time()
for _ in tqdm(range(num_steps)):
f1 = stepper(f0, f1, boundary_id, mask, _)
f1, f0 = f0, f1
wp.synchronize()
end = time.time()

# Print MLUPS
print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}")
2 changes: 2 additions & 0 deletions xlb/grid/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from xlb.grid.grid import Grid
from xlb.grid.warp_grid import WarpGrid
from xlb.grid.jax_grid import JaxGrid
19 changes: 2 additions & 17 deletions xlb/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,10 @@ class Grid(ABC):
def __init__(
self,
shape : tuple,
velocity_set : VelocitySet,
precision_policy : PrecisionPolicy,
grid_backend : ComputeBackend
):
# Set parameters
self.shape = shape
self.velocity_set = velocity_set
self.precision_policy = precision_policy
self.grid_backend = grid_backend
self.dim = self.velocity_set.d

# Create field dict
self.fields = {}
self.dim = len(shape)

def parallelize_operator(self, operator: Operator):
raise NotImplementedError("Parallelization not implemented, child class must implement")
Expand All @@ -33,10 +24,4 @@ def parallelize_operator(self, operator: Operator):
def create_field(
self, name: str, cardinality: int, precision: Precision, callback=None
):
pass

def get_field(self, name: str):
return self.fields[name]

def swap_fields(self, field1, field2):
self.fields[field1], self.fields[field2] = self.fields[field2], self.fields[field1]
raise NotImplementedError("create_field not implemented, child class must implement")
14 changes: 6 additions & 8 deletions xlb/grid/jax_grid.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from xlb.grid.grid import Grid
from xlb.compute_backends import ComputeBackends
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding, Mesh
from jax.experimental import mesh_utils
from xlb.operator.initializer import ConstInitializer
import jax

from xlb.grid import Grid
from xlb.compute_backend import ComputeBackend
from xlb.operator import Operator

class JaxGrid(Grid):
def __init__(self, grid_shape, velocity_set, precision_policy, grid_backend):
super().__init__(grid_shape, velocity_set, precision_policy, grid_backend)
def __init__(self, shape):
super().__init__(shape)
self._initialize_jax_backend()

def _initialize_jax_backend(self):
Expand Down Expand Up @@ -73,8 +73,6 @@ def _parallel_operator(f):
return f




def create_field(self, name: str, cardinality: int, callback=None):
# Get shape of the field
shape = (cardinality,) + (self.shape)
Expand All @@ -88,4 +86,4 @@ def create_field(self, name: str, cardinality: int, callback=None):
f = jax.make_array_from_callback(shape, self.sharding, callback)

# Add field to the field dictionary
self.fields[name] = f
return f
26 changes: 26 additions & 0 deletions xlb/grid/warp_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import warp as wp

from xlb.grid import Grid
from xlb.operator import Operator

class WarpGrid(Grid):
def __init__(self, shape):
super().__init__(shape)

def parallelize_operator(self, operator: Operator):
# TODO: Implement parallelization of the operator
raise NotImplementedError("Parallelization of the operator is not implemented yet for the WarpGrid")

def create_field(self, cardinality: int, dtype, callback=None):
# Get shape of the field
shape = (cardinality,) + (self.shape)

# Create the field
f = wp.zeros(shape, dtype=dtype)

# Raise error on callback
if callback is not None:
raise ValueError("Callback is not supported in the WarpGrid")

# Add field to the field dictionary
return f
2 changes: 1 addition & 1 deletion xlb/operator/equilibrium/quadratic_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def kernel(
def warp_implementation(self, rho, u, f):
# Launch the warp kernel
wp.launch(
self._kernel,
self.warp_kernel,
inputs=[
rho,
u,
Expand Down
6 changes: 3 additions & 3 deletions xlb/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, velocity_set, precision_policy, compute_backend):

# Construct the kernel based backend functions TODO: Maybe move this to the register or something
if self.compute_backend == ComputeBackend.WARP:
self._functional, self._kernel = self._construct_warp()
self.warp_functional, self.warp_kernel = self._construct_warp()

@classmethod
def register_backend(cls, backend_name):
Expand Down Expand Up @@ -189,9 +189,9 @@ def _warp_uint8_array_type(self):
Returns the warp type for arrays
"""
if self.velocity_set.d == 2:
return wp.array3d(dtype=wp.bool)
return wp.array3d(dtype=wp.uint8)
elif self.velocity_set.d == 3:
return wp.array4d(dtype=wp.bool)
return wp.array4d(dtype=wp.uint8)

@property
def _warp_bool_array_type(self):
Expand Down
66 changes: 40 additions & 26 deletions xlb/operator/stepper/nse.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _construct_warp(self):
_q = wp.constant(self.velocity_set.q)
_d = wp.constant(self.velocity_set.d)
_nr_boundary_conditions = wp.constant(len(self.boundary_conditions))
print(_q)

# Construct the kernel
@wp.kernel
Expand All @@ -145,58 +146,66 @@ def kernel(
boundary_id: self._warp_uint8_array_type,
mask: self._warp_bool_array_type,
timestep: wp.int32,
max_i: wp.int32,
max_j: wp.int32,
max_k: wp.int32,
):
# Get the global index
i, j, k = wp.tid()

# Get the f, boundary id and mask
_f = self._warp_lattice_vec()
_boundary_id = boundary_id[0, i, j, k]
_mask = self._bool_lattice_vec()
_mask = self._warp_bool_lattice_vec()
for l in range(_q):
_f[l] = self.f_0[l, i, j, k]
_mask[l] = mask[l, i, j, k]
_f[l] = f_0[l, i, j, k]

# TODO fix vec bool
if mask[l, i, j, k]:
_mask[l] = wp.uint8(1)
else:
_mask[l] = wp.uint8(0)

# Compute rho and u
rho, u = self.macroscopic.functional(_f)
rho, u = self.macroscopic.warp_functional(_f)

# Compute equilibrium
feq = self.equilibrium.functional(rho, u)
feq = self.equilibrium.warp_functional(rho, u)

# Apply collision
f_post_collision = self.collision.functional(
f_post_collision = self.collision.warp_functional(
_f,
feq,
rho,
u,
)

# Apply collision type boundary conditions
if _boundary_id == id_number:
f_post_collision = self.collision_boundary_conditions[
id_number
].functional(
_f,
f_post_collision,
_mask,
)
## Apply collision type boundary conditions
#if _boundary_id != wp.uint8(0):
# f_post_collision = self.collision_boundary_conditions[
# _boundary_id
# ].warp_functional(
# _f,
# f_post_collision,
# _mask,
# )
f_pre_streaming = f_post_collision # store pre streaming vector

# Apply forcing
# if self.forcing_op is not None:
# f = self.forcing.functional(f, timestep)
# f = self.forcing.warp_functional(f, timestep)

# Apply streaming
for l in range(_q):
# Get the streamed indices
streamed_i, streamed_j, streamed_k = self.stream.functional(
l, i, j, k, self._warp_max_i, self._warp_max_j, self._warp_max_k
streamed_i, streamed_j, streamed_k = self.stream.warp_functional(
l, i, j, k, max_i, max_j, max_k
)
streamed_l = l

## Modify the streamed indices based on streaming boundary condition
# if _boundary_id != 0:
# streamed_l, streamed_i, streamed_j, streamed_k = self.stream_boundary_conditions[id_number].functional(
# streamed_l, streamed_i, streamed_j, streamed_k = self.stream_boundary_conditions[id_number].warp_functional(
# streamed_l, streamed_i, streamed_j, streamed_k, self._warp_max_i, self._warp_max_j, self._warp_max_k
# )

Expand All @@ -206,15 +215,20 @@ def kernel(
return None, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, boundary_id, mask, timestep):
def warp_implementation(self, f_0, f_1, boundary_id, mask, timestep):
# Launch the warp kernel
wp.launch(
self._kernel,
self.warp_kernel,
inputs=[
f,
rho,
u,
f_0,
f_1,
boundary_id,
mask,
timestep,
f_0.shape[1],
f_0.shape[2],
f_0.shape[3],
],
dim=rho.shape[1:],
dim=f_0.shape[1:],
)
return rho, u
return f_1
Loading

0 comments on commit 299c3cc

Please sign in to comment.