Skip to content

Commit

Permalink
Clean up top_k support in the C++ code
Browse files Browse the repository at this point in the history
  • Loading branch information
cmikeh2 committed Dec 20, 2023
1 parent a383b67 commit 0255d6b
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 67 deletions.
15 changes: 15 additions & 0 deletions deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cassert>

// Dispatches a runtime top-k width onto a compile-time constant.
// Invokes the trailing callable (typically a lambda) with `CONST_TOP_K`
// in scope as a `constexpr int` equal to N_TOP_K, so kernel templates can
// be instantiated with the width as a template argument.
// Only widths 1 and 2 are supported; any other value asserts instead of
// silently skipping the dispatch (NOTE: assert is compiled out under
// NDEBUG, in which case an unsupported width is still a no-op).
#define TOP_K_SWITCH(N_TOP_K, ...)                                    \
    [&] {                                                             \
        if (1 == N_TOP_K) {                                           \
            constexpr int CONST_TOP_K = 1;                            \
            __VA_ARGS__();                                            \
        } else if (2 == N_TOP_K) {                                    \
            constexpr int CONST_TOP_K = 2;                            \
            __VA_ARGS__();                                            \
        } else {                                                      \
            /* Unsupported width: fail loudly rather than no-op. */   \
            assert(false && "TOP_K_SWITCH: unsupported n_top_k");     \
        }                                                             \
    }()
50 changes: 21 additions & 29 deletions deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "moe_gather.cuh"
#include "reduction_utils.h"
#include "top_k_gating.cuh"
#include "top_k_utils.h"

namespace gather {

Expand Down Expand Up @@ -105,27 +106,16 @@ __global__ void moe_gather_kernel(T* layer_output,
}
}

#define LAUNCH_FOR_UNROLL(COUNT) \
case COUNT: \
if (n_top_k == 1) { \
moe_gather_kernel<T, COUNT, 1><<<grid, block, 0, stream>>>(layer_output, \
moe_output, \
scores, \
mapped_slots, \
expert_counts, \
n_channels, \
n_experts, \
normalize_scales); \
} else if (n_top_k == 2) { \
moe_gather_kernel<T, COUNT, 2><<<grid, block, 0, stream>>>(layer_output, \
moe_output, \
scores, \
mapped_slots, \
expert_counts, \
n_channels, \
n_experts, \
normalize_scales); \
} \
// Expands to one `case` arm of the copy_unroll switch below: launches
// moe_gather_kernel with COUNT as the compile-time unroll factor and
// CONST_TOP_K as the top-k width. Must be expanded inside a switch that
// is itself wrapped in TOP_K_SWITCH, which is what binds CONST_TOP_K
// (see launch_moe_gather). `T`, `grid`, `block`, `stream`, and the kernel
// arguments are likewise taken from the expansion site.
// (Comments cannot go inside the macro: a `//` would comment out the
// line-continuation backslash.)
#define LAUNCH_FOR_UNROLL(COUNT) \
case COUNT: \
moe_gather_kernel<T, COUNT, CONST_TOP_K><<<grid, block, 0, stream>>>(layer_output, \
moe_output, \
scores, \
mapped_slots, \
expert_counts, \
n_channels, \
n_experts, \
normalize_scales); \
break;

template <typename T>
Expand All @@ -147,14 +137,16 @@ void launch_moe_gather(T* layer_output,
const dim3 block(gather::threads);
const dim3 grid(n_tokens);

switch (copy_unroll) {
LAUNCH_FOR_UNROLL(1)
LAUNCH_FOR_UNROLL(2)
LAUNCH_FOR_UNROLL(3)
LAUNCH_FOR_UNROLL(4)
LAUNCH_FOR_UNROLL(5)
LAUNCH_FOR_UNROLL(6)
}
TOP_K_SWITCH(n_top_k, [&] {
switch (copy_unroll) {
LAUNCH_FOR_UNROLL(1)
LAUNCH_FOR_UNROLL(2)
LAUNCH_FOR_UNROLL(3)
LAUNCH_FOR_UNROLL(4)
LAUNCH_FOR_UNROLL(5)
LAUNCH_FOR_UNROLL(6)
}
});
}

#define INSTANTIATE_GATHER_FOR_TYPE(TYPE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
// DeepSpeed Team

#include "ds_kernel_utils.h"
#include "moe_scatter.cuh"
#include "reduction_utils.h"
#include "top_k_gating.cuh"
#include "top_k_utils.h"

