diff --git a/onnxruntime/contrib_ops/cuda/moe/common.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/common.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/compute_occupancy.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/cutlass_heuristic.cc
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/cutlass_heuristic.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/epilogue_helpers.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/ft_gemm_configs.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/gemm_moe_problem_visitor.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/layout_traits_helper.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/layout_traits_helper.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/moe_cutlass_kernel.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels_fp16_fp16.cu
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels_fp32_fp32.cu
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
similarity index 97%
rename from onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels_template.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
index 4c00f53b67177..c68ad97bb968c 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_gemm_kernels_template.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
@@ -826,7 +826,34 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(const T* A,
                                    num_experts,
                                    stream);
       break;
+    case ActivationType::Silu:
+      run_gemm<EpilogueOpBiasSilu>(A,
+                                   B,
+                                   weight_scales,
+                                   biases,
+                                   C,
+                                   total_rows_before_expert,
+                                   total_rows,
+                                   gemm_n,
+                                   gemm_k,
+                                   num_experts,
+                                   stream);
+      break;
+    case ActivationType::Identity:
+      run_gemm<EpilogueOpBias>(A,
+                               B,
+                               weight_scales,
+                               biases,
+                               C,
+                               total_rows_before_expert,
+                               total_rows,
+                               gemm_n,
+                               gemm_k,
+                               num_experts,
+                               stream);
+      break;
     case ActivationType::InvalidType:
+      throw std::runtime_error("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
       break;
     default: {
       std::runtime_error("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
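Note on the two new dispatch cases: they extend the bias-epilogue switch that previously covered only Relu and Gelu. As a reference for what the added branches compute, here is a NumPy sketch assuming the epilogues follow FasterTransformer's usual definitions; it is not the CUTLASS functor code itself:

```python
# Each expert GEMM computes x = A @ B + bias, then applies the activation
# element-wise. Reference definitions for review only.
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def gelu(x):  # tanh approximation, as in FasterTransformer's FtGelu
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def silu(x):  # x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def identity(x):  # the epilogue still adds the bias; no nonlinearity
    return x
```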
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
similarity index 94%
rename from onnxruntime/contrib_ops/cuda/moe/moe_kernel.cu
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index 6a02c7172e545..d6693cde103d8 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -854,7 +854,11 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows,
   const int original_row = blockIdx.x;
   const int num_rows = gridDim.x;
   T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
-  const T* skip_1_row_ptr = skip_1 + original_row * cols;
+
+  const T* skip_1_row_ptr;
+  if (RESIDUAL_NUM == 1) {
+    skip_1_row_ptr = skip_1 + original_row * cols;
+  }
   const T* skip_2_row_ptr;
   if (RESIDUAL_NUM == 2) {
     skip_2_row_ptr = skip_2 + original_row * cols;
@@ -862,7 +866,10 @@
 
   for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
     T thread_output;
-    if (RESIDUAL_NUM == 1) {
+    if (RESIDUAL_NUM == 0) {
+      thread_output = T(0);
+    }
+    else if (RESIDUAL_NUM == 1) {
       thread_output = skip_1_row_ptr[tid];
     } else if (RESIDUAL_NUM == 2) {
@@ -885,6 +892,32 @@
   }
 }
 
+template <typename T>
+void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows,
+                                         T* reduced_unpermuted_output,
+                                         const T* bias,
+                                         const T* scales,
+                                         const int* expanded_source_row_to_expanded_dest_row,
+                                         const int* expert_for_source_row,
+                                         const int num_rows,
+                                         const int cols,
+                                         const int k,
+                                         cudaStream_t stream)
+{
+  const int blocks = num_rows;
+  const int threads = std::min(cols, 1024);
+  finalize_moe_routing_kernel<T, 0><<<blocks, threads, 0, stream>>>(expanded_permuted_rows,
+                                                                    reduced_unpermuted_output,
+                                                                    nullptr,
+                                                                    nullptr,
+                                                                    bias,
+                                                                    scales,
+                                                                    expanded_source_row_to_expanded_dest_row,
+                                                                    expert_for_source_row,
+                                                                    cols,
+                                                                    k);
+}
+
 template <typename T>
 void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows,
                                          T* reduced_unpermuted_output,
@@ -971,6 +1004,26 @@
 template void initialize_moe_routing_kernelLauncher(
     const half*, half*, const int*, int*, const int, const int, const int, const int, cudaStream_t);
 
 // ==================== Specializations for final routing ===================================
+template void finalize_moe_routing_kernelLauncher(const float*,
+                                                  float*,
+                                                  const float*,
+                                                  const float*,
+                                                  const int*,
+                                                  const int*,
+                                                  const int,
+                                                  const int,
+                                                  const int,
+                                                  cudaStream_t);
+template void finalize_moe_routing_kernelLauncher(const half*,
+                                                  half*,
+                                                  const half*,
+                                                  const half*,
+                                                  const int*,
+                                                  const int*,
+                                                  const int,
+                                                  const int,
+                                                  const int,
+                                                  cudaStream_t);
 template void finalize_moe_routing_kernelLauncher(const float*,
                                                   float*,
                                                   const float*,
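The new launcher overload instantiates the kernel with `RESIDUAL_NUM == 0`: both skip pointers are passed as `nullptr` and the accumulator now starts at `T(0)` instead of reading a skip row. A NumPy reference of the no-residual finalize step, as a sketch; it assumes the k-major expanded-row layout and the `row * k + k_idx` indexing of `scales` and `expert_for_source_row` used by this kernel:

```python
import numpy as np

def finalize_moe_routing_ref(expanded_permuted_rows, bias, scales,
                             expanded_src_to_dest, expert_for_source_row, k):
    """out[row] = sum over k of scale * (expert FC2 output row + expert bias)."""
    num_rows = expanded_permuted_rows.shape[0] // k
    cols = expanded_permuted_rows.shape[1]
    out = np.zeros((num_rows, cols), dtype=expanded_permuted_rows.dtype)
    for row in range(num_rows):
        for k_idx in range(k):
            expanded_original_row = k_idx * num_rows + row  # k-major layout
            permuted_row = expanded_src_to_dest[expanded_original_row]
            flat = row * k + k_idx
            expert = expert_for_source_row[flat]
            out[row] += scales[flat] * (expanded_permuted_rows[permuted_row] + bias[expert])
    return out
```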
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
similarity index 92%
rename from onnxruntime/contrib_ops/cuda/moe/moe_kernel.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
index e22f008a85add..44d5d3ea900e4 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -91,6 +91,18 @@ void initialize_moe_routing_kernelLauncher(const T* unpermuted_input,
                                            const int k,
                                            cudaStream_t stream);
 
+template <typename T>
+void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows,
+                                         T* reduced_unpermuted_output,
+                                         const T* bias,
+                                         const T* scales,
+                                         const int* expanded_source_row_to_expanded_dest_row,
+                                         const int* expert_for_source_row,
+                                         const int num_rows,
+                                         const int cols,
+                                         const int k,
+                                         cudaStream_t stream);
+
 template <typename T>
 void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows,
                                          T* reduced_unpermuted_output,
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/moe_problem_visitor.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h
similarity index 100%
rename from onnxruntime/contrib_ops/cuda/moe/tile_interleaved_layout.h
rename to onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index 46ece6147277b..4c7c65fe572f5 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -4,7 +4,6 @@
 #include "core/common/safeint.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "moe.h"
-#include "moe_kernel.h"
 
 using namespace onnxruntime::cuda;
 using namespace ::onnxruntime::common;
@@ -30,10 +29,6 @@ REGISTER_KERNEL_TYPED(MLFloat16)
 
 using namespace ONNX_NAMESPACE;
 
-template <typename T>
-MoEBlock<T>::MoEBlock(const OpKernelInfo& info) : CudaKernel(info) {
-}
-
 template <typename T>
 Status MoEBlock<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* input = context->Input<Tensor>(0);
@@ -43,7 +38,6 @@ Status MoEBlock<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* fc1_experts_bias = context->Input<Tensor>(4);
   const Tensor* fc2_experts_bias = context->Input<Tensor>(5);
 
-  // Shape
   const auto& input_dims = input->Shape().GetDims();
   const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
 
@@ -51,45 +45,47 @@ Status MoEBlock<T>::ComputeInternal(OpKernelContext* context) const {
 
   const int64_t hidden_size = input_dims[1];
   const int64_t num_experts = fc1_experts_weights_dims[0];
   const int64_t inter_size = fc1_experts_weights_dims[2];
-  const int64_t k = 1;
 
   typedef typename ToCudaType<T>::MappedType CudaT;
   auto stream = context->GetComputeStream();
 
   fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner;
 
-  size_t ws_size = moe_runner.getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k);
-  size_t fc2_output_size = k * num_rows * hidden_size * sizeof(CudaT);
-  size_t expert_scales_size = k * num_rows * sizeof(CudaT);
-  size_t expanded_source_row_to_expanded_dest_row_size = k * num_rows * sizeof(int);
-  size_t expert_for_source_row_size = k * num_rows * sizeof(int);
+  size_t ws_size = moe_runner.getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k_);
+  size_t fc2_output_size = k_ * num_rows * hidden_size * sizeof(CudaT);
+  size_t expert_scales_size = k_ * num_rows * sizeof(CudaT);
+  size_t expanded_source_row_to_expanded_dest_row_size = k_ * num_rows * sizeof(int);
+  size_t expert_for_source_row_size = k_ * num_rows * sizeof(int);
+
+  // TODO: check shape
 
   AllocatorPtr allocator;
   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+  // TODO: allocate once and reuse
   IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, ws_size, false, stream);
   IAllocatorUniquePtr<void> fc2_output = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
   IAllocatorUniquePtr<void> expert_scales = IAllocator::MakeUniquePtr<void>(allocator, expert_scales_size, false, stream);
   IAllocatorUniquePtr<void> expanded_source_row_to_expanded_dest_row = IAllocator::MakeUniquePtr<void>(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream);
   IAllocatorUniquePtr<void> expert_for_source_row = IAllocator::MakeUniquePtr<void>(allocator, expert_for_source_row_size, false, stream);
 
+  // fc1_scales and fc2_scales are used in quantized MoE
   const CudaT* fc1_scales_ptr = nullptr;
   const CudaT* fc2_scales_ptr = nullptr;
 
-  // bugbug: use a string to select from different activationType
   moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
                         reinterpret_cast<const CudaT*>(gated_output->template Data<T>()),
                         reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
                         std::move(fc1_scales_ptr),
                         reinterpret_cast<const CudaT*>(fc1_experts_bias->template Data<T>()),
-                        fastertransformer::ActivationType::Gelu,
+                        activation_type_,
                         reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
                         std::move(fc2_scales_ptr),
                         static_cast<int>(num_rows),
                         static_cast<int>(hidden_size),
                         static_cast<int>(inter_size),
                         static_cast<int>(num_experts),
-                        static_cast<int>(k),
+                        static_cast<int>(k_),
                         reinterpret_cast<char*>(work_space.get()),
                         reinterpret_cast<CudaT*>(fc2_output.get()),
                         reinterpret_cast<CudaT*>(expert_scales.get()),
@@ -99,18 +95,15 @@ Status MoEBlock<T>::ComputeInternal(OpKernelContext* context) const {
 
   Tensor* output = context->Output(0, input->Shape());
 
-  // bugbug: support no skip in moe_kernel
-  IAllocatorUniquePtr<void> skip_layer = IAllocator::MakeUniquePtr<void>(allocator, num_rows * hidden_size * sizeof(T), false, stream);
   fastertransformer::finalize_moe_routing_kernelLauncher(reinterpret_cast<CudaT*>(fc2_output.get()),
                                                          reinterpret_cast<CudaT*>(output->template MutableData<T>()),
-                                                         reinterpret_cast<CudaT*>(skip_layer.get()),
                                                          reinterpret_cast<const CudaT*>(fc2_experts_bias->template Data<T>()),
                                                          reinterpret_cast<CudaT*>(expert_scales.get()),
                                                          reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
                                                          reinterpret_cast<int*>(expert_for_source_row.get()),
                                                          static_cast<int>(num_rows),
                                                          static_cast<int>(hidden_size),
-                                                         static_cast<int>(k),
+                                                         static_cast<int>(k_),
                                                          Stream(context));
 
   return Status::OK();
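With `k_` now read from the node attribute, the temporary buffers scale with the number of experts selected per token rather than a hard-coded top-1. A quick feel for the arithmetic, with hypothetical dimensions chosen only for illustration:

```python
# Hypothetical sizes: 4096 tokens, hidden_size 768, top-2 routing, fp16.
num_rows, hidden_size, k = 4096, 768, 2
elem = 2  # sizeof(half) in bytes

fc2_output_bytes = k * num_rows * hidden_size * elem  # 12,582,912 B = 12 MiB
expert_scales_bytes = k * num_rows * elem             # 16,384 B
routing_index_bytes = k * num_rows * 4                # 32,768 B per int32 buffer (two of them)
```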
#pragma once
+
+#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
 
 #include "core/common/common.h"
 #include "core/providers/cuda/cuda_kernel.h"
@@ -14,8 +16,28 @@ using namespace onnxruntime::cuda;
 template <typename T>
 class MoEBlock final : public CudaKernel {
  public:
-  MoEBlock(const OpKernelInfo& op_kernel_info);
+  explicit MoEBlock(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
+    ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());
+
+    std::string activation_type_str;
+    ORT_ENFORCE(op_kernel_info.GetAttr<std::string>("activation_type", &activation_type_str).IsOK());
+    if (activation_type_str == "relu") {
+      activation_type_ = fastertransformer::ActivationType::Relu;
+    } else if (activation_type_str == "gelu") {
+      activation_type_ = fastertransformer::ActivationType::Gelu;
+    } else if (activation_type_str == "silu") {
+      activation_type_ = fastertransformer::ActivationType::Silu;
+    } else if (activation_type_str == "identity") {
+      activation_type_ = fastertransformer::ActivationType::Identity;
+    } else {
+      ORT_THROW("Unsupported MoE activation type: ", activation_type_str);
+    }
+  }
   Status ComputeInternal(OpKernelContext* ctx) const override;
+
+ private:
+  int64_t k_;
+  fastertransformer::ActivationType activation_type_;
 };
 
 }  // namespace cuda
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index f0a1013319b96..b2b6c0484629a 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1378,15 +1378,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
 ONNX_MS_OPERATOR_SET_SCHEMA(MoEBlock, 1,
                             OpSchema()
                                 .SetDoc("Mixture of experts.")
-                                //.Attr("expert_start_idx", "Not implemented", AttributeProto::INT, static_cast<int64_t>(-1))
-                                //.Attr("expert_end_idx", "Not implemented", AttributeProto::INT, static_cast<int64_t>(-1))
-                                //.Attr("k", "Not implemented", AttributeProto::INT, static_cast<int64_t>(1))
+                                .Attr("activation_type", "Activation function to use", AttributeProto::STRING, std::string("relu"))
+                                .Attr("k", "Number of top experts to select from the expert pool", AttributeProto::INT, static_cast<int64_t>(1))
                                 .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size)", "T")
                                 .Input(1, "gated_output", "2D input tensor with shape (num_rows, num_experts)", "T")
                                 .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T")
                                 .Input(3, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T")
-                                .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
-                                .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional)
+                                .Input(4, "fc1_experts_bias", "2D input tensor with shape (num_experts, inter_size)", "T")
+                                .Input(5, "fc2_experts_bias", "2D input tensor with shape (num_experts, hidden_size)", "T")
                                 .Output(0, "output", "2D output tensor with shape (num_rows, hidden_size)", "T")
                                 .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.")
                                 .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
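Under the updated schema, the routing fan-out and expert activation are node attributes, and the biases are plain (no longer optional) inputs. A minimal construction sketch, consistent with the schema above and the test further below; tensor names here are hypothetical:

```python
from onnx import helper

moe_node = helper.make_node(
    "MoEBlock",
    inputs=["input", "gated_output", "fc1_experts_weights",
            "fc2_experts_weights", "fc1_experts_bias", "fc2_experts_bias"],
    outputs=["output"],
    name="MoEBlock_0",
    k=1,                     # top-1 routing (the default)
    activation_type="silu",  # one of: relu, gelu, silu, identity
    domain="com.microsoft",
)
```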
int_max, auto_merge, guess_output_rank, verbose, prefix=""): "MaxPool": self._infer_Pool, "Max": self._infer_symbolic_compute_ops, "Min": self._infer_symbolic_compute_ops, + "MoEBlock": self._pass_on_shape_and_type, "Mul": self._infer_symbolic_compute_ops, "NonMaxSuppression": self._infer_NonMaxSuppression, "NonZero": self._infer_NonZero, diff --git a/onnxruntime/test/python/transformers/test_parity_moe_block.py b/onnxruntime/test/python/transformers/test_parity_moe_block.py index df3227a430e06..2e7e9b30e8c9b 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe_block.py +++ b/onnxruntime/test/python/transformers/test_parity_moe_block.py @@ -21,8 +21,6 @@ import torch.nn as nn import torch.nn.functional as F -np.random.seed(0) -torch.manual_seed(0) def create_moe_onnx_graph( num_rows, @@ -33,7 +31,6 @@ def create_moe_onnx_graph( fc2_experts_weights, fc1_experts_bias, fc2_experts_bias, - has_bias=True, ): from onnx import TensorProto, helper @@ -45,11 +42,13 @@ def create_moe_onnx_graph( "gated_output", "fc1_experts_weights", "fc2_experts_weights", - "fc1_experts_bias" if has_bias else "", - "fc2_experts_bias" if has_bias else "", + "fc1_experts_bias", + "fc2_experts_bias", ], ["output"], "MoEBlock_0", + k=1, + activation_type="gelu", domain="com.microsoft", ), ] @@ -74,34 +73,29 @@ def create_moe_onnx_graph( ), ] - fc1_bias_shape = None - fc2_bias_shape = None - if has_bias: - fc1_bias_shape = [num_experts, inter_size] - fc2_bias_shape = [num_experts, hidden_size] - initializers.extend( - [ - helper.make_tensor( - "fc1_experts_bias", - TensorProto.FLOAT16, - fc1_bias_shape, - fc1_experts_bias.to(torch.float16).flatten().tolist(), - raw=False, - ), - helper.make_tensor( - "fc2_experts_bias", - TensorProto.FLOAT16, - fc2_bias_shape, - fc2_experts_bias.to(torch.float16).flatten().tolist(), - raw=False, - ), - ] - ) + fc1_bias_shape = [num_experts, inter_size] + fc2_bias_shape = [num_experts, hidden_size] + initializers.extend( + [ + helper.make_tensor( + "fc1_experts_bias", + TensorProto.FLOAT16, + fc1_bias_shape, + fc1_experts_bias.to(torch.float16).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_bias", + TensorProto.FLOAT16, + fc2_bias_shape, + fc2_experts_bias.to(torch.float16).flatten().tolist(), + raw=False, + ), + ] + ) graph_inputs = [ - helper.make_tensor_value_info( - "input", TensorProto.FLOAT16, [num_rows, hidden_size] - ), + helper.make_tensor_value_info("input", TensorProto.FLOAT16, [num_rows, hidden_size]), ] graph_inputs.append( @@ -113,9 +107,7 @@ def create_moe_onnx_graph( ) graph_outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT16, [num_rows, hidden_size] - ), + helper.make_tensor_value_info("output", TensorProto.FLOAT16, [num_rows, hidden_size]), ] graph = helper.make_graph( @@ -138,9 +130,7 @@ def onnx_inference( sess_options = SessionOptions() sess_options.log_severity_level = 2 - ort_session = InferenceSession( - onnx_model_path, sess_options, providers=["CUDAExecutionProvider"] - ) + ort_session = InferenceSession(onnx_model_path, sess_options, providers=["CUDAExecutionProvider"]) ort_output = ort_session.run(None, ort_inputs) return ort_output @@ -189,7 +179,7 @@ def __init__( out_features=None, act_layer=nn.GELU, drop=0.0, - bias=False, + bias=True, chunk_size=-1, ): super().__init__() @@ -197,19 +187,11 @@ def __init__( assert drop == 0.0, "Current drop is not supported" assert chunk_size == -1, "Current chunk is not supported" - self.weight1 = nn.Parameter( - torch.Tensor(num_experts, in_features, 
diff --git a/onnxruntime/test/python/transformers/test_parity_moe_block.py b/onnxruntime/test/python/transformers/test_parity_moe_block.py
index df3227a430e06..2e7e9b30e8c9b 100644
--- a/onnxruntime/test/python/transformers/test_parity_moe_block.py
+++ b/onnxruntime/test/python/transformers/test_parity_moe_block.py
@@ -21,8 +21,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-np.random.seed(0)
-torch.manual_seed(0)
 
 def create_moe_onnx_graph(
     num_rows,
@@ -33,7 +31,6 @@ def create_moe_onnx_graph(
     fc2_experts_weights,
     fc1_experts_bias,
     fc2_experts_bias,
-    has_bias=True,
 ):
     from onnx import TensorProto, helper
 
@@ -45,11 +42,13 @@ def create_moe_onnx_graph(
                 "gated_output",
                 "fc1_experts_weights",
                 "fc2_experts_weights",
-                "fc1_experts_bias" if has_bias else "",
-                "fc2_experts_bias" if has_bias else "",
+                "fc1_experts_bias",
+                "fc2_experts_bias",
             ],
             ["output"],
             "MoEBlock_0",
+            k=1,
+            activation_type="gelu",
             domain="com.microsoft",
         ),
     ]
@@ -74,34 +73,29 @@ def create_moe_onnx_graph(
         ),
     ]
 
-    fc1_bias_shape = None
-    fc2_bias_shape = None
-    if has_bias:
-        fc1_bias_shape = [num_experts, inter_size]
-        fc2_bias_shape = [num_experts, hidden_size]
-        initializers.extend(
-            [
-                helper.make_tensor(
-                    "fc1_experts_bias",
-                    TensorProto.FLOAT16,
-                    fc1_bias_shape,
-                    fc1_experts_bias.to(torch.float16).flatten().tolist(),
-                    raw=False,
-                ),
-                helper.make_tensor(
-                    "fc2_experts_bias",
-                    TensorProto.FLOAT16,
-                    fc2_bias_shape,
-                    fc2_experts_bias.to(torch.float16).flatten().tolist(),
-                    raw=False,
-                ),
-            ]
-        )
+    fc1_bias_shape = [num_experts, inter_size]
+    fc2_bias_shape = [num_experts, hidden_size]
+    initializers.extend(
+        [
+            helper.make_tensor(
+                "fc1_experts_bias",
+                TensorProto.FLOAT16,
+                fc1_bias_shape,
+                fc1_experts_bias.to(torch.float16).flatten().tolist(),
+                raw=False,
+            ),
+            helper.make_tensor(
+                "fc2_experts_bias",
+                TensorProto.FLOAT16,
+                fc2_bias_shape,
+                fc2_experts_bias.to(torch.float16).flatten().tolist(),
+                raw=False,
+            ),
+        ]
+    )
 
     graph_inputs = [
-        helper.make_tensor_value_info(
-            "input", TensorProto.FLOAT16, [num_rows, hidden_size]
-        ),
+        helper.make_tensor_value_info("input", TensorProto.FLOAT16, [num_rows, hidden_size]),
     ]
 
     graph_inputs.append(
@@ -113,9 +107,7 @@
     )
 
     graph_outputs = [
-        helper.make_tensor_value_info(
-            "output", TensorProto.FLOAT16, [num_rows, hidden_size]
-        ),
+        helper.make_tensor_value_info("output", TensorProto.FLOAT16, [num_rows, hidden_size]),
    ]
 
     graph = helper.make_graph(
@@ -138,9 +130,7 @@ def onnx_inference(
     sess_options = SessionOptions()
     sess_options.log_severity_level = 2
 
-    ort_session = InferenceSession(
-        onnx_model_path, sess_options, providers=["CUDAExecutionProvider"]
-    )
+    ort_session = InferenceSession(onnx_model_path, sess_options, providers=["CUDAExecutionProvider"])
 
     ort_output = ort_session.run(None, ort_inputs)
     return ort_output
@@ -189,7 +179,7 @@ def __init__(
         out_features=None,
         act_layer=nn.GELU,
         drop=0.0,
-        bias=False,
+        bias=True,
         chunk_size=-1,
     ):
         super().__init__()
@@ -197,19 +187,11 @@
 
         assert drop == 0.0, "Current drop is not supported"
         assert chunk_size == -1, "Current chunk is not supported"
 
-        self.weight1 = nn.Parameter(
-            torch.Tensor(num_experts, in_features, hidden_features)
-        )
-        self.weight2 = nn.Parameter(
-            torch.Tensor(num_experts, hidden_features, out_features)
-        )
+        self.weight1 = nn.Parameter(torch.Tensor(num_experts, in_features, hidden_features))
+        self.weight2 = nn.Parameter(torch.Tensor(num_experts, hidden_features, out_features))
 
-        self.bias1 = (
-            nn.Parameter(torch.Tensor(num_experts, hidden_features)) if bias else None
-        )
-        self.bias2 = (
-            nn.Parameter(torch.Tensor(num_experts, in_features)) if bias else None
-        )
+        self.bias1 = nn.Parameter(torch.Tensor(num_experts, hidden_features)) if bias else None
+        self.bias2 = nn.Parameter(torch.Tensor(num_experts, in_features)) if bias else None
 
         self.act = act_layer()
@@ -279,15 +261,13 @@ def torch_forward(self):
 
         logits = self.gate(x)
         gates = torch.nn.functional.softmax(logits, dim=1)
         ret = torch.max(gates, dim=1)
-        indices_s = (
-            ret.indices
-        )  # dim: [bs], the index of the expert with highest softmax value
+        indices_s = ret.indices  # dim: [bs], the index of the expert with highest softmax value
         scores = ret.values.unsqueeze(-1).unsqueeze(-1)  # S
         x = self.moe_experts(x, indices_s)
 
         x = x * scores
         x = x.reshape(B, T, C)
-        print(x)
+        # print(x)
 
         return x, torch.sum(x)
@@ -299,14 +279,10 @@ def onnx_forward(self):
 
         ort_inputs = {
             "input": numpy.ascontiguousarray(y.detach().numpy().astype(numpy.float16)),
-            "gated_output": numpy.ascontiguousarray(
-                logits.detach().numpy().astype(numpy.float16)
-            ),
+            "gated_output": numpy.ascontiguousarray(logits.detach().numpy().astype(numpy.float16)),
         }
 
-        ort_output = onnx_inference(
-            self.moe_onnx_graph, ort_inputs
-        )
-        print(ort_output)
+        ort_output = onnx_inference(self.moe_onnx_graph, ort_inputs)
+        # print(ort_output)
 
         return ort_output