diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 79d75a4..8b0aacf 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -103,7 +103,8 @@ def _construct_warp(self): # Construct the warp kernel @wp.kernel def kernel2d( - f: wp.array3d(dtype=Any), + f_0: wp.array3d(dtype=Any), + f_1: wp.array3d(dtype=Any), bc_mask: wp.array3d(dtype=wp.uint8), missing_mask: wp.array3d(dtype=wp.bool), force: wp.array(dtype=Any), @@ -134,11 +135,12 @@ def kernel2d( # Get the distribution function f_post_collision = _f_vec() for l in range(self.velocity_set.q): - f_post_collision[l] = f[l, index[0], index[1]] + f_post_collision[l] = f_0[l, index[0], index[1]] # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(f_post_collision, f_post_stream, _f_vec(), _missing_mask) + timestep = 0 + f_post_stream = self.stream.warp_functional(f_0, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) # Compute the momentum transfer for d in range(self.velocity_set.d): @@ -156,7 +158,8 @@ def kernel2d( # Construct the warp kernel @wp.kernel def kernel3d( - fpop: wp.array4d(dtype=Any), + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), missing_mask: wp.array4d(dtype=wp.bool), force: wp.array(dtype=Any), @@ -187,12 +190,12 @@ def kernel3d( # Get the distribution function f_post_collision = _f_vec() for l in range(self.velocity_set.q): - f_post_collision[l] = fpop[l, index[0], index[1], index[2]] + f_post_collision[l] = f_0[l, index[0], index[1], index[2]] # Apply streaming (pull method) timestep = 0 - f_post_stream = self.stream.warp_functional(fpop, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, fpop, fpop, f_post_collision, f_post_stream) + f_post_stream = self.stream.warp_functional(f_0, index) + f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) # Compute the momentum transfer for d in range(self.velocity_set.d): @@ -213,7 +216,7 @@ def kernel3d( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, bc_mask, missing_mask): + def warp_implementation(self, f_0, f_1, bc_mask, missing_mask): # Allocate the force vector (the total integral value will be computed) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) force = wp.zeros((1), dtype=_u_vec) @@ -221,7 +224,7 @@ def warp_implementation(self, f, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f, bc_mask, missing_mask, force], - dim=f.shape[1:], + inputs=[f_0, f_1, bc_mask, missing_mask, force], + dim=f_0.shape[1:], ) return force.numpy()[0]