Review note: text ahead of the first "diff --git" header is skipped by patch tools.
Review note: the two Warp test hunks now compare f.shape against a precomputed expected_shape, because
"assert a == b if c else d" parses as "assert (a == b) if c else d" and is always true for dim == 2.
Review note: post-image blob ids on the "index" lines of edited files are carried over unchanged and are stale.
diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py
index 5eb0c10..6bd9311 100644
--- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py
+++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py
@@ -72,7 +72,10 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape):
     f = f.numpy()
     f_post = f_post.numpy()
 
-    assert f.shape == (velocity_set.q,) + grid_shape
+    # Build the expected shape first: a bare `a == b if c else d` parses as `(a == b) if c else d`,
+    # so for dim == 2 the assert would only test the truthiness of a non-empty tuple.
+    expected_shape = (velocity_set.q,) + grid_shape if dim == 3 else (velocity_set.q, grid_shape[0], grid_shape[1], 1)
+    assert f.shape == expected_shape
 
     # Assert that the values are correct in the indices of the sphere
     weights = velocity_set.w
diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py
index 10b9244..59c6c9d 100644
--- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py
+++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py
@@ -58,7 +58,10 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):
     bc_mask, missing_mask = indices_boundary_masker([fullway_bc], bc_mask, missing_mask, start_index=None)
 
     # Generate a random field with the same shape
-    random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32)
+    if dim == 2:
+        random_field = np.random.rand(velocity_set.q, grid_shape[0], grid_shape[1], 1).astype(np.float32)
+    else:
+        random_field = np.random.rand(velocity_set.q, grid_shape[0], grid_shape[1], grid_shape[2]).astype(np.float32)
     # Add the random field to f_pre
     f_pre = wp.array(random_field)
 
@@ -71,7 +74,10 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):
     f = f_pre.numpy()
     f_post = f_post.numpy()
 
