diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
new file mode 100644
index 000000000000..abb9e15f8f6f
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#define TOP_K_SWITCH(N_TOP_K, ...)         \
+    [&] {                                  \
+        if (1 == N_TOP_K) {                \
+            constexpr int CONST_TOP_K = 1; \
+            __VA_ARGS__();                 \
+        } else if (2 == N_TOP_K) {         \
+            constexpr int CONST_TOP_K = 2; \
+            __VA_ARGS__();                 \
+        }                                  \
+    }()
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu
index 7e8d8a14237a..4153a2a3636f 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu
+++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu
@@ -8,6 +8,7 @@
 #include "moe_gather.cuh"
 #include "reduction_utils.h"
 #include "top_k_gating.cuh"
+#include "top_k_utils.h"
 
 namespace gather {
 
@@ -105,27 +106,16 @@ __global__ void moe_gather_kernel(T* layer_output,
     }
 }
 
-#define LAUNCH_FOR_UNROLL(COUNT)                                                          \
-    case COUNT:                                                                           \
-        if (n_top_k == 1) {                                                               \
-            moe_gather_kernel<T, COUNT, 1><<<grid, block, 0, stream>>>(layer_output,      \
-                                                                       moe_output,        \
-                                                                       scores,            \
-                                                                       mapped_slots,      \
-                                                                       expert_counts,     \
-                                                                       n_channels,        \
-                                                                       n_experts,         \
-                                                                       normalize_scales); \
-        } else if (n_top_k == 2) {                                                        \
-            moe_gather_kernel<T, COUNT, 2><<<grid, block, 0, stream>>>(layer_output,      \
-                                                                       moe_output,        \
-                                                                       scores,            \
-                                                                       mapped_slots,      \
-                                                                       expert_counts,     \
-                                                                       n_channels,        \
-                                                                       n_experts,         \
-                                                                       normalize_scales); \
-        }                                                                                 \
+#define LAUNCH_FOR_UNROLL(COUNT)                                                                \
+    case COUNT:                                                                                 \
+        moe_gather_kernel<T, COUNT, CONST_TOP_K><<<grid, block, 0, stream>>>(layer_output,      \
+                                                                             moe_output,        \
+                                                                             scores,            \
+                                                                             mapped_slots,      \
+                                                                             expert_counts,     \
+                                                                             n_channels,        \
+                                                                             n_experts,         \
+                                                                             normalize_scales); \
         break;
 
 template <typename T>
@@ -147,14 +137,16 @@ void launch_moe_gather(T* layer_output,
     const dim3 block(gather::threads);
     const dim3 grid(n_tokens);
 
-    switch (copy_unroll) {
-        LAUNCH_FOR_UNROLL(1)
-        LAUNCH_FOR_UNROLL(2)
-        LAUNCH_FOR_UNROLL(3)
-        LAUNCH_FOR_UNROLL(4)
-        LAUNCH_FOR_UNROLL(5)
-        LAUNCH_FOR_UNROLL(6)
-    }
+    TOP_K_SWITCH(n_top_k, [&] {
+        switch (copy_unroll) {
+            LAUNCH_FOR_UNROLL(1)
+            LAUNCH_FOR_UNROLL(2)
+            LAUNCH_FOR_UNROLL(3)
+            LAUNCH_FOR_UNROLL(4)
+            LAUNCH_FOR_UNROLL(5)
+            LAUNCH_FOR_UNROLL(6)
+        }
+    });
 }
 
 #define INSTANTIATE_GATHER_FOR_TYPE(TYPE) \
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu
index 9738c417cd25..d3eb4f649e79 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu
+++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu
@@ -4,9 +4,9 @@
 // DeepSpeed Team
 
 #include "ds_kernel_utils.h"
-#include "moe_scatter.cuh"
 #include "reduction_utils.h"
 #include "top_k_gating.cuh"
+#include "top_k_utils.h"
 
 using ROp = reduce::ROpType;
 
@@ -149,29 +149,18 @@ __global__ void moe_scatter_kernel(T* moe_input,
     }
 }
 
-#define LAUNCH_FOR_UNROLL(COUNT)                                                               \
-    case COUNT:                                                                                \
-        if (n_top_k == 1) {                                                                    \
-            moe_scatter_kernel<T, COUNT, 1><<<grid, block, 0, stream>>>(moe_input,             \
-                                                                        expert_count_cumsums,  \
-                                                                        mapped_slots,          \
-                                                                        activations,           \
-                                                                        assignments,           \
-                                                                        expert_counts,         \
-                                                                        offsets,               \
-                                                                        n_channels,            \
-                                                                        n_experts);            \
-        } else if (n_top_k == 2) {                                                             \
-            moe_scatter_kernel<T, COUNT, 2><<<grid, block, 0, stream>>>(moe_input,             \
-                                                                        expert_count_cumsums,  \
-                                                                        mapped_slots,          \
-                                                                        activations,           \
-                                                                        assignments,           \
-                                                                        expert_counts,         \
-                                                                        offsets,               \
-                                                                        n_channels,            \
-                                                                        n_experts);            \
-        }                                                                                      \
+#define LAUNCH_FOR_UNROLL(COUNT)                                 \
+    case COUNT:                                                  \
+        moe_scatter_kernel<T, COUNT, CONST_TOP_K>                \
+            <<<grid, block, 0, stream>>>(moe_input,              \
+                                         expert_count_cumsums,   \
+                                         mapped_slots,           \
+                                         activations,            \
+                                         assignments,            \
+                                         expert_counts,          \
+                                         offsets,                \
+                                         n_channels,             \
+                                         n_experts);             \
         break;
 
 template <typename T>
@@ -194,14 +183,16 @@ void launch_moe_scatter(T* moe_input,
     const dim3 block(scatter::threads);
     const dim3 grid(n_tokens);
 
-    switch (copy_unroll) {
-        LAUNCH_FOR_UNROLL(1);
-        LAUNCH_FOR_UNROLL(2);
-        LAUNCH_FOR_UNROLL(3);
-        LAUNCH_FOR_UNROLL(4);
-        LAUNCH_FOR_UNROLL(5);
-        LAUNCH_FOR_UNROLL(6);
-    }
+    TOP_K_SWITCH(n_top_k, [&] {
+        switch (copy_unroll) {
+            LAUNCH_FOR_UNROLL(1);
+            LAUNCH_FOR_UNROLL(2);
+            LAUNCH_FOR_UNROLL(3);
+            LAUNCH_FOR_UNROLL(4);
+            LAUNCH_FOR_UNROLL(5);
+            LAUNCH_FOR_UNROLL(6);
+        }
+    });
 }
 
 #define INSTANTIATE_SCATTER_FOR_TYPE(TYPE) \
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu
index 47c620bd4cf1..58f95c045593 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu
+++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu
@@ -7,6 +7,7 @@
 #include "memory_access_utils.h"
 #include "reduction_utils.h"
 #include "top_k_gating.cuh"
+#include "top_k_utils.h"
 
 using ROp = reduce::ROpType;
 
@@ -100,13 +101,10 @@ void launch_top_k_gating(int32_t* expert_counts,
     const dim3 grid(n_tokens);
     const dim3 block(((n_experts + hw_warp_size - 1) / hw_warp_size) * hw_warp_size);
 
-    if (n_top_k == 1) {
-        top_k_gating_kernel<T, 1><<<grid, block, 0, stream>>>(
+    TOP_K_SWITCH(n_top_k, [&] {
+        top_k_gating_kernel<T, CONST_TOP_K><<<grid, block, 0, stream>>>(
             expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts);
-    } else if (n_top_k == 2) {
-        top_k_gating_kernel<T, 2><<<grid, block, 0, stream>>>(
-            expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts);
-    }
+    });
 }
 
 #define INSTANTIATE_top_k_KERNEL(T) \
diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py
index cf7b4c07d7f7..8cb372e96c37 100644
--- a/op_builder/ragged_ops.py
+++ b/op_builder/ragged_ops.py
@@ -101,6 +101,7 @@ def include_paths(self):
             'inference/v2/kernels/ragged_ops/atom_builder',
             'inference/v2/kernels/ragged_ops/blocked_flash',
             'inference/v2/kernels/ragged_ops/embed',
+            'inference/v2/kernels/ragged_ops/includes',
             'inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary',
             'inference/v2/kernels/ragged_ops/logits_gather',
             'inference/v2/kernels/ragged_ops/moe_gather',
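Note: TOP_K_SWITCH uses the same lambda-dispatch trick as PyTorch-style AT_DISPATCH macros. It maps the runtime n_top_k value onto a constexpr CONST_TOP_K inside the wrapped lambda, which is what lets the three launchers above instantiate their kernels for top-k 1 or 2 without duplicating the launch code. Below is a minimal host-only sketch of the pattern, not part of this diff: it assumes the new top_k_utils.h is on the include path, and gate_tokens is a hypothetical stand-in for the real templated kernels.

#include <cstdio>

#include "top_k_utils.h"  // new header from this diff: defines TOP_K_SWITCH / CONST_TOP_K

// Hypothetical templated worker standing in for the templated CUDA kernels.
template <int N_TOP_K>
void gate_tokens(int n_tokens)
{
    std::printf("instantiated for top-k = %d, tokens = %d\n", N_TOP_K, n_tokens);
}

int main()
{
    const int n_top_k = 2;  // runtime value, e.g. taken from the model config
    const int n_tokens = 128;

    // The lambda body is stamped out once per supported top-k value (1 and 2);
    // any other n_top_k falls through the macro and calls nothing.
    TOP_K_SWITCH(n_top_k, [&] { gate_tokens<CONST_TOP_K>(n_tokens); });

    return 0;
}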