Skip to content

Commit

Permalink
Added a faster and more accurate algorithm for mesh_boundary_masker
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Dec 19, 2024
1 parent 237ac22 commit 83f3683
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 209 deletions.
1 change: 1 addition & 0 deletions xlb/helper/nse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from xlb.precision_policy import Precision
from typing import Tuple


def create_nse_fields(
grid_shape: Tuple[int, int, int] = None,
grid=None,
Expand Down
1 change: 0 additions & 1 deletion xlb/operator/boundary_masker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from xlb.operator.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker
from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker
from xlb.operator.boundary_masker.mesh_distance_boundary_masker import MeshDistanceBoundaryMasker
4 changes: 2 additions & 2 deletions xlb/operator/boundary_masker/indices_boundary_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
# For now, we compute the bmap on GPU zero.
if dim == 2:
bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1]), dtype=jnp.uint8)
bmap = bmap.at[pad_x : -pad_x, pad_y : -pad_y].set(bc_mask[0])
bmap = bmap.at[pad_x:-pad_x, pad_y:-pad_y].set(bc_mask[0])
grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True)
# bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0)
if dim == 3:
bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1], pad_z * 2 + bc_mask[0].shape[2]), dtype=jnp.uint8)
bmap = bmap.at[pad_x : -pad_x, pad_y : -pad_y, pad_z : -pad_z].set(bc_mask[0])
bmap = bmap.at[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z].set(bc_mask[0])
grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True)
# bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0)

Expand Down
56 changes: 34 additions & 22 deletions xlb/operator/boundary_masker/mesh_boundary_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,20 @@ def jax_implementation(

def _construct_warp(self):
# Make constants for warp
_c = self.velocity_set.c
_q = self.velocity_set.q
_c_float = self.velocity_set.c_float
_q = wp.constant(self.velocity_set.q)
_opp_indices = self.velocity_set.opp_indices

@wp.func
def index_to_position(index: wp.vec3i):
# position of the point
ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]))
pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center
return pos

# Construct the warp kernel
# Do voxelization mesh query (warp.mesh_query_aabb) to find solid voxels
# - this gives an approximate 1 voxel thick surface around mesh
@wp.kernel
def kernel(
mesh_id: wp.uint64,
Expand All @@ -66,25 +76,27 @@ def kernel(
index = wp.vec3i(i, j, k)

# position of the point
ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]))
pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center
# Compute the maximum length
max_length = wp.sqrt(2.0) / 2.0 # half of unit cell diagonal

# evaluate if point is inside mesh
query = wp.mesh_query_point_no_sign(mesh_id, pos, max_length)
if query.result:
# set point to be solid
# Stream indices
pos_bc_cell = index_to_position(index)
half = wp.vec3(0.5, 0.5, 0.5)

vox_query = wp.mesh_query_aabb(mesh_id, pos_bc_cell - half, pos_bc_cell + half)
face = wp.int32(0)
if wp.mesh_query_aabb_next(vox_query, face):
# Make solid voxel
bc_mask[0, index[0], index[1], index[2]] = wp.uint8(255)
else:
# Find the fractional distance to the mesh in each direction
for l in range(1, _q):
# Get the index of the streaming direction
push_index = wp.vec3i()
for d in range(self.velocity_set.d):
push_index[d] = index[d] + _c[d, l]
_dir = wp.vec3f(_c_float[0, l], _c_float[1, l], _c_float[2, l])

# Set the boundary id and missing_mask
bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number)
missing_mask[l, push_index[0], push_index[1], push_index[2]] = True
# Check to see if this neighbor is solid
vox_query_dir = wp.mesh_query_aabb(mesh_id, pos_bc_cell + _dir - half, pos_bc_cell + _dir + half)
face = wp.int32(0)
if wp.mesh_query_aabb_next(vox_query_dir, face):
# We know we have a solid neighbor
# Set the boundary id and missing_mask
bc_mask[0, index[0], index[1], index[2]] = wp.uint8(id_number)
missing_mask[_opp_indices[l], index[0], index[1], index[2]] = True

return None, kernel

Expand All @@ -97,9 +109,9 @@ def warp_implementation(
):
assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!'
assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!"
assert bc.mesh_vertices.shape[1] == self.velocity_set.d, (
"Mesh points must be reshaped into an array (N, 3) where N indicates number of points!"
)
assert (
bc.mesh_vertices.shape[1] == self.velocity_set.d
), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!"
mesh_vertices = bc.mesh_vertices
id_number = bc.id

Expand Down
184 changes: 0 additions & 184 deletions xlb/operator/boundary_masker/mesh_distance_boundary_masker.py

This file was deleted.

1 change: 1 addition & 0 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def prepare_fields(self, initializer=None):
f_0 = initializer(self.grid, self.velocity_set, self.precision_policy, self.compute_backend)
else:
from xlb.helper.initializers import initialize_eq

f_0 = initialize_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend)

# Copy f_0 using backend-specific copy to f_1
Expand Down

0 comments on commit 83f3683

Please sign in to comment.