-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #65 from mehdiataei/major-refactoring
Separated zero and first moment kernels
- Loading branch information
Showing
5 changed files
with
173 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic | ||
from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment | ||
from xlb.operator.macroscopic.macroscopic import Macroscopic | ||
from xlb.operator.macroscopic.second_moment import SecondMoment | ||
from xlb.operator.macroscopic.zero_moment import ZeroMoment | ||
from xlb.operator.macroscopic.first_moment import FirstMoment |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from functools import partial | ||
import jax.numpy as jnp | ||
from jax import jit | ||
import warp as wp | ||
from typing import Any | ||
|
||
from xlb.compute_backend import ComputeBackend | ||
from xlb.operator.operator import Operator | ||
from xlb.operator.macroscopic.zero_moment import ZeroMoment | ||
from xlb.operator.macroscopic.first_moment import FirstMoment | ||
|
||
|
||
class Macroscopic(Operator): | ||
"""A class to compute both zero and first moments of distribution functions (rho, u).""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
self.zero_moment = ZeroMoment(*args, **kwargs) | ||
self.first_moment = FirstMoment(*args, **kwargs) | ||
super().__init__(*args, **kwargs) | ||
|
||
@Operator.register_backend(ComputeBackend.JAX) | ||
@partial(jit, static_argnums=(0), inline=True) | ||
def jax_implementation(self, f): | ||
rho = self.zero_moment(f) | ||
u = self.first_moment(f, rho) | ||
return rho, u | ||
|
||
def _construct_warp(self): | ||
zero_moment_func = self.zero_moment.warp_functional | ||
first_moment_func = self.first_moment.warp_functional | ||
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) | ||
|
||
@wp.func | ||
def functional(f: _f_vec): | ||
rho = zero_moment_func(f) | ||
u = first_moment_func(f, rho) | ||
return rho, u | ||
|
||
@wp.kernel | ||
def kernel3d( | ||
f: wp.array4d(dtype=Any), | ||
rho: wp.array4d(dtype=Any), | ||
u: wp.array4d(dtype=Any), | ||
): | ||
i, j, k = wp.tid() | ||
index = wp.vec3i(i, j, k) | ||
|
||
_f = _f_vec() | ||
for l in range(self.velocity_set.q): | ||
_f[l] = f[l, index[0], index[1], index[2]] | ||
_rho, _u = functional(_f) | ||
|
||
rho[0, index[0], index[1], index[2]] = _rho | ||
for d in range(self.velocity_set.d): | ||
u[d, index[0], index[1], index[2]] = _u[d] | ||
|
||
@wp.kernel | ||
def kernel2d( | ||
f: wp.array3d(dtype=Any), | ||
rho: wp.array3d(dtype=Any), | ||
u: wp.array3d(dtype=Any), | ||
): | ||
i, j = wp.tid() | ||
index = wp.vec2i(i, j) | ||
|
||
_f = _f_vec() | ||
for l in range(self.velocity_set.q): | ||
_f[l] = f[l, index[0], index[1]] | ||
_rho, _u = functional(_f) | ||
|
||
rho[0, index[0], index[1]] = _rho | ||
for d in range(self.velocity_set.d): | ||
u[d, index[0], index[1]] = _u[d] | ||
|
||
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d | ||
|
||
return functional, kernel | ||
|
||
@Operator.register_backend(ComputeBackend.WARP) | ||
def warp_implementation(self, f, rho, u): | ||
wp.launch( | ||
self.warp_kernel, | ||
inputs=[f, rho, u], | ||
dim=rho.shape[1:], | ||
) | ||
return rho, u |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from functools import partial | ||
import jax.numpy as jnp | ||
from jax import jit | ||
import warp as wp | ||
from typing import Any | ||
|
||
from xlb.compute_backend import ComputeBackend | ||
from xlb.operator.operator import Operator | ||
|
||
|
||
class ZeroMoment(Operator): | ||
"""A class to compute the zeroth moment (density) of distribution functions.""" | ||
|
||
@Operator.register_backend(ComputeBackend.JAX) | ||
@partial(jit, static_argnums=(0), inline=True) | ||
def jax_implementation(self, f): | ||
return jnp.sum(f, axis=0, keepdims=True) | ||
|
||
def _construct_warp(self): | ||
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) | ||
|
||
@wp.func | ||
def functional(f: _f_vec): | ||
rho = self.compute_dtype(0.0) | ||
for l in range(self.velocity_set.q): | ||
rho += f[l] | ||
return rho | ||
|
||
@wp.kernel | ||
def kernel3d( | ||
f: wp.array4d(dtype=Any), | ||
rho: wp.array4d(dtype=Any), | ||
): | ||
i, j, k = wp.tid() | ||
index = wp.vec3i(i, j, k) | ||
|
||
_f = _f_vec() | ||
for l in range(self.velocity_set.q): | ||
_f[l] = f[l, index[0], index[1], index[2]] | ||
_rho = functional(_f) | ||
|
||
rho[0, index[0], index[1], index[2]] = _rho | ||
|
||
@wp.kernel | ||
def kernel2d( | ||
f: wp.array3d(dtype=Any), | ||
rho: wp.array3d(dtype=Any), | ||
): | ||
i, j = wp.tid() | ||
index = wp.vec2i(i, j) | ||
|
||
_f = _f_vec() | ||
for l in range(self.velocity_set.q): | ||
_f[l] = f[l, index[0], index[1]] | ||
_rho = functional(_f) | ||
|
||
rho[0, index[0], index[1]] = _rho | ||
|
||
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d | ||
|
||
return functional, kernel | ||
|
||
@Operator.register_backend(ComputeBackend.WARP) | ||
def warp_implementation(self, f, rho): | ||
wp.launch(self.warp_kernel, inputs=[f, rho], dim=rho.shape[1:]) | ||
return rho |