From 05b87bf72d146721f48954f87367d36e5b838860 Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 22 Feb 2024 11:53:03 -0800 Subject: [PATCH] stepper almost working --- examples/warp_backend/equilibrium.py | 38 ------ examples/warp_backend/testing.py | 108 ++++++++++++++++++ xlb/operator/__init__.py | 1 + .../boundary_condition/boundary_condition.py | 26 +++++ .../indices_boundary_masker.py | 22 +++- .../equilibrium_boundary.py | 21 +++- .../boundary_condition/full_bounce_back.py | 13 ++- xlb/operator/collision/bgk.py | 13 ++- xlb/operator/collision/collision.py | 2 +- xlb/operator/initializer/__init__.py | 3 - xlb/operator/initializer/const_init.py | 36 ------ xlb/operator/initializer/equilibrium_init.py | 33 ------ xlb/operator/initializer/initializer.py | 13 --- xlb/operator/operator.py | 16 ++- xlb/operator/precision_caster/__init__.py | 1 + .../precision_caster/precision_caster.py | 13 +-- xlb/operator/stepper/__init__.py | 2 + xlb/operator/stepper/nse.py | 16 +-- xlb/operator/stepper/stepper.py | 30 ++--- xlb/operator/stream/stream.py | 3 +- xlb/operator/test/test.py | 1 + 21 files changed, 240 insertions(+), 171 deletions(-) delete mode 100644 examples/warp_backend/equilibrium.py create mode 100644 examples/warp_backend/testing.py delete mode 100644 xlb/operator/initializer/__init__.py delete mode 100644 xlb/operator/initializer/const_init.py delete mode 100644 xlb/operator/initializer/equilibrium_init.py delete mode 100644 xlb/operator/initializer/initializer.py create mode 100644 xlb/operator/precision_caster/__init__.py create mode 100644 xlb/operator/stepper/__init__.py create mode 100644 xlb/operator/test/test.py diff --git a/examples/warp_backend/equilibrium.py b/examples/warp_backend/equilibrium.py deleted file mode 100644 index a99ace4..0000000 --- a/examples/warp_backend/equilibrium.py +++ /dev/null @@ -1,38 +0,0 @@ -# from IPython import display -import numpy as np -import jax -import jax.numpy as jnp -import scipy -import time -from tqdm import tqdm -import matplotlib.pyplot as plt - -import warp as wp -wp.init() - -import xlb - -if __name__ == "__main__": - - # Make operator - precision_policy = xlb.PrecisionPolicy.FP32FP32 - velocity_set = xlb.velocity_set.D3Q27() - compute_backend = xlb.ComputeBackend.WARP - equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - - # Make warp arrays - nr = 128 - f = wp.zeros((27, nr, nr, nr), dtype=wp.float32) - u = wp.zeros((3, nr, nr, nr), dtype=wp.float32) - rho = wp.zeros((1, nr, nr, nr), dtype=wp.float32) - - # Run simulation - equilibrium(rho, u, f) - macroscopic(f, rho, u) diff --git a/examples/warp_backend/testing.py b/examples/warp_backend/testing.py new file mode 100644 index 0000000..3940378 --- /dev/null +++ b/examples/warp_backend/testing.py @@ -0,0 +1,108 @@ +# from IPython import display +import numpy as np +import jax +import jax.numpy as jnp +import scipy +import time +from tqdm import tqdm +import matplotlib.pyplot as plt + +import warp as wp +wp.init() + +import xlb + + +def test_backends(compute_backend): + + # Set parameters + precision_policy = xlb.PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q27() + + # Make operators + collision = xlb.operator.collision.BGK( + omega=1.0, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + stream = xlb.operator.stream.Stream( + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + bounceback = xlb.operator.boundary_condition.FullBounceBack.from_indices( + indices=np.array([[0, 0, 0], [0, 0, 1]]), + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend) + stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=collision, + equilibrium=equilibrium, + macroscopic=macroscopic, + stream=stream, + boundary_conditions=[bounceback]) + + # Test operators + if compute_backend == xlb.ComputeBackend.WARP: + # Make warp arrays + nr = 128 + f_0 = wp.zeros((27, nr, nr, nr), dtype=wp.float32) + f_1 = wp.zeros((27, nr, nr, nr), dtype=wp.float32) + f_out = wp.zeros((27, nr, nr, nr), dtype=wp.float32) + u = wp.zeros((3, nr, nr, nr), dtype=wp.float32) + rho = wp.zeros((1, nr, nr, nr), dtype=wp.float32) + boundary_id = wp.zeros((1, nr, nr, nr), dtype=wp.uint8) + boundary = wp.zeros((1, nr, nr, nr), dtype=wp.bool) + mask = wp.zeros((27, nr, nr, nr), dtype=wp.bool) + + # Test operators + collision(f_0, f_1, rho, u, f_out) + equilibrium(rho, u, f_0) + macroscopic(f_0, rho, u) + stream(f_0, f_1) + bounceback(f_0, f_1, f_out, boundary, mask) + #bounceback.boundary_masker((0, 0, 0), boundary_id, mask, 1) + + + + elif compute_backend == xlb.ComputeBackend.JAX: + # Make jax arrays + nr = 128 + f_0 = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) + f_1 = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) + f_out = jnp.zeros((27, nr, nr, nr), dtype=jnp.float32) + u = jnp.zeros((3, nr, nr, nr), dtype=jnp.float32) + rho = jnp.zeros((1, nr, nr, nr), dtype=jnp.float32) + boundary_id = jnp.zeros((1, nr, nr, nr), dtype=jnp.uint8) + boundary = jnp.zeros((1, nr, nr, nr), dtype=jnp.bool_) + mask = jnp.zeros((27, nr, nr, nr), dtype=jnp.bool_) + + # Test operators + collision(f_0, f_1, rho, u) + equilibrium(rho, u) + macroscopic(f_0) + stream(f_0) + bounceback(f_0, f_1, boundary, mask) + bounceback.boundary_masker((0, 0, 0), boundary_id, mask, 1) + stepper(f_0, boundary_id, mask, 0) + + + +if __name__ == "__main__": + + # Test backends + compute_backends = [ + xlb.ComputeBackend.WARP, + xlb.ComputeBackend.JAX + ] + + for compute_backend in compute_backends: + test_backends(compute_backend) + print(f"Backend {compute_backend} passed all tests.") diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index 02b8a59..501a7af 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1,2 +1,3 @@ from xlb.operator.operator import Operator from xlb.operator.parallel_operator import ParallelOperator +import xlb.operator.stepper # diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 42fa5d4..95b1265 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -45,6 +45,32 @@ def __init__( # Set boundary masker self.boundary_masker = boundary_masker + @classmethod + def from_function( + cls, + implementation_step: ImplementationStep, + boundary_function, + velocity_set, + precision_policy, + compute_backend, + ): + """ + Create a boundary condition from a function. + """ + # Create boundary mask + boundary_mask = BoundaryMasker.from_function( + boundary_function, velocity_set, precision_policy, compute_backend + ) + + # Create boundary condition + return cls( + implementation_step, + boundary_mask, + velocity_set, + precision_policy, + compute_backend, + ) + @classmethod def from_indices( cls, diff --git a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py index 3b57895..fdf8ced 100644 --- a/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_condition/boundary_masker/indices_boundary_masker.py @@ -47,10 +47,10 @@ def _indices_to_tuple(indices): return tuple([indices[:, i] for i in range(indices.shape[1])]) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0), inline=True) + #@partial(jit, static_argnums=(0), inline=True) TODO: Fix this def jax_implementation(self, start_index, boundary_id, mask, id_number): # Get local indices from the meshgrid and the indices - local_indices = self.indices - start_index + local_indices = self.indices - np.array(start_index)[np.newaxis, :] # Remove any indices that are out of bounds local_indices = local_indices[ @@ -98,3 +98,21 @@ def jax_implementation(self, start_index, boundary_id, mask, id_number): mask = mask.at[self._indices_to_tuple(local_indices)].set(True) return boundary_id, mask + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, start_index, boundary_id, mask, id_number): + # Reuse the jax implementation, TODO: implement a warp version + # Convert to jax + boundary_id = wp.jax.to_jax(boundary_id) + mask = wp.jax.to_jax(mask) + + # Call jax implementation + boundary_id, mask = self.jax_implementation( + start_index, boundary_id, mask, id_number + ) + + # Convert back to warp + boundary_id = wp.jax.to_warp(boundary_id) + mask = wp.jax.to_warp(mask) + + return boundary_id, mask diff --git a/xlb/operator/boundary_condition/equilibrium_boundary.py b/xlb/operator/boundary_condition/equilibrium_boundary.py index fbc5418..4b47980 100644 --- a/xlb/operator/boundary_condition/equilibrium_boundary.py +++ b/xlb/operator/boundary_condition/equilibrium_boundary.py @@ -5,18 +5,24 @@ import numpy as np from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend -from xlb.operator.stream.stream import Stream +from xlb.operator import Operator from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator.boundary_condition.boundary_condition import ( BoundaryCondition, ImplementationStep, ) +from xlb.operator.boundary_condition.boundary_masker import ( + BoundaryMasker, + IndicesBoundaryMasker, +) + class EquilibriumBoundary(BoundaryCondition): """ - A boundary condition that skips the streaming step. + Equilibrium boundary condition for a lattice Boltzmann method simulation. """ def __init__( @@ -25,11 +31,13 @@ def __init__( rho: float, u: tuple[float, float], equilibrium: Equilibrium, + boundary_masker: BoundaryMasker, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, ): super().__init__( - set_boundary=set_boundary, + ImplementationStep.COLLISION, implementation_step=ImplementationStep.STREAMING, velocity_set=velocity_set, compute_backend=compute_backend, @@ -39,12 +47,13 @@ def __init__( @classmethod def from_indices( cls, - indices, + indices: np.ndarray, rho: float, u: tuple[float, float], equilibrium: Equilibrium, velocity_set: VelocitySet, - compute_backend: ComputeBackend = ComputeBackend.JAX, + precision_policy: PrecisionPolicy, + compute_backend: ComputeBackend, ): """ Creates a boundary condition from a list of indices. diff --git a/xlb/operator/boundary_condition/full_bounce_back.py b/xlb/operator/boundary_condition/full_bounce_back.py index 91fda7f..ed0ec5a 100644 --- a/xlb/operator/boundary_condition/full_bounce_back.py +++ b/xlb/operator/boundary_condition/full_bounce_back.py @@ -7,6 +7,7 @@ import jax.lax as lax from functools import partial import numpy as np +import warp as wp from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -32,7 +33,7 @@ def __init__( boundary_masker: BoundaryMasker, velocity_set: VelocitySet, precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.JAX, + compute_backend: ComputeBackend, ): super().__init__( ImplementationStep.COLLISION, @@ -66,13 +67,12 @@ def from_indices( @partial(jit, static_argnums=(0), donate_argnums=(1, 2, 3, 4)) def apply_jax(self, f_pre, f_post, boundary, mask): flip = jnp.repeat(boundary, self.velocity_set.q, axis=0) - print(flip.shape) flipped_f = lax.select(flip, f_pre[self.velocity_set.opp_indices, ...], f_post) return flipped_f def _construct_warp(self): # Make constants for warp - _opp_indices = wp.constant(self.velocity_set.opp_indices) + _opp_indices = wp.constant(self._warp_int_lattice_vec(self.velocity_set.opp_indices)) _q = wp.constant(self.velocity_set.q) _d = wp.constant(self.velocity_set.d) @@ -107,7 +107,12 @@ def kernel( for l in range(_q): _f_pre[l] = f_pre[l, i, j, k] _f_post[l] = f_post[l, i, j, k] - _mask[l] = mask[l, i, j, k] + + # TODO fix vec bool + if mask[l, i, j, k]: + _mask[l] = wp.uint8(1) + else: + _mask[l] = wp.uint8(0) # Check if the boundary is active if boundary[i, j, k]: diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 5052a70..4071345 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -1,5 +1,7 @@ import jax.numpy as jnp from jax import jit +import warp as wp + from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision @@ -18,7 +20,7 @@ def jax_implementation( self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray ): fneq = f - feq - fout = f - self.omega * fneq + fout = f - self.compute_dtype(self.omega) * fneq return fout @Operator.register_backend(ComputeBackend.PALLAS) @@ -35,6 +37,7 @@ def _construct_warp(self): _q = wp.constant(self.velocity_set.q) _w = wp.constant(self._warp_lattice_vec(self.velocity_set.w)) _d = wp.constant(self.velocity_set.d) + _omega = wp.constant(self.compute_dtype(self.omega)) # Construct the functional @wp.func @@ -45,7 +48,7 @@ def functional( u: self._warp_u_vec, ) -> self._warp_lattice_vec: fneq = f - feq - fout = f - self.omega * fneq + fout = f - _omega * fneq return fout # Construct the warp kernel @@ -66,7 +69,11 @@ def kernel( for l in range(_q): _f[l] = f[l, i, j, k] _feq[l] = feq[l, i, j, k] - _fout = functional(_f, _feq) + _u = self._warp_u_vec() + for l in range(_d): + _u[l] = u[l, i, j, k] + _rho = rho[0, i, j, k] + _fout = functional(_f, _feq, _rho, _u) # Write the result for l in range(_q): diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py index 1fe0a5b..acf4538 100644 --- a/xlb/operator/collision/collision.py +++ b/xlb/operator/collision/collision.py @@ -26,5 +26,5 @@ def __init__( precision_policy=None, compute_backend=None, ): - super().__init__(velocity_set, precision_policy, compute_backend) self.omega = omega + super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/initializer/__init__.py b/xlb/operator/initializer/__init__.py deleted file mode 100644 index 026bee9..0000000 --- a/xlb/operator/initializer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from xlb.operator.initializer.initializer import Initializer -from xlb.operator.initializer.equilibrium_init import EquilibriumInitializer -from xlb.operator.initializer.const_init import ConstInitializer diff --git a/xlb/operator/initializer/const_init.py b/xlb/operator/initializer/const_init.py deleted file mode 100644 index 17235cd..0000000 --- a/xlb/operator/initializer/const_init.py +++ /dev/null @@ -1,36 +0,0 @@ -from xlb.velocity_set import VelocitySet -from xlb.global_config import GlobalConfig -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from xlb.grid.grid import Grid -from functools import partial -import numpy as np -import jax - - -class ConstInitializer(Operator): - def __init__( - self, - type=np.float32, - velocity_set: VelocitySet = None, - compute_backend: ComputeBackends = None, - ): - self.type = type - self.grid = grid - velocity_set = velocity_set or GlobalConfig.velocity_set - compute_backend = compute_backend or GlobalConfig.compute_backend - - super().__init__(velocity_set, compute_backend) - - @Operator.register_backend(ComputeBackends.JAX) - @partial(jax.jit, static_argnums=(0, 2)) - def jax_implementation(self, const_value, sharding=None): - if sharding is None: - sharding = self.grid.sharding - x = jax.numpy.full(shape=self.shape, fill_value=const_value, dtype=self.type) - return jax.lax.with_sharding_constraint(x, sharding) - - @Operator.register_backend(ComputeBackends.PALLAS) - @partial(jax.jit, static_argnums=(0, 2)) - def pallas_implementation(self, const_value, sharding=None): - return self.jax_implementation(const_value, sharding) diff --git a/xlb/operator/initializer/equilibrium_init.py b/xlb/operator/initializer/equilibrium_init.py deleted file mode 100644 index 5d96fbb..0000000 --- a/xlb/operator/initializer/equilibrium_init.py +++ /dev/null @@ -1,33 +0,0 @@ -from xlb.velocity_set import VelocitySet -from xlb.global_config import GlobalConfig -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from xlb.grid.grid import Grid -import numpy as np -import jax - - -class EquilibriumInitializer(Operator): - def __init__( - self, - grid: Grid, - velocity_set: VelocitySet = None, - compute_backend: ComputeBackends = None, - ): - velocity_set = velocity_set or GlobalConfig.velocity_set - compute_backend = compute_backend or GlobalConfig.compute_backend - local_shape = (-1,) + (1,) * (len(grid.pop_shape) - 1) - - self.init_values = np.zeros( - grid.field_global_to_local_shape(grid.pop_shape) - ) + velocity_set.w.reshape(local_shape) - - super().__init__(velocity_set, compute_backend) - - @Operator.register_backend(ComputeBackends.JAX) - def jax_implementation(self, index): - return self.init_values - - @Operator.register_backend(ComputeBackends.PALLAS) - def jax_implementation(self, index): - return self.init_values diff --git a/xlb/operator/initializer/initializer.py b/xlb/operator/initializer/initializer.py deleted file mode 100644 index e813934..0000000 --- a/xlb/operator/initializer/initializer.py +++ /dev/null @@ -1,13 +0,0 @@ -from xlb.velocity_set import VelocitySet -from xlb.global_config import GlobalConfig -from xlb.compute_backends import ComputeBackends -from xlb.operator.operator import Operator -from xlb.grid.grid import Grid -import numpy as np -import jax - - -class Initializer(Operator): - """ - Base class for all initializers. - """ diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 87c6f15..81b3035 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -92,7 +92,7 @@ def backend(self): This should be used with caution as all backends may not have the same API. """ if self.compute_backend == ComputeBackend.JAX: - import jax as backend + import jax.numpy as backend elif self.compute_backend == ComputeBackend.WARP: import warp as backend return backend @@ -152,7 +152,8 @@ def _warp_bool_lattice_vec(self): """ Returns the warp type for the streaming matrix (c) """ - return wp.vec(self.velocity_set.q, dtype=wp.bool) + #return wp.vec(self.velocity_set.q, dtype=wp.bool) + return wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO bool breaks @property def _warp_stream_mat(self): @@ -163,6 +164,15 @@ def _warp_stream_mat(self): (self.velocity_set.d, self.velocity_set.q), dtype=self.compute_dtype ) + @property + def _warp_int_stream_mat(self): + """ + Returns the warp type for the streaming matrix (c) + """ + return wp.mat( + (self.velocity_set.d, self.velocity_set.q), dtype=wp.int32 + ) + @property def _warp_array_type(self): """ @@ -199,4 +209,4 @@ def _construct_warp(self): TODO: Maybe a better way to do this? Maybe add this to the backend decorator? """ - raise NotImplementedError("Children must implement this method") + return None, None diff --git a/xlb/operator/precision_caster/__init__.py b/xlb/operator/precision_caster/__init__.py new file mode 100644 index 0000000..a027c52 --- /dev/null +++ b/xlb/operator/precision_caster/__init__.py @@ -0,0 +1 @@ +from xlb.operator.precision_caster.precision_caster import PrecisionCaster diff --git a/xlb/operator/precision_caster/precision_caster.py b/xlb/operator/precision_caster/precision_caster.py index be676f4..cb441c5 100644 --- a/xlb/operator/precision_caster/precision_caster.py +++ b/xlb/operator/precision_caster/precision_caster.py @@ -16,7 +16,7 @@ class PrecisionCaster(Operator): """ - Class that handles the construction of lattice boltzmann precision casting operator + Class that handles the construction of lattice boltzmann precision casting operator. """ def __init__( @@ -84,15 +84,14 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout): + def warp_implementation(self, from_f, to_f): # Launch the warp kernel wp.launch( self._kernel, inputs=[ - f, - feq, - fout, + from_f, + to_f, ], - dim=f.shape[1:], + dim=from_f.shape[1:], ) - return fout + return to_f diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py new file mode 100644 index 0000000..44ff137 --- /dev/null +++ b/xlb/operator/stepper/__init__.py @@ -0,0 +1,2 @@ +from xlb.operator.stepper.stepper import Stepper +from xlb.operator.stepper.nse import IncompressibleNavierStokesStepper diff --git a/xlb/operator/stepper/nse.py b/xlb/operator/stepper/nse.py index ebb9e13..55a7703 100644 --- a/xlb/operator/stepper/nse.py +++ b/xlb/operator/stepper/nse.py @@ -1,14 +1,16 @@ # Base class for all stepper operators +from logging import warning from functools import partial from jax import jit -from logging import warning +import warp as wp -from xlb.velocity_set.velocity_set import VelocitySet +from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend -from xlb.operator.stepper.stepper import Stepper +from xlb.operator import Operator +from xlb.operator.stepper import Stepper from xlb.operator.boundary_condition import ImplementationStep -from xlb.operator.collision.bgk import BGK +from xlb.operator.collision import BGK class IncompressibleNavierStokesStepper(Stepper): @@ -71,7 +73,7 @@ def apply_jax(self, f, boundary_id, mask, timestep): return f - @Operator.register_backend(ComputeBackends.PALLAS) + @Operator.register_backend(ComputeBackend.PALLAS) @partial(jit, static_argnums=(0,)) def apply_pallas(self, fin, boundary_id, mask, timestep): # Raise warning that the boundary conditions are not implemented @@ -141,7 +143,7 @@ def kernel( f_0: self._warp_array_type, f_1: self._warp_array_type, boundary_id: self._warp_uint8_array_type, - mask: self._warp_array_bool_array_type, + mask: self._warp_bool_array_type, timestep: wp.int32, ): # Get the global index @@ -201,7 +203,7 @@ def kernel( # Set the output f_1[streamed_l, streamed_i, streamed_j, streamed_k] = f_pre_streaming[l] - return functional, kernel + return None, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f, boundary_id, mask, timestep): diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index 48c62a8..b65924b 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,10 +1,13 @@ # Base class for all stepper operators +from functools import partial import jax.numpy as jnp +from jax import jit +import warp as wp -from xlb.velocity_set.velocity_set import VelocitySet +from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator +from xlb.operator import Operator from xlb.operator.boundary_condition import ImplementationStep from xlb.operator.precision_caster import PrecisionCaster @@ -31,6 +34,15 @@ def __init__( self.boundary_conditions = boundary_conditions self.forcing = forcing + # Get all operators for checking + self.operators = [ + collision, + stream, + equilibrium, + macroscopic, + *boundary_conditions, + ] + # Get velocity set, precision policy, and compute backend velocity_sets = set([op.velocity_set for op in self.operators]) assert len(velocity_sets) == 1, "All velocity sets must be the same" @@ -55,8 +67,7 @@ def __init__( raise ValueError("Boundary condition step not recognized") # Make operators for converting the precisions - self.cast_to_compute = PrecisionCaster( - + #self.cast_to_compute = PrecisionCaster( # Make operator for setting boundary condition arrays self.set_boundary = SetBoundary( @@ -66,16 +77,7 @@ def __init__( precision_policy, compute_backend, ) - - # Get all operators for checking - self.operators = [ - collision, - stream, - equilibrium, - macroscopic, - *boundary_conditions, - self.set_boundary, - ] + self.operators.append(self.set_boundary) # Initialize operator super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index e6a46c6..ebb932b 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -3,6 +3,7 @@ from functools import partial import jax.numpy as jnp from jax import jit, vmap +import warp as wp from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend @@ -51,7 +52,7 @@ def _streaming_jax_i(f, c): def _construct_warp(self): # Make constants for warp - _c = wp.constant(self._warp_stream_mat(self.velocity_set.c)) + _c = wp.constant(self._warp_int_stream_mat(self.velocity_set.c)) _q = wp.constant(self.velocity_set.q) _d = wp.constant(self.velocity_set.d) diff --git a/xlb/operator/test/test.py b/xlb/operator/test/test.py new file mode 100644 index 0000000..7d4290a --- /dev/null +++ b/xlb/operator/test/test.py @@ -0,0 +1 @@ +x = 1