Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

onboard MoE #18279

Merged
merged 28 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cgmanifests/generated/cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "c4f6b8c6bc94ff69048492fb34df0dfaf1983933",
"commitHash": "6f47420213f757831fae65c686aa471749fa8d60",
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
},
"comments": "cutlass"
Expand Down
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8b
re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374
safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0
2 changes: 1 addition & 1 deletion cmake/external/cutlass.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTIO
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch
# PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch
)

FetchContent_GetProperties(cutlass)
Expand Down
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ set(contrib_ops_excluded_files
"math/gemm_float8.cc"
"math/gemm_float8.cu"
"math/gemm_float8.h"
"moe/*"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
Expand Down
92 changes: 0 additions & 92 deletions cmake/patches/cutlass/cutlass.patch

This file was deleted.

50 changes: 50 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Do not modify directly.*
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
* <a href="#com.microsoft.MatMulNBits">com.microsoft.MatMulNBits</a>
* <a href="#com.microsoft.MaxpoolWithMask">com.microsoft.MaxpoolWithMask</a>
* <a href="#com.microsoft.MoEBlock">com.microsoft.MoEBlock</a>
* <a href="#com.microsoft.MulInteger">com.microsoft.MulInteger</a>
* <a href="#com.microsoft.MultiHeadAttention">com.microsoft.MultiHeadAttention</a>
* <a href="#com.microsoft.MurmurHash3">com.microsoft.MurmurHash3</a>
Expand Down Expand Up @@ -2906,6 +2907,55 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.MoEBlock"></a><a name="com.microsoft.moeblock">**com.microsoft.MoEBlock**</a>
wangyems marked this conversation as resolved.
Show resolved Hide resolved

Mixture of experts.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use</dd>
wangyems marked this conversation as resolved.
Show resolved Hide resolved
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
</dl>

#### Inputs

<dl>
<dt><tt>input</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, hidden_size)</dd>
<dt><tt>gated_output</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
<dt><tt>fc1_experts_bias</tt> : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_bias</tt> : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>2D output tensor with shape (num_rows, hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float or float16 tensors.</dd>
</dl>


### <a name="com.microsoft.MulInteger"></a><a name="com.microsoft.mulinteger">**com.microsoft.MulInteger**</a>

Performs element-wise binary quantized multiplication (with Numpy-style broadcasting support).
Expand Down
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ Do not modify directly.*
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MoEBlock|*in* input:**T**<br> *in* gated_output:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoEBlock);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoEBlock);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
Expand Down Expand Up @@ -252,6 +254,8 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoEBlock)>,

Check warning on line 257 in onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc#L257

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc:257:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoEBlock)>,

Check warning on line 258 in onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc#L258

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc:258:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
Expand Down
80 changes: 80 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <iostream>
#include <cuda_runtime.h>

Check warning on line 19 in onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h#L19

Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:19:  Found C system header after C++ system header. Should be: common.h, c system, c++ system, other.  [build/include_order] [4]
#include <cuda_fp16.h>

Check warning on line 20 in onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h#L20

Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:20:  Found C system header after C++ system header. Should be: common.h, c system, c++ system, other.  [build/include_order] [4]
#include <cublas_v2.h>

Check warning on line 21 in onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h#L21

Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:21:  Found C system header after C++ system header. Should be: common.h, c system, c++ system, other.  [build/include_order] [4]
#include <cublasLt.h>
#include <stdexcept>
#include <map>
#include "stdio.h"

Check warning on line 25 in onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h#L25

Include the directory when naming header files [build/include_subdir] [4]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:25:  Include the directory when naming header files  [build/include_subdir] [4]
#include <fstream>

Check warning on line 26 in onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h#L26

Found C++ system header after other header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:26:  Found C++ system header after other header. Should be: common.h, c system, c++ system, other.  [build/include_order] [4]

