Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a faster and more accurate algorithm for mesh_boundary_masker #96

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xlb/helper/nse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from xlb.precision_policy import Precision
from typing import Tuple


def create_nse_fields(
grid_shape: Tuple[int, int, int] = None,
grid=None,
Expand Down
1 change: 0 additions & 1 deletion xlb/operator/boundary_masker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from xlb.operator.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker
from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker
from xlb.operator.boundary_masker.mesh_distance_boundary_masker import MeshDistanceBoundaryMasker
4 changes: 2 additions & 2 deletions xlb/operator/boundary_masker/indices_boundary_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
# For now, we compute the bmap on GPU zero.
if dim == 2:
bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1]), dtype=jnp.uint8)
bmap = bmap.at[pad_x : -pad_x, pad_y : -pad_y].set(bc_mask[0])
bmap = bmap.at[pad_x:-pad_x, pad_y:-pad_y].set(bc_mask[0])
grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True)
# bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0)
if dim == 3:
bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1], pad_z * 2 + bc_mask[0].shape[2]), dtype=jnp.uint8)
bmap = bmap.at[pad_x : -pad_x, pad_y : -pad_y, pad_z : -pad_z].set(bc_mask[0])
bmap = bmap.at[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z].set(bc_mask[0])
grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True)
# bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0)

Expand Down
56 changes: 34 additions & 22 deletions xlb/operator/boundary_masker/mesh_boundary_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,20 @@ def jax_implementation(

def _construct_warp(self):
# Make constants for warp
_c = self.velocity_set.c
_q = self.velocity_set.q
_c_float = self.velocity_set.c_float
_q = wp.constant(self.velocity_set.q)
_opp_indices = self.velocity_set.opp_indices

@wp.func
def index_to_position(index: wp.vec3i):
# position of the point
ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]))
pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center
return pos

# Construct the warp kernel
# Do voxelization mesh query (warp.mesh_query_aabb) to find solid voxels
# - this gives an approximate 1 voxel thick surface around mesh
@wp.kernel
def kernel(
mesh_id: wp.uint64,
Expand All @@ -66,25 +76,27 @@ def kernel(
index = wp.vec3i(i, j, k)

# position of the point
ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]))
pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center
# Compute the maximum length
max_length = wp.sqrt(2.0) / 2.0 # half of unit cell diagonal

# evaluate if point is inside mesh
query = wp.mesh_query_point_no_sign(mesh_id, pos, max_length)
if query.result:
# set point to be solid
# Stream indices
pos_bc_cell = index_to_position(index)
half = wp.vec3(0.5, 0.5, 0.5)

vox_query = wp.mesh_query_aabb(mesh_id, pos_bc_cell - half, pos_bc_cell + half)
face = wp.int32(0)
if wp.mesh_query_aabb_next(vox_query, face):
# Make solid voxel
bc_mask[0, index[0], index[1], index[2]] = wp.uint8(255)
else:
# Find the fractional distance to the mesh in each direction
for l in range(1, _q):
# Get the index of the streaming direction
push_index = wp.vec3i()
for d in range(self.velocity_set.d):
push_index[d] = index[d] + _c[d, l]
_dir = wp.vec3f(_c_float[0, l], _c_float[1, l], _c_float[2, l])

# Set the boundary id and missing_mask
bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number)
missing_mask[l, push_index[0], push_index[1], push_index[2]] = True
# Check to see if this neighbor is solid
vox_query_dir = wp.mesh_query_aabb(mesh_id, pos_bc_cell + _dir - half, pos_bc_cell + _dir + half)
face = wp.int32(0)
if wp.mesh_query_aabb_next(vox_query_dir, face):
# We know we have a solid neighbor
# Set the boundary id and missing_mask
bc_mask[0, index[0], index[1], index[2]] = wp.uint8(id_number)
missing_mask[_opp_indices[l], index[0], index[1], index[2]] = True

return None, kernel

Expand All @@ -97,9 +109,9 @@ def warp_implementation(
):
assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!'
assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!"
assert bc.mesh_vertices.shape[1] == self.velocity_set.d, (
"Mesh points must be reshaped into an array (N, 3) where N indicates number of points!"
)
assert (
bc.mesh_vertices.shape[1] == self.velocity_set.d
), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!"
mesh_vertices = bc.mesh_vertices
id_number = bc.id

Expand Down
184 changes: 0 additions & 184 deletions xlb/operator/boundary_masker/mesh_distance_boundary_masker.py

This file was deleted.

1 change: 1 addition & 0 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def prepare_fields(self, initializer=None):
f_0 = initializer(self.grid, self.velocity_set, self.precision_policy, self.compute_backend)
else:
from xlb.helper.initializers import initialize_eq

f_0 = initialize_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend)

# Copy f_0 using backend-specific copy to f_1
Expand Down
Loading