Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

onboard MoE #18279

Merged
merged 28 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cgmanifests/generated/cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "c4f6b8c6bc94ff69048492fb34df0dfaf1983933",
"commitHash": "6f47420213f757831fae65c686aa471749fa8d60",
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
},
"comments": "cutlass"
Expand Down
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8b
re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374
safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/a4f72a314a85732ed67d5aa8d1088d207a7e0e61.zip;f57357ab6d300e207a632d034ebc8aa036a090d9
1 change: 0 additions & 1 deletion cmake/external/cutlass.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTIO
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch
)

FetchContent_GetProperties(cutlass)
Expand Down
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ set(contrib_ops_excluded_files
"math/gemm_float8.cc"
"math/gemm_float8.cu"
"math/gemm_float8.h"
"moe/*"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
Expand Down
92 changes: 0 additions & 92 deletions cmake/patches/cutlass/cutlass.patch

This file was deleted.

53 changes: 53 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Do not modify directly.*
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
* <a href="#com.microsoft.MatMulNBits">com.microsoft.MatMulNBits</a>
* <a href="#com.microsoft.MaxpoolWithMask">com.microsoft.MaxpoolWithMask</a>
* <a href="#com.microsoft.MoE">com.microsoft.MoE</a>
* <a href="#com.microsoft.MulInteger">com.microsoft.MulInteger</a>
* <a href="#com.microsoft.MultiHeadAttention">com.microsoft.MultiHeadAttention</a>
* <a href="#com.microsoft.MurmurHash3">com.microsoft.MurmurHash3</a>
Expand Down Expand Up @@ -2904,6 +2905,58 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.MoE"></a><a name="com.microsoft.moe">**com.microsoft.MoE**</a>

Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1,
GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
usually uses top 32 experts.


#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
</dl>

#### Inputs (4 - 6)

<dl>
<dt><tt>input</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float or float16 tensors.</dd>
</dl>


### <a name="com.microsoft.MulInteger"></a><a name="com.microsoft.mulinteger">**com.microsoft.MulInteger**</a>

Performs element-wise binary quantized multiplication (with Numpy-style broadcasting support).
Expand Down
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ Do not modify directly.*
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
Expand Down Expand Up @@ -260,6 +262,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
Expand Down
51 changes: 51 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cuda_runtime_api.h>

#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "cutlass/device_kernel.h"

using namespace onnxruntime;

Check warning on line 23 in onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h#L23

Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h:23:  Do not use namespace using-directives.  Use using-declarations instead.  [build/namespaces] [5]

namespace ort_fastertransformer {

template <typename GemmKernel>
inline int compute_occupancy_for_kernel() {
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

Check warning on line 29 in onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h#L29

Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h:29:  Using deprecated casting style.  Use static_cast<int>(...) instead  [readability/casting] [4]

if (smem_size > (48 << 10)) {
cudaError_t status =
cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (status == cudaError::cudaErrorInvalidValue) {
// Clear the error bit since we can ignore this.
// This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
// occupancy of 0. This will cause the heuristic to ignore this configuration.
status = cudaGetLastError();
return 0;
}
CUDA_CALL_THROW(status);
}

int max_active_blocks = -1;
CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::Kernel<GemmKernel>,
GemmKernel::kThreadCount, smem_size));

return max_active_blocks;
}

} // namespace ort_fastertransformer
Loading
Loading