Skip to content

Commit

Permalink
stepper almost working
Browse files Browse the repository at this point in the history
  • Loading branch information
loliverhennigh committed Feb 22, 2024
1 parent 93a130d commit 05b87bf
Show file tree
Hide file tree
Showing 21 changed files with 240 additions and 171 deletions.
38 changes: 0 additions & 38 deletions examples/warp_backend/equilibrium.py

This file was deleted.

108 changes: 108 additions & 0 deletions examples/warp_backend/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# from IPython import display
import numpy as np
import jax
import jax.numpy as jnp
import scipy
import time
from tqdm import tqdm
import matplotlib.pyplot as plt

import warp as wp
wp.init()

import xlb


def test_backends(compute_backend):

# Set parameters
precision_policy = xlb.PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q27()

# Make operators
collision = xlb.operator.collision.BGK(
omega=1.0,
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
macroscopic = xlb.operator.macroscopic.Macroscopic(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
stream = xlb.operator.stream.Stream(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
bounceback = xlb.operator.boundary_condition.FullBounceBack.from_indices(
indices=np.array([[0, 0, 0], [0, 0, 1]]),
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend)
stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper(
collision=collision,
equilibrium=equilibrium,
macroscopic=macroscopic,
stream=stream,
boundary_conditions=[bounceback])

# Test operators
if compute_backend == xlb.ComputeBackend.WARP:
# Make warp arrays
nr = 128
f_0 = wp.zeros((27, nr, nr, nr), dtype=wp.float32)
f_1 = wp.zeros((27, nr, nr, nr), dtype=wp.float32)
f_out = wp.zeros((27, nr, nr, nr), dtype=wp.float32)
u = wp.zeros((3, nr, nr, nr), dtype=wp.float32)
rho = wp.zeros((1, nr, nr, nr), dtype=wp.float32)
boundary_id = wp.zeros((1, nr, nr, nr), dtype=wp.uint8)
boundary = wp.zeros((1, nr, nr, nr), dtype=wp.bool)
mask = wp.zeros((27, nr, nr, nr), dtype=wp.bool)

# Test operators
collision(f_0, f_1, rho, u, f_out)
equilibrium(rho, u, f_0)
macroscopic(f_0, rho, u)
stream(f_0, f_1)
bounceback(f_0, f_1, f_out, boundary, mask)
#bounceback.boundary_masker((0, 0, 0), boundary_id, mask, 1)



elif compute_backend == xlb.ComputeBackend.JAX:
# Make jax arrays
nr = 128
f_0 = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32)
f_1 = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32)
f_out = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32)
u = jnp.zeros((3, nr, nr, nr), dtype=jnp.float32)
rho = jnp.zeros((1, nr, nr, nr), dtype=jnp.float32)
boundary_id = jnp.zeros((1, nr, nr, nr), dtype=jnp.uint8)
boundary = jnp.zeros((1, nr, nr, nr), dtype=jnp.bool_)
mask = jnp.zeros((27, nr, nr, nr), dtype=jnp.bool_)

# Test operators
collision(f_0, f_1, rho, u)
equilibrium(rho, u)
macroscopic(f_0)
stream(f_0)
bounceback(f_0, f_1, boundary, mask)
bounceback.boundary_masker((0, 0, 0), boundary_id, mask, 1)
stepper(f_0, boundary_id, mask, 0)



if __name__ == "__main__":

# Test backends
compute_backends = [
xlb.ComputeBackend.WARP,
xlb.ComputeBackend.JAX
]

for compute_backend in compute_backends:
test_backends(compute_backend)
print(f"Backend {compute_backend} passed all tests.")
1 change: 1 addition & 0 deletions xlb/operator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from xlb.operator.operator import Operator
from xlb.operator.parallel_operator import ParallelOperator
import xlb.operator.stepper #
26 changes: 26 additions & 0 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,32 @@ def __init__(
# Set boundary masker
self.boundary_masker = boundary_masker

@classmethod
def from_function(
cls,
implementation_step: ImplementationStep,
boundary_function,
velocity_set,
precision_policy,
compute_backend,
):
"""
Create a boundary condition from a function.
"""
# Create boundary mask
boundary_mask = BoundaryMasker.from_function(
boundary_function, velocity_set, precision_policy, compute_backend
)

# Create boundary condition
return cls(
implementation_step,
boundary_mask,
velocity_set,
precision_policy,
compute_backend,
)

@classmethod
def from_indices(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def _indices_to_tuple(indices):
return tuple([indices[:, i] for i in range(indices.shape[1])])

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0), inline=True)
#@partial(jit, static_argnums=(0), inline=True) TODO: Fix this
def jax_implementation(self, start_index, boundary_id, mask, id_number):
# Get local indices from the meshgrid and the indices
local_indices = self.indices - start_index
local_indices = self.indices - np.array(start_index)[np.newaxis, :]

# Remove any indices that are out of bounds
local_indices = local_indices[
Expand Down Expand Up @@ -98,3 +98,21 @@ def jax_implementation(self, start_index, boundary_id, mask, id_number):
mask = mask.at[self._indices_to_tuple(local_indices)].set(True)

return boundary_id, mask

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, start_index, boundary_id, mask, id_number):
# Reuse the jax implementation, TODO: implement a warp version
# Convert to jax
boundary_id = wp.jax.to_jax(boundary_id)
mask = wp.jax.to_jax(mask)

# Call jax implementation
boundary_id, mask = self.jax_implementation(
start_index, boundary_id, mask, id_number
)

# Convert back to warp
boundary_id = wp.jax.to_warp(boundary_id)
mask = wp.jax.to_warp(mask)

