Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Jun 22, 2024
1 parent 621af1a commit 6641d95
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ __launch_bounds__(TPB) __global__
using cub_kvp = cub::KeyValuePair<int, T>;
using KVBlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename KVBlockReduce::TempStorage kvTmpStorage;
__shared__ float result_kvp_value;

cub_kvp thread_kvp;
cub::ArgMax arg_max;
Expand Down Expand Up @@ -218,23 +219,22 @@ __launch_bounds__(TPB) __global__
}

const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = K * block_row + k_idx;
result_kvp_value = (float)result_kvp.value;
indices[idx] = result_kvp.key;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();

for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
factor[k_idx] = max(abs((float)inputs[idx]), (float)result_kvp.value);
logits_mask[k_idx] = ((float)result_kvp.value - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]);
factor[k_idx] = max(abs((float)inputs[idx]), result_kvp_value);
logits_mask[k_idx] = (result_kvp_value - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]);
if (k_idx == 1 && expert == indices[K * block_row]) {
logits_mask[1] = true;
}
}

if (threadIdx.x == 0) {
const int idx = K * block_row + k_idx;
output[idx] = result_kvp.value;
indices[idx] = result_kvp.key;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}

using BlockReduce = cub::BlockReduce<float, TPB>;
Expand Down

0 comments on commit 6641d95

Please sign in to comment.