-    assert f.shape == (velocity_set.q,) + grid_shape
+    # Build the expected shape first: a bare `a == b if c else d` parses as `(a == b) if c else d`,
+    # so for dim == 2 the assert would only test the truthiness of a non-empty tuple.
+    expected_shape = (velocity_set.q,) + grid_shape if dim == 3 else (velocity_set.q, grid_shape[0], grid_shape[1], 1)
+    assert f.shape == expected_shape
 
     for i in range(velocity_set.q):
         np.allclose(
diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py
index 67b343d..56a332f 100644
--- a/xlb/operator/boundary_condition/bc_do_nothing.py
+++ b/xlb/operator/boundary_condition/bc_do_nothing.py
@@ -64,4 +64,16 @@ def functional(
         ):
             return f_pre
 
-        return functional, None
+        kernel = self._construct_kernel(functional)
+
+        return functional, kernel
+
+    @Operator.register_backend(ComputeBackend.WARP)
+    def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+        # Launch the warp kernel
+        wp.launch(
+            self.warp_kernel,
+            inputs=[f_pre, f_post, bc_mask, missing_mask],
+            dim=f_pre.shape[1:],
+        )
+        return f_post
\ No newline at end of file
diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py
index b4b957a..77f408f 100644
--- a/xlb/operator/boundary_condition/bc_equilibrium.py
+++ b/xlb/operator/boundary_condition/bc_equilibrium.py
@@ -88,4 +88,17 @@ def functional(
             _f = self.equilibrium_operator.warp_functional(_rho, _u)
             return _f
 
-        return functional, None
+        # Use the parent class's kernel and pass the functional
+        kernel = self._construct_kernel(functional)
+
+        return functional, kernel
+
+    @Operator.register_backend(ComputeBackend.WARP)
+    def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+        # Launch the warp kernel
+        wp.launch(
+            self.warp_kernel,
+            inputs=[f_pre, f_post, bc_mask, missing_mask],
+            dim=f_pre.shape[1:],
+        )
+        return f_post
\ No newline at end of file
diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py
index 4a96c73..38657e5 100644
--- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py
+++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py
@@ -193,4 +193,16 @@ def prepare_bc_auxilary_data(
                     _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux
             return _f
 
-        return (functional, prepare_bc_auxilary_data), None
\ No newline at end of file
+        kernel = self._construct_kernel(functional)
+
+        return (functional, prepare_bc_auxilary_data), kernel
+
+    @Operator.register_backend(ComputeBackend.WARP)
+    def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+        # Launch the warp kernel
+        wp.launch(
+            self.warp_kernel,
+            inputs=[f_pre, f_post, bc_mask, missing_mask],
+            dim=f_pre.shape[1:],
+        )
+        return f_post
\ No newline at end of file
diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py
index afe05de..19a3013 100644
--- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py
+++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py
@@ -74,4 +74,16 @@ def functional(
                 fliped_f[l] = f_pre[_opp_indices[l]]
             return fliped_f
 
-        return functional, None
\ No newline at end of file
+        kernel = self._construct_kernel(functional)
+
+        return functional, kernel
+
+    @Operator.register_backend(ComputeBackend.WARP)
+    def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+        # Launch the warp kernel
+        wp.launch(
+            self.warp_kernel,
+            inputs=[f_pre, f_post, bc_mask, missing_mask],
+            dim=f_pre.shape[1:],
+        )
+        return f_post
\ No newline at end of file
diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py
index 5806375..94ddba3 100644
--- a/xlb/operator/boundary_condition/bc_grads_approximation.py
+++ b/xlb/operator/boundary_condition/bc_grads_approximation.py
@@ -309,4 +309,16 @@ def functional_method2(
 
         functional = functional_method1
 
-        return functional, None
\ No newline at end of file
+        kernel = self._construct_kernel(functional)
+
+        return functional, kernel
+
+    @Operator.register_backend(ComputeBackend.WARP)
+    def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+        # Launch the warp kernel
+        wp.launch(
+            self.warp_kernel,
+            inputs=[f_pre, f_post, bc_mask, missing_mask],
+            dim=f_pre.shape[1:],
+        )
+        return f_post
\ No newline at end of file
diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py
index ee68b50..bf04af0 100644
--- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py
+++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py
@@ -87,5 +87,16 @@ def functional(
 
             return _f
 
-        return functional, None
+        kernel = self._construct_kernel(functional)
 
+        return functional, kernel
+
+    @Operator.register_backend(ComputeBackend.WARP)
+    def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+        # Launch the warp kernel
+        wp.launch(
+            self.warp_kernel,
+            inputs=[f_pre, f_post, bc_mask, missing_mask],
+            dim=f_pre.shape[1:],
+        )
+        return f_post
diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py
index 12622e2..af4c783 100644
--- a/xlb/operator/boundary_condition/bc_regularized.py
+++ b/xlb/operator/boundary_condition/bc_regularized.py
@@ -266,4 +266,16 @@ def functional_pressure(
             functional = functional_velocity
         elif self.bc_type == "pressure":
             functional = functional_pressure
-        return functional, None
\ No newline at end of file
+        kernel = self._construct_kernel(functional)
+
+        return functional, kernel
+
+    @Operator.register_backend(ComputeBackend.WARP)
+    def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+        # Launch the warp kernel
+        wp.launch(
+            self.warp_kernel,
+            inputs=[f_pre, f_post, bc_mask, missing_mask],
+            dim=f_pre.shape[1:],
+        )
+        return f_post
diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py
index c5d9498..a92d909 100644
--- a/xlb/operator/boundary_condition/bc_zouhe.py
+++ b/xlb/operator/boundary_condition/bc_zouhe.py
@@ -283,4 +283,16 @@ def functional_pressure(
         elif self.bc_type == "velocity":
             functional = functional_pressure
 
-        return functional, None
\ No newline at end of file
+        kernel = self._construct_kernel(functional)
+
+        return functional, kernel
+
+    @Operator.register_backend(ComputeBackend.WARP)
+    def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+        # Launch the warp kernel
+        wp.launch(
+            self.warp_kernel,
+            inputs=[f_pre, f_post, bc_mask, missing_mask],
+            dim=f_pre.shape[1:],
+        )
+        return f_post
diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py
index 2cd2a11..bf1eef2 100644
--- a/xlb/operator/boundary_condition/boundary_condition.py
+++ b/xlb/operator/boundary_condition/boundary_condition.py
@@ -111,3 +111,39 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
         currently being called after collision only.
         """
         return f_post
+
+    def _construct_kernel(self, functional):
+        """
+        Constructs the warp kernel for the boundary condition.
+        The functional is specific to each boundary condition and should be passed as an argument.
+        """
+        _id = wp.uint8(self.id)
+
+        # Construct the warp kernel
+        @wp.kernel
+        def kernel(
+            f_pre: wp.array4d(dtype=Any),
+            f_post: wp.array4d(dtype=Any),
+            bc_mask: wp.array4d(dtype=wp.uint8),
+            missing_mask: wp.array4d(dtype=wp.bool),
+        ):
+            # Get the global index
+            i, j, k = wp.tid()
+            index = wp.vec3i(i, j, k)
+
+            # read tid data
+            _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data(f_pre, f_post, bc_mask, missing_mask, index)
+
+            # Apply the boundary condition
+            if _boundary_id == _id:
+                # NOTE(review): the step counter is pinned to 0 here, so every functional sees timestep == 0.
+                timestep = 0
+                _f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
+            else:
+                _f = _f_post
+
+            # Write the result
+            for l in range(self.velocity_set.q):
+                f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])
+
+        return kernel