From c4c994bfcf543d753b081c0b9cc5dd282968d07d Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 19 Sep 2024 18:16:19 -0400 Subject: [PATCH] Separated zero and first kernels --- xlb/operator/macroscopic/__init__.py | 6 +- xlb/operator/macroscopic/first_moment.py | 83 +++++++++++++++++++++++ xlb/operator/macroscopic/macroscopic.py | 85 ++++++++++++++++++++++++ xlb/operator/macroscopic/zero_moment.py | 69 +++++++++++++++++++ 4 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 xlb/operator/macroscopic/first_moment.py create mode 100644 xlb/operator/macroscopic/macroscopic.py create mode 100644 xlb/operator/macroscopic/zero_moment.py diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 38195cd..75dec9e 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1,2 +1,4 @@ -from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic -from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment +from xlb.operator.macroscopic.macroscopic import Macroscopic +from xlb.operator.macroscopic.second_moment import SecondMoment +from xlb.operator.macroscopic.zero_moment import ZeroMoment +from xlb.operator.macroscopic.first_moment import FirstMoment diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py new file mode 100644 index 0000000..218458e --- /dev/null +++ b/xlb/operator/macroscopic/first_moment.py @@ -0,0 +1,83 @@ +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + +class FirstMoment(Operator): + """A class to compute the first moment (velocity) of distribution functions.""" + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), inline=True) + def jax_implementation(self, f, rho): + u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho + return u + + def _construct_warp(self): + _c = self.velocity_set.c + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + + @wp.func + def functional(f: _f_vec, rho: float): + u = _u_vec() + for l in range(self.velocity_set.q): + for d in range(self.velocity_set.d): + if _c[d, l] == 1: + u[d] += f[l] + elif _c[d, l] == -1: + u[d] -= f[l] + u /= rho + return u + + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + ): + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] + _u = functional(_f, _rho) + + for d in range(self.velocity_set.d): + u[d, index[0], index[1], index[2]] = _u[d] + + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + ): + i, j = wp.tid() + index = wp.vec2i(i, j) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _rho = rho[0, index[0], index[1]] + _u = functional(_f, _rho) + + for d in range(self.velocity_set.d): + u[d, index[0], index[1]] = _u[d] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, rho, u): + wp.launch( + self.warp_kernel, + inputs=[f, rho, u], + dim=u.shape[1:], + ) + return u \ No newline at end of file diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py new file mode 100644 index 0000000..b585b5e --- /dev/null +++ b/xlb/operator/macroscopic/macroscopic.py @@ -0,0 +1,85 @@ +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.macroscopic.zero_moment import ZeroMoment +from xlb.operator.macroscopic.first_moment import FirstMoment + +class Macroscopic(Operator): + """A class to compute both zero and first moments of distribution functions (rho, u).""" + + def __init__(self, *args, **kwargs): + self.zero_moment = ZeroMoment(*args, **kwargs) + self.first_moment = FirstMoment(*args, **kwargs) + super().__init__(*args, **kwargs) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), inline=True) + def jax_implementation(self, f): + rho = self.zero_moment(f) + u = self.first_moment(f, rho) + return rho, u + + def _construct_warp(self): + zero_moment_func = self.zero_moment.warp_functional + first_moment_func = self.first_moment.warp_functional + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + @wp.func + def functional(f: _f_vec): + rho = zero_moment_func(f) + u = first_moment_func(f, rho) + return rho, u + + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + ): + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _rho, _u = functional(_f) + + rho[0, index[0], index[1], index[2]] = _rho + for d in range(self.velocity_set.d): + u[d, index[0], index[1], index[2]] = _u[d] + + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + u: wp.array3d(dtype=Any), + ): + i, j = wp.tid() + index = wp.vec2i(i, j) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _rho, _u = functional(_f) + + rho[0, index[0], index[1]] = _rho + for d in range(self.velocity_set.d): + u[d, index[0], index[1]] = _u[d] + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, rho, u): + wp.launch( + self.warp_kernel, + inputs=[f, rho, u], + dim=rho.shape[1:], + ) + return rho, u \ No newline at end of file diff --git a/xlb/operator/macroscopic/zero_moment.py b/xlb/operator/macroscopic/zero_moment.py new file mode 100644 index 0000000..a37ede7 --- /dev/null +++ b/xlb/operator/macroscopic/zero_moment.py @@ -0,0 +1,69 @@ +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator + +class ZeroMoment(Operator): + """A class to compute the zeroth moment (density) of distribution functions.""" + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0), inline=True) + def jax_implementation(self, f): + return jnp.sum(f, axis=0, keepdims=True) + + def _construct_warp(self): + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + @wp.func + def functional(f: _f_vec): + rho = self.compute_dtype(0.0) + for l in range(self.velocity_set.q): + rho += f[l] + return rho + + @wp.kernel + def kernel3d( + f: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + ): + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _rho = functional(_f) + + rho[0, index[0], index[1], index[2]] = _rho + + @wp.kernel + def kernel2d( + f: wp.array3d(dtype=Any), + rho: wp.array3d(dtype=Any), + ): + i, j = wp.tid() + index = wp.vec2i(i, j) + + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1]] + _rho = functional(_f) + + rho[0, index[0], index[1]] = _rho + + kernel = kernel3d if self.velocity_set.d == 3 else kernel2d + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f, rho): + wp.launch( + self.warp_kernel, + inputs=[f, rho], + dim=rho.shape[1:], + ) + return rho \ No newline at end of file