return boundary_id, mask
21 changes: 15 additions & 6 deletions xlb/operator/boundary_condition/equilibrium_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,24 @@
import numpy as np

from xlb.velocity_set.velocity_set import VelocitySet
from xlb.precision_policy import PrecisionPolicy
from xlb.compute_backend import ComputeBackend
from xlb.operator.stream.stream import Stream
from xlb.operator import Operator
from xlb.operator.equilibrium.equilibrium import Equilibrium
from xlb.operator.boundary_condition.boundary_condition import (
BoundaryCondition,
ImplementationStep,
)
from xlb.operator.boundary_condition.boundary_masker import (
BoundaryMasker,
IndicesBoundaryMasker,
)



class EquilibriumBoundary(BoundaryCondition):
"""
A boundary condition that skips the streaming step.
Equilibrium boundary condition for a lattice Boltzmann method simulation.
"""

def __init__(
Expand All @@ -25,11 +31,13 @@ def __init__(
rho: float,
u: tuple[float, float],
equilibrium: Equilibrium,
boundary_masker: BoundaryMasker,
velocity_set: VelocitySet,
compute_backend: ComputeBackend = ComputeBackend.JAX,
precision_policy: PrecisionPolicy,
compute_backend: ComputeBackend,
):
super().__init__(
set_boundary=set_boundary,
ImplementationStep.COLLISION,
implementation_step=ImplementationStep.STREAMING,
velocity_set=velocity_set,
compute_backend=compute_backend,
Expand All @@ -39,12 +47,13 @@ def __init__(
@classmethod
def from_indices(
cls,
indices,
indices: np.ndarray,
rho: float,
u: tuple[float, float],
equilibrium: Equilibrium,
velocity_set: VelocitySet,
compute_backend: ComputeBackend = ComputeBackend.JAX,
precision_policy: PrecisionPolicy,
compute_backend: ComputeBackend,
):
"""
Creates a boundary condition from a list of indices.
Expand Down
13 changes: 9 additions & 4 deletions xlb/operator/boundary_condition/full_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax.lax as lax
from functools import partial
import numpy as np
import warp as wp

from xlb.velocity_set.velocity_set import VelocitySet
from xlb.precision_policy import PrecisionPolicy
Expand All @@ -32,7 +33,7 @@ def __init__(
boundary_masker: BoundaryMasker,
velocity_set: VelocitySet,
precision_policy: PrecisionPolicy,
compute_backend: ComputeBackend.JAX,
compute_backend: ComputeBackend,
):
super().__init__(
ImplementationStep.COLLISION,
Expand Down Expand Up @@ -66,13 +67,12 @@ def from_indices(
@partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4))
def apply_jax(self, f_pre, f_post, boundary, mask):
flip = jnp.repeat(boundary, self.velocity_set.q, axis=0)
print(flip.shape)
flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post)
return flipped_f

def _construct_warp(self):
# Make constants for warp
_opp_indices = wp.constant(self.velocity_set.opp_indices)
_opp_indices = wp.constant(self._warp_int_lattice_vec(self.velocity_set.opp_indices))
_q = wp.constant(self.velocity_set.q)
_d = wp.constant(self.velocity_set.d)

Expand Down Expand Up @@ -107,7 +107,12 @@ def kernel(
for l in range(_q):
_f_pre[l] = f_pre[l, i, j, k]
_f_post[l] = f_post[l, i, j, k]
_mask[l] = mask[l, i, j, k]

# TODO fix vec bool
if mask[l, i, j, k]:
_mask[l] = wp.uint8(1)
else:
_mask[l] = wp.uint8(0)

# Check if the boundary is active
if boundary[i, j, k]:
Expand Down
13 changes: 10 additions & 3 deletions xlb/operator/collision/bgk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import jax.numpy as jnp
from jax import jit
import warp as wp

from xlb.velocity_set import VelocitySet
from xlb.compute_backend import ComputeBackend
from xlb.operator.collision.collision import Collision
Expand All @@ -18,7 +20,7 @@ def jax_implementation(
self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray
):
fneq = f - feq
fout = f - self.omega * fneq
fout = f - self.compute_dtype(self.omega) * fneq
return fout

@Operator.register_backend(ComputeBackend.PALLAS)
Expand All @@ -35,6 +37,7 @@ def _construct_warp(self):
_q = wp.constant(self.velocity_set.q)
_w = wp.constant(self._warp_lattice_vec(self.velocity_set.w))
_d = wp.constant(self.velocity_set.d)
_omega = wp.constant(self.compute_dtype(self.omega))

# Construct the functional
@wp.func
Expand All @@ -45,7 +48,7 @@ def functional(
u: self._warp_u_vec,
) -> self._warp_lattice_vec:
fneq = f - feq
fout = f - self.omega * fneq
fout = f - _omega * fneq
return fout

# Construct the warp kernel
Expand All @@ -66,7 +69,11 @@ def kernel(
for l in range(_q):
_f[l] = f[l, i, j, k]
_feq[l] = feq[l, i, j, k]
_fout = functional(_f, _feq)
_u = self._warp_u_vec()
for l in range(_d):
_u[l] = u[l, i, j, k]
_rho = rho[0, i, j, k]
_fout = functional(_f, _feq, _rho, _u)

# Write the result
for l in range(_q):
Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/collision/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ def __init__(
precision_policy=None,
compute_backend=None,
):
super().__init__(velocity_set, precision_policy, compute_backend)
self.omega = omega
super().__init__(velocity_set, precision_policy, compute_backend)
3 changes: 0 additions & 3 deletions xlb/operator/initializer/__init__.py

This file was deleted.

Loading

0 comments on commit 05b87bf

Please sign in to comment.