using ROp = reduce::ROpType;

Expand Down Expand Up @@ -149,29 +149,18 @@ __global__ void moe_scatter_kernel(T* moe_input,
}
}

#define LAUNCH_FOR_UNROLL(COUNT) \
case COUNT: \
if (n_top_k == 1) { \
moe_scatter_kernel<T, COUNT, 1><<<grid, block, 0, stream>>>(moe_input, \
expert_count_cumsums, \
mapped_slots, \
activations, \
assignments, \
expert_counts, \
offsets, \
n_channels, \
n_experts); \
} else if (n_top_k == 2) { \
moe_scatter_kernel<T, COUNT, 2><<<grid, block, 0, stream>>>(moe_input, \
expert_count_cumsums, \
mapped_slots, \
activations, \
assignments, \
expert_counts, \
offsets, \
n_channels, \
n_experts); \
} \
// Expands to one `case` arm of the copy_unroll switch below: launches
// moe_scatter_kernel with COUNT as the compile-time unroll factor and
// CONST_TOP_K as the top-k width. Must be expanded inside a switch that
// is itself wrapped in TOP_K_SWITCH, which is what binds CONST_TOP_K
// (see launch_moe_scatter). `T`, `grid`, `block`, `stream`, and the kernel
// arguments are likewise taken from the expansion site.
// (Comments cannot go inside the macro: a `//` would comment out the
// line-continuation backslash.)
#define LAUNCH_FOR_UNROLL(COUNT) \
case COUNT: \
moe_scatter_kernel<T, COUNT, CONST_TOP_K> \
<<<grid, block, 0, stream>>>(moe_input, \
expert_count_cumsums, \
mapped_slots, \
activations, \
assignments, \
expert_counts, \
offsets, \
n_channels, \
n_experts); \
break;

template <typename T>
Expand All @@ -194,14 +183,16 @@ void launch_moe_scatter(T* moe_input,
const dim3 block(scatter::threads);
const dim3 grid(n_tokens);

switch (copy_unroll) {
LAUNCH_FOR_UNROLL(1);
LAUNCH_FOR_UNROLL(2);
LAUNCH_FOR_UNROLL(3);
LAUNCH_FOR_UNROLL(4);
LAUNCH_FOR_UNROLL(5);
LAUNCH_FOR_UNROLL(6);
}
TOP_K_SWITCH(n_top_k, [&] {
switch (copy_unroll) {
LAUNCH_FOR_UNROLL(1);
LAUNCH_FOR_UNROLL(2);
LAUNCH_FOR_UNROLL(3);
LAUNCH_FOR_UNROLL(4);
LAUNCH_FOR_UNROLL(5);
LAUNCH_FOR_UNROLL(6);
}
});
}

#define INSTANTIATE_SCATTER_FOR_TYPE(TYPE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "memory_access_utils.h"
#include "reduction_utils.h"
#include "top_k_gating.cuh"
#include "top_k_utils.h"

using ROp = reduce::ROpType;

Expand Down Expand Up @@ -100,13 +101,10 @@ void launch_top_k_gating(int32_t* expert_counts,
const dim3 grid(n_tokens);
const dim3 block(((n_experts + hw_warp_size - 1) / hw_warp_size) * hw_warp_size);

if (n_top_k == 1) {
top_k_gating_kernel<T, 1><<<grid, block, 0, stream>>>(
TOP_K_SWITCH(n_top_k, [&] {
top_k_gating_kernel<T, CONST_TOP_K><<<grid, block, 0, stream>>>(
expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts);
} else if (n_top_k == 2) {
top_k_gating_kernel<T, 2><<<grid, block, 0, stream>>>(
expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts);
}
});
}

#define INSTANTIATE_top_k_KERNEL(T) \
Expand Down
1 change: 1 addition & 0 deletions op_builder/ragged_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def include_paths(self):
'inference/v2/kernels/ragged_ops/atom_builder',
'inference/v2/kernels/ragged_ops/blocked_flash',
'inference/v2/kernels/ragged_ops/embed',
'inference/v2/kernels/ragged_ops/includes',
'inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary',
'inference/v2/kernels/ragged_ops/logits_gather',
'inference/v2/kernels/ragged_ops/moe_gather',
Expand Down

0 comments on commit 0255d6b

Please sign in to comment.