namespace fastertransformer {

// Translates a cuBLAS status code into the spelling of its enumerator.
// Returns a string literal; any value that is not one of the enumerators
// handled below yields "<unknown>".
static const char* _cudaGetErrorEnum(cublasStatus_t error) {
  const char* name = "<unknown>";
  switch (error) {
    case CUBLAS_STATUS_SUCCESS:
      name = "CUBLAS_STATUS_SUCCESS";
      break;
    case CUBLAS_STATUS_NOT_INITIALIZED:
      name = "CUBLAS_STATUS_NOT_INITIALIZED";
      break;
    case CUBLAS_STATUS_ALLOC_FAILED:
      name = "CUBLAS_STATUS_ALLOC_FAILED";
      break;
    case CUBLAS_STATUS_INVALID_VALUE:
      name = "CUBLAS_STATUS_INVALID_VALUE";
      break;
    case CUBLAS_STATUS_ARCH_MISMATCH:
      name = "CUBLAS_STATUS_ARCH_MISMATCH";
      break;
    case CUBLAS_STATUS_MAPPING_ERROR:
      name = "CUBLAS_STATUS_MAPPING_ERROR";
      break;
    case CUBLAS_STATUS_EXECUTION_FAILED:
      name = "CUBLAS_STATUS_EXECUTION_FAILED";
      break;
    case CUBLAS_STATUS_INTERNAL_ERROR:
      name = "CUBLAS_STATUS_INTERNAL_ERROR";
      break;
    case CUBLAS_STATUS_NOT_SUPPORTED:
      name = "CUBLAS_STATUS_NOT_SUPPORTED";
      break;
    case CUBLAS_STATUS_LICENSE_ERROR:
      name = "CUBLAS_STATUS_LICENSE_ERROR";
      break;
    default:
      // Keep the "<unknown>" placeholder for anything not listed above.
      break;
  }
  return name;
}

// Overload for CUDA runtime error codes: defers to cudaGetErrorString so that
// check() can describe either kind of status through the same function name.
static const char* _cudaGetErrorEnum(cudaError_t error) {
  const char* const description = cudaGetErrorString(error);
  return description;
}

// Throws std::runtime_error when `result` evaluates to true in a boolean
// context (i.e. a non-zero status); returns normally otherwise.
//
//   result - status value to test; it must be accepted by one of the
//            _cudaGetErrorEnum overloads, which supplies the description.
//   func   - text of the expression that produced `result` (check_cuda_error
//            passes the stringified argument). May be null, in which case it
//            is left out of the message.
//   file   - source file of the call site.
//   line   - source line of the call site.
//
// The message keeps the original "[FT][ERROR] CUDA runtime error: <desc>
// <file>:<line>" prefix and now also includes the failing expression, which
// was previously accepted but silently dropped.
template <typename T>
void check(T result, char const* const func, const char* const file, int const line) {
  if (!result) {
    return;
  }

  std::string message("[FT][ERROR] CUDA runtime error: ");
  message += _cudaGetErrorEnum(result);
  message += " ";
  message += file;
  message += ":";
  message += std::to_string(line);
  if (func != nullptr) {
    // Report which call failed, not only where it failed.
    message += " \"";
    message += func;
    message += "\"";
  }
  message += " \n";
  throw std::runtime_error(message);
}

// Convenience wrapper around fastertransformer::check: forwards `val` together
// with its stringified source text (#val) and the call site's __FILE__ and
// __LINE__. The callee is fully qualified, so the macro does not depend on the
// namespace in which it is expanded.
#define check_cuda_error(val) fastertransformer::check((val), #val, __FILE__, __LINE__)

} // namespace fastertransformer
49 changes: 49 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cuda_runtime_api.h>

#include "cutlass/device_kernel.h"
#include "common.h"

Check warning on line 21 in onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h#L21

Include the directory when naming header files [build/include_subdir] [4]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h:21:  Include the directory when naming header files  [build/include_subdir] [4]

namespace fastertransformer {

template <typename GemmKernel>
inline int compute_occupancy_for_kernel() {
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

Check warning on line 27 in onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h#L27

Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
Raw output
onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h:27:  Using deprecated casting style.  Use static_cast<int>(...) instead  [readability/casting] [4]

if (smem_size > (48 << 10)) {
cudaError_t status =
cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (status == cudaError::cudaErrorInvalidValue) {
// Clear the error bit since we can ignore this.
// This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
// occupancy of 0. This will cause the heuristic to ignore this configuration.
status = cudaGetLastError();
return 0;
}
check_cuda_error(status);
}

int max_active_blocks = -1;
check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));

return max_active_blocks;
}

} // namespace fastertransformer
Loading
Loading