Skip to content

Commit

Permalink
added recent changes to the momentum exchange method
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Oct 4, 2024
1 parent 0f912d4 commit f215c5b
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions xlb/operator/force/momentum_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def _construct_warp(self):
# Construct the warp kernel
@wp.kernel
def kernel2d(
f: wp.array3d(dtype=Any),
f_0: wp.array3d(dtype=Any),
f_1: wp.array3d(dtype=Any),
bc_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.bool),
force: wp.array(dtype=Any),
Expand Down Expand Up @@ -134,11 +135,12 @@ def kernel2d(
# Get the distribution function
f_post_collision = _f_vec()
for l in range(self.velocity_set.q):
f_post_collision[l] = f[l, index[0], index[1]]
f_post_collision[l] = f_0[l, index[0], index[1]]

# Apply streaming (pull method)
f_post_stream = self.stream.warp_functional(f, index)
f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask)
timestep = 0
f_post_stream = self.stream.warp_functional(f_0, index)
f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream)

# Compute the momentum transfer
for d in range(self.velocity_set.d):
Expand All @@ -156,7 +158,8 @@ def kernel2d(
# Construct the warp kernel
@wp.kernel
def kernel3d(
fpop: wp.array4d(dtype=Any),
f_0: wp.array4d(dtype=Any),
f_1: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
force: wp.array(dtype=Any),
Expand Down Expand Up @@ -187,12 +190,12 @@ def kernel3d(
# Get the distribution function
f_post_collision = _f_vec()
for l in range(self.velocity_set.q):
f_post_collision[l] = fpop[l, index[0], index[1], index[2]]
f_post_collision[l] = f_0[l, index[0], index[1], index[2]]

# Apply streaming (pull method)
timestep = 0
f_post_stream = self.stream.warp_functional(fpop, index)
f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, fpop, fpop, f_post_collision, f_post_stream)
f_post_stream = self.stream.warp_functional(f_0, index)
f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream)

# Compute the momentum transfer
for d in range(self.velocity_set.d):
Expand All @@ -213,15 +216,15 @@ def kernel3d(
return None, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, bc_mask, missing_mask):
def warp_implementation(self, f_0, f_1, bc_mask, missing_mask):
# Allocate the force vector (the total integral value will be computed)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
force = wp.zeros((1), dtype=_u_vec)

# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f, bc_mask, missing_mask, force],
dim=f.shape[1:],
inputs=[f_0, f_1, bc_mask, missing_mask, force],
dim=f_0.shape[1:],
)
return force.numpy()[0]

0 comments on commit f215c5b

Please sign in to comment.