Commit 1d3ca92: refine
wangyems committed Nov 3, 2023 (1 parent: f733207)
Showing 22 changed files with 171 additions and 88 deletions.
12 files renamed without changes.
@@ -826,7 +826,34 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(const T* A,
num_experts,
stream);
break;
case ActivationType::Silu:
run_gemm<EpilogueOpBiasSilu>(A,
B,
weight_scales,
biases,
C,
total_rows_before_expert,
total_rows,
gemm_n,
gemm_k,
num_experts,
stream);
break;
case ActivationType::Identity:
run_gemm<EpilogueOpBias>(A,
B,
weight_scales,
biases,
C,
total_rows_before_expert,
total_rows,
gemm_n,
gemm_k,
num_experts,
stream);
break;
case ActivationType::InvalidType:
throw std::runtime_error("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
break;
default: {
throw std::runtime_error("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
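The two new cases above route Silu and Identity activations to the matching CUTLASS epilogues (EpilogueOpBiasSilu and the plain bias epilogue). For reference only, and not part of this commit, the elementwise math those activations compute is sketched below in NumPy.

import numpy as np

def silu(x):
    # SiLU (swish), the activation fused by EpilogueOpBiasSilu: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))

def identity(x):
    # Identity case: the epilogue still adds the bias, but applies no nonlinearity.
    return x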
onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -854,15 +854,22 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows,
const int original_row = blockIdx.x;
const int num_rows = gridDim.x;
T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
const T* skip_1_row_ptr = skip_1 + original_row * cols;

const T* skip_1_row_ptr = nullptr;
if (RESIDUAL_NUM == 1) {
skip_1_row_ptr = skip_1 + original_row * cols;
}
const T* skip_2_row_ptr = nullptr;
if (RESIDUAL_NUM == 2) {
skip_2_row_ptr = skip_2 + original_row * cols;
}

for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
T thread_output;
if (RESIDUAL_NUM == 1) {
if (RESIDUAL_NUM == 0) {
thread_output = T(0);
}
else if (RESIDUAL_NUM == 1) {

thread_output = skip_1_row_ptr[tid];
}
else if (RESIDUAL_NUM == 2) {

@@ -885,6 +892,32 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows,
}
}

template<typename T>
void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
const T* scales,
const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row,
const int num_rows,
const int cols,
const int k,
cudaStream_t stream)
{

const int blocks = num_rows;
const int threads = std::min(cols, 1024);
finalize_moe_routing_kernel<T, 0><<<blocks, threads, 0, stream>>>(expanded_permuted_rows,
reduced_unpermuted_output,
nullptr,
nullptr,
bias,
scales,
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
cols,
k);
}
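The overload added above instantiates finalize_moe_routing_kernel with RESIDUAL_NUM == 0, i.e. no skip/residual tensor: each output row is just the routing-weighted sum of the selected experts' fc2 outputs plus the corresponding fc2 bias. Below is a rough NumPy reference of that reduction; the array names, shapes and indexing are my reading of the kernel, not anything defined in this commit.

import numpy as np

def finalize_moe_routing_ref(expanded_rows, bias, scales, src_to_dest, expert_for_row,
                             num_rows, cols, k):
    # expanded_rows:  (num_rows * k, cols) permuted per-expert fc2 outputs
    # bias:           (num_experts, cols) fc2 bias
    # scales:         (num_rows * k,) routing weights, row-major as (num_rows, k)
    # src_to_dest:    (num_rows * k,) map from expanded source row to permuted row
    # expert_for_row: (num_rows * k,) expert id chosen for each expanded source row
    out = np.zeros((num_rows, cols), dtype=expanded_rows.dtype)
    for row in range(num_rows):
        for j in range(k):
            perm_row = src_to_dest[j * num_rows + row]
            expert = expert_for_row[row * k + j]
            out[row] += scales[row * k + j] * (expanded_rows[perm_row] + bias[expert])
    return out

The pre-existing overloads with one or two skip tensors perform the same accumulation but start from the residual row(s) instead of zero.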

template<typename T>
void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
@@ -971,6 +1004,26 @@ template void initialize_moe_routing_kernelLauncher(
const half*, half*, const int*, int*, const int, const int, const int, const int, cudaStream_t);

// ==================== Specializations for final routing ===================================
template void finalize_moe_routing_kernelLauncher(const float*,
float*,
const float*,
const float*,
const int*,
const int*,
const int,
const int,
const int,
cudaStream_t);
template void finalize_moe_routing_kernelLauncher(const half*,
half*,
const half*,
const half*,
const int*,
const int*,
const int,
const int,
const int,
cudaStream_t);
template void finalize_moe_routing_kernelLauncher(const float*,
float*,
const float*,
@@ -91,6 +91,18 @@ void initialize_moe_routing_kernelLauncher(const T* unpermuted_input,
const int k,
cudaStream_t stream);

template<typename T>
void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
const T* scales,
const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row,
const int num_rows,
const int cols,
const int k,
cudaStream_t stream);

template<typename T>
void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
2 files renamed without changes.
31 changes: 12 additions & 19 deletions onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -4,7 +4,6 @@
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "moe.h"

#include "moe_kernel.h"

using namespace onnxruntime::cuda;

using namespace ::onnxruntime::common;

@@ -30,10 +29,6 @@ REGISTER_KERNEL_TYPED(MLFloat16)

using namespace ONNX_NAMESPACE;


template <typename T>
MoEBlock<T>::MoEBlock(const OpKernelInfo& info) : CudaKernel(info) {
}

template <typename T>
Status MoEBlock<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
@@ -43,53 +38,54 @@ Status MoEBlock<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* fc1_experts_bias = context->Input<Tensor>(4);
const Tensor* fc2_experts_bias = context->Input<Tensor>(5);

// Shape
const auto& input_dims = input->Shape().GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();

const int64_t num_rows = input_dims[0];
const int64_t hidden_size = input_dims[1];
const int64_t num_experts = fc1_experts_weights_dims[0];
const int64_t inter_size = fc1_experts_weights_dims[2];
const int64_t k = 1;

typedef typename ToCudaType<T>::MappedType CudaT;
auto stream = context->GetComputeStream();

fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner;

size_t ws_size = moe_runner.getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k);
size_t fc2_output_size = k * num_rows * hidden_size * sizeof(CudaT);
size_t expert_scales_size = k * num_rows * sizeof(CudaT);
size_t expanded_source_row_to_expanded_dest_row_size = k * num_rows * sizeof(int);
size_t expert_for_source_row_size = k * num_rows * sizeof(int);
size_t ws_size = moe_runner.getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k_);
size_t fc2_output_size = k_ * num_rows * hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * num_rows * sizeof(CudaT);
size_t expanded_source_row_to_expanded_dest_row_size = k_ * num_rows * sizeof(int);
size_t expert_for_source_row_size = k_ * num_rows * sizeof(int);

// TODO: check shape


AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

// TODO: allocate once and reuse

IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, ws_size, false, stream);
IAllocatorUniquePtr<void> fc2_output = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
IAllocatorUniquePtr<void> expert_scales = IAllocator::MakeUniquePtr<void>(allocator, expert_scales_size, false, stream);

IAllocatorUniquePtr<void> expanded_source_row_to_expanded_dest_row = IAllocator::MakeUniquePtr<void>(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream);

IAllocatorUniquePtr<void> expert_for_source_row = IAllocator::MakeUniquePtr<void>(allocator, expert_for_source_row_size, false, stream);


// fc1_scales and fc2_scales are used in quantized MoE
const CudaT* fc1_scales_ptr = nullptr;
const CudaT* fc2_scales_ptr = nullptr;

// bugbug: use a string to select from different activationType
moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
reinterpret_cast<const CudaT*>(gated_output->template Data<T>()),
reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
std::move(fc1_scales_ptr),
reinterpret_cast<const CudaT*>(fc1_experts_bias->template Data<T>()),
fastertransformer::ActivationType::Gelu,
activation_type_,
reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
std::move(fc2_scales_ptr),

static_cast<int>(num_rows),
static_cast<int>(hidden_size),
static_cast<int>(inter_size),
static_cast<int>(num_experts),
static_cast<int>(k),
static_cast<int>(k_),
reinterpret_cast<char*>(work_space.get()),
reinterpret_cast<CudaT*>(fc2_output.get()),
reinterpret_cast<CudaT*>(expert_scales.get()),
@@ -99,18 +95,15 @@ Status MoEBlock<T>::ComputeInternal(OpKernelContext* context) const {

Tensor* output = context->Output(0, input->Shape());

// bugbug: support no skip in moe_kernel
IAllocatorUniquePtr<void> skip_layer = IAllocator::MakeUniquePtr<void>(allocator, num_rows * hidden_size * sizeof(T), false, stream);
fastertransformer::finalize_moe_routing_kernelLauncher(reinterpret_cast<CudaT*>(fc2_output.get()),
reinterpret_cast<CudaT*>(output->template MutableData<T>()),
reinterpret_cast<CudaT*>(skip_layer.get()),
reinterpret_cast<const CudaT*>(fc2_experts_bias->template Data<T>()),

reinterpret_cast<CudaT*>(expert_scales.get()),
reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),

reinterpret_cast<int*>(expert_for_source_row.get()),
static_cast<int>(num_rows),
static_cast<int>(hidden_size),
static_cast<int>(k),
static_cast<int>(k_),
Stream(context));

return Status::OK();
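Taken together, ComputeInternal runs the gating, the two expert GEMMs and the finalize-routing step on the GPU. As a very rough k = 1 reference of the per-row math, under the assumptions that the gate value is the softmax probability of the selected expert and that the fc2 bias and gate are applied in the finalize step (as in the kernel above):

import numpy as np

def moe_block_row_ref(x, router_logits, W1, b1, W2, b2, act):
    # x: (hidden_size,), W1: (num_experts, hidden_size, inter_size),
    # W2: (num_experts, inter_size, hidden_size); shapes follow the op schema below.
    p = np.exp(router_logits - router_logits.max())
    p /= p.sum()                       # softmax over experts
    e = int(np.argmax(p))              # top-1 expert
    h = act(x @ W1[e] + b1[e])         # fc1 + activation (run_moe_fc)
    return p[e] * (h @ W2[e] + b2[e])  # fc2; gate and bias applied in finalize routing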
24 changes: 23 additions & 1 deletion onnxruntime/contrib_ops/cuda/moe/moe.h
@@ -2,6 +2,8 @@
// Licensed under the MIT License.

#pragma once

#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"

@@ -14,8 +16,28 @@
template <typename T>
class MoEBlock final : public CudaKernel {
public:
MoEBlock(const OpKernelInfo& op_kernel_info);
explicit MoEBlock(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {

ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());

std::string activation_type_str;
ORT_ENFORCE(op_kernel_info.GetAttr<std::string>("activation_type", &activation_type_str).IsOK());

if (activation_type_str == "relu") {
activation_type_ = fastertransformer::ActivationType::Relu;
} else if (activation_type_str == "gelu") {
activation_type_ = fastertransformer::ActivationType::Gelu;
} else if (activation_type_str == "silu") {
activation_type_ = fastertransformer::ActivationType::Silu;
} else if (activation_type_str == "identity") {
activation_type_ = fastertransformer::ActivationType::Identity;
} else {
ORT_THROW("Unsupported MoE activation type: ", activation_type_str);
}
}
Status ComputeInternal(OpKernelContext* ctx) const override;

private:
int64_t k_;
fastertransformer::ActivationType activation_type_;
};

} // namespace cuda
9 changes: 4 additions & 5 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1378,15 +1378,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
ONNX_MS_OPERATOR_SET_SCHEMA(MoEBlock, 1,
OpSchema()
.SetDoc("Mixture of experts.")
//.Attr("expert_start_idx", "Not implemented", AttributeProto::INT, static_cast<int64_t>(-1))
//.Attr("expert_end_idx", "Not implemented", AttributeProto::INT, static_cast<int64_t>(-1))
//.Attr("k", "Not implemented", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("activation_type", "Activation function to use", AttributeProto::STRING, std::string("relu"))

.Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast<int64_t>(1))

.Input(0, "input", "2D input tensor with shape (num_rows, hidden_size)", "T")
.Input(1, "gated_output", "2D input tensor with shape (num_rows, num_experts)", "T")
.Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T")

.Input(3, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T")

.Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
.Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional)
.Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T")

.Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T")

.Output(0, "output", "3D input tensor with shape (num_rows, hidden_size)", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.")

.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
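For illustration only, a MoEBlock node matching this schema could be built with onnx.helper as sketched below; the tensor names are placeholders, and "com.microsoft" is the usual domain for ORT contrib ops.

from onnx import helper

moe_node = helper.make_node(
    "MoEBlock",
    inputs=["input", "gated_output", "fc1_experts_weights",
            "fc2_experts_weights", "fc1_experts_bias", "fc2_experts_bias"],
    outputs=["output"],
    domain="com.microsoft",
    activation_type="gelu",  # "relu", "gelu", "silu" or "identity", per moe.h above
    k=1,                     # number of top experts selected per row
)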
1 change: 1 addition & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -154,6 +154,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"MaxPool": self._infer_Pool,
"Max": self._infer_symbolic_compute_ops,
"Min": self._infer_symbolic_compute_ops,
"MoEBlock": self._pass_on_shape_and_type,
"Mul": self._infer_symbolic_compute_ops,
"NonMaxSuppression": self._infer_NonMaxSuppression,
"NonZero": self._infer_NonZero,