onboard MoE #18279

Merged: 28 commits, Nov 15, 2023
Changes from 8 commits
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -51,6 +51,7 @@ set(contrib_ops_excluded_files
"math/gemm_float8.cc"
"math/gemm_float8.cu"
"math/gemm_float8.h"
"moe/*"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -70,6 +70,8 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoEBlock);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoEBlock);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
@@ -242,6 +244,8 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoEBlock)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoEBlock)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
@@ -251,7 +255,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch)>,

[cpplint] onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc:258: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv)>,
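For reference, the MoEBlock registration above follows ONNX Runtime's standard two-step contrib-kernel pattern: forward-declare the kernel class that the ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME macro generates, then append a matching BuildKernelCreateInfo entry to the CUDA EP's kernel table. A minimal sketch of that pattern for a hypothetical op (FooBar is a placeholder, not part of this PR):

```cpp
// Step 1: forward-declare the generated kernel class, keyed by
// (execution provider, domain, opset version, element type, op name).
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FooBar);

// Step 2: add a factory entry so the op resolves when a session is created.
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FooBar)>,
```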
80 changes: 80 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <iostream>
#include <cuda_runtime.h>

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:19: Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]
#include <cuda_fp16.h>

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:20: Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]
#include <cublas_v2.h>

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:21: Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]
#include <cublasLt.h>
#include <stdexcept>
#include <map>
#include "stdio.h"

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:25: Include the directory when naming header files [build/include_subdir] [4]
#include <fstream>

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:26: Found C++ system header after other header. Should be: common.h, c system, c++ system, other. [build/include_order] [4]

namespace fastertransformer {

static const char* _cudaGetErrorEnum(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";

case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";

case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";

case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";

case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";

case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";

case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";

case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";

case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";

case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "<unknown>";
}

static const char* _cudaGetErrorEnum(cudaError_t error) {
return cudaGetErrorString(error);
}

template <typename T>
void check(T result, char const* const func, const char* const file, int const line) {
if (result) {
throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") +

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h:72: Add #include <string> for string [build/include_what_you_use] [4]
(_cudaGetErrorEnum(result)) + " " + file +
":" + std::to_string(line) + " \n");
}
}

#define check_cuda_error(val) fastertransformer::check((val), #val, __FILE__, __LINE__)

} // namespace fastertransformer
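A minimal usage sketch for the check_cuda_error macro defined above (the include path and buffer size are assumptions for illustration):

```cpp
#include "contrib_ops/cuda/moe/ft_moe/common.h"  // assumed include path

int main() {
  // check_cuda_error throws std::runtime_error naming the failing call,
  // file, and line whenever the returned status is non-zero.
  float* device_buffer = nullptr;
  check_cuda_error(cudaMalloc(&device_buffer, 1024 * sizeof(float)));
  check_cuda_error(cudaMemset(device_buffer, 0, 1024 * sizeof(float)));
  check_cuda_error(cudaFree(device_buffer));
  return 0;
}
```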
49 changes: 49 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cuda_runtime_api.h>

#include "cutlass/device_kernel.h"
#include "common.h"

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h:21: Include the directory when naming header files [build/include_subdir] [4]

namespace fastertransformer {

template <typename GemmKernel>
inline int compute_occupancy_for_kernel() {
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h:27: Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]

if (smem_size > (48 << 10)) {
cudaError_t status =
cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (status == cudaError::cudaErrorInvalidValue) {
// Clear the error bit since we can ignore this.
// This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
// occupancy of 0. This will cause the heuristic to ignore this configuration.
status = cudaGetLastError();
return 0;
}
check_cuda_error(status);
}

int max_active_blocks = -1;
check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));

return max_active_blocks;
}

} // namespace fastertransformer
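The occupancy pattern above in isolation, as a hedged standalone sketch (the kernel and sizes are placeholders standing in for cutlass::Kernel<GemmKernel> and GemmKernel::kThreadCount):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Placeholder kernel standing in for cutlass::Kernel<GemmKernel>.
__global__ void dummy_kernel(float* out) { out[threadIdx.x] = 0.0f; }

int main() {
  const int threads_per_block = 128;  // analogous to GemmKernel::kThreadCount
  const int smem_size = 64 << 10;     // 64 KiB of dynamic shared memory

  // Kernels that need more than 48 KiB of dynamic shared memory must opt in.
  if (smem_size > (48 << 10)) {
    cudaFuncSetAttribute(dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  }

  int max_active_blocks = -1;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, dummy_kernel,
                                                threads_per_block, smem_size);
  printf("resident blocks per SM: %d\n", max_active_blocks);
  return 0;
}
```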
195 changes: 195 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
@@ -0,0 +1,195 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cutlass_heuristic.h"

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc:17: Include the directory when naming header files [build/include_subdir] [4]

#include <cuda_runtime_api.h>
#include <vector>
#include <stdexcept>

namespace fastertransformer {

struct TileShape {
int m;
int n;
};

TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) {
switch (tile_config) {
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
return TileShape{32, 128};
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
return TileShape{64, 128};
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
return TileShape{128, 128};
default:
throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
}
}

bool is_valid_split_k_factor(const int64_t m,
const int64_t n,
const int64_t k,
const TileShape tile_shape,
const int split_k_factor,
const size_t workspace_bytes,
const bool is_weight_only) {
// All tile sizes have a k_tile of 64.
static constexpr int k_tile = 64;

// For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
if (is_weight_only) {
if ((k % k_tile) != 0) {
return false;
}

if ((k % split_k_factor) != 0) {
return false;
}

const int k_elements_per_split = k / split_k_factor;
if ((k_elements_per_split % k_tile) != 0) {
return false;
}
}

// Check that the workspace has sufficient space for this split-k factor
const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;

if (required_ws_bytes > workspace_bytes) {
return false;
}

return true;
}

std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) {
std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};

std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};

std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};

const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
return simt_configs_only ? simt_configs : allowed_configs;
}

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) {
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);

std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = 2;
const int max_stages = sm >= 80 ? 4 : 2;

for (const auto& tile_config : tiles) {
for (int stages = min_stages; stages <= max_stages; ++stages) {
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
candidate_configs.push_back(config);
}
}

return candidate_configs;
}

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t,
const int split_k_limit,
const size_t workspace_bytes,
const int multi_processor_count,
const int is_weight_only) {
if (occupancies.size() != candidate_configs.size()) {
throw std::runtime_error(
"[FT Error][estimate_best_config_from_occupancies] occpancies and "
"candidate configs vectors must have equal length.");
}

CutlassGemmConfig best_config;
// Score will be [0, 1]. The objective is to minimize this score.
// It represents the fraction of SM resources unused in the last wave.
float config_score = 1.0f;
int config_waves = INT_MAX;
int current_m_tile = 0;

const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (int ii = 0; ii < candidate_configs.size(); ++ii) {
CutlassGemmConfig candidate_config = candidate_configs[ii];
TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
int occupancy = occupancies[ii];

if (occupancy == 0) {
continue;
}

// Keep small tile sizes when possible.
if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile && current_m_tile < tile_shape.m) {

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc:150: Lines should be <= 120 characters long [whitespace/line_length] [2]
continue;
}

const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;

for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
const int ctas_per_wave = occupancy * multi_processor_count;
const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;

const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc:163: Using deprecated casting style. Use static_cast<float>(...) instead [readability/casting] [4]
const float current_score = float(num_waves_total) - num_waves_fractional;

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc:164: Using deprecated casting style. Use static_cast<float>(...) instead [readability/casting] [4]

const float score_slack = 0.1f;
if (current_score < config_score || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc:167: Lines should be <= 120 characters long [whitespace/line_length] [2]
config_score = current_score;
config_waves = num_waves_total;
SplitKStyle split_style =
split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
best_config = CutlassGemmConfig{
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
current_m_tile = tile_shape.m;
} else if (current_score == config_score && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || current_m_tile < tile_shape.m)) {

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc:175: Lines should be <= 120 characters long [whitespace/line_length] [2]
// Prefer deeper pipeline or smaller split-k
SplitKStyle split_style =
split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
best_config = CutlassGemmConfig{
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
current_m_tile = tile_shape.m;
config_waves = num_waves_total;
}
}
}
}

if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
throw std::runtime_error("[FT Error] Heurisitc failed to find a valid config.");
}

return best_config;
}

} // namespace fastertransformer
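To make the wave-quantization score concrete, here is a self-contained re-derivation of the arithmetic from estimate_best_config_from_occupancies for one illustrative problem size (all numbers are hypothetical):

```cpp
#include <cstdio>

int main() {
  // Example: m=512, n=1024, 128x128 tile, occupancy 2, 80 SMs, split_k = 1.
  const int m = 512, n = 1024, tile_m = 128, tile_n = 128;
  const int occupancy = 2, multi_processor_count = 80, split_k_factor = 1;

  const int ctas_in_m_dim = (m + tile_m - 1) / tile_m;                          // 4
  const int ctas_in_n_dim = (n + tile_n - 1) / tile_n;                          // 8
  const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;  // 32
  const int ctas_per_wave = occupancy * multi_processor_count;                  // 160

  const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;       // 1
  const float num_waves_fractional = ctas_for_problem / static_cast<float>(ctas_per_wave);  // 0.2

  // Score = fraction of the last wave's CTA slots left idle; lower is better.
  // Here 1 - 0.2 = 0.8, so this config wastes 80% of its final wave.
  printf("score = %.2f\n", num_waves_total - num_waves_fractional);
  return 0;
}
```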
42 changes: 42 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "ft_gemm_configs.h"

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h:19: Include the directory when naming header files [build/include_subdir] [4]

#include <cstddef>

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h:21: Found C++ system header after other header. Should be: cutlass_heuristic.h, c system, c++ system, other. [build/include_order] [4]
#include <cstdint>

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h:22: Found C++ system header after other header. Should be: cutlass_heuristic.h, c system, c++ system, other. [build/include_order] [4]
#include <vector>

[cpplint] onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h:23: Found C++ system header after other header. Should be: cutlass_heuristic.h, c system, c++ system, other. [build/include_order] [4]

// #include "src/fastertransformer/utils/cuda_utils.h"

namespace fastertransformer {

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t num_experts,
const int split_k_limit,
const size_t workspace_bytes,
const int multi_processor_count,
const int is_weight_only);

} // namespace fastertransformer
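Putting the two declarations together, a hedged sketch of the intended call pattern (the include path, the constant occupancies, and the workspace size are assumptions; in the real dispatcher each candidate maps to a concrete CUTLASS kernel whose occupancy comes from compute_occupancy_for_kernel):

```cpp
#include <cstdint>
#include <vector>
#include "contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h"  // assumed path

using fastertransformer::CutlassGemmConfig;

CutlassGemmConfig pick_config(int sm, int64_t m, int64_t n, int64_t k,
                              int64_t num_experts, int multi_processor_count) {
  // 1. Enumerate tile/stage candidates for this architecture.
  std::vector<CutlassGemmConfig> candidates =
      fastertransformer::get_candidate_configs(sm, /*is_weight_only=*/false,
                                               /*simt_configs_only=*/false);

  // 2. One occupancy value per candidate; a constant stands in here for
  //    compute_occupancy_for_kernel<GemmKernel>() on the real kernel types.
  std::vector<int> occupancies(candidates.size(), 1);

  // 3. Let the heuristic minimize wasted work in the final wave.
  const size_t workspace_bytes = 4u * 1024u * 1024u;
  return fastertransformer::estimate_best_config_from_occupancies(
      candidates, occupancies, m, n, k, num_experts, /*split_k_limit=*/7,
      workspace_bytes, multi_processor_count, /*is_weight_only=*/0);
}
```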