Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Marged 2D and 3D kernels in Warp #75

Merged
merged 5 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def run(self, num_steps, post_process_interval=100):
def post_process(self, i):
# Write the results. We'll use JAX backend for the post-processing
if not isinstance(self.f_0, jnp.ndarray):
f_0 = wp.to_jax(self.f_0)
# If the backend is warp, we need to drop the last dimension added by warp for 2D simulations
f_0 = wp.to_jax(self.f_0)[..., 0]
else:
f_0 = self.f_0

Expand Down
4 changes: 2 additions & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def define_boundary_indices(self):
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

# Load the mesh
stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl"
# Load the mesh (replace with your own mesh)
stl_filename = "../stl-files/DrivAer-Notchback.stl"
mesh = trimesh.load_mesh(stl_filename, process=False)
mesh_vertices = mesh.vertices

Expand Down
25 changes: 10 additions & 15 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
jax==0.4.20
jaxlib==0.4.20
matplotlib==3.8.0
numpy==1.26.1
pyvista==0.43.4
Rtree==1.0.1
trimesh==4.4.1
orbax-checkpoint==0.4.1
termcolor==2.3.0
PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git
tqdm==4.66.2
warp-lang==1.0.2
numpy-stl==3.1.1
pydantic==2.7.0
ruff==0.5.6
jax[cuda]
matplotlib
numpy
pyvista
Rtree
trimesh
warp-lang
numpy-stl
pydantic
ruff
7 changes: 4 additions & 3 deletions tests/grids/test_grid_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def test_warp_grid_create_field(grid_size):
init_xlb_env(xlb.velocity_set.D3Q19)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9, dtype=Precision.FP32)

assert f.shape == (9,) + grid_shape, "Field shape is incorrect"
if len(grid_shape) == 2:
assert f.shape == (9,) + grid_shape + (1,), "Field shape is incorrect got {}".format(f.shape)
else:
assert f.shape == (9,) + grid_shape, "Field shape is incorrect got {}".format(f.shape)
assert isinstance(f, wp.array), "Field should be a Warp ndarray"


Expand All @@ -37,7 +39,6 @@ def test_warp_grid_create_field_fill_value():
assert isinstance(f, wp.array), "Field should be a Warp ndarray"

f = f.numpy()
assert f.shape == (9,) + grid_shape, "Field shape is incorrect"
assert np.allclose(f, fill_value), "Field not properly initialized with fill_value"


Expand Down
7 changes: 5 additions & 2 deletions tests/kernels/stream/test_stream_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_stream_operator_warp(dim, velocity_set, grid_shape):
expected = jnp.stack(expected, axis=0)

if dim == 2:
f_initial_warp = wp.array(f_initial)
f_initial_warp = wp.array(f_initial[..., np.newaxis])

elif dim == 3:
f_initial_warp = wp.array(f_initial)
Expand All @@ -71,7 +71,10 @@ def test_stream_operator_warp(dim, velocity_set, grid_shape):
f_streamed = my_grid_warp.create_field(cardinality=velocity_set.q)
f_streamed = stream_op(f_initial_warp, f_streamed)

assert jnp.allclose(f_streamed.numpy(), np.array(expected)), "Streaming did not occur as expected"
if len(grid_shape) == 2:
assert jnp.allclose(f_streamed.numpy()[..., 0], np.array(expected)), "Streaming did not occur as expected"
else:
assert jnp.allclose(f_streamed.numpy(), np.array(expected)), "Streaming did not occur as expected"


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion xlb/grid/warp_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def create_field(
fill_value=None,
):
dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype
shape = (cardinality,) + (self.shape)

# Check if shape is 2D, and if so, append a singleton dimension to the shape
shape = (cardinality,) + (self.shape if len(self.shape) != 2 else self.shape + (1,))

if fill_value is None:
f = wp.zeros(shape, dtype=dtype)
Expand Down
65 changes: 1 addition & 64 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,67 +64,4 @@ def functional(
):
return f_pre

@wp.kernel
def kernel2d(
f_pre: wp.array3d(dtype=Any),
f_post: wp.array3d(dtype=Any),
bc_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.uint8),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
def kernel3d(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, bc_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f_post
return functional, None
66 changes: 1 addition & 65 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,68 +88,4 @@ def functional(
_f = self.equilibrium_operator.warp_functional(_rho, _u)
return _f

# Construct the warp kernel
@wp.kernel
def kernel2d(
f_pre: wp.array3d(dtype=Any),
f_post: wp.array3d(dtype=Any),
bc_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.bool),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
def kernel3d(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, bc_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f_post
return functional, None
121 changes: 4 additions & 117 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,7 @@ def _construct_warp(self):
_opp_indices = self.velocity_set.opp_indices

@wp.func
def get_normal_vectors_2d(
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -wp.vec2i(_c[0, l], _c[1, l])

@wp.func
def get_normal_vectors_3d(
def get_normal_vectors(
missing_mask: Any,
):
for l in range(_q):
Expand All @@ -175,7 +167,7 @@ def functional(
return _f

@wp.func
def prepare_bc_auxilary_data_2d(
def prepare_bc_auxilary_data(
index: Any,
timestep: Any,
missing_mask: Any,
Expand All @@ -188,34 +180,7 @@ def prepare_bc_auxilary_data_2d(
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
# for storing this prepared data.
_f = f_post
nv = get_normal_vectors_2d(missing_mask)
for l in range(self.velocity_set.q):
if missing_mask[l] == wp.uint8(1):
# f_0 is the post-collision values of the current time-step
# Get pull index associated with the "neighbours" pull_index
pull_index = type(index)()
for d in range(self.velocity_set.d):
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]])
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux
return _f

@wp.func
def prepare_bc_auxilary_data_3d(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
):
# Preparing the formulation for this BC using the neighbour's populations stored in f_aux and
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
# for storing this prepared data.
_f = f_post
nv = get_normal_vectors_3d(missing_mask)
nv = get_normal_vectors(missing_mask)
for l in range(self.velocity_set.q):
if missing_mask[l] == wp.uint8(1):
# f_0 is the post-collision values of the current time-step
Expand All @@ -228,82 +193,4 @@ def prepare_bc_auxilary_data_3d(
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux
return _f

# Construct the warp kernel
@wp.kernel
def kernel2d(
f_pre: wp.array3d(dtype=Any),
f_post: wp.array3d(dtype=Any),
bc_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.bool),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)
timestep = 0

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index)

# special preparation of auxiliary data
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
_f_pre = prepare_bc_auxilary_data_2d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post)

# Apply the boundary condition
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
# TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both
# collision and streaming?
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
def kernel3d(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)
timestep = 0

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index)
_f_aux = _f_vec()

# special preparation of auxiliary data
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
_f_pre = prepare_bc_auxilary_data_3d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post)

# Apply the boundary condition
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
# TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both
# collision and streaming?
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
prepare_bc_auxilary_data = prepare_bc_auxilary_data_3d if self.velocity_set.d == 3 else prepare_bc_auxilary_data_2d

return (functional, prepare_bc_auxilary_data), kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, bc_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f_post
return (functional, prepare_bc_auxilary_data), None
Loading
Loading