Skip to content

Commit

Permalink
Merge pull request #65 from mehdiataei/major-refactoring
Browse files Browse the repository at this point in the history
Separated zero and first moment kernels
  • Loading branch information
hsalehipour authored Sep 20, 2024
2 parents 3ff0814 + 448883f commit 834da9a
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 68 deletions.
6 changes: 4 additions & 2 deletions xlb/operator/macroscopic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic
from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment
from xlb.operator.macroscopic.macroscopic import Macroscopic
from xlb.operator.macroscopic.second_moment import SecondMoment
from xlb.operator.macroscopic.zero_moment import ZeroMoment
from xlb.operator.macroscopic.first_moment import FirstMoment
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# Base class for all equilibriums

from functools import partial
import jax.numpy as jnp
from jax import jit
Expand All @@ -10,82 +8,50 @@
from xlb.operator.operator import Operator


class ZeroAndFirstMoments(Operator):
"""
A class to compute first and zeroth moments of distribution functions.
TODO: Currently this is only used for the standard rho and u moments.
In the future, this should be extended to include higher order moments
and other physic types (e.g. temperature, electromagnetism, etc...)
"""
class FirstMoment(Operator):
"""A class to compute the first moment (velocity) of distribution functions."""

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0), inline=True)
def jax_implementation(self, f):
"""
Apply the macroscopic operator to the lattice distribution function
TODO: Check if the following implementation is more efficient (
as the compiler may be able to remove operations resulting in zero)
c_x = tuple(self.velocity_set.c[0])
c_y = tuple(self.velocity_set.c[1])
u_x = 0.0
u_y = 0.0
rho = jnp.sum(f, axis=0, keepdims=True)
for i in range(self.velocity_set.q):
u_x += c_x[i] * f[i, ...]
u_y += c_y[i] * f[i, ...]
return rho, jnp.stack((u_x, u_y))
"""
rho = jnp.sum(f, axis=0, keepdims=True)
def jax_implementation(self, f, rho):
u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho

return rho, u
return u

def _construct_warp(self):
# Make constants for warp
_c = self.velocity_set.c
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)

# Construct the functional
@wp.func
def functional(f: _f_vec):
# Compute rho and u
rho = self.compute_dtype(0.0)
def functional(
f: _f_vec,
rho: Any,
):
u = _u_vec()
for l in range(self.velocity_set.q):
rho += f[l]
for d in range(self.velocity_set.d):
if _c[d, l] == 1:
u[d] += f[l]
elif _c[d, l] == -1:
u[d] -= f[l]
u /= rho
return u

return rho, u

# Construct the kernel
@wp.kernel
def kernel3d(
f: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the equilibrium
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
(_rho, _u) = functional(_f)
_rho = rho[0, index[0], index[1], index[2]]
_u = functional(_f, _rho)

# Set the output
rho[0, index[0], index[1], index[2]] = _rho
for d in range(self.velocity_set.d):
u[d, index[0], index[1], index[2]] = _u[d]

Expand All @@ -95,18 +61,15 @@ def kernel2d(
rho: wp.array3d(dtype=Any),
u: wp.array3d(dtype=Any),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# Get the equilibrium
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
(_rho, _u) = functional(_f)
_rho = rho[0, index[0], index[1]]
_u = functional(_f, _rho)

# Set the output
rho[0, index[0], index[1]] = _rho
for d in range(self.velocity_set.d):
u[d, index[0], index[1]] = _u[d]

Expand All @@ -116,14 +79,9 @@ def kernel2d(

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, rho, u):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[
f,
rho,
u,
],
dim=rho.shape[1:],
inputs=[f, rho, u],
dim=u.shape[1:],
)
return rho, u
return u
86 changes: 86 additions & 0 deletions xlb/operator/macroscopic/macroscopic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from functools import partial
import jax.numpy as jnp
from jax import jit
import warp as wp
from typing import Any

from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator
from xlb.operator.macroscopic.zero_moment import ZeroMoment
from xlb.operator.macroscopic.first_moment import FirstMoment


class Macroscopic(Operator):
"""A class to compute both zero and first moments of distribution functions (rho, u)."""

def __init__(self, *args, **kwargs):
self.zero_moment = ZeroMoment(*args, **kwargs)
self.first_moment = FirstMoment(*args, **kwargs)
super().__init__(*args, **kwargs)

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0), inline=True)
def jax_implementation(self, f):
rho = self.zero_moment(f)
u = self.first_moment(f, rho)
return rho, u

def _construct_warp(self):
zero_moment_func = self.zero_moment.warp_functional
first_moment_func = self.first_moment.warp_functional
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)

@wp.func
def functional(f: _f_vec):
rho = zero_moment_func(f)
u = first_moment_func(f, rho)
return rho, u

@wp.kernel
def kernel3d(
f: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
):
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_rho, _u = functional(_f)

rho[0, index[0], index[1], index[2]] = _rho
for d in range(self.velocity_set.d):
u[d, index[0], index[1], index[2]] = _u[d]

@wp.kernel
def kernel2d(
f: wp.array3d(dtype=Any),
rho: wp.array3d(dtype=Any),
u: wp.array3d(dtype=Any),
):
i, j = wp.tid()
index = wp.vec2i(i, j)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
_rho, _u = functional(_f)

rho[0, index[0], index[1]] = _rho
for d in range(self.velocity_set.d):
u[d, index[0], index[1]] = _u[d]

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, rho, u):
wp.launch(
self.warp_kernel,
inputs=[f, rho, u],
dim=rho.shape[1:],
)
return rho, u
9 changes: 1 addition & 8 deletions xlb/operator/macroscopic/second_moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,5 @@ def kernel2d(
@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, pi):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[
f,
pi,
],
dim=pi.shape[1:],
)
wp.launch(self.warp_kernel, inputs=[f, pi], dim=pi.shape[1:])
return pi
66 changes: 66 additions & 0 deletions xlb/operator/macroscopic/zero_moment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from functools import partial
import jax.numpy as jnp
from jax import jit
import warp as wp
from typing import Any

from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator


class ZeroMoment(Operator):
"""A class to compute the zeroth moment (density) of distribution functions."""

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0), inline=True)
def jax_implementation(self, f):
return jnp.sum(f, axis=0, keepdims=True)

def _construct_warp(self):
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)

@wp.func
def functional(f: _f_vec):
rho = self.compute_dtype(0.0)
for l in range(self.velocity_set.q):
rho += f[l]
return rho

@wp.kernel
def kernel3d(
f: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
):
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_rho = functional(_f)

rho[0, index[0], index[1], index[2]] = _rho

@wp.kernel
def kernel2d(
f: wp.array3d(dtype=Any),
rho: wp.array3d(dtype=Any),
):
i, j = wp.tid()
index = wp.vec2i(i, j)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
_rho = functional(_f)

rho[0, index[0], index[1]] = _rho

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, rho):
wp.launch(self.warp_kernel, inputs=[f, rho], dim=rho.shape[1:])
return rho

0 comments on commit 834da9a

Please sign in to comment.