From 48fb8a7e56a7263a8405dc644756eb5c55560352 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Sat, 27 Jul 2024 11:10:52 -0700 Subject: [PATCH 01/40] Security fuzz address sanitizer fix Bug #2 and #3 (#21528) ### Description Security fuzz testing with address sanitizer found several bugs. --- onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc | 2 ++ onnxruntime/core/optimizer/attention_fusion.cc | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc index 34a1da99316a2..030cdb1e1b17f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc @@ -143,6 +143,8 @@ Status GptSubgraph::Validate(const std::vector& subgraph_inputs, // Past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads). const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); + ORT_RETURN_IF(past_shape == nullptr, + "subgraph past state cannot be nullptr"); ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size()); diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 08066f030a381..64a38214caff0 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -210,7 +210,7 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) && - graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.InputDefs().size() > 2) { // Get hidden size from layer norm bias tensor shape. const NodeArg& layer_norm_bias = *(node.InputDefs()[2]); if (!optimizer_utils::IsShapeKnownOnAllDims(layer_norm_bias, 1)) { From 82b2955268e14f26eb71ad2d660452ab8db454d7 Mon Sep 17 00:00:00 2001 From: Ranjit Ranjan <165394499+ranjitshs@users.noreply.github.com> Date: Sat, 27 Jul 2024 23:47:22 +0530 Subject: [PATCH 02/40] [AIX]test failure fix using gtest-1.15.0 for AIX (#21497) ### Description The local CI setup for AIX reported test failures after the gtest 1.15.0 upgrade. ### Motivation and Context The following test failures were observed after the gtest upgrade: 1 - onnxruntime_test_all (ILLEGAL) 7 - onnxruntime_logging_apis_test (Subprocess aborted) To fix this, I enabled pthread support under gtest; it had been disabled for the previous version of gtest. With pthread support enabled, the above tests now pass with gtest 1.15.0.
--- cmake/external/onnxruntime_external_deps.cmake | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 14e6ed515fd6e..775576a771529 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -46,9 +46,6 @@ if (onnxruntime_BUILD_UNIT_TESTS) if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(gtest_disable_pthreads ON) endif() - if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - set(gtest_disable_pthreads ON CACHE BOOL "gtest_disable_pthreads" FORCE) - endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) if (IOS OR ANDROID) # on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing From 7e23212de9746ed2452061958f8aae3ffc171cee Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Sat, 27 Jul 2024 15:58:12 -0700 Subject: [PATCH 03/40] Delete tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml (#21529) ### Description Delete tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml ### Motivation and Context This CI pipeline has been divided into 4 different pipelines. --- .../azure-pipelines/win-gpu-ci-pipeline.yml | 125 ------------------ 1 file changed, 125 deletions(-) delete mode 100644 tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml deleted file mode 100644 index c5262880c4c55..0000000000000 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ /dev/null @@ -1,125 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -parameters: -- name: CudaVersion - displayName: CUDA version - type: string - default: '12.2' - values: - - 11.8 - - 12.2 -- name: RunOnnxRuntimeTests - displayName: Run Tests?
- type: boolean - default: true - -stages: -- stage: cuda - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - additionalBuildFlags: >- - --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" - --enable_cuda_profiling --enable_transformers_tool_test - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON - --cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: CUDA - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 - -- stage: training - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - additionalBuildFlags: >- - --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" - --skip_onnx_tests - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: CUDA - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 - isTraining: true - -- stage: dml - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat - buildArch: x64 - additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: DML - WITH_CACHE: false - MachinePool: onnxruntime-Win2022-GPU-dml-A10 - -- stage: kernelDocumentation - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - # note: need to specify `--gen_doc` when creating the build config so it has to be in additionalBuildFlags - additionalBuildFlags: >- - --gen_doc validate --skip_tests --enable_pybind --use_dml --use_cuda - --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: false - GenerateDocumentation: true - ORT_EP_NAME: CUDA # It doesn't really matter which EP is selected here since this stage is for documentation. - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 From a4d3a1ce0c18e1d1b31a9cc0b45beba290ee114c Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Sat, 27 Jul 2024 15:58:36 -0700 Subject: [PATCH 04/40] pick changes from https://github.com/onnx/onnx/pull/6195 to fix heap-buffer-overflow in onnx::convPoolShapeInference (#21507) ### Description onnx 1.16.2 is not available before ort 1.19.0 code freeze. 
Thus pick the needed change as patch --- cmake/patches/onnx/onnx.patch | 383 ++++++++++++++++++ .../providers/cpu/generator/random_test.cc | 8 +- .../core/graph/training_op_defs.cc | 104 +++-- 3 files changed, 447 insertions(+), 48 deletions(-) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 162d33581a5ca..6ac3555eeecf1 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -86,3 +86,386 @@ index 0aab3e26..398ac2d6 100644 +#endif + #endif // ! ONNX_ONNX_PB_H +diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc +index c315a2a7..58963154 100644 +--- a/onnx/defs/math/defs.cc ++++ b/onnx/defs/math/defs.cc +@@ -3472,6 +3472,9 @@ ONNX_OPERATOR_SET_SCHEMA( + } + + auto& input_shape = getInputShape(ctx, 0); ++ if (input_shape.dim_size() < 2) { ++ fail_shape_inference("First input should have at least 2 dimensions in ", ctx.getDisplayName(), "."); ++ } + auto signal_dim = input_shape.dim(1); + if (!signal_dim.has_dim_value()) { + return; +diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc +index be6a851d..fad595d0 100644 +--- a/onnx/defs/nn/defs.cc ++++ b/onnx/defs/nn/defs.cc +@@ -126,6 +126,9 @@ void convPoolShapeInference( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +@@ -959,19 +962,21 @@ ONNX_OPERATOR_SET_SCHEMA( + auto w_type = ctx.getInputType(3); + if (nullptr == x_type || nullptr == w_type || x_type->value_case() != TypeProto::kTensorType || + w_type->value_case() != TypeProto::kTensorType) { +- fail_type_inference("inputs are expected to have tensor type."); ++ fail_type_inference("inputs are expected to have tensor type in ", ctx.getDisplayName(), "."); + } + + auto x_zero_point_type = ctx.getInputType(2); + if (nullptr == x_zero_point_type || + x_zero_point_type->tensor_type().elem_type() != x_type->tensor_type().elem_type()) { +- fail_type_inference("input and zero_point pair is expected to have be same type."); ++ fail_type_inference( ++ "input and zero_point pair is expected to have be same type in ", ctx.getDisplayName(), "."); + } + + auto w_zero_point_type = ctx.getInputType(5); + if (nullptr == w_zero_point_type || + w_zero_point_type->tensor_type().elem_type() != w_type->tensor_type().elem_type()) { +- fail_type_inference("weight and zero_point pair is expected to have same type."); ++ fail_type_inference( ++ "weight and zero_point pair is expected to have same type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromInputToOutput(ctx, 7, 0); +@@ -2647,7 +2652,8 @@ ONNX_OPERATOR_SET_SCHEMA( + if (!hasNInputShapes(ctx, 1)) { + return; + } +- auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); ++ ++ auto& input_shape = getInputShape(ctx, 0); + int64_t input_ndim = input_shape.dim_size(); + int64_t axis = -1; + auto axis_proto = ctx.getAttribute("axis"); +@@ -2659,7 +2665,16 @@ ONNX_OPERATOR_SET_SCHEMA( + // positive value. 
+ axis += input_ndim; + } +- ++ if (axis < 0) { ++ fail_shape_inference( ++ "Unexpected axis value (", ++ axis, ++ ") rank of first input is ", ++ input_ndim, ++ " in ", ++ ctx.getDisplayName(), ++ "."); ++ } + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); +diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc +index 57f8e2a4..8b2dc07f 100644 +--- a/onnx/defs/nn/old.cc ++++ b/onnx/defs/nn/old.cc +@@ -201,6 +201,9 @@ void convPoolShapeInference_opset19( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +diff --git a/onnx/defs/shape_inference.h b/onnx/defs/shape_inference.h +index a80473b3..d1bcd401 100644 +--- a/onnx/defs/shape_inference.h ++++ b/onnx/defs/shape_inference.h +@@ -105,6 +105,10 @@ struct InferenceContext { + virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0; + // Gets the shape inputs computed by partial data propagation. + virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0; ++ // To display a name the user can use to narrow its search. ++ virtual std::string getDisplayName() const { ++ return ""; ++ } + }; + + // We use data propagation to perform partial evaluation of the model, to compute statically +@@ -263,7 +267,15 @@ inline void propagateElemTypeFromDtypeToOutput( + } else { + // This is not expected to happen + fail_type_inference( +- "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case); ++ "Output ", ++ outputIndex, ++ " expected to have: ", ++ expected_value_case, ++ " or UNDEFINED. 
Got: ", ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -277,18 +289,18 @@ inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const Attr + const auto attr_type = attr->type(); + if (attr_type == AttributeProto::TENSOR) { + if (attr->t().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim tensor"); ++ fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->t().data_type(); + expected_value_case = TypeProto::kTensorType; + } else if (attr_type == AttributeProto::SPARSE_TENSOR) { + if (attr->sparse_tensor().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim sparse tensor"); ++ fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->sparse_tensor().values().data_type(); + expected_value_case = TypeProto::kSparseTensorType; + } else { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case); +@@ -326,7 +338,10 @@ inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t + const auto* input_type = ctx.getInputType(n); + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); ++ } ++ if (!hasShape(*input_type)) { ++ fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return input_type->tensor_type().shape(); +@@ -344,7 +359,7 @@ inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size + + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return &input_type->tensor_type().shape(); +@@ -372,7 +387,10 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + " does not match type of output: ", + outputIndex, + "type: ", +- output_value_case); ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + if (TypeProto::kTensorType == input_value_case) { + auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim(); +@@ -382,7 +400,13 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + *dim = input_type->sparse_tensor_type().shape().dim(static_cast(fromDimIndex)); + } else { + fail_type_inference( +- "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type"); ++ "Input ", ++ inputIndex, ++ " and Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -440,7 +464,14 @@ updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType + setTensorElementType(elemType, expected_type, 
*output_type); + } else { + // This is not expected to happen +- fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type); ++ fail_type_inference( ++ "Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type: ", ++ expected_type, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -462,16 +493,17 @@ inline void propagateElemTypeFromAttributeToOutput( + updateOutputElemType(ctx, outputIndex, default_value, expected_type); + return; + } else { +- fail_type_inference("Value of attribute ", attributeName, " not specified"); ++ fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), "."); + } + } + if (!attr_proto->has_i()) { +- fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type."); ++ fail_type_inference( ++ "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), "."); + } + auto attr_value = attr_proto->i(); + auto elem_type = static_cast(attr_value); + if (!TensorProto_DataType_IsValid(elem_type)) { +- fail_type_inference("Attribute ", attributeName, " does not specify a valid type."); ++ fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), "."); + } + updateOutputElemType(ctx, outputIndex, elem_type, expected_type); + } +@@ -497,7 +529,7 @@ inline TensorShapeProto* + getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) { + auto output_type = ctx.getOutputType(n); + if (output_type == nullptr) { +- fail_type_inference("Output ", n, " expected to have tensor or sparse type"); ++ fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), "."); + } + const auto output_value_case = output_type->value_case(); + if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) { +@@ -505,7 +537,7 @@ getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_typ + } else if (output_value_case == TypeProto::VALUE_NOT_SET) { + return getTensorMutableShape(default_type, *output_type); + } else { +- fail_type_inference("Output ", n, " expected to have tensor type"); ++ fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), "."); + } + } + +@@ -562,13 +594,13 @@ inline void propagateShapeFromAttributeToOutput( + auto attr_proto = ctx.getAttribute(attributeName); + if ((nullptr == attr_proto) || (!attr_proto->has_type()) || + (attr_proto->type() != AttributeProto_AttributeType_INTS)) { +- fail_shape_inference("Attribute ", attributeName, " should specify a shape"); ++ fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), "."); + } + auto& int_list = attr_proto->ints(); + TensorShapeProto shape; + for (auto dim_size : int_list) { + if (dim_size < 0) { +- fail_shape_inference("Negative values are not allowed in a shape specification"); ++ fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), "."); + } + shape.add_dim()->set_dim_value(dim_size); + } +@@ -745,7 +777,16 @@ inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expect + if (hasInputShape(ctx, input_index)) { + auto rank = getInputShape(ctx, input_index).dim_size(); + if (rank != expected_rank) { +- fail_shape_inference("Input ", input_index, " expected to have rank 
", expected_rank, " but has rank ", rank); ++ fail_shape_inference( ++ "Input ", ++ input_index, ++ " expected to have rank ", ++ expected_rank, ++ " but has rank ", ++ rank, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + } +@@ -798,7 +839,15 @@ inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_ind + // This shape is expected to have rank > dim_index: + if (input_shape.dim_size() <= dim_index) { + fail_shape_inference( +- "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size()); ++ "Input ", ++ input_index, ++ " expected to have rank >", ++ dim_index, ++ " but has rank ", ++ input_shape.dim_size(), ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + const Dim& input_dim = input_shape.dim(dim_index); + // Now, unify dim and input_dim: +diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc +index 8723dcd4..8249fc59 100644 +--- a/onnx/shape_inference/implementation.cc ++++ b/onnx/shape_inference/implementation.cc +@@ -906,7 +906,7 @@ struct FunctionInferenceContext : public InferenceContext { + const std::vector& input_types, + const std::vector& attributes, + const ShapeInferenceOptions& options) +- : input_types_(input_types), options_(options) { ++ : input_types_(input_types), options_(options), func_proto_(&func_proto) { + for (const auto& attr : attributes) { + attributesByName_[attr.name()] = &attr; + } +@@ -971,11 +971,25 @@ struct FunctionInferenceContext : public InferenceContext { + return std::move(output_types_); + } + ++ std::string getDisplayName() const override { ++ if (func_proto_ == nullptr) ++ return ""; ++ if (func_proto_->domain().empty()) { ++ if (func_proto_->name().empty()) ++ return ""; ++ return MakeString("function ", func_proto_->name()); ++ } ++ if (func_proto_->name().empty()) ++ return MakeString("function [", func_proto_->domain(), "]"); ++ return MakeString("function ", func_proto_->name(), "[", func_proto_->domain(), "]"); ++ } ++ + private: + const std::vector& input_types_; + std::vector output_types_; + std::unordered_map attributesByName_; + ShapeInferenceOptions options_; ++ const FunctionProto* func_proto_; + }; + + std::vector InferFunctionOutputTypes( +diff --git a/onnx/shape_inference/implementation.h b/onnx/shape_inference/implementation.h +index 2c63c910..b0e4c32d 100644 +--- a/onnx/shape_inference/implementation.h ++++ b/onnx/shape_inference/implementation.h +@@ -146,7 +146,7 @@ struct InferenceContextImpl : public InferenceContext { + const ShapeInferenceOptions& options, + DataValueMap* generatedShapeData = nullptr, + GraphInferenceContext* graphInferenceContext = nullptr) +- : graphInferenceContext_{graphInferenceContext}, options_(options) { ++ : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) { + for (auto& attr : *n.mutable_attribute()) { + attributesByName_[attr.name()] = &attr; + if (attr.has_g()) { +@@ -277,6 +277,19 @@ struct InferenceContextImpl : public InferenceContext { + return inferencer; + } + ++ std::string getDisplayName() const override { ++ if (node_ == nullptr) ++ return ""; ++ if (node_->domain().empty()) { ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type()); ++ return MakeString("node ", node_->op_type(), " (", node_->name(), ")"); ++ } ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]"); ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")"); ++ } ++ + 
std::vector allInputData_; + std::vector allInputSparseData_; + std::vector allShapeInputData_; +@@ -289,6 +302,7 @@ struct InferenceContextImpl : public InferenceContext { + // mutable as internal cache of GraphInferencer instances + mutable std::unordered_map> graphAttributeInferencers_; + ShapeInferenceOptions options_; ++ NodeProto* node_; + }; + + struct DataPropagationContextImpl : public DataPropagationContext { diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index ec9b1614488a7..f42f32d63d1fa 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -178,7 +178,7 @@ TEST(Random, InvalidDType) { test.AddAttribute("shape", dims); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomNormal) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -194,7 +194,7 @@ TEST(Random, InvalidDType) { test.AddAttribute("shape", dims); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomUniform) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -210,7 +210,7 @@ TEST(Random, InvalidDType) { test.AddInput("X", dims, input); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomNormalLike) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -226,7 +226,7 @@ TEST(Random, InvalidDType) { test.AddInput("X", dims, input); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomUniformLike) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } } diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 2a8d2de982e79..92f803030ada4 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -181,6 +181,64 @@ static void propagateRecvOutputTensorElemTypes( } } +void SendShapeInfer(ONNX_NAMESPACE::InferenceContext& ctx) { + if (ctx.getNumInputs() < 3) { + fail_shape_inference("Send must have at least three inputs."); + } else { + if (hasInputShape(ctx, 0)) { + auto& signal_input_shape = getInputShape(ctx, 0); + if (static_cast(signal_input_shape.dim_size()) != 0) { + fail_shape_inference("InputSignal of Send must be a scalar."); + } + } + if (hasInputShape(ctx, 1)) { + auto& remote_input_shape = getInputShape(ctx, 1); + if (static_cast(remote_input_shape.dim_size()) != 0) { + fail_shape_inference("Remote of Send must be a scalar."); + } + } + + checkSendInputTensorElemTypes(ctx, "element_types", ctx.getNumInputs() - 2); + } + + if (ctx.getNumOutputs() != 1) { + fail_shape_inference("Send must have one output."); + } + + auto output_element_type = ctx.getOutputType(0)->mutable_tensor_type(); + output_element_type->set_elem_type(TensorProto::BOOL); + 
ONNX_NAMESPACE::TensorShapeProto output_shape; + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); +} + +void RecvShapeInfer(ONNX_NAMESPACE::InferenceContext& ctx) { + if (ctx.getNumInputs() != 2) { + fail_shape_inference("Recv must have two inputs."); + } else { + if (hasInputShape(ctx, 0)) { + auto& signal_input_shape = getInputShape(ctx, 0); + if (static_cast(signal_input_shape.dim_size()) != 0) { + fail_shape_inference("InputSignal of Recv must be a scalar."); + } + } + if (hasInputShape(ctx, 1)) { + auto& remote_input_shape = getInputShape(ctx, 1); + if (static_cast(remote_input_shape.dim_size()) != 0) { + fail_shape_inference("Remote of Recv must be a scalar."); + } + } + } + + if (ctx.getNumOutputs() < 2) { + fail_shape_inference("Recv must have at least two outputs."); + } + + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + propagateRecvOutputTensorElemTypes(ctx, "element_types", ctx.getNumOutputs() - 1); +} + TensorProto ToDimensionOneFloatTensor(float value) { auto t = ToTensor(std::vector({value})); t.add_dims(1); @@ -3388,30 +3446,7 @@ Return true if all elements are true and false otherwise. "Constrain types to boolean tensors.") .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - if (ctx.getNumInputs() < 3) { - fail_shape_inference("Send must have at least three inputs."); - } else { - auto& signal_input_shape = getInputShape(ctx, 0); - if (static_cast(signal_input_shape.dim_size()) != 0) { - fail_shape_inference("InputSignal of Send must be a scalar."); - } - auto& remote_input_shape = getInputShape(ctx, 1); - if (static_cast(remote_input_shape.dim_size()) != 0) { - fail_shape_inference("Remote of Send must be a scalar."); - } - - checkSendInputTensorElemTypes(ctx, "element_types", ctx.getNumInputs() - 2); - } - - if (ctx.getNumOutputs() != 1) { - fail_shape_inference("Send must have one output."); - } - - auto output_element_type = ctx.getOutputType(0)->mutable_tensor_type(); - output_element_type->set_elem_type(TensorProto::BOOL); - ONNX_NAMESPACE::TensorShapeProto output_shape; - updateOutputShape(ctx, 0, {}); - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + SendShapeInfer(ctx); }); ONNX_CONTRIB_OPERATOR_SCHEMA(Recv) @@ -3437,26 +3472,7 @@ Return true if all elements are true and false otherwise. 
"Constrain types to boolean tensors.") .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - if (ctx.getNumInputs() != 2) { - fail_shape_inference("Recv must have two inputs."); - } else { - auto& signal_input_shape = getInputShape(ctx, 0); - if (static_cast(signal_input_shape.dim_size()) != 0) { - fail_shape_inference("InputSignal of Recv must be a scalar."); - } - auto& remote_input_shape = getInputShape(ctx, 1); - if (static_cast(remote_input_shape.dim_size()) != 0) { - fail_shape_inference("Remote of Recv must be a scalar."); - } - } - - if (ctx.getNumOutputs() < 2) { - fail_shape_inference("Recv must have at least two outputs."); - } - - updateOutputShape(ctx, 0, {}); - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); - propagateRecvOutputTensorElemTypes(ctx, "element_types", ctx.getNumOutputs() - 1); + RecvShapeInfer(ctx); }); ONNX_CONTRIB_OPERATOR_SCHEMA(MegatronF) From dbff0cd09860b60bd0a251c1dbe76785b0b2818c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 28 Jul 2024 13:03:17 -0700 Subject: [PATCH 05/40] [js/node] enable float16 support for Node.js binding (#20581) ### Description enable float16 support for Node.js binding. data of float16 tensor uses `Uint16Array`. --- js/node/src/tensor_helper.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/js/node/src/tensor_helper.cc b/js/node/src/tensor_helper.cc index 1c0b141e6a44f..1062d89f76c5f 100644 --- a/js/node/src/tensor_helper.cc +++ b/js/node/src/tensor_helper.cc @@ -38,13 +38,13 @@ constexpr size_t DATA_TYPE_ELEMENT_SIZE_MAP[] = { 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 INT64 not working in Javascript + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING N/A 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 not working in Javascript + 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 UINT64 not working in Javascript + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported 0 // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported @@ -60,13 +60,13 @@ constexpr napi_typedarray_type DATA_TYPE_TYPEDARRAY_MAP[] = { napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 napi_int16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 napi_int32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - napi_bigint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 INT64 not working i + napi_bigint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING not supported napi_uint8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 not working + napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 uses Uint16Array napi_float64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE napi_uint32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - napi_biguint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 UINT64 not working + napi_biguint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 
(napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported (napi_typedarray_type)(-1) // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported @@ -182,9 +182,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo * char *buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); - // there is a bug in TypedArray::ElementSize(): https://github.com/nodejs/node-addon-api/pull/705 - // TODO: change to TypedArray::ByteLength() in next node-addon-api release. - size_t bufferByteLength = tensorDataTypedArray.ElementLength() * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]; + size_t bufferByteLength = tensorDataTypedArray.ByteLength(); return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength, dims.empty() ? nullptr : &dims[0], dims.size(), elemType); } From 5bc12bf209304e7f5800845bd612bb3e7b7ab918 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Mon, 29 Jul 2024 23:47:41 +0800 Subject: [PATCH 06/40] [js/webgpu] Add activation for conv3d naive (#21466) ### Description ### Motivation and Context --- .../ops/3rd-party/conv3d_naive_webgpu.ts | 64 +++++----- js/web/test/data/ops/fused-conv3dncdhw.jsonc | 112 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + 3 files changed, 149 insertions(+), 28 deletions(-) create mode 100644 js/web/test/data/ops/fused-conv3dncdhw.jsonc diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts index f428293add599..a2e5428385101 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts @@ -26,6 +26,9 @@ import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; + +import {typeSnippet} from './activation_util'; const arrayProduct = (arr: number[]) => { let product = 1; @@ -218,8 +221,8 @@ export const computeConv3DInfo = export const createConv3DNaiveProgramInfo = (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], filterDims: readonly number[], pads: readonly number[], dataFormat: string): ProgramInfo => { - const isChannelsLast = dataFormat === 'channelsLast'; - const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const isChannelLast = dataFormat === 'channelsLast'; + const inChannels = isChannelLast ? inputs[0].dims[3] : inputs[0].dims[1]; // TODO: enable vec4. const isVec4 = false; const workGroupSize: [number, number, number] = [64, 1, 1]; @@ -228,13 +231,14 @@ export const createConv3DNaiveProgramInfo = LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; + const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 
3 : 4) : 1; const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: pads}, {type: DataType.uint32, data: attributes.strides}, {type: DataType.uint32, data: attributes.dilations} ]; + appendActivationUniformsData(attributes, programUniforms); programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length === 3; @@ -251,6 +255,7 @@ export const createConv3DNaiveProgramInfo = {name: 'strides', type: 'u32', length: attributes.strides.length}, {name: 'dilations', type: 'u32', length: attributes.dilations.length} ]; + appendActivationUniforms(attributes, uniforms); // TODO: support component 2, 3. const components = isVec4 ? 4 : 1; const t = tensorTypeToWsglStorageType(inputs[0].dataType); @@ -266,10 +271,12 @@ export const createConv3DNaiveProgramInfo = inputVariables.push(bias); declareFunctions += ` fn getBiasByOutputCoords(coords : array) -> ${isVec4 ? `vec4<${t}>` : t} { - return bias[${isChannelsLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${ + return bias[${isChannelLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${ isVec4 ? '/ 4' : ''}]; }`; } + const resType = typeSnippet(innerElementSize, t); + const applyActivation = getActivationSnippet(attributes, resType, t); return ` ${declareFunctions} @@ -287,28 +294,28 @@ export const createConv3DNaiveProgramInfo = let coords = ${output.offsetToIndices('global_idx')}; let batch = ${getElementAt('coords', 0, x.rank)}; let d2 = ${ - isChannelsLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)}; + isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)}; let xFRCCorner = vec3(${ - isChannelsLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)}, - ${isChannelsLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)}, + isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)}, + ${isChannelLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)}, ${ - isChannelsLast ? getElementAt('coords', 3, x.rank) : - getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads; + isChannelLast ? getElementAt('coords', 3, x.rank) : + getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads; let xFCorner = xFRCCorner.x; let xRCorner = xFRCCorner.y; let xCCorner = xFRCCorner.z; let xShapeY = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; let xShapeZ = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; let xShapeW = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; let xShapeU = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; + isChannelLast ? 
getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; let inputDepthNearestVec4 = (xShapeU / 4) * 4; let inputDepthVec4Remainder = xShapeU % 4; - var dotProd = 0.0; + var value = 0.0; for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) { let xF = xFCorner + wF * uniforms.dilations[0]; if (xF < 0 || xF >= xShapeY) { @@ -329,13 +336,13 @@ export const createConv3DNaiveProgramInfo = for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) { ${ - isChannelsLast ? `let xValues = vec4( + isChannelLast ? `let xValues = vec4( getX(batch, xF, xR, xC, d1), getX(batch, xF, xR, xC, d1 + 1), getX(batch, xF, xR, xC, d1 + 2), getX(batch, xF, xR, xC, d1 + 3)); ` : - `let xValues = vec4( + `let xValues = vec4( getX(batch, d1, xF, xR, xC), getX(batch, d1 + 1, xF, xR, xC), getX(batch, d1 + 2, xF, xR, xC), @@ -346,36 +353,36 @@ export const createConv3DNaiveProgramInfo = getW(d2, d1 + 1, wF, wR, wC), getW(d2, d1 + 2, wF, wR, wC), getW(d2, d1 + 3, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } if (inputDepthVec4Remainder == 1) { ${ - isChannelsLast ? `dotProd += getX(batch, xF, xR, xC, inputDepthNearestVec4) + isChannelLast ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4) * getW(d2, inputDepthNearestVec4, wF, wR, wC);` : - `dotProd += getX(batch, inputDepthNearestVec4, xF, xR, xC) + `value += getX(batch, inputDepthNearestVec4, xF, xR, xC) * getW(d2, inputDepthNearestVec4, wF, wR, wC);`} } else if (inputDepthVec4Remainder == 2) { ${ - isChannelsLast ? `let xValues = vec2( + isChannelLast ? `let xValues = vec2( getX(batch, xF, xR, xC, inputDepthNearestVec4), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1)); ` : - `let xValues = vec2( + `let xValues = vec2( getX(batch, inputDepthNearestVec4, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC)); `} let wValues = vec2( getW(d2, inputDepthNearestVec4, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } else if (inputDepthVec4Remainder == 3) { ${ - isChannelsLast ? `let xValues = vec3( + isChannelLast ? `let xValues = vec3( getX(batch, xF, xR, xC, inputDepthNearestVec4), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2)); ` : - `let xValues = vec3( + `let xValues = vec3( getX(batch, inputDepthNearestVec4, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC)); @@ -384,19 +391,20 @@ export const createConv3DNaiveProgramInfo = getW(d2, inputDepthNearestVec4, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } } } } - ${hasBias ? 'dotProd = dotProd + getBiasByOutputCoords(coords)' : ''}; - result[global_idx] = f32(dotProd); + ${hasBias ? 
'value = value + getBiasByOutputCoords(coords)' : ''}; + ${applyActivation} + result[global_idx] = f32(value); }`; }; return { name: 'Conv3DNaive', shaderCache: - {hint: `${attributes.cacheKey};${isChannelsLast};${innerElementSize};${hasBias}`, inputDependencies}, + {hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, diff --git a/js/web/test/data/ops/fused-conv3dncdhw.jsonc b/js/web/test/data/ops/fused-conv3dncdhw.jsonc new file mode 100644 index 0000000000000..1801ca380aa09 --- /dev/null +++ b/js/web/test/data/ops/fused-conv3dncdhw.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "fused conv3d with relu, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [-0.125, -0.25, -0.375, 0.5, 0.625, -0.75, -0.875, -1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0625, 0], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv3d with clip", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "activation_params", "data": [1.0, 3.0], "type": "floats" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2.1875], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv3d with HardSigmoid, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "activation_params", "data": [0.1, 0.3], "type": "floats" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.39375001192092896, 0.518750011920929], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + } +] 
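The expected outputs in the Clip and HardSigmoid cases above can be checked by hand: with a VALID [2, 1, 2] kernel over a [1, 1, 2, 1, 2] input, each output channel reduces to a single dot product of the flattened input with that channel's filter, after which the activation is applied elementwise. A minimal TypeScript sketch of that arithmetic (an illustration only, not part of the test suite or of the WebGPU kernel; the `dot` helper and variable names are ad hoc):

```ts
// Flattened input of shape [1, 1, 2, 1, 2] and the two filter channels of shape [2, 1, 2, 1, 2]
// used by the Clip and HardSigmoid cases above.
const x: number[] = [0.25, 0.5, 0.75, 1];
const w: number[][] = [
  [0.125, 0.25, 0.375, 0.5], // output channel 0
  [0.625, 0.75, 0.875, 1],   // output channel 1
];

const dot = (a: number[], b: number[]): number => a.reduce((s, v, i) => s + v * b[i], 0);

// VALID padding with a kernel spanning the whole input: one dot product per output channel.
const conv = w.map((wc) => dot(x, wc)); // [0.9375, 2.1875]

// Clip with activation_params [1, 3].
const clipped = conv.map((v) => Math.min(3, Math.max(1, v))); // [1, 2.1875]

// HardSigmoid with activation_params [0.1, 0.3]: max(0, min(1, alpha * v + beta)).
const hardSigmoid = conv.map((v) => Math.max(0, Math.min(1, 0.1 * v + 0.3))); // [0.39375, 0.51875]

console.log(conv, clipped, hardSigmoid);
```

The jsonc file stores the float32-rounded versions of these values (for example 0.39375001192092896 for 0.39375).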
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 4a3a23bfe91b4..4aaf9d16b2b0e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1358,6 +1358,7 @@ "fast-gelu.jsonc", "floor.jsonc", "fused-conv.jsonc", + "fused-conv3dncdhw.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", From 94eb70d98348d83343207e113f9abaa0e7c6ea37 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 29 Jul 2024 23:50:14 +0800 Subject: [PATCH 07/40] [WebNN EP] Add labels for all WebNN operators (#21516) In order to provide more diagnosable error messages for developers. Spec change: https://github.com/webmachinelearning/webnn/pull/742 --- .../builders/impl/activation_op_builder.cc | 13 ++++--- .../builders/impl/argmax_min_op_builder.cc | 1 + .../webnn/builders/impl/binary_op_builder.cc | 15 ++++--- .../webnn/builders/impl/cast_op_builder.cc | 5 ++- .../webnn/builders/impl/clip_op_builder.cc | 1 + .../webnn/builders/impl/concat_op_builder.cc | 5 ++- .../webnn/builders/impl/conv_op_builder.cc | 16 +++++++- .../impl/dequantizeLinear_op_builder.cc | 17 ++++++-- .../impl/dynamicQuantizeLinear_op_builder.cc | 3 +- .../webnn/builders/impl/expand_op_builder.cc | 6 ++- .../webnn/builders/impl/flatten_op_builder.cc | 4 +- .../webnn/builders/impl/gather_op_builder.cc | 1 + .../webnn/builders/impl/gemm_op_builder.cc | 39 +++++++++++++++---- .../webnn/builders/impl/logical_op_builder.cc | 12 +++--- .../webnn/builders/impl/max_min_op_builder.cc | 10 +++-- .../builders/impl/normalization_op_builder.cc | 16 ++++++-- .../webnn/builders/impl/pad_op_builder.cc | 6 ++- .../webnn/builders/impl/pool_op_builder.cc | 1 + .../builders/impl/reduction_op_builder.cc | 1 + .../webnn/builders/impl/reshape_op_builder.cc | 7 +++- .../webnn/builders/impl/resize_op_builder.cc | 1 + .../webnn/builders/impl/shape_op_builder.cc | 9 ++++- .../webnn/builders/impl/slice_op_builder.cc | 5 ++- .../webnn/builders/impl/softmax_op_builder.cc | 4 +- .../webnn/builders/impl/split_op_builder.cc | 1 + .../impl/squeeze_unsqueeze_op_builder.cc | 8 +++- .../webnn/builders/impl/ternary_op_builder.cc | 4 +- .../builders/impl/transpose_op_builder.cc | 1 + .../builders/impl/triangular_op_builder.cc | 1 + .../webnn/builders/impl/unary_op_builder.cc | 30 +++++++------- 30 files changed, 180 insertions(+), 63 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index af0f0133b497a..626aaf5c71b74 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -36,6 +36,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Elu") { options.set("alpha", helper.Get("alpha", 1.0f)); output = model_builder.GetBuilder().call("elu", input, options); @@ -46,20 +47,20 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("beta", helper.Get("beta", 0.5f)); output = model_builder.GetBuilder().call("hardSigmoid", input, options); } else if (op_type == "HardSwish") { - output = model_builder.GetBuilder().call("hardSwish", input); + output = model_builder.GetBuilder().call("hardSwish", input, options); } else if (op_type == "LeakyRelu") { options.set("alpha", 
helper.Get("alpha", 0.0f)); output = model_builder.GetBuilder().call("leakyRelu", input, options); } else if (op_type == "Relu") { - output = model_builder.GetBuilder().call("relu", input); + output = model_builder.GetBuilder().call("relu", input, options); } else if (op_type == "Sigmoid") { - output = model_builder.GetBuilder().call("sigmoid", input); + output = model_builder.GetBuilder().call("sigmoid", input, options); } else if (op_type == "Softplus") { - output = model_builder.GetBuilder().call("softplus", input); + output = model_builder.GetBuilder().call("softplus", input, options); } else if (op_type == "Softsign") { - output = model_builder.GetBuilder().call("softsign", input); + output = model_builder.GetBuilder().call("softsign", input, options); } else if (op_type == "Tanh") { - output = model_builder.GetBuilder().call("tanh", input); + output = model_builder.GetBuilder().call("tanh", input, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 1ae63a644a287..05f3a742a3775 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -47,6 +47,7 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("keepDimensions", keep_dims == 1); // TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API. options.set("outputDataType", "int64"); + options.set("label", node.Name()); emscripten::val output = emscripten::val::object(); const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 23e19d5943144..555de68cd60fe 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -35,18 +35,21 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + if (op_type == "Add") { - output = model_builder.GetBuilder().call("add", input0, input1); + output = model_builder.GetBuilder().call("add", input0, input1, options); } else if (op_type == "Sub") { - output = model_builder.GetBuilder().call("sub", input0, input1); + output = model_builder.GetBuilder().call("sub", input0, input1, options); } else if (op_type == "Mul") { - output = model_builder.GetBuilder().call("mul", input0, input1); + output = model_builder.GetBuilder().call("mul", input0, input1, options); } else if (op_type == "Div") { - output = model_builder.GetBuilder().call("div", input0, input1); + output = model_builder.GetBuilder().call("div", input0, input1, options); } else if (op_type == "Pow") { - output = model_builder.GetBuilder().call("pow", input0, input1); + output = model_builder.GetBuilder().call("pow", input0, input1, options); } else if (op_type == "PRelu") { - output = model_builder.GetBuilder().call("prelu", input0, input1); + output = 
model_builder.GetBuilder().call("prelu", input0, input1, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index a97d71b90de55..a08e1681a8464 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -69,8 +69,11 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, node.Name(), " type: ", to_type); } + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = - model_builder.GetBuilder().call("cast", input, emscripten::val(operand_type)); + model_builder.GetBuilder().call("cast", input, emscripten::val(operand_type), options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index e6403a4cd12dc..b5c3206072d50 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -53,6 +53,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, "GetClipMinMax failed"); options.set("minValue", minValue); options.set("maxValue", maxValue); + options.set("label", node.Name()); emscripten::val input = model_builder.GetOperand(input_name); emscripten::val output = model_builder.GetBuilder().call("clamp", input, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index e4f98b09e03c5..dedc76b80e978 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -42,8 +42,11 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, inputs.push_back(model_builder.GetOperand(input->Name())); } + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = - model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis); + model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 320aaa03930fd..4f3f7459a7b5b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -242,6 +242,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); ORT_RETURN_IF_ERROR(SetConvBaseOptions( model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); bool depthwise = false; @@ -276,7 +277,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (!is_nhwc || !is_constant_weight) { // The weight_shape has been appended 1's, reshape weight operand. 
std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); - filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_reshape_filter"); + filter = model_builder.GetBuilder().call("reshape", + filter, + emscripten::val::array(new_shape), + reshape_options); } } @@ -293,6 +299,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N perm = {0, 2, 3, 1}; // L_0231 } transpose_options.set("permutation", emscripten::val::array(perm)); + transpose_options.set("label", node.Name() + "_transpose_filter"); filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); } @@ -323,7 +330,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector output_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); std::vector new_shape = GetVecUint32FromVecInt64(output_shape); - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_reshape_output"); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc index 66d502a4e6727..93a12a696cce1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc @@ -50,11 +50,22 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil std::vector target_shape{static_cast(input_shape[axis])}; target_shape.insert(target_shape.begin(), axis, 1); target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1); - scale = model_builder.GetBuilder().call("reshape", scale, emscripten::val::array(target_shape)); + emscripten::val reshape_scale_options = emscripten::val::object(); + reshape_scale_options.set("label", node.Name() + "_reshape_scale"); + scale = model_builder.GetBuilder().call("reshape", + scale, + emscripten::val::array(target_shape), + reshape_scale_options); + emscripten::val reshape_zero_point_options = emscripten::val::object(); + reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); zero_point = model_builder.GetBuilder().call("reshape", - zero_point, emscripten::val::array(target_shape)); + zero_point, + emscripten::val::array(target_shape), + reshape_zero_point_options); } - output = model_builder.GetBuilder().call("dequantizeLinear", input, scale, zero_point); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + output = model_builder.GetBuilder().call("dequantizeLinear", input, scale, zero_point, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc index 3b5f64584b828..55746bb1f61f0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc +++ 
b/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc @@ -31,8 +31,9 @@ Status DynamicQuantizaLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); - output_array = model_builder.GetBuilder().call("dynamicQuantizeLinear", input); + output_array = model_builder.GetBuilder().call("dynamicQuantizeLinear", input, options); for (size_t i = 0, count = output_array["length"].as(); i < count; i++) { model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i])); diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc index 9c75c00fa9273..c8cea833983b1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc @@ -53,10 +53,14 @@ Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector output_shape; ORT_RETURN_IF_NOT(GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape), "Cannot get output shape."); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("expand", input, - emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(output_shape)), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc index 31b1bd92a9503..d0ece026a7048 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc @@ -52,8 +52,10 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, SafeInt(num_post_axis_elements)}; emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call( - "reshape", inputs, emscripten::val::array(new_shape)); + "reshape", inputs, emscripten::val::array(new_shape), options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 014a08616c44f..23233539d34c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -42,6 +42,7 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); emscripten::val options = emscripten::val::object(); options.set("axis", axis); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("gather", input, indices, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc 
b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 53f885019ab2f..bd452b118fe3e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -39,6 +39,8 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val a = model_builder.GetOperand(node.InputDefs()[a_idx]->Name()); emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "MatMul") { std::vector a_shape; if (!GetShape(*input_defs[a_idx], a_shape, logger)) { @@ -53,23 +55,34 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (a_shape.size() == 1) { extended_a_shape = true; a_shape.insert(a_shape.begin(), 1); + emscripten::val reshape_a_options = emscripten::val::object(); + reshape_a_options.set("label", node.Name() + "_reshape_a"); a = model_builder.GetBuilder().call("reshape", a, - emscripten::val::array(GetVecUint32FromVecInt64(a_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(a_shape)), + reshape_a_options); } // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. bool extended_b_shape = false; if (b_shape.size() == 1) { extended_b_shape = true; b_shape.push_back(1); + emscripten::val reshape_b_options = emscripten::val::object(); + reshape_b_options.set("label", node.Name() + "_reshape_b"); b = model_builder.GetBuilder().call("reshape", b, - emscripten::val::array(GetVecUint32FromVecInt64(b_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(b_shape)), + reshape_b_options); } - output = model_builder.GetBuilder().call("matmul", a, b); + output = model_builder.GetBuilder().call("matmul", a, b, options); + emscripten::val reshape_output_options = emscripten::val::object(); + reshape_output_options.set("label", node.Name() + "_reshape_output"); // If the inputs are both 1D, reduce the output to a scalar. if (extended_a_shape && extended_b_shape) { - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array()); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(), + reshape_output_options); } // After matrix multiplication the prepended 1 is removed. else if (extended_a_shape) { @@ -78,7 +91,10 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N new_shape.push_back(narrow(b_shape[i])); } new_shape.push_back(narrow(b_shape.back())); - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_output_options); } // After matrix multiplication the appended 1 is removed. 
else if (extended_b_shape) { @@ -86,7 +102,10 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N for (size_t i = 0; i < a_shape.size() - 1; i++) { new_shape.push_back(narrow(a_shape[i])); } - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_output_options); } } else if (op_type == "MatMulInteger") { emscripten::val a_zero_point = emscripten::val::null(); @@ -101,9 +120,13 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { b_zero_point = model_builder.GetZeroConstant("uint8"); } - output = model_builder.GetBuilder().call("matmulInteger", a, a_zero_point, b, b_zero_point); + output = model_builder.GetBuilder().call("matmulInteger", + a, + a_zero_point, + b, + b_zero_point, + options); } else { // Gemm - emscripten::val options = emscripten::val::object(); NodeAttrHelper helper(node); const auto transA = helper.Get("transA", 0); options.set("aTranspose", emscripten::val(transA == 1)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index e56e8f6a3eb6d..23f3a938fee5e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -33,16 +33,18 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Equal") { - output = model_builder.GetBuilder().call("equal", input0, input1); + output = model_builder.GetBuilder().call("equal", input0, input1, options); } else if (op_type == "Greater") { - output = model_builder.GetBuilder().call("greater", input0, input1); + output = model_builder.GetBuilder().call("greater", input0, input1, options); } else if (op_type == "GreaterOrEqual") { - output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1); + output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1, options); } else if (op_type == "Less") { - output = model_builder.GetBuilder().call("lesser", input0, input1); + output = model_builder.GetBuilder().call("lesser", input0, input1, options); } else if (op_type == "LessOrEqual") { - output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1); + output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 0168f59273545..1080fd0a3f943 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -43,22 +43,26 @@ Status MaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(op_type == "Max" || op_type == "Min", "MaxMinOpBuilder, unknown op: ", op_type); emscripten::val output = emscripten::val::object(); + 
emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (input_count == 1) { // For 1 input, just concat the single input as workaround. // TODO: use identity instead once it's available in WebNN. emscripten::val inputs = emscripten::val::array(); inputs.call("push", input0); - output = model_builder.GetBuilder().call("concat", inputs, 0); + output = model_builder.GetBuilder().call("concat", inputs, 0, options); } else { std::string webnn_op_name = op_type == "Max" ? "max" : "min"; emscripten::val input1 = model_builder.GetOperand(input_defs[1]->Name()); - output = model_builder.GetBuilder().call(webnn_op_name.c_str(), input0, input1); + output = model_builder.GetBuilder().call(webnn_op_name.c_str(), input0, input1, options); for (size_t input_index = 2; input_index < input_count; ++input_index) { emscripten::val next_input = model_builder.GetOperand(input_defs[input_index]->Name()); - output = model_builder.GetBuilder().call(webnn_op_name.c_str(), output, next_input); + emscripten::val next_options = emscripten::val::object(); + next_options.set("label", node.Name() + "_" + input_defs[input_index]->Name()); + output = model_builder.GetBuilder().call(webnn_op_name.c_str(), output, next_input, next_options); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index a2aa0df5586e3..4d068baf35e72 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -42,6 +42,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder const auto rank = input_shape.size(); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); std::vector scale_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); @@ -116,7 +117,12 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder new_shape.erase(insertion_point, insertion_point + excess_rank); *insertion_point = sum; } - input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + emscripten::val reshape_input_options = emscripten::val::object(); + reshape_input_options.set("label", node.Name() + "_reshape_input"); + input = model_builder.GetBuilder().call("reshape", + input, + emscripten::val::array(new_shape), + reshape_input_options); } if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { @@ -126,8 +132,12 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder // Reshape back to the original output shape for 3D input. 
if (input_shape.size() != 4) { std::vector output_shape = GetVecUint32FromVecInt64(input_shape); - output = model_builder.GetBuilder().call( - "reshape", output, emscripten::val::array(output_shape)); + emscripten::val reshape_output_options = emscripten::val::object(); + reshape_output_options.set("label", node.Name() + "reshape_output"); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(output_shape), + reshape_output_options); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index bc90821ba4ed8..071155a2fb372 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -73,6 +73,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto pad_mode = helper.Get("mode", std::string("constant")); @@ -145,9 +146,12 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, starts.push_back(start_padding[i] >= 0 ? SafeInt(0) : SafeInt(-start_padding[i])); sizes.push_back(SafeInt(input_shape[i] + start_padding[i] + end_padding[i])); } + emscripten::val slice_options = emscripten::val::object(); + slice_options.set("label", node.Name() + "_slice_output"); output = model_builder.GetBuilder().call("slice", output, emscripten::val::array(starts), - emscripten::val::array(sizes)); + emscripten::val::array(sizes), + slice_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 8b3eecf35fcc8..0af62dacedbd5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -59,6 +59,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 461050849385a..3e6d4d9820e9a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -57,6 +57,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); const auto keep_dims = helper.Get("keepdims", 1); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); options.set("keepDimensions", keep_dims == 1); std::vector axes_data; diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index b5005269b96a7..a7911683f0355 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ 
b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -58,8 +58,13 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::transform(target_shape.cbegin(), target_shape.cend(), std::back_inserter(new_shape), [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("reshape", - input, emscripten::val::array(new_shape)); + input, + emscripten::val::array(new_shape), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index c4ca980fec715..2218c858951d3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -106,6 +106,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto mode = helper.Get("mode", "nearest"); if (mode == "linear") { diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 1552023d3f876..0eb7dafdffe4d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -55,8 +55,15 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val sizes = emscripten::val::array(); sizes.call("push", slice_length); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + // Since WebNN doesn't support Shape op, we use constant + slice ops as workaround. 
- emscripten::val output = model_builder.GetBuilder().call("slice", shape_constant, starts, sizes); + emscripten::val output = model_builder.GetBuilder().call("slice", + shape_constant, + starts, + sizes, + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index fb452aec1c929..bef13841c646c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -97,9 +97,12 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, sizes.begin(), [](int64_t i, int64_t j) { return SafeInt(i - j); }); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("slice", inputs, emscripten::val::array(starts), - emscripten::val::array(sizes)); + emscripten::val::array(sizes), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 95c1dbd518061..798cfabae65db 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -42,7 +42,9 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, int32_t axis = helper.Get("axis", default_axis); axis = static_cast(HandleNegativeAxis(axis, input_size)); - emscripten::val output = model_builder.GetBuilder().call("softmax", input, axis); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("softmax", input, axis, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index ea3b8ef384ddc..4c59b694d690a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -49,6 +49,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const size_t rank = input_shape.size(); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); int32_t axis = helper.Get("axis", 0); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 8e6feb62fa8c4..5eff96873b8c4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -54,7 +54,6 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); const auto input_rank = input_shape.size(); - emscripten::val options = emscripten::val::object(); std::vector axes_data; auto rank = input_rank; @@ -111,7 +110,12 @@ 
Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil "SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); } - output = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + output = model_builder.GetBuilder().call("reshape", + input, + emscripten::val::array(new_shape), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 841e2d18244d5..2ed8330bf25be 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -32,9 +32,11 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val input2 = model_builder.GetOperand(node.InputDefs()[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = emscripten::val::object(); if (op_type == "Where") { - output = model_builder.GetBuilder().call("where", input0, input1, input2); + output = model_builder.GetBuilder().call("where", input0, input1, input2, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TernaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 3921b1da188c3..03c88ad9db88a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -42,6 +42,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); std::vector permutation = GetVecUint32FromVecInt64(perm); options.set("permutation", emscripten::val::array(permutation)); emscripten::val output = model_builder.GetBuilder().call("transpose", input, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index e4b7021d49b30..0c818533918a4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -46,6 +46,7 @@ Status TriangularOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val output = emscripten::val::object(); NodeAttrHelper helper(node); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); const bool upper = helper.Get("upper", 1); options.set("upper", upper); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index e0016de8e69b7..061404c8a9ce0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ 
b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -30,35 +30,37 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Abs") { - output = model_builder.GetBuilder().call("abs", input); + output = model_builder.GetBuilder().call("abs", input, options); } else if (op_type == "Ceil") { - output = model_builder.GetBuilder().call("ceil", input); + output = model_builder.GetBuilder().call("ceil", input, options); } else if (op_type == "Cos") { - output = model_builder.GetBuilder().call("cos", input); + output = model_builder.GetBuilder().call("cos", input, options); } else if (op_type == "Erf") { - output = model_builder.GetBuilder().call("erf", input); + output = model_builder.GetBuilder().call("erf", input, options); } else if (op_type == "Exp") { - output = model_builder.GetBuilder().call("exp", input); + output = model_builder.GetBuilder().call("exp", input, options); } else if (op_type == "Floor") { - output = model_builder.GetBuilder().call("floor", input); + output = model_builder.GetBuilder().call("floor", input, options); } else if (op_type == "Identity") { - output = model_builder.GetBuilder().call("identity", input); + output = model_builder.GetBuilder().call("identity", input, options); } else if (op_type == "Log") { - output = model_builder.GetBuilder().call("log", input); + output = model_builder.GetBuilder().call("log", input, options); } else if (op_type == "Neg") { - output = model_builder.GetBuilder().call("neg", input); + output = model_builder.GetBuilder().call("neg", input, options); } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input); + output = model_builder.GetBuilder().call("logicalNot", input, options); } else if (op_type == "Reciprocal") { - output = model_builder.GetBuilder().call("reciprocal", input); + output = model_builder.GetBuilder().call("reciprocal", input, options); } else if (op_type == "Sin") { - output = model_builder.GetBuilder().call("sin", input); + output = model_builder.GetBuilder().call("sin", input, options); } else if (op_type == "Sqrt") { - output = model_builder.GetBuilder().call("sqrt", input); + output = model_builder.GetBuilder().call("sqrt", input, options); } else if (op_type == "Tan") { - output = model_builder.GetBuilder().call("tan", input); + output = model_builder.GetBuilder().call("tan", input, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); From d8888136e3cdf29fa63d3b0a08a58683a7c9f0a0 Mon Sep 17 00:00:00 2001 From: mingyueliuh <131847423+mingyueliuh@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:45:52 -0400 Subject: [PATCH 08/40] Add support tensor element type for register custom op shape infer function (#21387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Functionality extension for the SetOutputShape method in custom op shape inference. ### Motivation and Context - **SetOutputShape** Interface enhancement Actually, the shape infer function need set the tensor type and shape ,Add a parameter **type** to allow users to specify the tensor type, and set **ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT** as default value to ensure compatibility. 
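To illustrate the extended interface, here is a minimal sketch of a shape-inference callback that uses the new `type` argument. It relies only on the `Ort::ShapeInferContext` methods touched by this patch; the function name and the fp16 output type are illustrative placeholders, not part of the change.

```cpp
#include <onnxruntime_cxx_api.h>

// Hypothetical shape-inference callback for a custom op whose output keeps
// the input shape but carries an explicit (non-float) element type.
Ort::Status InferMyCustomOpShape(Ort::ShapeInferContext& ctx) {
  const Ort::ShapeInferContext::Shape& input_shape = ctx.GetInputShape(0);
  // New in this patch: the element type can be set alongside the shape.
  return ctx.SetOutputShape(0, input_shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
}
```

Callers that omit the third argument keep the previous behavior, because it defaults to `ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT`.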
Co-authored-by: mingyue --- include/onnxruntime/core/session/onnxruntime_cxx_api.h | 2 +- include/onnxruntime/core/session/onnxruntime_cxx_inline.h | 3 ++- onnxruntime/core/session/custom_ops.cc | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5d974e1ff5185..29a229f427163 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2216,7 +2216,7 @@ struct ShapeInferContext { size_t GetInputCount() const { return input_shapes_.size(); } - Status SetOutputShape(size_t indice, const Shape& shape); + Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); int64_t GetAttrInt(const char* attr_name); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index aaef111b9f15b..9b9dd81a749c0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1998,9 +1998,10 @@ inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, } } -inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) { +inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) { OrtTensorTypeAndShapeInfo* info = {}; ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type)); using InfoPtr = std::unique_ptr>; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 4c782f647371e..33d2a0244b453 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -105,6 +105,7 @@ struct OrtShapeInferContext { } } ONNX_NAMESPACE::updateOutputShape(ctx_, index, shape_proto); + ONNX_NAMESPACE::updateOutputElemType(ctx_, index, info->type); return onnxruntime::Status::OK(); } From 05cef469e81e3695667f122beecf97600094d09b Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 30 Jul 2024 00:59:46 +0800 Subject: [PATCH 09/40] Move on-device training packages publish step (#21539) ### Description Since the onedevice training cpu packaging has been a separated pipeline, it's nuget package publishing step must be moved as well. ### Motivation and Context Fixes the exception in Nuget Publishing Packaging Pipeline caused by #21485 --- .../c-api-training-packaging-pipelines.yml | 27 +++++++++++++++++-- .../github/azure-pipelines/publish-nuget.yml | 7 +---- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml index aecece05a0e58..22ee7de8a5de0 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml @@ -32,13 +32,25 @@ parameters: displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. type: number default: 0 - + +# these 2 parameters are used for debugging. 
+- name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean + default: false + +- name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + stages: - template: stages/set_packaging_variables_stage.yml parameters: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + - template: templates/ondevice-training-cpu-packaging-pipeline.yml parameters: RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} @@ -48,4 +60,15 @@ stages: OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime.Training' AdditionalBuildFlags: '--enable_training_apis' AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos' - BuildVariant: 'default' \ No newline at end of file + BuildVariant: 'default' + +- template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet Training Package' + ArtifactName: 'drop-signed-nuget-Training-CPU' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/publish-nuget.yml index 206a9464de6ef..b78d586288ba3 100644 --- a/tools/ci_build/github/azure-pipelines/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/publish-nuget.yml @@ -32,11 +32,6 @@ stages: artifact: 'drop-signed-nuget-dml' - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-dml\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-Training-CPU' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-Training-CPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - # Publish CUDA 11 Nuget/Java pkgs to ADO feed - template: stages/nuget-cuda-publishing-stage.yml parameters: @@ -44,4 +39,4 @@ stages: - template: stages/java-cuda-publishing-stage.yml parameters: - artifact_feed: $(ArtifactFeed) \ No newline at end of file + artifact_feed: $(ArtifactFeed) From bc3713206dc1d6c7e5062389ef7db42ac2051a30 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Mon, 29 Jul 2024 10:00:21 -0700 Subject: [PATCH 10/40] Update QNN pipeline pool (#21482) ### Description Update QNN pipeline pool ### Motivation and Context Let all our pipelines are using the latest NDK version --- ...droid-arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 6649206c0d79c..c80092fc82ed5 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -35,7 +35,7 @@ parameters: jobs: - job: Build_QNN_EP - pool: onnxruntime-qnn-ubuntu-2004-cpu + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: 30 workspace: clean: all @@ -46,6 +46,10 @@ jobs: inputs: versionSpec: $(pythonVersion) + - script: | + env | grep ANDROID + displayName: View 
Android ENVs + - script: sudo apt-get update -y && sudo apt-get install -y coreutils ninja-build displayName: Install coreutils and ninja @@ -56,13 +60,6 @@ jobs: parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} - - script: | - export ANDROID_SDK_ROOT=/usr/local/lib/android/sdk - export ANDROID_HOME=/usr/local/lib/android/sdk - export ANDROID_NDK_HOME=/usr/local/lib/android/sdk/ndk-bundle - export ANDROID_NDK_ROOT=/usr/local/lib/android/sdk/ndk-bundle - displayName: set Android ENVs - - script: | set -e -x rm -rf /tmp/scripts From 79537d0523a7c215ef1685bf46efbd423242c4c1 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Mon, 29 Jul 2024 10:00:52 -0700 Subject: [PATCH 11/40] Remove tools/ci_build/github/android/run_nnapi_code_coverage.sh (#21371) ### Description Remove tools/ci_build/github/android/run_nnapi_code_coverage.sh ### Motivation and Context This file is no longer needed --- .../github/android/run_nnapi_code_coverage.sh | 36 ------------------- 1 file changed, 36 deletions(-) delete mode 100755 tools/ci_build/github/android/run_nnapi_code_coverage.sh diff --git a/tools/ci_build/github/android/run_nnapi_code_coverage.sh b/tools/ci_build/github/android/run_nnapi_code_coverage.sh deleted file mode 100755 index 472e824eaa47a..0000000000000 --- a/tools/ci_build/github/android/run_nnapi_code_coverage.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -# This script will run ORT build for Android with code coverage option - -set -e -set -x - -if [ $# -ne 1 ]; then - echo "One command line argument, the ROOT root directory, is expected" -fi - -ORT_ROOT=$1 -# Build and run onnxruntime using NNAPI execution provider targeting android emulator -python3 ${ORT_ROOT}/tools/ci_build/build.py \ - --android \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME \ - --android_ndk_path $ANDROID_NDK_HOME \ - --android_abi=x86_64 \ - --android_api=29 \ - --skip_submodule_sync \ - --parallel \ - --use_nnapi \ - --cmake_generator=Ninja \ - --build_java \ - --path_to_protoc_exe $ORT_ROOT/protobuf_install/bin/protoc \ - --code_coverage - -# Install gcovr -python3 -m pip install gcovr - -# Retrieve runtime code coverage files from the emulator and analyze -python3 ${ORT_ROOT}/tools/ci_build/coverage.py \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME - From 0d7cf301a1e0ea784edcdf2242e973643f0bb9c9 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 30 Jul 2024 02:05:34 +0800 Subject: [PATCH 12/40] [js/webgpu] Add activation Tanh (#21540) Bug:https://github.com/microsoft/onnxruntime/issues/21467 ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 4 +++ js/web/test/data/ops/fused-conv.jsonc | 33 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 6e66abacf3471..cfa0b42ef9eeb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -30,6 +30,10 @@ export const getActivationSnippet = baseType}(uniforms.beta)));`; case 'LeakyRelu': return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; + case 'Tanh': + return `let e2x = exp(-2.0 * abs(value)); + value = sign(value) * (1.0 - e2x) / (1.0 + e2x); + `; case '': return ''; // TODO: adding other activations that can be fused. 
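For reference (a note added for clarity, not part of the patch): the WGSL snippet added above for `Tanh` evaluates the activation through the exp-based identity

$\tanh(x) = \operatorname{sign}(x)\cdot\dfrac{1 - e^{-2|x|}}{1 + e^{-2|x|}}$

which only ever exponentiates a non-positive argument, so it cannot overflow for large |x|.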
diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 6a10e3b96a26a..d88c91ebc9de7 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -430,5 +430,38 @@ ] } ] + }, + { + "name": "fused conv with tanh", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Tanh", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [0.11, 0.12, 0.13, 0.14], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.15572261810302734, 0.20409323275089264, 0.29770541191101074, 0.3425688147544861], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] } ] From b03c9496aa081fa6c07c5b266800694c830afd60 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:39:38 -0700 Subject: [PATCH 13/40] [js/web] allow load WebAssembly binary from buffer (#21534) ### Description This PR adds a new option `ort.env.wasm.wasmBinary`, which allows user to set to a buffer containing preload .wasm file content. This PR should resolve the problem from latest discussion in #20876. --- cmake/onnxruntime_webassembly.cmake | 2 +- js/common/lib/env.ts | 6 +++++ js/web/lib/wasm/wasm-factory.ts | 8 ++++++- .../e2e/browser-test-wasm-binary-override.js | 22 +++++++++++++++++++ js/web/test/e2e/run-data.js | 3 +++ 5 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 js/web/test/e2e/browser-test-wasm-binary-override.js diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 7a49e90c00bce..0686b66876d9f 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -225,7 +225,7 @@ else() "SHELL:-s EXPORT_ALL=0" "SHELL:-s VERBOSE=0" "SHELL:-s FILESYSTEM=0" - "SHELL:-s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm,mainScriptUrlOrBlob]" + "SHELL:-s INCOMING_MODULE_JS_API=[locateFile,instantiateWasm,wasmBinary]" "SHELL:-s WASM_BIGINT=1" ${WASM_API_EXCEPTION_CATCHING} --no-entry diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index dbb5f8118363f..1a87569a115a6 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -74,6 +74,12 @@ export declare namespace Env { */ wasmPaths?: WasmPrefixOrFilePaths; + /** + * Set a custom buffer which contains the WebAssembly binary. If this property is set, the `wasmPaths` property will + * be ignored. + */ + wasmBinary?: ArrayBufferLike|Uint8Array; + /** * Set or get a boolean value indicating whether to proxy the execution of main thread to a worker thread. * diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index fb068ab42d04c..0f5f10716a00b 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -108,6 +108,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const mjsPathOverride = (mjsPathOverrideFlag as URL)?.href ?? mjsPathOverrideFlag; const wasmPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.wasm; const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? 
wasmPathOverrideFlag; + const wasmBinaryOverride = flags.wasmBinary; const [objectUrl, ortWasmFactory] = (await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1)); @@ -135,7 +136,12 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise numThreads, }; - if (wasmPathOverride || wasmPrefixOverride) { + if (wasmBinaryOverride) { + /** + * Set a custom buffer which contains the WebAssembly binary. This will skip the wasm file fetching. + */ + config.wasmBinary = wasmBinaryOverride; + } else if (wasmPathOverride || wasmPrefixOverride) { /** * A callback function to locate the WebAssembly file. The function should return the full path of the file. * diff --git a/js/web/test/e2e/browser-test-wasm-binary-override.js b/js/web/test/e2e/browser-test-wasm-binary-override.js new file mode 100644 index 0000000000000..35d427fa3b722 --- /dev/null +++ b/js/web/test/e2e/browser-test-wasm-binary-override.js @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +const documentUrl = document.currentScript.src; + +it('Browser E2E testing - WebAssembly backend', async function() { + // preload .wasm file binary + const wasmUrl = new URL('./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.wasm', documentUrl).href; + const response = await fetch(wasmUrl); + + // make sure the .wasm file is loaded successfully + assert(response.ok); + assert(response.headers.get('Content-Type') === 'application/wasm'); + + // override wasm binary + const binary = await response.arrayBuffer(); + ort.env.wasm.wasmBinary = binary; + + await testFunction(ort, {executionProviders: ['wasm']}); +}); diff --git a/js/web/test/e2e/run-data.js b/js/web/test/e2e/run-data.js index 507192f29be9c..856f29eac6ddf 100644 --- a/js/web/test/e2e/run-data.js +++ b/js/web/test/e2e/run-data.js @@ -36,6 +36,9 @@ const BROWSER_TEST_CASES = [ [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=2', 'proxy=1']], // 2 threads, proxy [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=1', 'proxy=1']], // 1 thread, proxy + // wasm binary override: + [true, false, './browser-test-wasm-binary-override.js', 'ort.min.js'], + // path override: // wasm, path override filenames for both mjs and wasm, same origin [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=mjs,wasm']], From c39f1c4fd80668fd7619719ebe7a374f4ae11a5e Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Mon, 29 Jul 2024 14:12:36 -0700 Subject: [PATCH 14/40] ORT- OVEP 1.19 PR-follow up (#21546) ### Description Follow up PR for bug fixes on 1.19 ### Motivation and Context - Handles 1.19 docker file fixes. - Sets the default file naming of epctx onnx model with _ctx.onnx as suffix. - Create epctx model directories if it doesn't exist. 
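As a point of reference, the default name of the exported EPContext model is now derived from the source graph name by swapping its extension for `_ctx.onnx`. A minimal sketch of that naming rule follows (the helper name, the `find_last_of` call, and the example input are illustrative assumptions, not the exact code):

```cpp
#include <string>

// Sketch of the default EPContext model naming: "<graph stem>_ctx.onnx".
std::string MakeEpCtxModelName(const std::string& graph_name) {
  const size_t dot = graph_name.find_last_of('.');
  const std::string stem = (dot == std::string::npos) ? graph_name : graph_name.substr(0, dot);
  return stem + "_ctx.onnx";  // e.g. "model.onnx" -> "model_ctx.onnx"
}
```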
--------- Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> --- dockerfiles/Dockerfile.openvino | 10 ++++------ .../providers/openvino/backend_manager.cc | 9 ++++++++- .../openvino/openvino_execution_provider.cc | 5 ----- .../openvino/openvino_provider_factory.cc | 20 ++++++++++++++++++- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino index 75898770acf28..39e75a68a369f 100644 --- a/dockerfiles/Dockerfile.openvino +++ b/dockerfiles/Dockerfile.openvino @@ -3,11 +3,11 @@ # SPDX-License-Identifier: MIT #-------------------------------------------------------------------------- -ARG OPENVINO_VERSION=2024.0.0 +ARG OPENVINO_VERSION=2024.2.0 # Build stage -FROM openvino/ubuntu20_runtime:${OPENVINO_VERSION} AS builder +FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} AS builder ENV WORKDIR_PATH=/home/openvino WORKDIR $WORKDIR_PATH @@ -34,20 +34,18 @@ RUN cat /etc/apt/sources.list | sed 's/^# deb-src/deb-src/g' > ./temp; mv temp / RUN apt update; apt install dpkg-dev RUN mkdir /sources WORKDIR /sources -RUN apt-get source cron iso-codes lsb-release powermgmt-base python-apt-common python3-apt python3-dbus python3-gi unattended-upgrades libapt-pkg6.0 libhogweed5 libnettle7 +RUN apt-get source cron iso-codes lsb-release powermgmt-base python-apt-common python3-apt python3-dbus python3-gi libapt-pkg6.0 libhogweed6 libnettle8 WORKDIR / RUN tar cvf GPL_sources.tar.gz /sources # Deploy stage -FROM openvino/ubuntu20_runtime:${OPENVINO_VERSION} +FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} ENV DEBIAN_FRONTEND noninteractive USER root COPY --from=builder /home/openvino/onnxruntime/build/Linux/Release/dist/*.whl ./ COPY --from=builder /GPL_sources.tar.gz ./ RUN python3 -m pip install ./*.whl && rm ./*.whl -RUN apt update; apt install -y unattended-upgrades && \ - unattended-upgrade ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 8f3658df0d09d..18a6257910a56 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -128,6 +128,13 @@ BackendManager::BackendManager(const GlobalContext& global_context, #endif } } + if (global_context_.export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { + auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph, + logger); + if ((!status.IsOK())) { + ORT_THROW(status); + } + } } // Call EPContext model exporter here if the provider option for exporting @@ -158,7 +165,7 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie if (dot == std::string::npos) return graph_name; return graph_name.substr(0, dot); }(); - graph_name = graph_name + "-ov_" + GetGlobalContext().device_type + "_blob.onnx"; + graph_name = graph_name + "_ctx.onnx"; } // If embed_mode, then pass on the serialized blob // If not embed_mode, dump the blob here and only pass on the path to the blob diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 5627cb2c122fb..29c45916795d3 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -147,11 +147,6 @@ common::Status 
OpenVINOExecutionProvider::Compile( *GetLogger(), ep_ctx_handle_); - if (global_context_->export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { - ORT_RETURN_IF_ERROR(backend_manager->ExportCompiledBlobAsEPCtxNode(graph_body_viewer, - *GetLogger())); - } - compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState(); diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 716a7cd936405..3738f2a534154 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -192,6 +192,10 @@ struct OpenVINO_Provider : Provider { } if (provider_options_map.find("num_of_threads") != provider_options_map.end()) { + if (!std::all_of(provider_options_map.at("num_of_threads").begin(), + provider_options_map.at("num_of_threads").end(), ::isdigit)) { + ORT_THROW("[ERROR] [OpenVINO-EP] Number of threads should be a number. \n"); + } num_of_threads = std::stoi(provider_options_map.at("num_of_threads")); if (num_of_threads <= 0) { num_of_threads = 1; @@ -298,7 +302,21 @@ struct OpenVINO_Provider : Provider { // The path to dump epctx model is valid only when epctx is enabled. // Overrides the cache_dir option to dump model cache files from OV. if (export_ep_ctx_blob) { - cache_dir = provider_options_map.at("so_epctx_path").c_str(); + auto ep_context_file_path_ = provider_options_map.at("so_epctx_path"); + auto file_path = std::filesystem::path(ep_context_file_path_); + // ep_context_file_path_ file extension must be .onnx + if (!ep_context_file_path_.empty() && + file_path.extension().generic_string() == ".onnx") { + // ep_context_file_path_ must be provided as a directory, create it if doesn't exist + auto parent_path = file_path.parent_path(); + if (!std::filesystem::is_directory(parent_path) && + !std::filesystem::create_directory(parent_path)) { + ORT_THROW("[ERROR] [OpenVINO] Failed to create directory : " + file_path.parent_path().generic_string() + " \n"); + } + cache_dir = ep_context_file_path_.c_str(); + } else { + ORT_THROW("[ERROR] [OpenVINO] Invalid ep_ctx_file_path" + ep_context_file_path_ + " \n"); + } } } From 7543dd040b2d32109a2718d7276d3aca1edadaae Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Jul 2024 10:50:13 +1200 Subject: [PATCH 15/40] Propagate NaNs in the CPU min and max operators (#21492) ### Description Propagates NaN values in the min and max operators so that min or max with a NaN in either input always produces NaN. 
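The intended semantics can be summarized with a small standalone sketch (illustrative only; the kernels themselves delegate to Eigen's NaN-propagating array min/max, as in the diff below):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>

// Sketch of the new behavior: a NaN in either operand wins, unlike std::min,
// which only returns NaN when it happens to be the first argument.
float PropagatingMin(float a, float b) {
  if (std::isnan(a) || std::isnan(b)) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  return std::min(a, b);
}
```

For example, `PropagatingMin(0.25f, NaN)` yields NaN, matching the expectations of the new `Min_12_Float_with_scalar_Nan` test added below.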
### Motivation and Context Fixes #21455 --- .../providers/cpu/math/element_wise_ops.cc | 18 +- onnxruntime/test/providers/checkers.cc | 2 +- .../cpu/math/element_wise_ops_test.cc | 188 ++++++++++++++++-- 3 files changed, 187 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 1d524a90302e7..5ea6000da1cba 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -705,7 +705,7 @@ Status Min_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - min = min.array().min(EigenMap(data_n).array()); + min = min.array().template min(EigenMap(data_n).array()); } return Status::OK(); @@ -721,15 +721,16 @@ struct Min_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().min(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template min(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template min(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template min( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); @@ -827,7 +828,7 @@ Status Max_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - max = max.array().max(EigenMap(data_n).array()); + max = max.array().template max(EigenMap(data_n).array()); } return Status::OK(); @@ -843,15 +844,16 @@ struct Max_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().max(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template max(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template max(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template max( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 5f332ddcddb8d..182fa4729a88f 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -427,7 +427,7 @@ struct TensorCheck { for (int64_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { - EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(f_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. 
i:" << i; } else { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index eb3575f2cde88..bd3d21d4929f3 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1553,6 +1553,47 @@ TEST(MathOpTest, Min_12_Float_Nan) { } } +TEST(MathOpTest, Min_12_Float_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.25f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Float_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Min_12_Double) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, @@ -1586,12 +1627,53 @@ TEST(MathOpTest, Min_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -1.0, -1.0, -2.0, 0.5, 0.0, 1.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.25}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + 
execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1666,7 +1748,7 @@ TEST(MathOpTest, Min_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16) { +TEST(MathOpTest, Min_12_MLFloat16) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({1.f, 1.f, 1.f})); @@ -1679,7 +1761,7 @@ TEST(MathOpTest, Min_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar0) { OpTester test("Min", 12); test.AddInput("data_0", {}, MakeMLFloat16({-10.f})); @@ -1692,7 +1774,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({2.f, 3.f, 4.f})); @@ -1809,12 +1891,53 @@ TEST(MathOpTest, Max_12_Float_Nan) { std::numeric_limits::quiet_NaN(), -0.5f, 0.0f, -1.0f, 1.0f, 1.0f, 2.0f}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25f, 0.5f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + 
execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1854,12 +1977,53 @@ TEST(MathOpTest, Max_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -0.5, 0.0, -1.0, 1.0, 1.0, 2.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25, 0.5}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1934,7 +2098,7 @@ TEST(MathOpTest, Max_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", 
{kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16) { +TEST(MathOpTest, Max_12_MLFloat16) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -1.f, -1.f})); @@ -1947,7 +2111,7 @@ TEST(MathOpTest, Max_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar0) { OpTester test("Max", 12); test.AddInput("data_0", {}, MakeMLFloat16({-1.f})); @@ -1960,7 +2124,7 @@ TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -2.f, -3.f})); From d98581495f996084af65ae1e6600378bed949460 Mon Sep 17 00:00:00 2001 From: Sophie Schoenmeyer <107952697+sophies927@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:06:03 -0700 Subject: [PATCH 16/40] Update labeling bot (#21548) Current labeling bot over-applies many of the labels (e.g., ep:CUDA and platform:windows) and is missing some of the APIs + EPs Working on migrating this workflow to GitHub policies but would like to use this fix in the meantime to avoid causing any issues w/ ORT 1.19 ### Description ### Motivation and Context --- .github/labeler.yml | 31 ++++++++++++++---------- .github/title-only-labeler.yml | 4 +++ .github/workflows/title-only-labeler.yml | 20 +++++++++++++++ 3 files changed, 42 insertions(+), 13 deletions(-) create mode 100644 .github/title-only-labeler.yml create mode 100644 .github/workflows/title-only-labeler.yml diff --git a/.github/labeler.yml b/.github/labeler.yml index 526d8a643e713..c14e2a213bc60 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,20 +1,25 @@ -api:javascript: '/\bjavascript\b/i' +api:CSharp: '/(\bc\s*sharp\b|\bc#)/i' api:java: '/\bjava\b/i' +api:javascript: '/\bjavascript\b/i' ep:ACL: '/\bacl\b/i' ep:ArmNN: '/\barmnn\b/i' -ep:CUDA: '/\bcuda\b/i' -ep:DML: '/(\bdirectml\b|\bdml\b)/i' -ep:MIGraphX: '/\bmigraphx\b/i' -ep:oneDNN: '/\bonednn\b/i' +ep:CANN: '/\bcann\b/i' +ep:CoreML: '/\bcore\s*ml\b/i' +ep:DML: '/(\bdirect\s*ml\b|\bdml\b)/i' +ep:MIGraphX: '/\bmi\s*graph\s*x\b/i' +ep:oneDNN: '/\bone\s*dnn\b/i' ep:OpenVINO: '/\bopen\s*vino\b/i' -ep:RockchipNPU: '/\brockchip\b/i' +ep:QNN: '/\bqnn\b/i' +ep:RockchipNPU: '/\brockchip(?:npu)?\b/i' ep:ROCm: '/\brocm\b/i' -ep:TensorRT: '/(\btensor\s*rt\b|\btrt\b)/i' +ep:SNPE: '/\bsnpe\b/i' ep:tvm: '/\btvm\b/i' ep:VitisAI: '/\bvitis(?:ai)?\b/i' -platform:jetson: '/\bjetson\b/i' -platform:mobile: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bcore-?ml\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' -platform:web: '/(\bwebgl\b|\bweb-?gpu\b|\bwasm\b|\bonnxruntime-node\b|\bonnxruntime-web\b)/i' -platform:windows: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' -model:transformer: '/(\bbert\b|\bgpt-?2\b|\bhugging-?face\b|\blong-?former\b|\bt5\b)/i' -quantization: '/(is this a quantized model\?\n\nYes|\bquantization\b)/i' +ep:WebGPU: '/\bwebgpu\b/i' +ep:WebNN: '/\bwebnn\b/i' +ep:Xnnpack: '/\bxnn\s*pack\b/i' +.NET: '/(\bdot\s*net\b|\bnuget\b|\.net\b)/i' +platform:jetson: '/(\bjetson\b|\bjetpack\b)/i' +platform:mobile: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' +platform:web: 
'/(\bwebgl\b|\bweb-?gpu\b|\bwasm\b|\bonnxruntime-node\b|\bonnxruntime-web\b|\bonnxruntime-react-native\b|\bnpm\b|\btransformers\.js\b)/i' +model:transformer: '/\btransformers(?!\.js)\b/i' diff --git a/.github/title-only-labeler.yml b/.github/title-only-labeler.yml new file mode 100644 index 0000000000000..4980f7251bcb4 --- /dev/null +++ b/.github/title-only-labeler.yml @@ -0,0 +1,4 @@ +ep:CUDA: '/\bcuda\b/i' +ep:TensorRT: '/(\btensor\s*rt\b|\btrt\b)/i' +platform:windows: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' +quantization: '/(quant|\bqdq\b)/i' diff --git a/.github/workflows/title-only-labeler.yml b/.github/workflows/title-only-labeler.yml new file mode 100644 index 0000000000000..e0af2dd06b1b7 --- /dev/null +++ b/.github/workflows/title-only-labeler.yml @@ -0,0 +1,20 @@ +name: "Title Only Issue Labeler" +on: + issues: + types: [opened, edited] + +permissions: + issues: write + +jobs: + triage: + runs-on: ubuntu-latest + steps: + - uses: github/issue-labeler@v3.4 + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" + configuration-path: .github/title-only-labeler.yml + not-before: 2020-01-15T02:54:32Z + enable-versioned-regex: 0 + include-title: 1 + include-body: 0 From 8417c325ec160dc8ee62edaf6d1daf91ad979d56 Mon Sep 17 00:00:00 2001 From: mcollinswisc Date: Mon, 29 Jul 2024 16:06:51 -0700 Subject: [PATCH 17/40] Keep QDQ nodes w/ nonpositive scale around MaxPool (#21182) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This change adds a check for whether the scale in the QuantizeLinear (or DequantizeLinear) is a positive scalar, and a new selector to disallow removing the QDQ around MaxPool if it is not. ### Motivation and Context Currently, the DropQDQNodesRules optimization removes QuantizeLinear and DequantizeLinear nodes from DequantizeLinear ∘ MaxPool ∘ QuantizeLinear. However, if the x_scale/y_scale values are non-positive, the (de-)quantization changes the ordering of the elements in the input value, so this optimization is changing the results. 
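To see why a non-positive scale breaks the rewrite, consider the usual affine quantization `q = round(x / scale) + zero_point`: a negative `scale` reverses the ordering of the inputs, so MaxPool selects a different element before and after quantization. A minimal stand-alone sketch (illustrative only; `quantize` is a hand-written helper, not repository code):

```cpp
// negative_scale_order.cpp - shows that a negative scale reverses element ordering.
#include <algorithm>
#include <cmath>
#include <cstdio>

// Hypothetical 8-bit affine quantizer matching the usual QuantizeLinear definition.
static int quantize(float x, float scale, int zero_point) {
  int q = static_cast<int>(std::nearbyint(x / scale)) + zero_point;
  return std::clamp(q, 0, 255);
}

int main() {
  const float scale = -0.003f;  // negative scale, as in the regression test below
  const int zero_point = 128;
  const float a = -0.3f, b = 0.3f;  // a < b in real space

  const int qa = quantize(a, scale, zero_point);
  const int qb = quantize(b, scale, zero_point);
  std::printf("q(a)=%d q(b)=%d\n", qa, qb);  // qa > qb: the ordering flipped

  // MaxPool over the dequantized values would pick b, but MaxPool over the
  // quantized values picks a, so dropping the Q/DQ pair changes the result.
  std::printf("max real: %f, dequant(max quant): %f\n",
              std::max(a, b), (std::max(qa, qb) - zero_point) * scale);
  return 0;
}
```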
https://github.com/microsoft/onnxruntime/issues/21176 --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../optimizer/qdq_transformer/qdq_util.cc | 35 ++++++++++++++ .../core/optimizer/qdq_transformer/qdq_util.h | 4 ++ .../qdq_selector_action_transformer.cc | 27 +++++++++-- .../selectors_actions/qdq_selectors.cc | 7 +++ .../selectors_actions/qdq_selectors.h | 10 ++-- .../test/optimizer/qdq_transformer_test.cc | 46 +++++++++++++++++++ 6 files changed, 120 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index a4d1ea3c7cf56..7ef4ced1835f0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -166,6 +166,41 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint( return true; } +bool IsQOrDQScalePositiveConstantScalar( + const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer, + const std::filesystem::path& model_path) { + auto q_or_dq_input_defs = q_or_dq_node.InputDefs(); + + ORT_ENFORCE(q_or_dq_input_defs.size() >= 2); + + if (!optimizer_utils::IsScalar(*q_or_dq_input_defs[InputIndex::SCALE_ID])) { + return false; + } + + const ONNX_NAMESPACE::TensorProto* q_or_dq_scale_tensor_proto = + get_const_initializer(q_or_dq_input_defs[InputIndex::SCALE_ID]->Name()); + if (nullptr == q_or_dq_scale_tensor_proto) { + return false; + } + + Initializer q_or_dq_scale(*q_or_dq_scale_tensor_proto, model_path); + + switch (q_or_dq_scale.data_type()) { + case ONNX_NAMESPACE::TensorProto::FLOAT: + return q_or_dq_scale.data()[0] > 0; + + case ONNX_NAMESPACE::TensorProto::FLOAT16: + return q_or_dq_scale.data()[0] > 0; + + case ONNX_NAMESPACE::TensorProto::BFLOAT16: + return q_or_dq_scale.data()[0] > 0; + + default: + assert(false); + return false; + } +} + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) bool MatchQNode(const Node& node) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index 5d11b8bfd5558..008f9972a143b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -65,6 +65,10 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint( const GetConstantInitializerFn& get_const_initializer, bool& zero_point_exists); +// Checks that the y_scale/x_scale input to the QuantizeLinear/DequantizeLinear node is a positive scalar. +bool IsQOrDQScalePositiveConstantScalar(const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer, + const std::filesystem::path& model_path); + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // Check Q node op type, version, and domain. bool MatchQNode(const Node& node); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 17e66a3953b97..d81701fdf443b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -35,6 +35,7 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { // 3 nodes. DQ, target, Q. Merge into target and remove DQ and Q. 
const std::string drop_action_name{"drop"}; const std::string drop_action_no_int16_name{"drop_no_int16_support"}; + const std::string drop_action_no_int16_and_positive_scale_name{"drop_no_int16_support_and_positive_scale"}; NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; @@ -46,19 +47,32 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr drop_action_no_int16 = std::make_unique( std::vector(moves)); // Copy before std::move(moves) + std::unique_ptr drop_action_no_int16_and_positive_scale = std::make_unique( + std::vector(moves)); // Copy before std::move(moves) std::unique_ptr drop_action = std::make_unique(std::move(moves)); #if !defined(ORT_MINIMAL_BUILD) - // Use a separate selector + action that disallows 16-bit types for MaxPool and Resize. + // Use separate selectors & actions for MaxPool and Resize. + // + // They disallow 16-bit types for MaxPool and Resize: // int16 MaxPool is not supported by the ONNX specification. // int16 Resize is not supported by the ORT implementation (although allowed by ONNX). - std::unique_ptr selector_disallow_16bit = std::make_unique(false); + // + // And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative + // scale will change the ordering of the elements between quantized & de-quantized values. + std::unique_ptr selector_no_16bit = std::make_unique(false); qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name, - {{"MaxPool", {12}}, - {"Resize", {}}}, - std::move(selector_disallow_16bit), + {{"Resize", {}}}, + std::move(selector_no_16bit), std::move(drop_action_no_int16)); + std::unique_ptr selector_no_16bit_and_positive_scale = + std::make_unique(false, true, false); + qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_and_positive_scale_name, + {{"MaxPool", {12}}}, + std::move(selector_no_16bit_and_positive_scale), + std::move(drop_action_no_int16_and_positive_scale)); + std::unique_ptr selector = std::make_unique(true); qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name, {{"Gather", {}}, @@ -70,6 +84,9 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::move(drop_action)); #else qdq_selector_action_registry.RegisterAction(drop_action_no_int16_name, std::move(drop_action_no_int16)); + qdq_selector_action_registry.RegisterAction( + drop_action_no_int16_and_positive_scale_name, + std::move(drop_action_no_int16_and_positive_scale)); qdq_selector_action_registry.RegisterAction(drop_action_name, std::move(drop_action)); #endif } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index e271ae8df3356..203aba2c3dd91 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -150,6 +150,13 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return graph_viewer.GetConstantInitializer(initializer_name, true); }; + if (!allow_nonpositive_scale_) { + // IsQDQPairSupported will check that the scale is the same between q_node and dq_node. 
+ if (!IsQOrDQScalePositiveConstantScalar(q_node, get_const_initializer, graph_viewer.ModelPath())) { + return false; + } + } + return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 491a15b62cb03..7e009da39403b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -48,8 +48,9 @@ class NodeGroupSelector { // Zero point and scale are constant scalars and must match class DropQDQNodeGroupSelector : public NodeGroupSelector { public: - explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) - : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} + explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true, + bool allow_nonpositive_scale = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -58,6 +59,7 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector { bool allow_16bit_; bool allow_4bit_; + bool allow_nonpositive_scale_; }; // Single DQ -> node. @@ -300,8 +302,8 @@ class BaseSelector : public NodeSelector { class DropQDQNodesSelector : public BaseSelector { public: - explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false) - : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} + explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false, bool allow_nonpositive_scale = true) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit, allow_nonpositive_scale)) {} }; class DropDQNodesSelector : public BaseSelector { diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 367b4a65e3b7b..a043d6553bdfd 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -980,6 +980,52 @@ TEST(QDQTransformerTests, ReshapeDropQDQ) { RunReshapeDropQDQTestCase({1, 3, 2, 2}, {1, 12}, false, 21); // Use int16 ONNX QDQ ops } +// Runs a test case that checks if Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q if the quantization scale is +// negative. 
+template +static void RunMaxPoolNegativeScaleDropQDQTestCase() { + auto build_test_case = [](ModelTestBuilder& builder) { + constexpr QuantType qmin = std::numeric_limits::min(); + constexpr QuantType qmax = std::numeric_limits::max(); + + const std::vector input_shape = {1, 17, 17, 3}; + auto* input_arg = builder.MakeInput(input_shape, qmin, qmax); + auto* output_arg = builder.MakeOutput(); + + constexpr float scale = -0.003f; + QuantType zero_point = 1 + (qmax + qmin) / 2; + + auto* input_arg_dq = builder.MakeIntermediate(); + auto* maxpool_output = builder.MakeIntermediate(); + + builder.AddDequantizeLinearNode(input_arg, scale, zero_point, input_arg_dq); + + Node& maxpool_node = builder.AddNode("MaxPool", {input_arg_dq}, {maxpool_output}); + maxpool_node.AddAttribute("auto_pad", "VALID"); + maxpool_node.AddAttribute("kernel_shape", std::vector({2, 2})); + + builder.AddQuantizeLinearNode(maxpool_output, scale, zero_point, output_arg); + }; + + auto check_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["MaxPool"], 1); + EXPECT_EQ(op_to_count["QuantizeLinear"], 1); + EXPECT_EQ(op_to_count["DequantizeLinear"], 1); + }; + + constexpr int opset = 21; + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset); +} + +// Checks that Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q for negative scale. Uses 8-bit and 16-bit Q/DQ ops. +TEST(QDQTransformerTests, MaxpoolDontDropQDQForNegativeScale) { + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); +} + // Runs a test case that checks if Q/DQ nodes are dropped from DQ -> (Un)Squeeze -> Q. 
template static void RunSqueezeUnsqueezeDropQDQTestCase(const std::string& squeeze_type, From 5d78b9a17bb6d126f8ae7fa7eef05cabe4a08dae Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Mon, 29 Jul 2024 17:27:38 -0700 Subject: [PATCH 18/40] [TensorRT EP] Update TRT OSS Parser to 10.2 (#21552) ### Description Update TRT OSS Parser to [latest 10.2-GA branch](https://github.com/onnx/onnx-tensorrt/commit/f161f95883b4ebd8cb789de5efc67b73c0a6e694) ### Motivation and Context --- cgmanifests/generated/cgmanifest.json | 2 +- cmake/deps.txt | 4 ++-- .../github/azure-pipelines/templates/download-deps.yml | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 66b305a6d36de..7de3f346f6386 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -216,7 +216,7 @@ "component": { "type": "git", "git": { - "commitHash": "06adf4461ac84035bee658c6cf5df39f7ab6071d", + "commitHash": "f161f95883b4ebd8cb789de5efc67b73c0a6e694", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" diff --git a/cmake/deps.txt b/cmake/deps.txt index 9d206b6bb3aeb..d0edf963451d5 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -37,8 +37,8 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/v0.3.zip;5ec64e3071edc7347ebd8a81679cf06e2bb9b851 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.16.1.zip;2eb9198bb352757d5ff13977cbe0634898e0837c -#use the latest commit of 10.0-GA -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/06adf4461ac84035bee658c6cf5df39f7ab6071d.zip;46dceef659d75d276e7914a8057c2282269d5e7b +#use the latest commit of 10.2-GA +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/f161f95883b4ebd8cb789de5efc67b73c0a6e694.zip;2148d0c79a171abf2b9451f3bfec164e85caf2ef protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index bf11730c2ce28..01965343c4592 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.167 + version: 1.0.173 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.167 + version: 1.0.173 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. 
From 07d3be5b0e037927c3defd8a7e389e59ec748ad8 Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 29 Jul 2024 21:04:47 -0700 Subject: [PATCH 19/40] CoreML: Add ML Program Split Op (#21456) ### Description Add support for Split Op ### Motivation and Context Address operator gaps in high priority model. --------- Co-authored-by: Scott McKay Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../coreml/builders/impl/split_op_builder.cc | 138 ++++++++++++------ .../apple/coreml_supported_mlprogram_ops.md | 1 + 2 files changed, 94 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 0497357c45c54..dbd0f48576f8b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -5,6 +5,7 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -24,6 +25,8 @@ class SplitOpBuilder : public BaseOpBuilder { // Split opset 13- uses "split" as attribute. Currently it's not supported. int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } + + bool SupportsMLProgram() const override { return true; } }; void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -43,55 +46,98 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); NodeAttrHelper helper(node); - const auto axis = helper.Get("axis", 0); + int64_t axis = helper.Get("axis", 0); - // attribute introduced since opset 18 - uint64_t num_outputs; - - std::unique_ptr layer = model_builder.CreateNNLayer(node); - auto* coreml_splitnd = layer->mutable_splitnd(); - coreml_splitnd->set_axis(axis); - - if (input_defs.size() > 1) { - // if "split" is explicitly provided as an input - const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - Initializer unpacked_tensor(split_tensor); - auto split_span = unpacked_tensor.DataAsSpan(); - auto split_sizes = split_span.size(); - num_outputs = narrow(split_sizes); - for (size_t i = 0; i < split_sizes; i++) { - coreml_splitnd->add_splitsizes(split_span[i]); - } - } else if (node.SinceVersion() < 18) { - num_outputs = narrow(node.OutputDefs().size()); - coreml_splitnd->set_numsplits(num_outputs); - } else { - // note: for opset 18+ 'num_outputs' is a required attribute - num_outputs = narrow(helper.GetInt64("num_outputs").value()); + auto calculate_remainder_and_chunk_size = [&](int32_t num_outputs) { // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; - uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); + uint64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs; uint64_t remainder = split_dim_size % chunk_size; - if (remainder) { - // uneven - auto split_sizes = InlinedVector(num_outputs, chunk_size); - split_sizes.back() = remainder; - for (size_t i = 0; i < 
split_sizes.size(); i++) { - coreml_splitnd->add_splitsizes(split_sizes[i]); - } + return std::make_tuple(remainder, chunk_size); + }; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + std::unique_ptr split_op = model_builder.CreateOperation(node, "split"); + AddOperationInput(*split_op, "axis", model_builder.AddScalarConstant(split_op->type(), "axis", axis)); + + if (input_defs.size() > 1) { + // if "split" is explicitly provided as an input + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + auto split_span = unpacked_tensor.DataAsSpan(); + AddOperationInput(*split_op, "split_sizes", + model_builder.AddConstant(split_op->type(), "split_sizes", split_span)); + } else if (node.SinceVersion() < 18) { + int64_t num_outputs = narrow(node.OutputDefs().size()); + AddOperationInput(*split_op, "num_splits", + model_builder.AddScalarConstant(split_op->type(), "num_splits", num_outputs)); } else { - // even + // note: for opset 18+ 'num_outputs' is a required attribute + int64_t num_outputs = helper.GetInt64("num_outputs").value(); + auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast(num_outputs)); + if (remainder) { + // uneven + std::vector split_sizes(num_outputs, chunk_size); + split_sizes.back() = remainder; + AddOperationInput(*split_op, "split_sizes", + model_builder.AddConstant(split_op->type(), "split_sizes", split_sizes)); + } else { + // even + AddOperationInput(*split_op, "num_splits", + model_builder.AddScalarConstant(split_op->type(), "num_splits", num_outputs)); + } + } + + AddOperationInput(*split_op, "x", input_defs[0]->Name()); + for (const auto& output_def : node.OutputDefs()) { + AddOperationOutput(*split_op, *output_def); + } + model_builder.AddOperation(std::move(split_op)); + + } else +#endif + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + auto* coreml_splitnd = layer->mutable_splitnd(); + coreml_splitnd->set_axis(axis); + + if (input_defs.size() > 1) { + // if "split" is explicitly provided as an input + // const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + auto split_span = unpacked_tensor.DataAsSpan(); + for (const auto& split_size : split_span) { + coreml_splitnd->add_splitsizes(split_size); + } + } else if (node.SinceVersion() < 18) { + uint64_t num_outputs = narrow(node.OutputDefs().size()); coreml_splitnd->set_numsplits(num_outputs); + } else { + // note: for opset 18+ 'num_outputs' is a required attribute + uint64_t num_outputs = narrow(helper.GetInt64("num_outputs").value()); + auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast(num_outputs)); + if (remainder) { + // uneven + auto split_sizes = InlinedVector(num_outputs, chunk_size); + split_sizes.back() = remainder; + for (size_t i = 0; i < split_sizes.size(); i++) { + coreml_splitnd->add_splitsizes(split_sizes[i]); + } + } else { + // even + coreml_splitnd->set_numsplits(num_outputs); + } } - } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - // variadic number of outputs. Calculated based on the length of the given splitSizes if provided. - // Otherwise, uses attribute value 'num_outputs'. 
- for (uint64_t i = 0; i < num_outputs; i++) { - *layer->mutable_output()->Add() = node.OutputDefs()[i]->Name(); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + // variadic number of outputs. Calculated based on the length of the given splitSizes if provided. + // Otherwise, uses attribute value 'num_outputs'. + for (const auto& output_def : node.OutputDefs()) { + *layer->mutable_output()->Add() = output_def->Name(); + } + model_builder.AddLayer(std::move(layer)); } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } @@ -99,7 +145,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); NodeAttrHelper helper(node); const auto axis = helper.Get("axis", 0); @@ -110,16 +155,19 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())]; if (input_defs.size() > 1 && input_defs[1]->Exists()) { - if (!CheckIsConstantInitializer(*input_defs[1], input_params.graph_viewer, logger, "'split'")) { + const auto* splits_tensor = input_params.graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!splits_tensor) { + LOGS(logger, VERBOSE) << "CoreML 'splits' input must be a constant initializer."; return false; } + const auto split_shape = *input_defs[1]->Shape(); if (split_shape.dim_size() < 2) { - LOGS(logger, VERBOSE) << "CoreML SplitND requires to produce at least 2 outputs."; + LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs."; return false; } - const auto& splits_tensor = *initializers.at(input_defs[1]->Name()); - Initializer unpacked_tensor(splits_tensor); + + Initializer unpacked_tensor(*splits_tensor); auto splits_span = unpacked_tensor.DataAsSpan(); int64_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), int64_t{0}); if (sum_of_splits != split_dims_at_axis) { diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index d2a961f17bd6a..b546c266c131b 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -24,6 +24,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Reshape|| |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.| |ai.onnx.Slice|starts/ends/axes/steps must be constant initializers.| +|ai.onnx:Split|| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| |ai:onnx:Tanh|| From 82036b04978b7930185996a70d2146c2895469ea Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 29 Jul 2024 21:59:16 -0700 Subject: [PATCH 20/40] Remove references to the outdated CUDA EP factory method (#21549) The function "OrtSessionOptionsAppendExecutionProvider_CUDA" is deprecated. 
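A minimal sketch of the replacement pattern based on the V2 provider-options API (the helper name `CreateCudaSession` and the `device_id` setting are only for illustration; error handling is trimmed):

```cpp
// Sketch: create a session with the CUDA EP via OrtCUDAProviderOptionsV2.
#include <onnxruntime_cxx_api.h>

Ort::Session CreateCudaSession(Ort::Env& env, const ORTCHAR_T* model_path) {
  Ort::SessionOptions session_options;

  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));

  // Optionally set provider options, e.g. the device id, before appending.
  const char* keys[] = {"device_id"};
  const char* values[] = {"0"};
  Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  // Release the options struct once appended; the session options copy what they need
  // (following the documented usage pattern of the V2 options API).
  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);

  return Ort::Session(env, model_path, session_options);
}
```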
--- .../global_thread_pools/test_inference.cc | 4 +++- onnxruntime/test/shared_lib/test_inference.cc | 20 ++++++++++++++----- .../test/shared_lib/test_model_loading.cc | 5 +++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index f553682975f11..c6d958536f488 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -74,7 +74,9 @@ static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) if (provider_type == 1) { #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); std::cout << "Running simple inference with cuda provider" << std::endl; #else return Ort::Session(nullptr); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 52491a179c2ce..7a33bf8a527cd 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1959,7 +1959,9 @@ TEST(CApiTest, get_allocator_cpu) { #ifdef USE_CUDA TEST(CApiTest, get_allocator_cuda) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); @@ -2076,7 +2078,9 @@ TEST(CApiTest, io_binding_cuda) { #ifdef USE_TENSORRT Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0)); #else - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); @@ -3438,7 +3442,9 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); #else // arena is enabled but the sole initializer will still be allocated from non-arena memory Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -3890,7 +3896,9 @@ TEST(CApiTest, GitHubIssue10179) { try { const auto* model_path = MODEL_URI; Ort::SessionOptions session_options{}; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; } catch (const std::exception& e) { std::cerr << "exception: " << e.what() << "\n"; @@ -3920,7 +3928,9 @@ TEST(CApiTest, GitHubIssue10179) { TEST(CApiTest, 
TestCudaMemcpyToHostWithSequenceTensors) { const auto* model_path = SEQUENCE_MODEL_URI_2; Ort::SessionOptions session_options{}; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); diff --git a/onnxruntime/test/shared_lib/test_model_loading.cc b/onnxruntime/test/shared_lib/test_model_loading.cc index b7f6f7f4b9a77..5694398b9cb10 100644 --- a/onnxruntime/test/shared_lib/test_model_loading.cc +++ b/onnxruntime/test/shared_lib/test_model_loading.cc @@ -60,8 +60,9 @@ TEST(CApiTest, model_from_array) { create_session(so); #ifdef USE_CUDA - // test with CUDA provider when using onnxruntime as dll - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(so, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + so.AppendExecutionProvider_CUDA_V2(*options); create_session(so); #endif } From 530a2d7b41b0584f67ddfef6679a79e9dbeee556 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Tue, 30 Jul 2024 03:49:14 -0700 Subject: [PATCH 21/40] Enable FP16 Clip and Handle Bias in FP16 Depthwise Conv (#21493) - Improved accuracy for face-detection, image-classification, and object-detection in the GeekBench ML benchmark on ARM64. - Fixed issue https://github.com/microsoft/onnxruntime/issues/18992 --- docs/OperatorKernels.md | 4 +- onnxruntime/core/mlas/inc/mlas.h | 2 + onnxruntime/core/mlas/lib/dwconv.cpp | 32 +-- onnxruntime/core/mlas/lib/fp16_common.h | 17 ++ .../core/providers/cpu/fp16/fp16_conv.cc | 4 +- onnxruntime/core/providers/cpu/math/clip.cc | 2 +- .../test/providers/cpu/math/clip_test.cc | 18 ++ .../test/providers/cpu/nn/conv_fp16_test.cc | 237 +++++++++++++++++- .../test/providers/cpu/nn/conv_op_test.cc | 235 +++++++++++++++++ 9 files changed, 531 insertions(+), 20 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 211c53d0fecc8..f265c9f985070 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -58,8 +58,8 @@ Do not modify directly.* |Ceil|*in* X:**T**
 *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
 |||[6, 12]|**T** = tensor(double), tensor(float)|
 |Celu|*in* X:**T**<br> *out* Y:**T**|12+|**T** = tensor(float)|
-|Clip|*in* input:**T**<br> *in* min:**T**<br> *in* max:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Clip|*in* input:**T**<br> *in* min:**T**<br> *in* max:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||11|**T** = tensor(float)|
 |||[6, 10]|**T** = tensor(float)|
 |Col2Im|*in* input:**T**<br> *in* image_shape:**tensor(int64)**<br> *in* block_shape:**tensor(int64)**<br>
*out* output:**T**|18+|**T** = tensor(float)| diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 675f7c7a13e8c..e46105324a7fb 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1751,6 +1751,7 @@ MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* Pac * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input * @param Filter Supplies the address for filter tensor + * @param Bias Supplies the address for 1D bias tensor B, has size of M * @param Output Supplies the address for the result tensor * @param Channels # of input channels * @param OutputCount # of output pixels @@ -1762,6 +1763,7 @@ MLASCALL MlasConvDepthwise( const MLAS_FP16* const* Input, const MLAS_FP16* Filter, + const MLAS_FP16* Bias, MLAS_FP16* Output, size_t Channels, size_t OutputCount, diff --git a/onnxruntime/core/mlas/lib/dwconv.cpp b/onnxruntime/core/mlas/lib/dwconv.cpp index 15511d2d8ceac..d48d9cbb17502 100644 --- a/onnxruntime/core/mlas/lib/dwconv.cpp +++ b/onnxruntime/core/mlas/lib/dwconv.cpp @@ -14,7 +14,6 @@ Module Name: --*/ - #include "fp16_common.h" #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -24,19 +23,20 @@ void MlasConvDepthwiseKernel( const _mlas_fp16_* const* Input, const _mlas_fp16_* Filter, + const _mlas_fp16_* Bias, _mlas_fp16_* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { while (OutputCount > 0) { size_t ChannelOffset = 0; size_t c = Channels; while (c >= 8) { - MLAS_FLOAT16X8 Accumulator = MlasZeroFloat16x8(); + MLAS_FLOAT16X8 Accumulator = Bias == nullptr ? MlasZeroFloat16x8() : MlasLoadFloat16x8(&Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -54,7 +54,7 @@ MlasConvDepthwiseKernel( } if (c >= 4) { - MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4(); + MLAS_FLOAT16X4 Accumulator = Bias == nullptr ? MlasZeroFloat16x4() : MlasLoadFloat16x4(&Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -72,7 +72,8 @@ MlasConvDepthwiseKernel( } if (c > 0) { - MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4(); + MLAS_FLOAT16X4 Accumulator = + Bias == nullptr ? MlasZeroFloat16x4() : MlasLoadPartialFloat16x4(&Bias[ChannelOffset], c); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -86,8 +87,7 @@ MlasConvDepthwiseKernel( Output += c; } if (PostProc) { - PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, - Channels); + PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, Channels); } Input += KernelSize; OutputCount -= 1; @@ -101,16 +101,17 @@ void MlasConvDepthwiseKernel( const _mlas_fp16_* const* Input, const _mlas_fp16_* Filter, + const _mlas_fp16_* Bias, _mlas_fp16_* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { while (OutputCount > 0) { for (size_t ChannelOffset = 0; ChannelOffset < Channels; ChannelOffset++) { - float Accumulator = 0.0f; + float Accumulator = Bias == nullptr ? 
0.0f : MLAS_Half2Float(Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -120,35 +121,36 @@ MlasConvDepthwiseKernel( *Output++ = MLAS_Float2Half(Accumulator); } if (PostProc) { - PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, - Channels); + PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, Channels); } Input += KernelSize; OutputCount -= 1; } } -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED - +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED void MLASCALL MlasConvDepthwise( const MLAS_FP16* const* Input, const MLAS_FP16* Filter, + const MLAS_FP16* Bias, MLAS_FP16* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { MlasConvDepthwiseKernel( reinterpret_cast(Input), reinterpret_cast(Filter), + reinterpret_cast(Bias), reinterpret_cast<_mlas_fp16_*>(Output), Channels, OutputCount, KernelSize, - PostProc); + PostProc + ); } diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index 1fcab870af64f..30b66cdb2ea78 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -64,6 +64,23 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) +{ + MLAS_FLOAT16X4 Vector = MlasZeroFloat16x4(); + if ((len & 1) != 0) { + Vector = vreinterpret_f16_u16(vld1_lane_u16(Buffer + (len - 1), vreinterpret_u16_f16(Vector), 0)); + } + if ((len & 2) != 0) { + Vector = vreinterpret_f16_f32(vdup_lane_f32(vreinterpret_f32_f16(Vector), 0)); + Vector = vreinterpret_f16_f32( + vld1_lane_f32(reinterpret_cast(Buffer), vreinterpret_f32_f16(Vector), 0) + ); + } + return Vector; +} + MLAS_FORCEINLINE void MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector) diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index e6867f10819ae..37db095e92570 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -139,8 +139,9 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr bool share_prepacked_weights = (prepacked_weights != nullptr); + const bool is_depthwise_conv = (group_input_channels == 1 && group_output_channels == 1); // Don't pack the filter buffer if the MlasConvDepthwise path is used. 
- if (!(group_input_channels == 1 && group_output_channels == 1)) { + if (!is_depthwise_conv) { packed_W_size_ = MlasHalfGemmPackBSize(group_output_channels, kernel_dim, false); if (packed_W_size_ != 0) { size_t packed_W_data_size = SafeInt(group_count) * packed_W_size_; @@ -472,6 +473,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { MlasConvDepthwise( worker_indirection_buffer, reordered_W, + Bdata, worker_output, static_cast(M), static_cast(output_count), diff --git a/onnxruntime/core/providers/cpu/math/clip.cc b/onnxruntime/core/providers/cpu/math/clip.cc index ddb64a5a0e461..200469bc47835 100644 --- a/onnxruntime/core/providers/cpu/math/clip.cc +++ b/onnxruntime/core/providers/cpu/math/clip.cc @@ -23,7 +23,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES( float); ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES( kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0, - float, double, int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t); + float, MLFloat16, double, int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t); } // namespace op_kernel_type_control using EnabledClip11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index 6f81bbbe31d54..9948a6cc8a681 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -119,6 +119,24 @@ TEST(MathOpTest, Clip_Default_uint64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(MathOpTest, Clip_MLFloat16) { + OpTester test("Clip", 12); + + std::vector dims{3, 3}; + test.AddInput("X", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), + MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); + test.AddInput("min", {}, {MLFloat16(0.0f)}); + test.AddInput("max", {}, {MLFloat16(6.0f)}); + test.AddOutput("Y", dims, + {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), + MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + + test.Run(); +} + TEST(MathOpTest, Clip_int32) { OpTester test("Clip", 12); diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index cb5fc8095982c..95b274966fbbb 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -714,6 +714,241 @@ TEST(ConvFp16Test, Conv2D_group) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvFp16Test, Depthwise2D_Bias_Group1_Issue18992) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(1.0f)}; + vector X_shape = {1, 1, 1, 1}; + vector W = {MLFloat16(0.5f)}; + vector W_shape = {1, 1, 1, 1}; + vector B = {MLFloat16(0.5f)}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 1, 1}; + auto expected_vals = {MLFloat16(1.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Depthwise2D_Bias_Group2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 2, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // 
pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), + MLFloat16(12.0f), MLFloat16(13.0f), MLFloat16(14.0f), + MLFloat16(15.0f), MLFloat16(16.0f), MLFloat16(17.0f)}; + vector X_shape = {1, 2, 3, 3}; + vector W = {MLFloat16(1.0f), MLFloat16(2.0f)}; + vector W_shape = {2, 1, 1, 1}; + vector B = {MLFloat16(1.0f), MLFloat16(-1.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f), + + MLFloat16(17.0f), MLFloat16(19.0f), MLFloat16(21.0f), + MLFloat16(23.0f), MLFloat16(25.0f), MLFloat16(27.0f), + MLFloat16(29.0f), MLFloat16(31.0f), MLFloat16(33.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Depthwise2D_Bias_Group15) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 15, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + // C = 0 + MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(2.0f), MLFloat16(3.0f), + + // C = 1 + MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), + + // C = 2 + MLFloat16(8.0f), MLFloat16(9.0f), + MLFloat16(10.0f), MLFloat16(11.0f), + + // C = 3 + MLFloat16(12.0f), MLFloat16(13.0f), + MLFloat16(14.0f), MLFloat16(15.0f), + + // C = 4 + MLFloat16(16.0f), MLFloat16(17.0f), + MLFloat16(18.0f), MLFloat16(19.0f), + + // C = 5 + MLFloat16(20.0f), MLFloat16(21.0f), + MLFloat16(22.0f), MLFloat16(23.0f), + + // C = 6 + MLFloat16(24.0f), MLFloat16(25.0f), + MLFloat16(26.0f), MLFloat16(27.0f), + + // C = 7 + MLFloat16(28.0f), MLFloat16(29.0f), + MLFloat16(30.0f), MLFloat16(31.0f), + + // C = 8 + MLFloat16(32.0f), MLFloat16(33.0f), + MLFloat16(34.0f), MLFloat16(35.0f), + + // C = 9 + MLFloat16(36.0f), MLFloat16(37.0f), + MLFloat16(38.0f), MLFloat16(39.0f), + + // C = 10 + MLFloat16(40.0f), MLFloat16(41.0f), + MLFloat16(42.0f), MLFloat16(43.0f), + + // C = 11 + MLFloat16(44.0f), MLFloat16(45.0f), + MLFloat16(46.0f), MLFloat16(47.0f), + + // C = 12 + MLFloat16(48.0f), MLFloat16(49.0f), + MLFloat16(50.0f), MLFloat16(51.0f), + + // C = 13 + MLFloat16(52.0f), MLFloat16(53.0f), + MLFloat16(54.0f), MLFloat16(55.0f), + + // C = 14 + MLFloat16(56.0f), MLFloat16(57.0f), + MLFloat16(58.0f), MLFloat16(59.0f)}; + vector X_shape = {1, 15, 2, 2}; + vector W = { + // M = 0 + MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(2.0f), MLFloat16(3.0f), + + // M = 1 + MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), + + // M = 2 + MLFloat16(8.0f), MLFloat16(9.0f), + MLFloat16(10.0f), MLFloat16(11.0f), + + // M = 3 + MLFloat16(12.0f), MLFloat16(13.0f), + MLFloat16(14.0f), MLFloat16(15.0f), + + // M = 4 + MLFloat16(16.0f), MLFloat16(17.0f), + MLFloat16(18.0f), MLFloat16(19.0f), + + // M = 5 + MLFloat16(20.0f), MLFloat16(21.0f), + MLFloat16(22.0f), MLFloat16(23.0f), + + // M = 6 + MLFloat16(24.0f), MLFloat16(25.0f), + MLFloat16(26.0f), MLFloat16(27.0f), + + // M = 7 + MLFloat16(28.0f), MLFloat16(29.0f), + MLFloat16(30.0f), MLFloat16(31.0f), + + // M = 8 + 
MLFloat16(32.0f), MLFloat16(33.0f), + MLFloat16(34.0f), MLFloat16(35.0f), + + // M = 9 + MLFloat16(36.0f), MLFloat16(37.0f), + MLFloat16(38.0f), MLFloat16(39.0f), + + // M = 10 + MLFloat16(40.0f), MLFloat16(41.0f), + MLFloat16(42.0f), MLFloat16(43.0f), + + // M = 11 + MLFloat16(44.0f), MLFloat16(45.0f), + MLFloat16(46.0f), MLFloat16(47.0f), + + // M = 12 + MLFloat16(48.0f), MLFloat16(49.0f), + MLFloat16(50.0f), MLFloat16(51.0f), + + // M = 13 + MLFloat16(52.0f), MLFloat16(53.0f), + MLFloat16(54.0f), MLFloat16(55.0f), + + // M = 14 + MLFloat16(56.0f), MLFloat16(57.0f), + MLFloat16(58.0f), MLFloat16(59.0f)}; + vector W_shape = {15, 1, 2, 2}; + vector B = { + MLFloat16(101.0f), + MLFloat16(102.0f), + MLFloat16(103.0f), + MLFloat16(104.0f), + MLFloat16(105.0f), + MLFloat16(106.0f), + MLFloat16(107.0f), + MLFloat16(108.0f), + MLFloat16(109.0f), + MLFloat16(110.0f), + MLFloat16(111.0f), + MLFloat16(112.0f), + MLFloat16(113.0f), + MLFloat16(114.0f), + MLFloat16(115.0f)}; + vector B_shape = {15}; + vector Y_shape = {1, 15, 1, 1}; + auto expected_vals = { + MLFloat16(115.0f), // 0.0*0.0 + 1.0*1.0 + 2.0*2.0 + 3.0*3.0 + 101.0 + MLFloat16(228.0f), + MLFloat16(469.0f), + MLFloat16(838.0f), + MLFloat16(1335.0f), + MLFloat16(1960.0f), + MLFloat16(2713.0f), // 24.0*24.0 + 25.0*25.0 + 26.0*26.0 + 27.0*27.0 + 107.0 + MLFloat16(3594.0f), + MLFloat16(4603.0f), + MLFloat16(5740.0f), + MLFloat16(7005.0f), + MLFloat16(8398.0f), + MLFloat16(9919.0f), // 48.0*48.0 + 49.0*49.0 + 50.0*50.0 + 51.0*51.0 + 113.0 + MLFloat16(11568.0f), // 52.0*52.0 + 53.0*53.0 + 54.0*54.0 + 55.0*55.0 + 114.0 + MLFloat16(13345.0f) // 56.0*56.0 + 57.0*57.0 + 58.0*58.0 + 59.0*59.0 + 115.0 + }; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + TEST(ConvFp16Test, ConvDimWithZero) { ConvOpAndTestAttributes attrs = { "", // auto_pad @@ -1074,4 +1309,4 @@ TEST(ConvFp16Test, SharedPrepackedWeights) { } // namespace test } // namespace onnxruntime -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED \ No newline at end of file +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 0efa78af2795c..2d885ee9d479f 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -647,6 +647,241 @@ TEST(ConvTest, Conv2D_group) { TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvTest, Depthwise2D_Bias_Group1_Issue18992) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {1.0f}; + vector X_shape = {1, 1, 1, 1}; + vector W = {0.5f}; + vector W_shape = {1, 1, 1, 1}; + vector B = {0.5f}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 1, 1}; + auto expected_vals = {1.0f}; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvTest, Depthwise2D_Bias_Group2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 2, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + 0.0f, 1.0f, 2.0f, + 3.0f, 
4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, + + 9.0f, 10.0f, 11.0f, + 12.0f, 13.0f, 14.0f, + 15.0f, 16.0f, 17.0f}; + vector X_shape = {1, 2, 3, 3}; + vector W = {1.0f, 2.0f}; + vector W_shape = {2, 1, 1, 1}; + vector B = {1.0f, -1.0f}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, + + 17.0f, 19.0f, 21.0f, + 23.0f, 25.0f, 27.0f, + 29.0f, 31.0f, 33.0f}; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvTest, Depthwise2D_Bias_Group15) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 15, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + // C = 0 + 0.0f, 1.0f, + 2.0f, 3.0f, + + // C = 1 + 4.0f, 5.0f, + 6.0f, 7.0f, + + // C = 2 + 8.0f, 9.0f, + 10.0f, 11.0f, + + // C = 3 + 12.0f, 13.0f, + 14.0f, 15.0f, + + // C = 4 + 16.0f, 17.0f, + 18.0f, 19.0f, + + // C = 5 + 20.0f, 21.0f, + 22.0f, 23.0f, + + // C = 6 + 24.0f, 25.0f, + 26.0f, 27.0f, + + // C = 7 + 28.0f, 29.0f, + 30.0f, 31.0f, + + // C = 8 + 32.0f, 33.0f, + 34.0f, 35.0f, + + // C = 9 + 36.0f, 37.0f, + 38.0f, 39.0f, + + // C = 10 + 40.0f, 41.0f, + 42.0f, 43.0f, + + // C = 11 + 44.0f, 45.0f, + 46.0f, 47.0f, + + // C = 12 + 48.0f, 49.0f, + 50.0f, 51.0f, + + // C = 13 + 52.0f, 53.0f, + 54.0f, 55.0f, + + // C = 14 + 56.0f, 57.0f, + 58.0f, 59.0f}; + vector X_shape = {1, 15, 2, 2}; + vector W = { + // M = 0 + 0.0f, 1.0f, + 2.0f, 3.0f, + + // M = 1 + 4.0f, 5.0f, + 6.0f, 7.0f, + + // M = 2 + 8.0f, 9.0f, + 10.0f, 11.0f, + + // M = 3 + 12.0f, 13.0f, + 14.0f, 15.0f, + + // M = 4 + 16.0f, 17.0f, + 18.0f, 19.0f, + + // M = 5 + 20.0f, 21.0f, + 22.0f, 23.0f, + + // M = 6 + 24.0f, 25.0f, + 26.0f, 27.0f, + + // M = 7 + 28.0f, 29.0f, + 30.0f, 31.0f, + + // M = 8 + 32.0f, 33.0f, + 34.0f, 35.0f, + + // M = 9 + 36.0f, 37.0f, + 38.0f, 39.0f, + + // M = 10 + 40.0f, 41.0f, + 42.0f, 43.0f, + + // M = 11 + 44.0f, 45.0f, + 46.0f, 47.0f, + + // M = 12 + 48.0f, 49.0f, + 50.0f, 51.0f, + + // M = 13 + 52.0f, 53.0f, + 54.0f, 55.0f, + + // M = 14 + 56.0f, 57.0f, + 58.0f, 59.0f}; + vector W_shape = {15, 1, 2, 2}; + vector B = { + 101.0f, + 102.0f, + 103.0f, + 104.0f, + 105.0f, + 106.0f, + 107.0f, + 108.0f, + 109.0f, + 110.0f, + 111.0f, + 112.0f, + 113.0f, + 114.0f, + 115.0f}; + vector B_shape = {15}; + vector Y_shape = {1, 15, 1, 1}; + auto expected_vals = { + 115.0f, // 0.0*0.0 + 1.0*1.0 + 2.0*2.0 + 3.0*3.0 + 101.0 + 228.0f, + 469.0f, + 838.0f, + 1335.0f, + 1960.0f, + 2713.0f, // 24.0*24.0 + 25.0*25.0 + 26.0*26.0 + 27.0*27.0 + 107.0 + 3594.0f, + 4603.0f, + 5740.0f, + 7005.0f, + 8398.0f, + 9919.0f, // 48.0*48.0 + 49.0*49.0 + 50.0*50.0 + 51.0*51.0 + 113.0 + 11568.0f, // 52.0*52.0 + 53.0*53.0 + 54.0*54.0 + 55.0*55.0 + 114.0 + 13345.0f // 56.0*56.0 + 57.0*57.0 + 58.0*58.0 + 59.0*59.0 + 115.0 + }; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + TEST(ConvTest, ConvDimWithZero) { ConvOpAndTestAttributes attrs = { "", // auto_pad From 1637f22d39b6d57d2774d7d41e6a8ae1815180c5 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Tue, 30 Jul 2024 09:35:45 -0700 Subject: [PATCH 22/40] Extend Pad Fusion for AveragePool (#21556) ### Description This extends the existing pad_fusion for AveragePool operator i.e. 
fuse Pad if it is followed by AveragePool operator. ### Motivation and Context --- onnxruntime/core/optimizer/pad_fusion.cc | 3 ++- onnxruntime/core/optimizer/pad_fusion.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index a1c7f8de9e6fe..e266946b0d9e0 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -12,7 +12,7 @@ namespace onnxruntime { * It matches following pattern: * Pad * | - * Conv/MaxPool + * Conv/MaxPool/AveragePool */ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { // if Pad has input axis, don't fuse it. @@ -28,6 +28,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log const Node& child_node = *node.OutputNodesBegin(); if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) && !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) { return false; } diff --git a/onnxruntime/core/optimizer/pad_fusion.h b/onnxruntime/core/optimizer/pad_fusion.h index a1b6978a83d1e..ca05d219b7e2c 100644 --- a/onnxruntime/core/optimizer/pad_fusion.h +++ b/onnxruntime/core/optimizer/pad_fusion.h @@ -8,7 +8,7 @@ namespace onnxruntime { /* * This fusion submerges a Pad operator to it's child - * Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition() + * Conv or MaxPool or AveragePool operator, if and only if PadFusion::SatisfyCondition() * is true. */ class PadFusion : public RewriteRule { From e7aa11607f59c7efeb8505af8fe9186ff2e3dd37 Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:22:46 -0700 Subject: [PATCH 23/40] Utilize ext data location to reduce qd matmul memory usage (#21451) ### Description When the graph is quantized to qdq format, the DQ + MatMul is transformed to MatMulNBits in the level 2 optimizer when the model is initialized in an inference session. In the transformation step, tensors are transposed and new tensor protos are created. Instead of using protobuf arena allocated memory, the PR sets the tensor proto to use external buffer, and point the external location to memory location which contains the tensor buffer allocated by CPU. Then, in the step that creates OrtValue using the tensor proto, the memory buffers in the tensor proto are directly assigned to the tensors which were originally allocated by Ort Arena. With these two steps, the peak memory usage of QDQ format model is the same as usage of QOperator model. Besides, the model initialization time is significantly reduced. Take [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) for example: || QOperator Model (MatMulNBits) | QDQ Model (DQ + MatMul, original code) | QDQ Model (this PR) | |---|---|---|---| | peak memory consumption | 2.8 GB | ~4.8 GB | 2.8 GB | | initialization time | 3 sec | 9 sec | 5 sec | ### Motivation and Context When the graph is quantized to qdq format, the DQ + MatMul is converted to MatMulNBits in the level 2 optimizer. Originally, the newly created tensor proto use memory allocated by protobuf arena. These memory usage cannot be fully released when the tensor protos are deleted. Then, in the tensor proto to OrtValue step, tensors are created using ORT arena. 
Later, in the pre-pack step for MatMulNBits, new OrtValues are created. The tensors in the ORT arena are not fully released as well. The two arena memory allocation steps in the DQ + MatMul -> MatMulNBits transformation will result in almost 2x memory consumption in the model initialization. --- .../core/optimizer/graph_transformer_utils.h | 9 +- onnxruntime/core/framework/session_state.cc | 3 +- onnxruntime/core/framework/session_state.h | 11 ++ .../core/framework/session_state_utils.cc | 41 ++++-- .../core/framework/session_state_utils.h | 6 +- .../core/framework/tensorprotoutils.cc | 36 +++++- onnxruntime/core/framework/tensorprotoutils.h | 29 +++-- .../core/optimizer/graph_transformer_utils.cc | 12 +- .../selectors_actions/qdq_actions.cc | 119 +++++++++++------- .../selectors_actions/qdq_actions.h | 6 +- .../qdq_selector_action_transformer.cc | 35 ++++-- .../qdq_selector_action_transformer.h | 7 +- onnxruntime/core/session/inference_session.cc | 14 ++- 13 files changed, 239 insertions(+), 89 deletions(-) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 0bb5c7432f0a7..6cff153c336f0 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -3,12 +3,15 @@ #pragma once +#include #include +#include #include #include #include "core/common/inlined_containers.h" #include "core/framework/session_options.h" +#include "core/framework/tensor.h" #include "core/optimizer/graph_transformer.h" #include "core/platform/threadpool.h" @@ -51,7 +54,8 @@ InlinedVector> GenerateTransformers( const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -81,7 +85,8 @@ InlinedVector> GenerateTransformersForMinimalB const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index a88f36f63639c..ddb0c3356e544 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1486,7 +1486,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string #include #include +#include #include #include "core/common/flatbuffers.h" @@ -303,6 +304,10 @@ class SessionState { const InlinedHashSet* GetToBeExecutedRange(gsl::span fetch_mlvalue_idxs) const; #endif + std::unordered_map>* GetMutableBufferedTensors() { + return &name_to_buffered_tensor_; + } + Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, bool remove_initializers = true, @@ -562,6 +567,12 @@ class SessionState { // flag to indicate whether current session using any EP that create device stream dynamically. 
bool has_device_stream_enabled_ep_ = false; #endif + + // Holds the tensors which provide memory buffer for TensorProtos + // Use case: in optimizer, transform a TensorProto to a new TensorProto whose the memory buffer is + // allocated by CPU instead by protobuf's arena. Arena style memory allocators do not fully release + // a instance's memory which may result large memory consumption, which is a tradeoff for speed. + std::unordered_map> name_to_buffered_tensor_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 059de8e3c8c4a..b13b0cd27496d 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include @@ -61,17 +63,23 @@ struct ExtDataValueDeleter { // given a tensor proto with external data return an OrtValue with a tensor for // that data; the pointers for the tensor data and the tensor itself are owned -// by the OrtValue's deleter +// by the OrtValue's deleter. +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. static inline common::Status ExtDataTensorProtoToTensor(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, - Tensor& tensor, OrtCallback& ext_data_deleter) { + Tensor& tensor, OrtCallback& ext_data_deleter, + Tensor* buffered_tensor = nullptr) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); void* ext_data_buf = nullptr; SafeInt ext_data_len = 0; ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, - ext_data_buf, ext_data_len, ext_data_deleter)); + ext_data_buf, ext_data_len, ext_data_deleter, + buffered_tensor)); // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -83,16 +91,24 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, return common::Status::OK(); } +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. 
static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, - bool use_device_allocator_for_initializers = false) { + bool use_device_allocator_for_initializers = false, + Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DeserializeTensorProto() takes either pre-allocated buffer or an allocator!"); } + ORT_RETURN_IF(buffered_tensor && !utils::HasExternalData(tensor_proto), + "With buffered tensor, tensor proto must use external location and point to buffered tensor"); + // Get shape and type of the tensor, and allocate the empty tensor TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); @@ -123,7 +139,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // utilize the mmap'd buffer directly by calling ExtDataTensorProtoToTensor. If we called // TensorProtoToTensor it would copy the data, causing unnecessary overhead OrtCallback ext_data_deleter; - ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, ext_data_deleter)); + ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, + ext_data_deleter, buffered_tensor)); ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; @@ -154,7 +171,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st std::optional scoped_ort_callback_invoker; if (utils::HasExternalData(tensor_proto)) { ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter)); + ext_data_deleter, buffered_tensor)); scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); } else { ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor)); @@ -187,7 +204,8 @@ common::Status SaveInitializedTensors( const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, - const MemoryProfileFunction& memory_profile_func) { + const MemoryProfileFunction& memory_profile_func, + std::unordered_map>& buffered_tensors) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -307,9 +325,16 @@ common::Status SaveInitializedTensors( bool use_device_allocator_for_initializers = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; + Tensor* p_tensor = nullptr; + if (auto iter = buffered_tensors.find(name); + iter != buffered_tensors.end()) { + p_tensor = iter->second.release(); + buffered_tensors.erase(iter); + } + Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, - use_device_allocator_for_initializers); + use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; oss << "Deserialize tensor " << name << " failed." 
<< st.ErrorMessage(); diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index af44c35fbb7f5..499222b6ec613 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -3,6 +3,9 @@ #pragma once #include +#include +#include +#include #include "core/common/const_pointer_container.h" #include "core/framework/allocator.h" @@ -44,7 +47,8 @@ common::Status SaveInitializedTensors( const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, - const MemoryProfileFunction& memory_profile_func); + const MemoryProfileFunction& memory_profile_func, + std::unordered_map>& buffered_tensors); common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph, SessionState& session_state, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 4ecd61962d797..cbd53298ab2ad 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -987,7 +987,8 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, - SafeInt& ext_data_len, OrtCallback& ext_data_deleter) { + SafeInt& ext_data_len, OrtCallback& ext_data_deleter, + Tensor* buffered_tensor) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { @@ -1003,7 +1004,12 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo // the value in location is the memory address of the data ext_data_buf = reinterpret_cast(file_offset); ext_data_len = raw_data_safe_len; - ext_data_deleter = OrtCallback{nullptr, nullptr}; + if (buffered_tensor) { + ext_data_deleter = OrtCallback{[](void* p) noexcept { delete reinterpret_cast(p); }, + reinterpret_cast(buffered_tensor)}; + } else { + ext_data_deleter = OrtCallback{nullptr, nullptr}; + } } else { #if defined(__wasm__) ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296, @@ -1241,7 +1247,9 @@ ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto return CApiElementTypeFromProtoType(tensor_proto.data_type()); } -ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name) { +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, + const std::string& tensor_proto_name, + bool use_tensor_buffer) { // Set name, dimensions, type, and data of the TensorProto. ONNX_NAMESPACE::TensorProto tensor_proto; @@ -1259,6 +1267,28 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std: for (; f < end; ++f) { *mutable_string_data->Add() = *f; } + } else if (use_tensor_buffer && tensor.SizeInBytes() > 127) { + // The logic aligns with + // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_flatbuffers_utils.cc#L302 + const auto* raw_data = tensor.DataRaw(); + ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); + static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. 
+ // use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the + // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. + auto offset = narrow(reinterpret_cast(raw_data)); + + ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("location"); + entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("offset"); + entry->set_value(std::to_string(offset)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("length"); + entry->set_value(std::to_string(tensor.SizeInBytes())); } else { utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index e5197adcb94ec..2af1f080be7ee 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -114,14 +114,22 @@ common::Status TensorProtoToTensor(const Env& env, const std::filesystem::path& const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor); -/** Creates a TensorProto from a Tensor. - @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. - @param[in] tensor_proto_name the name of the TensorProto. - @return the TensorProto. - - Note: Method currently requires that data is in little-endian format. +/** + * @brief Creates a TensorProto from a Tensor. + * @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. + * @param[in] tensor_proto_name the name of the TensorProto. + * @param[in] use_tensor_buffer the tensor proto is set to use external location, with + * 'location' set to onnxruntime::utils::kTensorProtoMemoryAddressTag + * 'offset' set to tensor's memory location, and 'length' set to tensor's + * memory size. The caller is responsible to maintain the lifetime of + * the allocated memory buffer. Use with caution. + * @return the TensorProto. + * + * Note: Method currently requires that data is in little-endian format. */ -ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name); +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, + const std::string& tensor_proto_name, + bool use_tensor_buffer = false); ONNXTensorElementDataType CApiElementTypeFromProtoType(int type); ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto); @@ -141,10 +149,15 @@ constexpr const ORTCHAR_T* kTensorProtoMemoryAddressTag = ORT_TSTR("*/_ORT_MEM_A // Given a tensor proto with external data obtain a pointer to the data and its length. // The ext_data_deleter argument is updated with a callback that owns/releases the data. +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. 
common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, - OrtCallback& ext_data_deleter); + OrtCallback& ext_data_deleter, + Tensor* buffered_tensor = nullptr); // Convert the AttributeProto from a Constant node into a TensorProto that can be used as an initializer // If AttributeProto contains a TensorProto, this tensor proto is converted as is including the case when the diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index ab1dbaea7b7fd..54bd44ec2dba4 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -189,7 +189,8 @@ InlinedVector> GenerateTransformers( const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -309,7 +310,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + p_buffered_tensors)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -419,7 +421,8 @@ InlinedVector> GenerateTransformersForMinimalB const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -444,7 +447,8 @@ InlinedVector> GenerateTransformersForMinimalB transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + p_buffered_tensors)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 74fecb0427e14..8f99b7409d4fe 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include +#include + #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/initializer.h" @@ -275,8 +278,10 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } -DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) +DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( + int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -286,7 +291,8 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool} { + intra_op_thread_pool_{intra_op_thread_pool}, + p_buffered_tensors_{p_buffered_tensors} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -311,6 +317,7 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { + ORT_RETURN_IF_NOT(p_buffered_tensors_, "Buffered tensors map cannot be null"); const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -338,24 +345,35 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, // to what we need. But it does not handle external data. Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); - std::optional zp_src; - Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(weight_arg->Name() + "_T"), - std::vector{N, quant_num, blob_bytes}); - Initializer scale_dst(static_cast(scale_src.data_type()), - graph.GenerateNodeArgName(scale_arg->Name() + "_T"), - std::vector{N * quant_num}); - std::optional zp_dst; + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); + std::optional zp_src_ptr; + auto cpu_allocator = std::make_shared(); + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); + auto weight_dst_ptr = std::make_unique(uint8_type, + TensorShape{N, quant_num, blob_bytes}, + cpu_allocator); + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); + auto scale_size = (TensorShape{N, quant_num}).Size(); + auto scale_dst_ptr = std::make_unique(scale_type, + TensorShape{scale_size}, + cpu_allocator); + std::string zp_dst_name; + std::unique_ptr zp_dst_ptr; + auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); if (zp_tensor_proto) { - zp_src.emplace(*zp_tensor_proto, graph.ModelPath()); - zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(zp_arg->Name() + "_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_src_ptr.emplace(*zp_tensor_proto, graph.ModelPath()); + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); + zp_dst_ptr = std::make_unique(uint8_type, + TensorShape{zp_size}, + cpu_allocator); } else if (weight_src.data_type() == 
ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); + zp_dst_ptr = std::make_unique(uint8_type, + TensorShape{zp_size}, + cpu_allocator); + memset(zp_dst_ptr->MutableDataRaw(), 0, zp_dst_ptr->SizeInBytes()); } if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -363,10 +381,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -376,10 +394,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -391,10 +409,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -405,10 +423,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? 
zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -417,28 +435,43 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, } } - ONNX_NAMESPACE::TensorProto weight_T_tp; - ONNX_NAMESPACE::TensorProto scale_T_tp; + auto weight_T_tp = utils::TensorToTensorProto(*weight_dst_ptr, weight_dst_name, true); + auto scale_T_tp = utils::TensorToTensorProto(*scale_dst_ptr, scale_dst_name, true); std::optional zp_T_tp; - // TODO(fajin): external_data to memory location to avoid arena allocation - // https://github.com/microsoft/onnxruntime/pull/12465 - weight_dst.ToProto(weight_T_tp); - scale_dst.ToProto(scale_T_tp); - if (zp_dst) { - zp_T_tp.emplace(); - zp_dst->ToProto(zp_T_tp.value()); + if (zp_dst_ptr) { + zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst_ptr, zp_dst_name, true)); } auto& input_defs = replacement_node.MutableInputDefs(); input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); + if (weight_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + // If tensor is too small, tensor proto directly copies data from tensor. The tensor allocated + // here can be directly destructed. + // Only keep the tensor in p_buffered_tensors_ when the tensor proto is using external data location + // and pointing the location to tensor's buffer. + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(weight_dst_name, std::move(weight_dst_ptr)).second, + "Failed to add buffered tensor ", + weight_dst_name); + } + input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); + if (scale_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(scale_dst_name, std::move(scale_dst_ptr)).second, + "Failed to add buffered tensor ", + scale_dst_name); + } if (zp_T_tp) { input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); replacement_node.MutableInputArgsCount().push_back(1); + if (zp_T_tp->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(zp_dst_name, std::move(zp_dst_ptr)).second, + "Failed to add buffered tensor ", + zp_dst_name); + } } return Status::OK(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 47821619db65a..d25077ca4b491 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -5,10 +5,12 @@ #include #include +#include #include #include "core/optimizer/selectors_actions/actions.h" #include "core/platform/threadpool.h" +#include "core/framework/tensor.h" namespace onnxruntime { @@ -84,7 +86,8 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool); + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -103,6 +106,7 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; 
concurrency::ThreadPool* intra_op_thread_pool_; + std::unordered_map>* p_buffered_tensors_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index d81701fdf443b..379d271fbdca7 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include +#include +#include + +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/mlas/inc/mlas.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" @@ -247,7 +250,8 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is int4/uint4. DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. @@ -255,7 +259,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + p_buffered_tensors); #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr selector = std::make_unique(); @@ -312,9 +317,11 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { #endif } -SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, - int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { +SelectorActionRegistry CreateSelectorActionRegistry( + bool is_int8_allowed, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -328,20 +335,24 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + p_buffered_tensors); return qdq_selector_action_registry; } } // namespace -QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed, - const SatApplyContextVariant& apply_context, - int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) +QDQSelectorActionTransformer::QDQSelectorActionTransformer( + bool is_int8_allowed, + const SatApplyContextVariant& apply_context, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, 
intra_op_thread_pool), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, + intra_op_thread_pool, p_buffered_tensors), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index ba636f76d1900..627ddd35b9919 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + #include "core/optimizer/selectors_actions/selector_action_transformer.h" #include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" @@ -25,7 +29,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5ad2f08467792..5eed7c5c6f2b5 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1615,7 +1615,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1623,7 +1624,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable, intra_op_thread_pool); + optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2012,7 +2013,8 @@ common::Status InferenceSession::Initialize() { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, - cpu_ep, GetIntraOpThreadPoolToUse())); + cpu_ep, GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3175,7 +3177,8 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); + GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3185,7 +3188,8 @@ 
common::Status InferenceSession::AddPredefinedTransformers( : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); + GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors()); } }(); From 1d4b161145ee9604b98d904a6c3da5610fbf6d01 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 31 Jul 2024 08:46:08 +0800 Subject: [PATCH 24/40] [WebNN EP] Support ConvTranspose for TFLite backend (#21291) ### Description Chromium supports ConvTranspose for TFLite in https://chromium-review.googlesource.com/c/chromium/src/+/5635194 With constraint that only default dilations and groups are supported. --------- Co-authored-by: Dwayne Robinson --- js/web/docs/webnn-operators.md | 2 +- .../core/providers/webnn/builders/helper.h | 2 +- .../webnn/builders/impl/conv_op_builder.cc | 16 ++++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 75652899b5e5e..8711d4d20e370 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -22,7 +22,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) | | Concat | ai.onnx(7-10, 11-12, 13+) | concat | ✓ | ✓ | | | Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) | -| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✗ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). | +| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | | Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value is 1.0 | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 496f886e5a076..63fd97abb9a9a 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -167,7 +167,7 @@ static const InlinedHashMap op_map = { {"Concat", {"concat", true}}, {"Conv", {"conv2d", true}}, {"ConvInteger", {"conv2dInteger", false}}, - {"ConvTranspose", {"convTranspose2d", false}}, + {"ConvTranspose", {"convTranspose2d", true}}, {"Cos", {"cos", true}}, {"Div", {"div", true}}, {"DequantizeLinear", {"dequantizeLinear", false}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 4f3f7459a7b5b..22049d2519712 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -427,6 +427,22 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } + // WebNN CPU backend (TFLite) only supports default dilations and group. 
+ // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 + if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { + NodeAttrHelper helper(node); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + const auto group = helper.Get("group", 1); + if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; + return false; + } + if (group != 1) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; + return false; + } + } + return true; }

From b341c44c20d45c88430feb62b684033ee7b12a8f Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Wed, 31 Jul 2024 08:59:55 -0700
Subject: [PATCH 25/40] Fix ETW trace logging crash in multithreading situations (#21566)

### Description
The ETW trace logger was reported as registered before registration had actually completed: initialized_ was set to true before TraceLoggingRegisterEx finished, which caused a crash in the Lenovo camera application.
A prior attempt to address this was made here: https://github.com/microsoft/onnxruntime/pull/21226
It was reverted here: https://github.com/microsoft/onnxruntime/pull/21360

### Motivation and Context
During initialization, TraceLoggingRegisterEx re-invokes the enable callback, which then attempts re-initialization; that is not allowed. At the same time, TraceLoggingRegisterEx can be called concurrently when initialization happens on multiple threads, so it must be protected by a lock. The lock cannot naively block, however, because the callback's re-invocation would deadlock on it. To solve this, another tracking state, "Initializing", is added; it guards against re-initialization while the first initialization is still in progress.
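To illustrate the pattern described above, here is a minimal, self-contained C++ sketch of the same idea. It is not the actual EtwRegistrationManager/TraceLoggingRegisterEx code; the names (`LazyRegistrar`, `EnsureRegistered`, `RegisterProvider`) are invented for the example, and `RegisterProvider` merely stands in for an API that can synchronously invoke the enable callback before it returns.

```cpp
#include <atomic>
#include <iostream>
#include <mutex>

// Tri-state guard: a re-entrant call made from inside the registration
// callback observes kInitializing and returns early instead of taking the
// (non-recursive) mutex again, which would deadlock.
enum class InitStatus { kNotInitialized, kInitializing, kInitialized, kFailed };

class LazyRegistrar {
 public:
  void EnsureRegistered() {
    // Fast path: registration already finished.
    if (status_.load(std::memory_order_acquire) == InitStatus::kInitialized) return;
    // Re-entrant or concurrent call while registration is in flight: bail out.
    // Callers must tolerate "not registered yet" (e.g. drop log messages).
    if (status_.load(std::memory_order_acquire) == InitStatus::kInitializing) return;

    std::lock_guard<std::mutex> lock(init_mutex_);
    if (status_.load(std::memory_order_relaxed) != InitStatus::kNotInitialized) return;

    status_.store(InitStatus::kInitializing, std::memory_order_release);
    try {
      RegisterProvider();  // may synchronously invoke OnEnableCallback()
      status_.store(InitStatus::kInitialized, std::memory_order_release);
    } catch (...) {
      status_.store(InitStatus::kFailed, std::memory_order_release);
      throw;
    }
  }

 private:
  void RegisterProvider() {
    // Stand-in for a registration API that can invoke the enable callback
    // before it returns, re-entering EnsureRegistered() on the same thread.
    OnEnableCallback();
  }

  void OnEnableCallback() {
    EnsureRegistered();  // returns immediately: status_ == kInitializing
    std::cout << "enable callback ran during registration\n";
  }

  std::atomic<InitStatus> status_{InitStatus::kNotInitialized};
  std::mutex init_mutex_;
};

int main() {
  LazyRegistrar r;
  r.EnsureRegistered();
  r.EnsureRegistered();  // no-op: already initialized
  return 0;
}
```

The key design choice is that a caller observing the "Initializing" state returns immediately rather than waiting on the mutex, which is what makes the synchronous re-entrant callback safe; the trade-off is that events arriving before initialization completes are dropped, matching the behavior of InvokeCallbacks in this patch.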
--------- Co-authored-by: Sheil Kumar --- .../core/platform/windows/logging/etw_sink.cc | 27 +++++++++++++++---- .../core/platform/windows/logging/etw_sink.h | 7 ++++- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index b0f9eaf4f62d2..ef42c88a67ba6 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -137,28 +137,45 @@ void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( EtwRegistrationManager::~EtwRegistrationManager() { std::lock_guard lock(callbacks_mutex_); callbacks_.clear(); - ::TraceLoggingUnregister(etw_provider_handle); + if (initialization_status_ == InitializationStatus::Initialized || + initialization_status_ == InitializationStatus::Initializing) { + std::lock_guard init_lock(init_mutex_); + assert(initialization_status_ != InitializationStatus::Initializing); + if (initialization_status_ == InitializationStatus::Initialized) { + ::TraceLoggingUnregister(etw_provider_handle); + initialization_status_ = InitializationStatus::NotInitialized; + } + } } EtwRegistrationManager::EtwRegistrationManager() { } -void EtwRegistrationManager::LazyInitialize() { - if (!initialized_) { +void EtwRegistrationManager::LazyInitialize() try { + if (initialization_status_ == InitializationStatus::NotInitialized) { std::lock_guard lock(init_mutex_); - if (!initialized_) { // Double-check locking pattern - initialized_ = true; + if (initialization_status_ == InitializationStatus::NotInitialized) { // Double-check locking pattern + initialization_status_ = InitializationStatus::Initializing; etw_status_ = ::TraceLoggingRegisterEx(etw_provider_handle, ORT_TL_EtwEnableCallback, nullptr); if (FAILED(etw_status_)) { ORT_THROW("ETW registration failed. Logging will be broken: " + std::to_string(etw_status_)); } + initialization_status_ = InitializationStatus::Initialized; } } +} catch (...) { + initialization_status_ = InitializationStatus::Failed; + throw; } void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { + if (initialization_status_ != InitializationStatus::Initialized) { + // Drop messages until manager is fully initialized. + return; + } + std::lock_guard lock(callbacks_mutex_); for (const auto& callback : callbacks_) { (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index 3af45b813a625..d6c9ea27b2955 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -47,6 +47,11 @@ class EtwSink : public ISink { }; class EtwRegistrationManager { + enum class InitializationStatus { NotInitialized, + Initializing, + Initialized, + Failed }; + public: using EtwInternalCallback = std::function Date: Wed, 31 Jul 2024 09:01:05 -0700 Subject: [PATCH 26/40] [CUDA] Fix MultiHeadAttention thread safe and bias support (#21498) ### Description #### Issues Fixed (1) **TRT cross attention not thread safe**. 
[Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is only initialized once, and change the cache to be read-only after initialization. Previously the content was not read-only, so it could be changed by another thread and potentially cause a buffer overrun.
* The kernel initialization is not guarded (although the kernel-loading factory has a static mutex guarding multiple threads), so the mutable variable could be set by two different threads at the same time. Add a once_flag to avoid that. This also requires some workspace computation changes, so I did not create a separate pull request for it.

(2) **Bias for cross attention**

That scenario assumes that only query has bias, not key and value. However, the assumption was not verified at runtime, it was not documented in comments, and there was no test case, so support for the scenario was disabled by mistake. The scenario is actually used in the Whisper model (TODO: add Whisper tests to the CI pipeline, and update the fusion script to verify such assumptions if needed). The CUDA/CPU kernels support bias for cross attention as long as the bias is zero for key and value. I updated the check to support the scenario and added comments wherever the assumption is relied upon.

(3) **Fallback support**

Previously, the unfused kernel did not support the packed qkv and packed kv formats, so some cases could fail because there was no fallback. I added new AddBiasTranspose CUDA kernels for them so that all supported cases have a fallback and will not fail.

#### Improvements

(4) **QKV workspace size**

The logic for no_qkv_workspace could easily get out of sync since the related code was scattered across different source files. I refactored the code to move all related code into one file (attention_prepare_qkv.cu) and added asserts, so that the logic stays in sync.

(5) **Remove the confusing concept of pass past in kv**

parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. The new code does not use past_key/past_value for cross attention, so the logic is clearer.

(6) **More coverage, less workspace, and fewer transposes for flash and efficient attention**

Previously, there was one condition that prevented flash or efficient attention from running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for that case, with less workspace as well. For example, for cross attention with bias, the original code used two additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is like:
```
if (has bias)
  Add bias to query, key, value, and store in q, k, v workspace
else
  Use query, key and value directly as q, k and v in kernel
```
We no longer need to allocate temp_k_workspace and temp_v_workspace, so less memory is used, and the new code saves two transposes in this case. Flash and efficient attention support BSNH or BNSH formats for k and v. In the old code, k/v were always converted to BSNH, which is not always necessary; I changed the code to convert k/v to BSNH or BNSH case by case, as sketched below.
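To make the dispatch concrete, here is a small hedged sketch of the "add bias into workspace only when bias exists" logic described above. It is illustrative only: the names (QkvViews, AddBiasIntoWorkspace, PrepareQkvSketch) and signatures are invented for this sketch and are not the actual attention_prepare_qkv.cu implementation, and the real bias addition runs as CUDA kernels rather than the CPU loop shown here.

```
#include <cstddef>

// Views over the q/k/v buffers that the attention kernel will consume.
// Flash and memory-efficient attention accept k/v in either BSNH or BNSH layout,
// so no transpose is forced here.
struct QkvViews {
  const float* q;
  const float* k;
  const float* v;
  bool kv_is_bnsh;
};

// Stand-in for the real add-bias kernels. This CPU version only copies the inputs
// into the workspace; the real kernels also broadcast-add the bias.
void AddBiasIntoWorkspace(const float* query, const float* key, const float* value,
                          const float* bias, size_t q_elems, size_t kv_elems,
                          float* q_ws, float* k_ws, float* v_ws) {
  (void)bias;  // bias addition omitted in this sketch
  for (size_t i = 0; i < q_elems; ++i) q_ws[i] = query[i];
  for (size_t i = 0; i < kv_elems; ++i) {
    k_ws[i] = key[i];
    v_ws[i] = value[i];
  }
}

// Only populate the q/k/v workspace when a bias tensor exists; otherwise pass the
// operator inputs straight through, so no extra workspace or transpose is needed.
QkvViews PrepareQkvSketch(const float* query, const float* key, const float* value,
                          const float* bias, bool kv_in_bnsh,
                          size_t q_elems, size_t kv_elems,
                          float* q_ws, float* k_ws, float* v_ws) {
  if (bias != nullptr) {
    AddBiasIntoWorkspace(query, key, value, bias, q_elems, kv_elems, q_ws, k_ws, v_ws);
    return {q_ws, k_ws, v_ws, kv_in_bnsh};
  }
  return {query, key, value, kv_in_bnsh};  // no bias: use inputs directly
}
```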
These changes allow more cases to be covered by flash or efficient attention, which improves performance.

(7) **Debugging support**

Previously there was little debug info. In this change, I add a debug-info flag to AttentionData so that we can output debug info during processing, add functions to consolidate the dumping of inputs, QKV processing, and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disabling dumping from the CUDA kernels.

#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or packed kv inputs.
(3) Change a few cases of PrepareQKV so that more cases are covered by flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments about the assumption that key/value have no bias in this case.
(6) Fix a thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.

Currently supported scenarios for MultiHeadAttention on CUDA/CPU:

| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed, not supported on CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed, not supported on CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention, bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past, only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present
(not share buffer) ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/18854 --- .../contrib_ops/cpu/bert/attention_base.cc | 1 - .../contrib_ops/cpu/bert/attention_common.h | 13 +- .../cpu/bert/multihead_attention.cc | 15 +- .../cpu/bert/multihead_attention_helper.h | 574 +++++++----- .../cuda/bert/add_bias_transpose.cu | 114 +++ .../cuda/bert/add_bias_transpose.h | 54 +- .../contrib_ops/cuda/bert/attention.cc | 4 +- .../contrib_ops/cuda/bert/attention_impl.cu | 202 ++-- .../contrib_ops/cuda/bert/attention_impl.h | 55 +- .../cuda/bert/attention_kernel_options.h | 1 + .../cuda/bert/attention_kv_cache.cu | 73 +- .../cuda/bert/attention_prepare_qkv.cu | 864 ++++++++++++------ .../cuda/bert/attention_transpose.cu | 6 + .../decoder_masked_multihead_attention.cc | 13 +- ...decoder_masked_multihead_attention_impl.cu | 3 + .../cuda/bert/multihead_attention.cc | 221 ++--- .../cuda/bert/multihead_attention.h | 6 + .../cuda/bert/packed_attention_impl.cu | 22 +- .../bert/packed_multihead_attention_impl.cu | 58 +- .../quantization/attention_quantization.cc | 4 +- .../cuda/utils/dump_cuda_tensor.cc | 9 + .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 2 +- .../contrib_ops/rocm/bert/attention_impl.cu | 10 +- .../contrib_ops/rocm/bert/attention_impl.h | 6 - .../rocm/bert/multihead_attention.cu | 8 +- .../tools/transformers/fusion_attention.py | 3 + .../test/python/transformers/benchmark_mha.py | 99 +- .../test/python/transformers/test_mha.py | 346 ++++--- 28 files changed, 1729 insertions(+), 1057 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 515a967aa2386..f7d8fedc734e4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -258,7 +258,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->scale = scale_; output_parameters->mask_type = mask_type; output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; - output_parameters->pass_past_in_kv = false; output_parameters->qkv_format = Q_K_V_BNSH; } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 55292b35e1e38..88127387d08ea 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -6,6 +6,12 @@ namespace onnxruntime { namespace contrib { +enum AttentionType { + kAttention, + kMultiHeadAttention, + kDecoderMaskedMultiHeadAttention, +}; + enum AttentionMaskType { MASK_NONE, // No mask MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length @@ -24,10 +30,12 @@ enum AttentionQkvFormat { UNKNOWN, // enum value not set, or depends on qkv projection implementation details Q_K_V_BNSH, // for non-packed qkv, permuted Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention - QKV_BSN3H, // for TRT fused attention, qkv are packed + Q_K_V_BSNH_BNSH_BNSH, // for cross attention, k and v are permuted Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) - Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed. 
+ Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed + QKV_BSN3H, // for TRT fused attention, qkv are packed + QKV_BS3NH, // for DecoderMaskedMultiHeadAttention, qkv are packed QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed }; @@ -61,7 +69,6 @@ struct AttentionParameters { bool past_present_share_buffer; bool do_rotary; bool broadcast_res_pos_bias; - bool pass_past_in_kv; float mask_filter_value; float scale; bool use_tf32; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 9677c30f22d8a..0d77376779230 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -85,7 +85,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { scale_, is_unidirectional_, past_present_share_buffer, - false)); + kMultiHeadAttention)); const int batch_size = parameters.batch_size; const int q_sequence_length = parameters.sequence_length; @@ -121,20 +121,13 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - // For each of Q/K/V, there are multiple scenarios: - // 1) Combined QKV bias is null - // a) Q/K/V is (B, S, D) - // b) Q/K/V is (B, S, N, H) - // 2) No packed QKV in Q - // a) Q/K/V has seq_len = 1 - // b) Q/K/V has seq_len > 1 - OrtValue Q; ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( context, allocator, batch_size, num_heads_, q_sequence_length, qk_head_size, query, bias, q_bias_offset, Q)); - if (parameters.pass_past_in_kv) { // key and value in BNSH format - assert(bias == nullptr); + if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + // For cross attention with k and v in BNSH format, we assume that bias for key and value are zeros. + // So we don't need to add bias for key and value here. assert(past_key == nullptr); assert(past_value == nullptr); return ApplyAttention(Q.GetMutable()->MutableData(), diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index bd7ab09659170..cfb8d36843777 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -11,6 +11,232 @@ namespace onnxruntime { namespace contrib { namespace multihead_attention_helper { +template +Status Check_QKV(const T* packed_qkv, AttentionQkvFormat& qkv_format) { + const auto& query_dims = packed_qkv->Shape().GetDims(); + if (query_dims.size() == 3) { + // Packed qkv used by DecoderMaskedMultiHeadAttention. Query shape is (B, S, 3D), no key and value. 
+ qkv_format = AttentionQkvFormat::QKV_BS3NH; + } else { + assert(query_dims.size() == 5); + if (static_cast(query_dims[3]) != 3) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'query' shape (batch_size, sequence_length, num_heads, 3, head_size) for packed qkv"); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + + return Status::OK(); +} + +template +Status Check_Q_KV(const T* query, const T* packed_kv, int num_heads, int head_size, + AttentionQkvFormat& qkv_format, int& kv_sequence_length) { + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = packed_kv->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of query be 3 for packed kv"); + } + + if (key_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of key be 5 for packed kv"); + } + + if (key_dims[0] != query_dims[0] || + static_cast(key_dims[2]) != num_heads || + static_cast(key_dims[3]) != 2 || + static_cast(key_dims[4]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + kv_sequence_length = static_cast(key_dims[1]); + return Status::OK(); +} + +template +Status Check_Q_K_V(const T* query, const T* key, const T* value, int num_heads, int head_size, + AttentionQkvFormat& qkv_format, int& kv_sequence_length, int& v_hidden_size) { + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = key->Shape().GetDims(); + const auto& value_dims = value->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of query be 3 for packed kv"); + } + + if (key_dims.size() != value_dims.size() || (key_dims.size() != 3 && value_dims.size() != 4)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of key and value be same, and either 3 or 4"); + } + + if (key_dims[0] != query_dims[0] || value_dims[0] != query_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query', 'key' and 'value' shall have same dim 0 (batch_size)"); + } + + if (key_dims.size() == 3) { + if (key_dims[2] != query_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + } + + if (key_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same dim 1 (kv_sequence_length)"); + } + + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + kv_sequence_length = static_cast(key_dims[1]); + v_hidden_size = static_cast(value_dims[2]); + } else { // key_dims.size() == 4 + if (value->Shape() != key->Shape() || + static_cast(key_dims[1]) != num_heads || + static_cast(key_dims[3]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same shape (batch_size, num_heads, kv_sequence_length, head_size)"); + } + + qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + kv_sequence_length = static_cast(key_dims[2]); + v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]); + } + + return Status::OK(); +} + +template +Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len, + int batch_size, int num_heads, int head_size, bool past_present_share_buffer, + int& past_sequence_length, int& 
max_sequence_length) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + if (past_key_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 1 should be same as number of heads, got ", + past_key_dims[1]); + } + if (past_value_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 1 should be same as number of heads, got ", + past_value_dims[1]); + } + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", + past_key_dims[2], " vs ", past_value_dims[2]); + } + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + past_sequence_length = static_cast(past_key_dims[2]); + if (past_present_share_buffer) { + max_sequence_length = static_cast(past_key_dims[2]); + if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); + } + past_sequence_length = *((*past_seq_len).template Data()); + } + return Status::OK(); +} + +template +Status CheckRelativePositionBias( + const T* relative_position_bias, int batch_size, int num_heads, int sequence_length, int total_sequence_length, + bool& broadcast_res_pos_bias) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + + if (relative_position_bias_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); + } + if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", + relative_position_bias_dims[0]); + } + if (relative_position_bias_dims[0] == 1) { + broadcast_res_pos_bias = true; + } + if (relative_position_bias_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); + } + if (relative_position_bias_dims[2] != sequence_length) { + return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); + } + if (relative_position_bias_dims[3] != total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); + } + return Status::OK(); +} + +template +AttentionMaskType GetMaskType(const T* key_padding_mask, int batch_size, int sequence_length, int total_sequence_length) { + AttentionMaskType mask_type = AttentionMaskType::MASK_UNKNOWN; + const auto& mask_dims = key_padding_mask->Shape().GetDims(); + if (mask_dims.size() == 1) { + if (mask_dims[0] == static_cast(batch_size)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; + } + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; + } + return mask_type; +} + template Status CheckInputs(const T* query, const T* key, @@ -27,176 +253,128 @@ Status CheckInputs(const T* query, float scale, bool is_unidirectional, bool past_present_share_buffer, - bool dmmha_packing) { - // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None - // relative_position_bias : (B, 1, S, L) - // past_key : (B, N, S*, H) - // past_value : (B, N, S*, H) - // When no packing for q/k/v: + AttentionType operator_type) { + // --------------------------------------------------------------- + // Notations: + // B: batch_size + // N: num_heads + // H: head_size (V might have different head size than Q and K) + // D: hidden_size = N * H + // S: q_sequence_length + // P: past_sequence_length + // L: kv_sequence_length + // T: total_sequence_length = P + L + // M: max_sequence_length + // --------------------------------------------------------------- + // MultiHeadAttention inputs: + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: // query (Q) : (B, S, D) - // key (K) : (B, L, D) or (B, N, S*, H) - // value (V) : (B, L, D_v) or (B, N, S*, H) - // bias (Q/K/V) : (D + D + D_v) - // When packed kv is used: + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v): // query (Q) : (B, S, D) - // key (K) : (B, L, N, 2, H) - // value (V) : None - // bias (Q/K/V) : None - // When packed qkv is used: - // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H) + // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv): + // query (Q) : (B, S, D) + // key (K/V) : (B, L, N, 2, H) + // value : None + // QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v): + // query (Q/K/V) : (B, S, N, 3, H) + // key : None + // value : None + // + // Other inputs: + // bias (Q/K/V) : None or (D + D + D_v) + // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T) + // relative_position_bias : (B, N, S, T) or (1, N, S, T) + // past_key : 
(B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // --------------------------------------------------------------- + // DecoderMaskedMultiHeadAttention inputs (S == 1, D == D_v): + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) + // value (V) : (B, L, D) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and relative_position_bias are not used. L == T): + // query (Q) : (B, S, D) + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H) + // QKV_BS3NH - packed qkv (S == L): + // query (Q) : (B, S, 3 * D) // key (K) : None // value (V) : None - // bias (Q/K/V) : None or (D + D + D_v) - - AttentionQkvFormat qkv_format; + // + // Other inputs: + // bias (Q/K/V) : None or (3 * D) + // key_padding_mask (K/V) : None or (B, T) + // relative_position_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA. + // + // The following inputs are not used in cross attention (so they are None for cross attention): + // past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. + // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // past_value : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. + // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // past_sequence_length : scalar (1) when past_present_share_buffer is True. + // CUDA version has extra inputs (beam_width, cache_indirection) that are not checked in the class. + // For ROCm, see contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh for more details. + // --------------------------------------------------------------- + AttentionQkvFormat qkv_format = UNKNOWN; const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3 && query_dims.size() != 5) { + + int query_rank = static_cast(query_dims.size()); + if (query_rank != 3 && query_rank != 5) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ", - query_dims.size()); + query_rank); } int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); - int hidden_size = (query_dims.size() == 3) + bool dmmha_packing = operator_type == kDecoderMaskedMultiHeadAttention && key == nullptr && value == nullptr; + int hidden_size = (query_rank == 3) ? (dmmha_packing ? 
(static_cast(query_dims[2]) / 3) : static_cast(query_dims[2])) : (num_heads * static_cast(query_dims[4])); int head_size = static_cast(hidden_size) / num_heads; int kv_sequence_length = sequence_length; + int v_hidden_size = hidden_size; + if (key != nullptr) { + if (value == nullptr) { + ORT_RETURN_IF_ERROR(Check_Q_KV(query, key, num_heads, head_size, qkv_format, kv_sequence_length)); + } else { + ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, head_size, + qkv_format, kv_sequence_length, v_hidden_size)); + } + } else if (value == nullptr) { // no key and value + ORT_RETURN_IF_ERROR(Check_QKV(query, qkv_format)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value' shall absent when 'key' is absent"); + } + int past_sequence_length = 0; int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - if (past_key_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 1 should be same as number of heads, got ", - past_key_dims[1]); - } - if (past_value_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 1 should be same as number of heads, got ", - past_value_dims[1]); - } - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). 
", - past_key_dims[2], " vs ", past_value_dims[2]); - } - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } - past_sequence_length = static_cast(past_key_dims[2]); - max_sequence_length = static_cast(past_key_dims[2]); - if (past_present_share_buffer) { - if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); - } - past_sequence_length = *((*past_seq_len).template Data()); - } + ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, past_seq_len, + batch_size, num_heads, head_size, past_present_share_buffer, + past_sequence_length, max_sequence_length)); } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ", - query_dims.size()); - } - - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3 && key_dims.size() != 4 && key_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3, 4, or 5 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { + if (operator_type == kMultiHeadAttention) { + if (qkv_format == AttentionQkvFormat::QKV_BS3NH) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); + "Packed qkv of 3D BS3NH format is not support by MultiHeadAttention"); } - if (key_dims.size() == 3) { - if (key_dims[2] != query_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else if (key_dims.size() == 5) { - if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); - } - - qkv_format = Q_KV_BSNH_BSN2H; - kv_sequence_length = static_cast(key_dims[1]); - } else { // key_dims.size() == 4 (cross-attention with past_key) - if (static_cast(key_dims[1]) != num_heads || static_cast(key_dims[3]) != head_size) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, num_heads, kv_sequence_length, head_size)"); - } - - if (value == nullptr || value->Shape().GetDims().size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' shall be 4D when 'key' is 4D"); - } - - if (bias != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when 'key' is 4D"); - } - - qkv_format = UNKNOWN; - kv_sequence_length = 
static_cast(key_dims[2]); - } - } else { // packed QKV - if (query_dims.size() != 3 && query_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions when key is empty, got ", - query_dims.size()); - } - if (query_dims.size() == 5 && (static_cast(query_dims[2]) != num_heads || static_cast(query_dims[3]) != 3)) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv"); + if (qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H && bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when packed kv is used"); } - - qkv_format = QKV_BSN3H; } if (bias != nullptr) { @@ -206,116 +384,31 @@ Status CheckInputs(const T* query, bias_dims.size()); } - if (value == nullptr) { - // Currently, bias is not allowed for packed KV. This constraint can be removed later. - // Here we assume that fusion tool will not include bias for packed KV. - if (query_dims.size() == 5 && query_dims[3] == 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. "); - } + int expected_bias_length = 2 * hidden_size + v_hidden_size; + if (bias_dims[0] != static_cast(expected_bias_length)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' length is expected to be 2 * hidden_size + hidden_size_v, got ", + bias_dims.size()); } } int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { - mask_type = AttentionMaskType::MASK_UNKNOWN; - const auto& mask_dims = key_padding_mask->Shape().GetDims(); - if (mask_dims.size() == 1) { - if (mask_dims[0] == static_cast(batch_size)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; - } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(kv_sequence_length)) { - mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(total_sequence_length)) { - mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; - } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(sequence_length) && - mask_dims[2] == static_cast(total_sequence_length)) { - mask_type = AttentionMaskType::MASK_3D_ATTENTION; - } - + mask_type = GetMaskType(key_padding_mask, batch_size, sequence_length, total_sequence_length); if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); - } - } - - // NOTE: In Cross-Attention, we pass the past key and value to 'key' and 'value' instead of 'past_key' and 'past_value'. 
- bool pass_past_in_kv = false; - int v_hidden_size = hidden_size; - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3 && value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 or 4 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (value_dims.size() == 3) { - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } - v_hidden_size = static_cast(value_dims[2]); - } else { // value_dims.size() == 4 - if (static_cast(kv_sequence_length) != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 2 (kv_sequence_length)"); - } - - if (past_key != nullptr || past_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be empty when 'value' is 4D"); - } - - v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]); - pass_past_in_kv = true; + "Input 'key_padding_mask' shape is not expected."); } } bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - if (relative_position_bias_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - if (relative_position_bias_dims[3] != total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", - relative_position_bias_dims[3]); - } + ORT_RETURN_IF_ERROR(CheckRelativePositionBias( + relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length, broadcast_res_pos_bias)); } - // TODO: ORT_RETURN_IF(qkv_format == UNKNOWN, "Unrecognized QKV format"); + assert(qkv_format != UNKNOWN); + if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); output_parameters->batch_size = batch_size; @@ -323,7 +416,7 @@ Status CheckInputs(const T* query, output_parameters->past_sequence_length = past_sequence_length; output_parameters->kv_sequence_length = kv_sequence_length; output_parameters->total_sequence_length = total_sequence_length; - output_parameters->max_sequence_length = 
max_sequence_length; + output_parameters->max_sequence_length = past_present_share_buffer ? max_sequence_length : total_sequence_length; output_parameters->input_hidden_size = 0; output_parameters->hidden_size = hidden_size; output_parameters->v_hidden_size = v_hidden_size; @@ -336,7 +429,6 @@ Status CheckInputs(const T* query, output_parameters->mask_type = mask_type; output_parameters->scale = scale; output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; - output_parameters->pass_past_in_kv = pass_past_in_kv; output_parameters->qkv_format = qkv_format; } @@ -359,7 +451,7 @@ Status CheckInputs(const T* query, float scale, bool is_unidirectional, bool past_present_share_buffer, - bool dmmha_packing, + AttentionType operator_type, int max_threads_per_block) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); @@ -367,7 +459,7 @@ Status CheckInputs(const T* query, return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, - past_present_share_buffer, dmmha_packing); + past_present_share_buffer, operator_type); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 9e6752b451868..62d6a723bf32c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -520,6 +520,39 @@ __global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output) } } +template +__global__ void AddBiasTransposeUnpack(int M, const T* input, const T* biases, T* output) { + // Format 5 to unpack TRT packed input format to BNSH for unfused attention. 
+ // Input: BxSxNxMxH + // Output: MxBxNxSxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const int NH = num_heads * head_size; + const int NHS = NH * sequence_length; + + int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M; + const int out_offset = (s + n * sequence_length) * head_size + (b + m * batch_size) * NHS; + + const int h = threadIdx.x; + if (h < head_size) { + if (biases != nullptr) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + } else { + output[out_offset + h] = input[in_offset + h]; + } + } +} + template __global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) { // Format 3 for cutlass memory efficient attention @@ -692,6 +725,8 @@ void InvokeAddBiasTranspose( } } else if (format == 4) { // format == 4 AddBiasUnpack<<>>(total_matrix_count, input, biases, output); + } else if (format == 5) { // format == 5 + AddBiasTransposeUnpack<<>>(total_matrix_count, input, biases, output); } else { // format == 0 AddBiasTranspose<<>>(input, biases, output); } @@ -716,6 +751,8 @@ void InvokeAddBiasTranspose( } } else if (format == 4) { // format == 4 ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block"); + } else if (format == 5) { // format == 5 + ORT_THROW("AddBiasTranspose (format 5) not implemented for hidden_size > max_threads_per_block"); } else { // format 0 AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } @@ -904,6 +941,7 @@ void InvokeAddBias( AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); } } + // K { const dim3 grid(kv_sequence_length, batch_size, num_matrices); @@ -1011,6 +1049,82 @@ void LaunchAddBias( } } +template +void InvokeAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const T* biases, const T* query, T* q) { + assert(num_heads <= max_threads_per_block); + constexpr int num_matrices = 1; + const dim3 grid(sequence_length, batch_size, num_matrices); + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(query, biases, q); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); + } +} + +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const float* biases, const float* query, float* q) { + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const float4* query2 = reinterpret_cast(query); + const float4* biases2 = reinterpret_cast(biases); + float4* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const float2* query2 = reinterpret_cast(query); + const float2* biases2 = reinterpret_cast(biases); + float2* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, 
H, + biases2, query2, q2); + } else { + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const half* biases, const half* query, half* q) { + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const Half4* query2 = reinterpret_cast(query); + const Half4* biases2 = reinterpret_cast(biases); + Half4* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const half2* query2 = reinterpret_cast(query); + const half2* biases2 = reinterpret_cast(biases); + half2* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else { + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index efc31db43bcdb..bd4e123a272bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -3,14 +3,15 @@ #pragma once #include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { namespace cuda { -// Fused kernel of Add (bias) and Transpose. +// Fused kernel of Add bias (optional, can be None) and Transpose. // Shape of inputs and outputs: -// biases: (num_matrices, num_heads * head_size) +// biases: (num_matrices, num_heads * head_size) or None // format 0: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (num_matrices, batch_size, sequence_length, num_heads, head_size) // output: (num_matrices, batch_size, num_heads, sequence_length, head_size) @@ -24,9 +25,12 @@ namespace cuda { // format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (batch_size, sequence_length, num_matrices, num_heads, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) -// format 4: (requires qk_head_size = v_head_size) +// format 4: (requires qk_head_size == v_head_size) // input: (batch_size, sequence_length, num_heads, num_matrices, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) +// format 5: (requires qk_head_size == v_head_size) +// input: (batch_size, sequence_length, num_heads, num_matrices, head_size) +// output: (num_matrices, batch_size, num_heads, sequence_length, head_size) template void LaunchAddBiasTranspose( @@ -35,7 +39,7 @@ void LaunchAddBiasTranspose( const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr, int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0); -// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format. +// Add bias (optional, can be None) and Transpose for separated inputs of Q, K and V, and output Trt format. 
// For self attention: // output: (batch_size, sequence_length, num_heads, 3, head_size) // It assumes sequence_length == kv_sequence_length and head_size == v_head_size. @@ -50,7 +54,7 @@ void LaunchAddBiasTransposeTrt( const T* biases, const T* query, const T* key, const T* value, T* output, bool is_cross_attention, int kv_sequence_length = -1); -// Add (bias) for separated inputs of Q, K and V. +// Add bias (required) for separated inputs of Q, K and V. // Q: (batch_size, sequence_length, num_heads, head_size) // K: (batch_size, kv_sequence_length, num_heads, head_size) // V: (batch_size, kv_sequence_length, num_heads, v_head_size) @@ -61,6 +65,46 @@ void LaunchAddBias( const int num_heads, const int head_size, const int v_head_size, const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v); +// Add bias (required) for Q: (batch_size, sequence_length, num_heads, head_size) +template +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const T* biases, const T* query, T* q); + +// Add bias (optional, can be None) transpose kernel defined in packed_multihead_attention_impl.cu. +// Support the following format transforms (for float and half only). +// source_format => target_format: +// Q_K_V_TNH => Q_K_V_BNSH (requires token_offset) +// Q_K_V_TNH => Q_K_V_TNH +// Q_K_V_TNH => QKV_TN3H +// QKV_TN3H => Q_K_V_BNSH (requires token_offset) +// QKV_TN3H => Q_K_V_TNH +// QKV_TN3H => QKV_TN3H +template +void AddBiasTransposePacked( + const T* query, const T* key, const T* value, const T* bias, T* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +// Add bias (required) transpose kernel defined in packed_attention_impl.cu. 
+// Support the following format transforms (for float and half only): +// format transform +// Q_K_V_BNSH: Tx3xNxH => 3xBxNxSxH (requires token_offset) +// Q_K_V_BSNH: Tx3xNxH => 3xTxNxH +// QKV_BSN3H: Tx3xNxH => TxNx3xH +template +void AddBiasTransposePacked( + const T* input, const T* biases, T* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 3b7f980ba1881..5c0989bced70c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -260,7 +260,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { fused_runner, use_flash_attention, use_fused_cross_attention, - use_memory_efficient_attention); + use_memory_efficient_attention, + false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; @@ -281,6 +282,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workSpaceSize; data.output = reinterpret_cast(output->MutableData()); if (nullptr != present) { data.present = reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 997493acd9cb7..f9eabe27d97e4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,31 +58,25 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { - if (this->sequence_length != seq_length) { - ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); - LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, seq_length, stream); - this->sequence_length = seq_length; +const int32_t* CumulatedSequenceLengthCache::TryGet(int batch_size, int32_t seq_len, cudaStream_t stream) { + if (this->sequence_length == 0 && seq_len > 0) { + // Initialize only once with sequence length in the first request. + std::call_once(init_once_flag_, [&]() { + ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); + LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, + this->max_batch_size, seq_len, stream); + // Syncronize to ensure thread-safe since other thread will not wait for the above kernel finish. + // Otherwise, the data might be consumed by other threads before it is ready and causes data race issue. 
+ cudaStreamSynchronize(stream); + this->sequence_length = seq_len; + }); } -} -int* GetCumulatedSequenceLength(CumulatedSequenceLengthCache* cache, - const int* mask_index, - int batch_size, - int sequence_length, - cudaStream_t stream, - void* scratch_buffer) { - if (mask_index == nullptr && cache != nullptr) { - if (batch_size <= cache->max_batch_size) { - cache->Initialize(sequence_length, stream); - return reinterpret_cast(cache->buffer.get()); - } + if (this->sequence_length == seq_len && batch_size <= this->max_batch_size) { + return reinterpret_cast(buffer.get()); } - int* sequence_offset = reinterpret_cast(scratch_buffer); - LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, sequence_length, stream); - return sequence_offset; + return nullptr; } size_t GetAttentionScratchSize( @@ -114,10 +108,12 @@ size_t GetAttentionWorkspaceSize( void* fused_runner, bool use_flash_attention, bool use_fused_cross_attention, - bool use_memory_efficient_attention) { + bool use_memory_efficient_attention, + bool no_qkv_workspace) { // Note that q, k and v might need alignment for fused attention kernels. - const size_t qkv_bytes = element_size * batch_size * num_heads * - ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); + const size_t qkv_size = element_size * batch_size * num_heads * + ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); + const size_t qkv_bytes = no_qkv_workspace ? 0 : qkv_size; #if USE_FLASH_ATTENTION if (use_flash_attention) { @@ -162,39 +158,44 @@ Status FusedTrtCrossAttention( // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. 
assert(data.mask_index == nullptr); - + assert(data.scratch != nullptr); + assert(data.q != nullptr); + assert(data.k != nullptr); + +#ifndef NDEBUG + char* scratch_end = reinterpret_cast(data.scratch) + 2 * GetSequenceOffsetSize(parameters.batch_size, false); + char* buffer_end = reinterpret_cast(data.workspace) + data.workspace_bytes; + assert(scratch_end <= buffer_end); +#endif const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, - sequence_length, stream, - data.scratch); + int32_t* q_sequence_offset = const_cast(data.cumulated_sequence_length_q_cache); + if (q_sequence_offset == nullptr) { + q_sequence_offset = reinterpret_cast(data.scratch); + LaunchTrtSequenceOffset(q_sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_INIT(); DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, parameters.kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); + int32_t* kv_sequence_offset = const_cast(data.cumulated_sequence_length_kv_cache); + if (kv_sequence_offset == nullptr) { + int* scratch = reinterpret_cast(data.scratch) + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = reinterpret_cast(scratch); + LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, parameters.kv_sequence_length, stream); + } + CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); - // When there is no bias, we can directly use q and packed kv from inputs. 
- void const* query = data.q; - void const* packed_kv = data.k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV + data.q, // Q + data.k, // packed KV q_sequence_offset, // cumulated sequence length of Q kv_sequence_offset, // cumulated sequence length of KV data.output, // output @@ -206,8 +207,6 @@ Status FusedTrtCrossAttention( parameters.kv_sequence_length, // sequence length of KV stream); - DUMP_TENSOR("trt cross output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); return Status::OK(); } @@ -225,24 +224,33 @@ Status FusedTrtSelfAttention( cudaStream_t stream, contrib::AttentionParameters& parameters, AttentionData& data) { + assert(data.scratch != nullptr); +#ifndef NDEBUG + char* scratch_end = reinterpret_cast(data.scratch) + GetSequenceOffsetSize(parameters.batch_size, false); + char* buffer_end = reinterpret_cast(data.workspace) + data.workspace_bytes; + assert(scratch_end <= buffer_end); +#endif + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const bool causal = parameters.is_unidirectional; - int* sequence_offset = reinterpret_cast(data.scratch); - - DUMP_TENSOR_INIT(); + const int32_t* sequence_offset = data.cumulated_sequence_length_q_cache; if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + LaunchTrtSequenceOffset2d(reinterpret_cast(data.scratch), data.mask_index, batch_size, sequence_length, stream); + sequence_offset = reinterpret_cast(data.scratch); } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); + if (sequence_offset == nullptr) { + LaunchTrtSequenceOffset(reinterpret_cast(data.scratch), data.mask_index, batch_size, sequence_length, stream); + sequence_offset = reinterpret_cast(data.scratch); + } } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length); @@ -252,22 +260,12 @@ Status FusedTrtSelfAttention( if (!causal) { assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. 
- void const* packed_qkv = data.q; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + fused_fp16_runner->Run(b, s, data.q, sequence_offset, data.output, stream); } else { assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } + return Status::OK(); } @@ -289,38 +287,19 @@ Status FlashAttention( contrib::AttentionParameters& parameters, AttentionData& data, float scale) { - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); assert(nullptr == data.mask_index); assert(nullptr == data.relative_position_bias); assert(parameters.head_size == parameters.v_head_size); - void* query = reinterpret_cast(data.q); - void* key = reinterpret_cast(data.k); - void* value = reinterpret_cast(data.v); - // For packed KV, we can use query input directly. - if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { - query = reinterpret_cast(const_cast(data.query)); - } - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, - parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, - parameters.batch_size, parameters.total_sequence_length, - parameters.num_heads, parameters.v_head_size); - - bool is_bf16 = false; + constexpr bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), - true)); - - DUMP_TENSOR("flash attention output", data.output, - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); return Status::OK(); } @@ -351,25 +330,8 @@ Status EfficientAttention( float scale) { // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = data.q; - const void* key = data.k; - const void* value = data.v; - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, - parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, - parameters.batch_size, parameters.total_sequence_length, - parameters.num_heads, parameters.v_head_size); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -394,21 +356,19 @@ Status EfficientAttention( ? nullptr : const_cast(reinterpret_cast( data.mask_index + 2 * parameters.batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; + p.query = data.q; + p.key = data.k; + p.value = data.v; p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; - p.is_kv_bsnh = true; + p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; p.stream = stream; p.has_custom_right_padding = false; run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); return Status::OK(); } @@ -449,10 +409,6 @@ Status UnfusedAttention( cublasSetStream(cublas, stream); - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); - const int present_sequence_length = parameters.past_present_share_buffer ? 
parameters.max_sequence_length : total_sequence_length;
@@ -467,8 +423,7 @@ Status UnfusedAttention(
                                          &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches,
                                          device_prop, parameters.use_tf32));
-  DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size);
-  DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length);
+  DUMP_TENSOR_INIT();
   DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);
   constexpr size_t element_size = sizeof(T);
@@ -523,7 +478,6 @@ Status UnfusedAttention(
   // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
   Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
                                  device_prop.maxThreadsPerBlock, false, temp_output, data.output);
-  DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size);
   return result;
 }
@@ -554,7 +508,7 @@ Status QkvToContext(
   if (!parameters.past_present_share_buffer) {
     ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size,
-                                            sequence_length, total_sequence_length, parameters.pass_past_in_kv,
+                                            sequence_length, total_sequence_length,
                                             stream, max_threads_per_block, data));
   } else {  // past_present_share_buffer
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index 56836bdda197c..fad353dcfeb07 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -6,6 +6,8 @@
 #include 
 #include 
 #include 
+#include <iostream>
+#include <mutex>
 #include "core/framework/allocator.h"
 #include "contrib_ops/cpu/bert/attention_common.h"
@@ -15,13 +17,18 @@ namespace cuda {
 constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128;
+// A cache for cumulated sequence lengths. It is initialized in the first request and becomes read-only after that.
 struct CumulatedSequenceLengthCache {
   onnxruntime::IAllocatorUniquePtr buffer;
   int32_t max_batch_size;
   int32_t sequence_length;
-  CumulatedSequenceLengthCache() : max_batch_size(0), sequence_length(0) {}
-  void Initialize(int32_t sequence_length, cudaStream_t stream);
+  CumulatedSequenceLengthCache() : max_batch_size(kCumulatedSequenceLengthCacheMaxBatchSize), sequence_length(0) {}
+
+  const int32_t* TryGet(int batch_size, int32_t sequence_length, cudaStream_t stream);
+
+  // Use this flag to guard the initialization so that it happens only once in multi-threaded scenarios.
+ mutable std::once_flag init_once_flag_; }; size_t @@ -46,7 +53,8 @@ size_t GetAttentionWorkspaceSize( void* fused_runner, bool use_flash_attention, bool use_fused_cross_attention, - bool use_memory_efficient_attention); + bool use_memory_efficient_attention, + bool no_qkv_workspace); template struct AttentionData { @@ -65,8 +73,6 @@ struct AttentionData { bool has_qkv_workspace = false; T* workspace = nullptr; - T* temp_k_workspace = nullptr; - T* temp_v_workspace = nullptr; T* output = nullptr; T* present = nullptr; @@ -79,22 +85,50 @@ struct AttentionData { bool use_flash_attention = false; bool use_memory_efficient_attention = false; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; + const int32_t* cumulated_sequence_length_q_cache = nullptr; + const int32_t* cumulated_sequence_length_kv_cache = nullptr; // Intermediate data T* q = nullptr; T* k = nullptr; T* v = nullptr; T* scratch = nullptr; - AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + AttentionQkvFormat qkv_format = AttentionQkvFormat::UNKNOWN; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; + + // For Debugging + size_t workspace_bytes = 0; + bool allow_debug_info = false; + + bool IsUnfused() const { + return !use_flash_attention && !use_memory_efficient_attention && + (fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr); + } + + void PrintDebugInfo() const { + std::cout << "flash=" << use_flash_attention + << ", efficient=" << use_memory_efficient_attention + << ", fused_runner=" << (fused_runner != nullptr) + << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) + << ", bias=" << (bias != nullptr) + << ", attn_bias=" << (relative_position_bias != nullptr) + << ", mask_dims=" << mask_index_dims.size() + << ", has_qkv_workspace=" << has_qkv_workspace + << ", workspace=" << workspace_bytes + << ", past=" << (past != nullptr ? 1 : (past_key != nullptr ? 2 : 0)) + << ", present=" << (present != nullptr ? 1 : (present_key != nullptr ? 2 : 0)) + << std::endl; + } }; +// Return true if it does not need qkv workspace, false otherwise. 
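For orientation before the helper declarations that follow, here is a minimal CPU-side sketch of what the cumulated sequence length buffer above holds, namely the "sequence offset" array consumed by the fused TRT kernels and cached by CumulatedSequenceLengthCache. The exclusive-prefix-sum layout with batch_size + 1 entries matches the tensor dumps earlier in this patch; treating every sequence as having the same length is an assumption made only for illustration, since the real cache fills the buffer on the GPU and guards that one-time fill with the std::once_flag member.

// Illustrative sketch only: build the cumulated sequence length ("sequence offset")
// array on the CPU, assuming every sequence in the batch has the same length.
#include <cstdint>
#include <vector>

std::vector<int32_t> MakeCumulatedSequenceLengths(int batch_size, int32_t sequence_length) {
  std::vector<int32_t> offsets(static_cast<size_t>(batch_size) + 1);
  for (int b = 0; b <= batch_size; ++b) {
    offsets[static_cast<size_t>(b)] = b * sequence_length;  // e.g. batch_size=3, sequence_length=4 -> {0, 4, 8, 12}
  }
  return offsets;
}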
+template +bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); + template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, @@ -129,6 +163,9 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, int total_matrix_count = -1); +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block); + Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, const half* input, half* output, cudaStream_t stream, const int max_threads_per_block); @@ -158,7 +195,7 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index bd7df5f490c76..aba1e01bfd91b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -50,6 +50,7 @@ class AttentionKernelOptions { bool use_unfused_{true}; bool use_trt_flash_attention_{true}; + bool use_trt_cross_attention_{true}; // Causal attention is disabled by default in #14732. diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 89be0f1115f41..9f0f49348c225 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -249,16 +249,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, - cudaStream_t stream, - int max_threads_per_block, + int sequence_length, int total_sequence_length, + cudaStream_t stream, int max_threads_per_block, AttentionData& data) { // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) // When there is past state, the head size for Q/K/V shall be same: H == H_v. - if (nullptr != data.present) { + if (nullptr != data.present) { // Attention op assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); @@ -270,58 +269,52 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int // Update pointers to present_k and present_v. 
data.k = data.present; data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; - } else if (nullptr != data.past_key || nullptr != data.present_key) { - if (nullptr != data.past_key && nullptr == data.present_key) { - data.k = const_cast(data.past_key); - data.v = const_cast(data.past_value); - } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + } else { // MultiHeadAttention op + if (nullptr != data.present_key) { + ORT_ENFORCE(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + if (nullptr != data.past_key) { + assert(data.past_key != data.k); + assert(data.past_value != data.v); + + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, data.k, data.present_key)); + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); + // Update pointers to present_k and present_v. data.k = data.present_key; data.v = data.present_value; - } else { - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - data.k = data.temp_k_workspace; - data.v = data.temp_v_workspace; + } else { // nullptr == data.past_key && nullptr != data.present_key + if (data.k != data.present_key) { + int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present_key, data.k, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } + + if (data.v != data.present_value) { + int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; + cudaMemcpyAsync(data.present_value, data.v, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } } - } else if (pass_past_in_kv) { - // past_key and past_value are used directly as key and value in attention computations - data.k = const_cast(data.past_key); - data.v = const_cast(data.past_value); - - // This path has a memory copy from past_key and past_value to present_key and present_value - // Avoid this path since the memory copy is unnecessary because past_key == present_key and - // past_value == present_value - int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; - int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; - cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, data.k, data.present_key)); - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, data.v, data.present_value)); - // Update pointers to present_k and present_v. 
- data.k = data.present_key; - data.v = data.present_value; } } + return CUDA_CALL(cudaGetLastError()); } // Template Instantiation template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 040d6124e7456..05c592ec61059 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -12,12 +12,101 @@ namespace onnxruntime { namespace contrib { namespace cuda { +#if DEBUG_TENSOR_LEVEL > 1 +// Dump the workspace for Q, K, V after processing QKV data. +template +void DumpQkv(AttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + DUMP_TENSOR_D("q(BSN3H)", data.q, batch_size, sequence_length, num_heads * 3, qk_head_size); + } +} + +// Dump the inputs before processing QKV data. 
+template +void DumpInputs(contrib::AttentionParameters& parameters, AttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("Query(BSNH)", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("Key(BSNH)", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("Value(BSNH)", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + DUMP_TENSOR_D("Query(BSN3H)", data.query, batch_size, sequence_length, num_heads * 3, qk_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BSN2H)", data.value, batch_size, sequence_length, num_heads * 2, qk_head_size); + } + + if (data.bias != nullptr) { + DUMP_TENSOR_D("Q_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("K_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("V_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } + + if (data.relative_position_bias != nullptr) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + parameters.broadcast_res_pos_bias ? 
1 : batch_size, + num_heads, sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr) { + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", data.mask_index, batch_size, parameters.total_sequence_length); + } + if (parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask", data.mask_index, 3 * batch_size + 2, 1); + } + } +} + +// Dump the kernel outputs +template +void DumpOutputs(AttentionData& data) { + DUMP_TENSOR_INIT(); + DUMP_TENSOR("output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); +} +#endif + template Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -40,7 +129,7 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, int matrix_to_trans = (past_present_share_buffer ? 1 : 3); ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } else { // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) @@ -48,13 +137,13 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, // For fused causal kernel, use format 1 since we need have K and V to update present state, // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); + data.qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); // For fused causal, we will update gemm_buffer with bias directly. T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; @@ -71,367 +160,526 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, return Status::OK(); } -// For MultiHeadAttention with past state +// Return true if the workspace is not needed for Q, K, V inputs, false otherwise. +// This shall be in sync with the following function PrepareQkv_MHA_Cross. template -Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { +bool NoQkvWorkspace_MHA_Cross(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. 
+  return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr);
+}
+
+// For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format)
+template 
+Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
+                            AttentionData& data,
+                            cudaStream_t stream,
+                            int max_threads_per_block) {
+  assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
+  // past_key or past_value is not supported for cross attention
+  // present_key and present_value can be supported in theory, although we do not allow that scenario for now.
+  assert(data.past_key == nullptr);
+  assert(data.past_value == nullptr);
+  assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_Cross(data));
+
   const int batch_size = parameters.batch_size;
   const int sequence_length = parameters.sequence_length;
-  const int kv_sequence_length = parameters.kv_sequence_length;
   const int num_heads = parameters.num_heads;
   const int qk_head_size = parameters.head_size;
-  const int v_head_size = parameters.v_head_size;
-
-  DUMP_TENSOR_INIT();
-  if (data.bias == nullptr) {
-    // Below logic does not support fused attention with past without bias
-    // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past.
-
-    // cross attention with past state
-    if (data.past_key != nullptr && data.present_key == nullptr) {
-      assert(data.past_value != nullptr);
-      assert(data.query != nullptr);
-      assert(data.key == nullptr);
-      assert(data.value == nullptr);
-      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
-                                         max_threads_per_block, false, data.query, q));
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+  if (data.use_memory_efficient_attention || data.use_flash_attention) {
+    // Add bias for Q
+    if (data.bias != nullptr) {
+      LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
+                    data.bias, data.query, data.q);
+    } else {
+      data.q = const_cast(data.query);
     }
-    // cross attention with present state or self attention with present state
-    else if (data.past_key == nullptr && data.present_key != nullptr) {
-      assert(data.past_value == nullptr);
-      assert(data.present_value != nullptr);
-      assert(data.query != nullptr);
-      assert(data.key != nullptr);
-      assert(data.value != nullptr);
-
-      // TODO: supporting packed qkv for self attention may benefit performance
-      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
-                                         max_threads_per_block, false, data.query, q));
-      // TODO: supporting packed kv for cross attention may benefit performance
-      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
-                                         max_threads_per_block, false, data.key, data.present_key));
-      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
-                                         max_threads_per_block, false, data.value, data.present_value));
-    }
-    // self attention with past and present state
-    else {
-      assert(data.past_key != nullptr);
-      assert(data.past_value != nullptr);
-      assert(data.present_key != nullptr);
-      assert(data.present_value != nullptr);
-      assert(data.query != nullptr);
-      assert(data.key != nullptr);
-      assert(data.value != nullptr);
-      // TODO: supporting packed qkv for self attention may benefit performance
+    // Here we assume that there is no bias for key and value when they are in BNSH format.
+ data.k = const_cast(data.key); + data.v = const_cast(data.value); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else +#endif + { // unfused kernel + assert(data.IsUnfused()); + if (data.bias == nullptr) { + // Transpose query from BSNH to BNSH ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); + max_threads_per_block, false, data.query, data.q)); + } else { + // Add bias to query, and transpose it: Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, -1); } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + + // Here we have assumption that there is no bias for key and value when they are in BNSH format. + // So we do not need to add bias for key and value. Just use the key and value directly. + data.k = const_cast(data.key); + data.v = const_cast(data.value); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +template +bool NoQkvWorkspace_MHA_NoPast(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr; +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(!parameters.is_unidirectional); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.relative_position_bias == nullptr); + assert(data.mask_index == nullptr); + assert(parameters.hidden_size == parameters.v_hidden_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length); + data.v = nullptr; + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; } #if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if 
((data.use_memory_efficient_attention || data.use_flash_attention) && - data.past_key != nullptr && - data.past_value != nullptr && - parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; + else if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.bias != nullptr) { + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, data.q, data.k, data.v); + } else { + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = const_cast(data.value); + } + + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } - // When there is no past_key/past_value and there is present_key/present_value - // (e.g. 
get initial kv to use as past_kv in the next iteration) - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.present_key != nullptr && - data.present_value != nullptr) { - // Use memory efficient attention kernel - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); - - // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.temp_k_workspace, data.present_key)); +#endif + else if (data.fused_runner != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.relative_position_bias == nullptr); + + // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length); + data.k = nullptr; + data.v = nullptr; + + data.qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + assert(data.IsUnfused()); + // Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, -1); + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k, + true, -1); - // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + // Value (BxLxNxH_v) => K (BxNxLxH_v) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v, + true, -1); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + + return Status::OK(); +} + +template +bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData& data) { + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Q, K and V redirects to query, present_k and present_v, so we do not need extra workspace for QKV. 
+    return data.past_key == nullptr && data.present_key != nullptr;
+  }
+  return false;
+}
+
+// For MultiHeadAttention with kv cache (past or present), but no bias
+template 
+Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
+                                      AttentionData& data,
+                                      cudaStream_t stream,
+                                      int max_threads_per_block) {
+  assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+  assert(data.query != nullptr);
+  assert(data.key != nullptr);
+  assert(data.value != nullptr);
+  assert(data.bias == nullptr);
+  assert(data.fused_runner == nullptr);
+  assert(data.fused_cross_attention_kernel == nullptr);
+  assert(data.present_key != nullptr);
+  assert(data.present_value != nullptr);
+  assert(data.past_key == nullptr && data.past_value == nullptr ||
+         data.past_key != nullptr && data.past_value != nullptr);
+  assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data));
+
+  const int batch_size = parameters.batch_size;
+  const int sequence_length = parameters.sequence_length;
+  const int kv_sequence_length = parameters.kv_sequence_length;
+  const int num_heads = parameters.num_heads;
+  const int qk_head_size = parameters.head_size;
+  const int v_head_size = parameters.v_head_size;
+
+  // When there is no past state and there is present state, we output K and V directly to present state.
+  if (data.past_key == nullptr && data.present_key != nullptr) {
+    data.k = data.present_key;
+    data.v = data.present_value;
+  }
+
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+  if (data.use_memory_efficient_attention || data.use_flash_attention) {
+    // Use original Query (BSNH) since there is no bias.
+    data.q = const_cast(data.query);
+
+    // Key (BxLxNxH) => K (BxNxLxH)
+    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+                                       max_threads_per_block, false, data.key, data.k));
+    // Value (BxLxNxH) => V (BxNxLxH)
     ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
-                                       max_threads_per_block, false, data.temp_v_workspace, data.present_value));
+                                       max_threads_per_block, false, data.value, data.v));
+    data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+  } else
+#endif
+  {  // unfused kernel
+    assert(data.IsUnfused());
+    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
+                                       max_threads_per_block, false, data.query, data.q));
+    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+                                       max_threads_per_block, false, data.key, data.k));
+    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+                                       max_threads_per_block, false, data.value, data.v));
+    data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+  }
+
+  return Status::OK();
+}
-  DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
-  DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size);
-  DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size);
-  qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+template 
+constexpr bool NoQkvWorkspace_MHA_WithPast_Bias(AttentionData& /*data*/) {
+  return false;
+}
+
+// For MultiHeadAttention with both kv cache (past or present) and bias
+template 
+Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters,
+                                    AttentionData& data,
+                                    cudaStream_t stream,
+                                    int max_threads_per_block) {
+  assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+  assert(data.bias != nullptr);
+  assert(!(data.past_key != nullptr && data.present_key == nullptr));
+  assert(data.fused_runner == nullptr);
+  assert(data.fused_cross_attention_kernel == nullptr);
+  assert(data.present_key != nullptr);
+  assert(data.present_value != nullptr);
+  assert(data.past_key == nullptr && data.past_value == nullptr ||
+         data.past_key != nullptr && data.past_value != nullptr);
+  assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_Bias(data));
+
+  const int batch_size = parameters.batch_size;
+  const int sequence_length = parameters.sequence_length;
+  const int kv_sequence_length = parameters.kv_sequence_length;
+  const int num_heads = parameters.num_heads;
+  const int qk_head_size = parameters.head_size;
+  const int v_head_size = parameters.v_head_size;
+
+  // When there is no past state and there is present state, we output K and V directly to present state.
+  if (data.past_key == nullptr && data.present_key != nullptr) {
+    data.k = data.present_key;
+    data.v = data.present_value;
   }
+
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+  if (data.use_memory_efficient_attention || data.use_flash_attention) {
+    // Query(BxSxNxH) + Bias_Q => Q (BxSxNxH)
+    LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
+                  data.bias, data.query, data.q);
+
+    // Key (BxLxNxH) + Bias_K => K (BxNxLxH)
+    constexpr int format = 0;
+    LaunchAddBiasTranspose(
+        stream, 1, format, max_threads_per_block,
+        batch_size, kv_sequence_length, num_heads, qk_head_size,
+        data.key, data.bias + num_heads * qk_head_size, data.k, true, -1);
+
+    // Value (BxLxNxH_v) + Bias_V => V (BxNxLxH_v)
+    LaunchAddBiasTranspose(
+        stream, 1, format, max_threads_per_block,
+        batch_size, kv_sequence_length, num_heads, v_head_size,
+        data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1);
+
+    data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+  } else
 #endif
-  else {
-    // Use unfused kernel for Q, use unfused kernel for K and V if needed
+  {  // unfused kernel
+    assert(data.IsUnfused());
+    constexpr int format = 0;
     // Query (BxSxNxH) => Q (BxNxSxH)
     LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
                            batch_size, sequence_length, num_heads, qk_head_size,
-                           data.query, data.bias, q,
+                           data.query, data.bias, data.q,
                            true, -1);
-    if (!parameters.pass_past_in_kv) {
-      T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
-      T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ?
data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, data.k, + true, -1); - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); + // Value (BxLxNxH_v) => V (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, data.v, + true, -1); - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } +template +bool NoQkvWorkspace_MHA_PackedQKV(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return nullptr != data.fused_runner && data.bias == nullptr; +} + // For MultiHeadAttention without past state, with packed QKV inputs template Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::QKV_BSN3H); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(parameters.head_size == parameters.v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(!parameters.is_unidirectional); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_PackedQKV(data)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + // unpack qkv to BSNH. 
constexpr int format = 4; T* qkv_add_bias = nullptr; LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, + data.query, data.bias, data.q, true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (nullptr != data.fused_runner) { + assert(nullptr == data.relative_position_bias); + if (data.bias == nullptr) { + // When there is no bias, we can directly use the original packed QKV input. + // Need revisit this when we add support for causal. + data.q = const_cast(data.query); + data.k = nullptr; + data.v = nullptr; + } else { // data.bias != nullptr + AddBiasTransposePacked( + data.query, data.key, data.value, data.bias, data.q, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + AttentionQkvFormat::QKV_TN3H, AttentionQkvFormat::QKV_TN3H, + nullptr, batch_size * sequence_length, + stream); } - qkv_format = AttentionQkvFormat::QKV_BSN3H; + data.qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + assert(data.IsUnfused()); + // unpack qkv to BNSH + constexpr int format = 5; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, v_head_size, qkv_add_bias, 3); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } +// This shall be in sync with the following function PrepareQkv_MHA_PackedQKV. +template +bool NoQkvWorkspace_MHA_PackedKV(AttentionData& data) { + return data.fused_cross_attention_kernel != nullptr; +} + // For MultiHeadAttention without past state, with packed KV inputs template Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); + assert(data.bias == nullptr); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(parameters.head_size == parameters.v_head_size); + assert(data.fused_runner == nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_PackedKV(data)); + const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. 
- assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + // Note that there is no bias so we need not output query to q. + data.q = const_cast(data.query); + // Unpack kv to BSNH. constexpr int format = 4; T* qkv_add_bias = nullptr; const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, batch_size, kv_sequence_length, num_heads, qk_head_size, data.key, kv_bias, data.k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); - } + true, v_head_size, qkv_add_bias); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (data.fused_cross_attention_kernel != nullptr) { + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = nullptr; + } else { // unfused kernel + assert(data.IsUnfused()); + // Transpose q from BSNH to BNSH. Note that there is no bias. + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(batch_size, parameters.sequence_length, num_heads, qk_head_size, + data.query, data.q, stream, max_threads_per_block)); - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + // Unpack kv to BNSH. + constexpr int format = 5; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, data.k, + true, v_head_size, qkv_add_bias, 2); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } -// For MultiHeadAttention without past state, with Q, K and V inputs +// Prepare Q, K and V for MultiHeadAttention operator. 
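Before the dispatcher below, a minimal CPU reference for the Q_KV_BSNH_BSN2H packed KV layout that PrepareQkv_MHA_PackedKV above handles. The [batch, sequence, head, {K, V}, head_size] ordering with K in slot 0 and V in slot 1 is an assumption made for illustration; the patch itself performs the equivalent unpacking on the GPU through LaunchAddBiasTranspose.

// Illustrative sketch only: unpack packed KV of shape (B, S, N, 2, H) into separate
// K and V tensors in BSNH layout.
#include <cstddef>
#include <vector>

void UnpackKvBsn2h(const std::vector<float>& packed_kv,  // B * S * N * 2 * H elements
                   std::vector<float>& k_bsnh,           // B * S * N * H elements
                   std::vector<float>& v_bsnh,           // B * S * N * H elements
                   int B, int S, int N, int H) {
  for (int b = 0; b < B; ++b) {
    for (int s = 0; s < S; ++s) {
      for (int n = 0; n < N; ++n) {
        // Offset of the (b, s, n) pair: K occupies slot 0 and V occupies slot 1 of the packed axis.
        const size_t src = (((static_cast<size_t>(b) * S + s) * N + n) * 2) * H;
        const size_t dst = ((static_cast<size_t>(b) * S + s) * N + n) * H;
        for (int h = 0; h < H; ++h) {
          k_bsnh[dst + h] = packed_kv[src + h];
          v_bsnh[dst + h] = packed_kv[src + H + h];
        }
      }
    }
  }
}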
template -Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); - -#if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } -#endif - - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - num_heads, sequence_length, kv_sequence_length); - } - - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); - } - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); - - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; +Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + switch (parameters.qkv_format) { + case AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_Cross(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::Q_KV_BSNH_BSN2H: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block)); + break; + case 
AttentionQkvFormat::QKV_BSN3H:
+      ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block));
+      break;
+    case AttentionQkvFormat::Q_K_V_BSNH:
+      if (data.past_key != nullptr || data.present_key != nullptr) {
+        if (data.bias == nullptr) {
+          ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
+        } else {
+          ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block));
+        }
+      } else {  // no past state
+        ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
+      }
+      break;
+    default:
+      ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
   }
-#endif
-  else if (use_fused_kernel) {
-    assert(qk_head_size == v_head_size);
-
-    // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
-    LaunchAddBiasTransposeTrt(
-        stream, max_threads_per_block,
-        batch_size, sequence_length,
-        num_heads, qk_head_size,
-        data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
-    DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
-
-    qkv_format = AttentionQkvFormat::QKV_BSN3H;
-  } else {  // unfused kernel
-    ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
-
-    // Query (BxSxNxH) => Q (BxNxSxH)
-    constexpr int format = 0;
-    LaunchAddBiasTranspose(
-        stream, 1, format, max_threads_per_block,
-        batch_size, sequence_length, num_heads, qk_head_size,
-        data.query, data.bias, q,
-        true, -1);
-
-    // Key (BxLxNxH) => K (BxNxLxH)
-    LaunchAddBiasTranspose(
-        stream, 1, format, max_threads_per_block,
-        batch_size, kv_sequence_length, num_heads, qk_head_size,
-        data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
-        true, -1);
-
-    // Value (BxLxNxH_v) => K (BxNxLxH_v)
-    LaunchAddBiasTranspose(
-        stream, 1, format, max_threads_per_block,
-        batch_size, kv_sequence_length, num_heads, v_head_size,
-        data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
-        true, -1);
+  return Status::OK();
+}
-    DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
-    DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
-    DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size);
-    qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+
+// Check whether there is no need to have workspace for Q, K and V for MultiHeadAttention operator.
+// Please make it in sync with PrepareQkv_MultiHeadAttention.
+template +bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data) { + switch (parameters.qkv_format) { + case AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH: + return NoQkvWorkspace_MHA_Cross(data); + case AttentionQkvFormat::Q_KV_BSNH_BSN2H: + return NoQkvWorkspace_MHA_PackedKV(data); + case AttentionQkvFormat::QKV_BSN3H: + return NoQkvWorkspace_MHA_PackedQKV(data); + case AttentionQkvFormat::Q_K_V_BSNH: + if (data.past_key != nullptr || data.present_key != nullptr) { + if (data.bias == nullptr) { + return NoQkvWorkspace_MHA_WithPast_NoBias(data); + } else { + return NoQkvWorkspace_MHA_WithPast_Bias(data); + } + } else { // no past state + return NoQkvWorkspace_MHA_NoPast(data); + } + default: + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } - return Status::OK(); } template @@ -439,7 +687,6 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block) { - data.scratch = data.workspace; if (data.has_qkv_workspace) { const int size_per_batch_q = parameters.sequence_length * parameters.head_size; const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; @@ -452,28 +699,37 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.k = data.workspace + elements_q; data.v = data.k + elements_k; data.scratch = data.v + elements_v; + } else { + data.q = nullptr; + data.k = nullptr; + data.v = nullptr; + data.scratch = data.workspace; } +#if DEBUG_TENSOR_LEVEL > 1 + DumpInputs(parameters, data); +#endif + if (nullptr != data.gemm_buffer) { // Attention operator - ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, - data.qkv_format)); - } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, - data.q, data.k, data.v, data.qkv_format)); - } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, data.qkv_format)); - } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, data.qkv_format)); - } else { // multihead attention operator, no past, separated Q/K/V inputs - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, - data.q, data.k, data.v, data.qkv_format)); + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block)); + } else { // MultiHeadAttention operator + ORT_RETURN_IF_ERROR(PrepareQkv_MultiHeadAttention(parameters, data, stream, max_threads_per_block)); } + assert(data.qkv_format != AttentionQkvFormat::UNKNOWN); + +#if DEBUG_TENSOR_LEVEL > 1 + DumpQkv(data); +#endif + CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } // Template Instantiation +template bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); +template bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); + template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu index bd38a21aadfcb..9f3e396b7f949 100644 --- 
a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu @@ -304,6 +304,12 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c max_threads_per_block, false, input, output); } +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block) { + return LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, output); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 66c0aceaed1e7..037a4fdf3d9a0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -75,7 +75,6 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); bool is_unidirectional = false; - bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, @@ -91,7 +90,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* scale_, is_unidirectional, past_present_share_buffer_, - is_dmmha_packing, // dmmha_packing + kDecoderMaskedMultiHeadAttention, device_prop.maxThreadsPerBlock)); if (bias) { @@ -157,7 +156,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.is_cross_attention = true; parameters.total_sequence_length = parameters.kv_sequence_length; parameters.max_sequence_length = parameters.kv_sequence_length; - // parameters.k and paraneters.v are nullptr + // parameters.k and parameters.v are nullptr parameters.k_cache = const_cast(key->Data()); parameters.v_cache = const_cast(value->Data()); parameters.k_bias = nullptr; @@ -188,12 +187,14 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* } parameters.is_cross_attention = false; - parameters.is_packed_qkv = is_dmmha_packing; - parameters.k = is_dmmha_packing + bool is_packed_qkv = (key == nullptr && value == nullptr); + parameters.is_packed_qkv = is_packed_qkv; + + parameters.k = is_packed_qkv ? const_cast(query->Data() + parameters.hidden_size) : const_cast(key->Data()); - parameters.v = is_dmmha_packing + parameters.v = is_packed_qkv ? const_cast(query->Data() + 2 * static_cast(parameters.hidden_size)) : const_cast(value->Data()); parameters.k_cache = present_key_data; diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 9efb6f08e8e99..2f8d277cb7342 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -183,6 +183,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; } + // This assumes that key and value do not have bias for cross attention when they are in BNSH format.
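To make that assumption concrete before the kernel code continues: only the query bias is applied when K and V arrive in BNSH format for cross attention. The PyTorch sketch below is illustrative only (it mirrors the reference path and the no_bias_k_v handling in test_mha.py further down in this series, not the CUDA kernel itself); it also shows the BSNH to BNSH layout change that the new float overload of Transpose_BSNH_to_BNSH performs on the GPU.

import torch

def reference_qkv_for_cross_attention(q_bsnh, k_bnsh, v_bnsh, bias_q):
    # q_bsnh: (B, S, N, H); k_bnsh, v_bnsh: (B, N, L, H); bias_q: (N, H).
    # Key/value bias is assumed to be zero or already folded in when K and V are in BNSH format.
    q_bnsh = (q_bsnh + bias_q.reshape(1, 1, *bias_q.shape)).transpose(1, 2)  # BSNH -> BNSH
    return q_bnsh, k_bnsh, v_bnsh

# Example shapes: B=2, S=8, L=10, N=4, H=16.
q, k, v = reference_qkv_for_cross_attention(
    torch.randn(2, 8, 4, 16),   # query in BSNH
    torch.randn(2, 4, 10, 16),  # key already in BNSH
    torch.randn(2, 4, 10, 16),  # value already in BNSH
    torch.zeros(4, 16))         # query bias
assert q.shape == (2, 4, 8, 16)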
if (!params.is_cross_attention) { Qk_vec_k k; @@ -580,6 +581,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // One group of threads computes the product(s) for the current timestep. V_vec_k v_bias; + + // This assumes that key and value do not have bias for cross attention when they are in BNSH format. if (params.v_bias && !params.is_cross_attention) { zero(v_bias); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 663bd020ddac7..c36abc8e1d624 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -7,6 +7,7 @@ #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -44,7 +45,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); + ORT_ENFORCE(!is_unidirectional_, + "Unidirectional MHA is not supported by the CUDA kernel. Consider using Attention or GQA instead."); kernel_options_ = this->GetAttentionKernelOptions(); @@ -95,7 +97,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { scale_, is_unidirectional_, false, // past_present_share_buffer - false, // dmmha_packing + kMultiHeadAttention, device_prop.maxThreadsPerBlock)); int sequence_length = parameters.sequence_length; @@ -111,25 +113,43 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_key = context->Output(1, present_shape); Tensor* present_value = context->Output(2, present_shape); - MHARunner* fused_runner = nullptr; + int num_past = static_cast(past_key != nullptr) + static_cast(past_value != nullptr); + int num_present = static_cast(present_key != nullptr) + static_cast(present_value != nullptr); + if (num_past == 0 && num_present == 0) { + // It is a valid case without past state.
+ } else if ((num_past == 2 && num_present == 2) || (num_past == 0 && num_present == 2)) { + if (parameters.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for packed QKV format"); + } + + if (parameters.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for packed KV format"); + } + + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for cross attention"); + } + } else { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be all provided, " + "or all empty, or only present_key and present_value are provided"); + } + MHARunner* fused_runner = nullptr; const FusedMultiHeadCrossAttentionKernel* fused_cross_attention_kernel = nullptr; // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; - bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - - const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value); - -#if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION - // Exclude this case since PrepareQkv will convert the format to BNSH. - bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; -#endif - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && - !past_no_bias && nullptr == relative_position_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && @@ -138,7 +158,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads, parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. 
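Stepping back to the validation added earlier in this hunk: the accepted combinations of KV-cache inputs and outputs can be summarized with a small predicate. This is an illustrative Python restatement of the checks, not code from the patch; the format names follow the InputFormats used in the test files later in this series.

def is_valid_kv_cache_combination(num_past: int, num_present: int, input_format: str) -> bool:
    # No cache at all is always fine.
    if num_past == 0 and num_present == 0:
        return True
    # Otherwise past/present must come as complete pairs; present-only covers the
    # prompt case where there is no past yet but a cache is produced.
    if (num_past, num_present) not in ((2, 2), (0, 2)):
        return False
    # Packed QKV, packed KV and cross-attention formats cannot carry a KV cache.
    return input_format not in ("QKV_BSN3H", "Q_KV_BSNH_BSN2H", "Q_K_V_BSNH_BNSH_BNSH")

With the cache combinations validated, the selection logic below then prefers the TensorRT fused kernel over flash attention for packed QKV with short sequences, as the comment above notes.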
- if (use_flash_attention && key == nullptr && value == nullptr && + if (use_flash_attention && parameters.qkv_format == AttentionQkvFormat::QKV_BS3NH && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } @@ -162,19 +182,21 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - bool use_fused_cross_attention = !use_flash_attention && - !disable_fused_cross_attention_ && - nullptr == key_padding_mask && - nullptr == relative_position_bias && - (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && - key != nullptr && - (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV - parameters.hidden_size == parameters.v_hidden_size && - has_fused_cross_attention_kernel(sm, parameters.head_size, - parameters.kv_sequence_length); + bool use_fused_cross_attention = + !use_flash_attention && + !disable_fused_cross_attention_ && + nullptr == key_padding_mask && + nullptr == relative_position_bias && + nullptr == past_key && nullptr == present_key && + (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) && + parameters.hidden_size == parameters.v_hidden_size && + has_fused_cross_attention_kernel(sm, parameters.head_size, + parameters.kv_sequence_length); if (use_fused_cross_attention) { if (fused_fp16_cross_attention_kernel_ == nullptr) { - fused_fp16_cross_attention_kernel_ = get_fused_cross_attention_kernels(sm); + std::call_once(fused_cross_init_once_flag_, [&]() { + fused_fp16_cross_attention_kernel_ = get_fused_cross_attention_kernels(sm); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -184,17 +206,18 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - bool use_fused_runner = !use_flash_attention && - !disable_fused_self_attention_ && - fused_cross_attention_kernel == nullptr && - nullptr == relative_position_bias && - (value != nullptr || key == nullptr) && - (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && - (nullptr == key_padding_mask || is_mask_1d_seq_len) && - parameters.hidden_size == parameters.v_hidden_size && - parameters.sequence_length == parameters.kv_sequence_length && - FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); + bool use_fused_runner = + !use_flash_attention && + !disable_fused_self_attention_ && + fused_cross_attention_kernel == nullptr && + nullptr == relative_position_bias && + (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && + nullptr == past_key && nullptr == present_key && + (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN) && + parameters.hidden_size == parameters.v_hidden_size && + parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner + FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, false); if (use_fused_runner) { // Here we assume that num_heads and head_size do not change for a MultiHeadAttention node.
if (nullptr == fused_fp16_runner_.get()) { @@ -214,10 +237,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { #if USE_MEMORY_EFFICIENT_ATTENTION int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32(); - bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 + bool is_long_sequence = std::is_same::value || // sequence length threshold is 0 for FP16 parameters.sequence_length >= length_threshold || parameters.kv_sequence_length >= length_threshold; + // Check whether the relative position bias alignment is good for memory efficient attention. bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; bool use_memory_efficient_attention = @@ -226,82 +250,25 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && is_long_sequence && - !past_no_bias && (relative_position_bias == nullptr || is_good_for_rpb) && (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && - has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); + has_memory_efficient_attention(sm, std::is_same::value, + parameters.head_size, parameters.v_head_size); #else constexpr bool use_memory_efficient_attention = false; #endif - if (kernel_options_->AllowDebugInfo()) { - AttentionKernelDebugInfo debug_info; - debug_info.use_flash_attention = use_flash_attention; - debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; - debug_info.use_efficient_attention = use_memory_efficient_attention; - if (fused_fp16_runner_ != nullptr) { - debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); - } - - debug_info.Print("MultiHeadAttention", - this->Node().Name(), - std::is_same::value, - std::is_same::value); - } - - // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. - // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. - bool no_qkv_workspace = nullptr == value && - (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) && - nullptr == key_padding_mask && - nullptr == bias; - - size_t workspace_bytes; - constexpr size_t element_size = sizeof(T); - if (no_qkv_workspace) { - workspace_bytes = (parameters.batch_size > kCumulatedSequenceLengthCacheMaxBatchSize) ? 
2 * GetSequenceOffsetSize(parameters.batch_size, true) : 0; - } else { - workspace_bytes = GetAttentionWorkspaceSize(element_size, - parameters.batch_size, - parameters.num_heads, - parameters.head_size, - parameters.v_head_size, - parameters.sequence_length, - parameters.kv_sequence_length, - parameters.total_sequence_length, - fused_runner, - use_flash_attention, - use_fused_cross_attention, - use_memory_efficient_attention); - } - - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - const size_t past_k_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.head_size; - const size_t past_v_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.v_head_size; - const bool use_temp_k_v_workspace = parameters.pass_past_in_kv || use_memory_efficient_attention || use_flash_attention; - auto temp_k_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_k_bytes, context->GetComputeStream()) : nullptr; - auto temp_v_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_v_bytes, context->GetComputeStream()) : nullptr; - typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); - data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(key->Data()); - data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(value->Data()); + data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); + data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); - data.past_key = pass_key_value_as_past ? reinterpret_cast(key->Data()) - : (nullptr == past_key) ? nullptr - : reinterpret_cast(past_key->Data()); - data.past_value = pass_key_value_as_past ? reinterpret_cast(value->Data()) - : (nullptr == past_value) ? nullptr - : reinterpret_cast(past_value->Data()); + data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; - data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); @@ -309,8 +276,41 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); - data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + + // Cache of cumulated sequence length that could help when sequence length does not change (for example, image model). + // The cache will be initialized only once, and become readonly after that. + if ((data.fused_cross_attention_kernel != nullptr || data.fused_runner != nullptr) && data.mask_index == nullptr) { + cudaStream_t stream = Stream(context); + data.cumulated_sequence_length_q_cache = this->cumulated_sequence_length_q_cache_.TryGet( + parameters.batch_size, parameters.sequence_length, stream); + + if (data.fused_cross_attention_kernel != nullptr) { + data.cumulated_sequence_length_kv_cache = this->cumulated_sequence_length_kv_cache_.TryGet( + parameters.batch_size, parameters.kv_sequence_length, stream); + } + } + + const bool no_qkv_workspace = NoQkvWorkspace(parameters, data); + size_t workspace_bytes = GetAttentionWorkspaceSize(sizeof(T), + parameters.batch_size, + parameters.num_heads, + parameters.head_size, + parameters.v_head_size, + parameters.sequence_length, + parameters.kv_sequence_length, + parameters.total_sequence_length, + fused_runner, + use_flash_attention, + use_fused_cross_attention, + use_memory_efficient_attention, + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + data.allow_debug_info = kernel_options_->AllowDebugInfo(); if (softmax_lse_accum_buffer != nullptr) { data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); } @@ -318,8 +318,23 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } - cublasHandle_t cublas = GetCublasHandle(context); + if (data.allow_debug_info) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_fp16_runner_ != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + debug_info.Print("MultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + + data.PrintDebugInfo(); + } + cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 26e38dbad9fd7..68fd0c9943fca 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include 
"contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" @@ -32,11 +33,16 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + + // These mutable members are readonly after they are initialized so that they can be shared among multiple threads. + // Initialization are done only once by the first thread using the resource, so use once_flag to guard each resource. mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; + mutable std::once_flag fused_cross_init_once_flag_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; + const AttentionKernelOptions* kernel_options_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index ac2cb5165a94c..2521cd49b5482 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -297,7 +297,7 @@ struct T2 { }; template -void LaunchAddBiasTranspose( +void AddBiasTransposePacked( const T* input, const T* biases, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, @@ -452,7 +452,7 @@ Status FusedScaledDotProductAttention( void* fused_runner = data.fused_runner; ORT_RETURN_IF_NOT(nullptr != fused_runner, "fused_runner cannot be NULL"); - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::QKV_BSN3H, data.token_offset, @@ -477,7 +477,7 @@ Status FusedScaledDotProductAttentionCutlass( const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::Q_K_V_BSNH, data.token_offset, @@ -564,7 +564,7 @@ Status UnfusedScaledDotProductAttention( T* k = q + elements_q; T* v = k + elements_k; - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::Q_K_V_BNSH, data.token_offset, @@ -657,6 +657,20 @@ Status QkvToContext( return UnfusedScaledDotProductAttention(device_prop, cublas, stream, parameters, data); } +template void AddBiasTransposePacked( + const float* input, const float* biases, float* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +template void AddBiasTransposePacked( + const half* input, const half* biases, half* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t 
stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index b4ca0194b08bc..e5a4c54f48903 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -502,7 +502,7 @@ struct T2 { }; template -void LaunchTranspose( +void AddBiasTransposePacked( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, @@ -566,11 +566,11 @@ Status FusedAttentionTrt( // When packed QKV is used, we can directly pass it to fused runner. Otherwise, we need transpose to BSN3H format. const T* qkv = data.query; if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::QKV_TN3H, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::QKV_TN3H, + data.token_offset, parameters.token_count, stream); qkv = data.workspace; } @@ -601,11 +601,11 @@ Status FlashAttention( // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); } float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) @@ -675,11 +675,11 @@ Status FusedAttentionCutlass( // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. 
Otherwise, transpose BSN3H to 3BSNH if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); } MemoryEfficientAttentionParams p; @@ -746,11 +746,11 @@ Status UnfusedAttention( const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); // Q, K and V pointers when fused attention is not used - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_BNSH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_BNSH, + data.token_offset, parameters.token_count, stream); T* qkv = data.workspace; T* q = qkv; @@ -848,6 +848,22 @@ Status QkvToContext( return UnfusedAttention(device_prop, cublas, stream, parameters, data); } +template void AddBiasTransposePacked( + const half* query, const half* key, const half* value, const half* bias, half* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +template void AddBiasTransposePacked( + const float* query, const float* key, const float* value, const float* bias, float* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 168c69c69f003..b62e566d43f89 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -190,7 +190,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { fused_runner, use_flash_attention, use_fused_cross_attention, - use_memory_efficient_attention); + use_memory_efficient_attention, + true); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -208,6 +209,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workSpaceSize; data.output = reinterpret_cast(output->MutableData()); if (nullptr != present) { data.present = reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index e10c2ec63fd51..6d52ff7282799 100644 
--- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -13,6 +13,9 @@ namespace cuda { #if DUMP_TENSOR_LEVEL > 0 +// Environment variable to enable/disable GPU Tensor dumping +constexpr const char* kEnableGpuTensorDumper = "ORT_ENABLE_GPU_DUMP"; + // Total number of elements which trigger snippet rather than full dump (default 200). Value 0 disables snippet. constexpr const char* kTensorSnippetThreshold = "ORT_TENSOR_SNIPPET_THRESHOLD"; @@ -202,6 +205,10 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } +CudaTensorConsoleDumper::CudaTensorConsoleDumper() { + is_enabled_ = ParseEnvironmentVariableWithDefault(kEnableGpuTensorDumper, 1) != 0; +} + void CudaTensorConsoleDumper::Print(const std::string& value) const { std::cout << value << std::endl; } @@ -329,6 +336,8 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else +CudaTensorConsoleDumper::CudaTensorConsoleDumper() { +} void CudaTensorConsoleDumper::Print(const std::string&) const { } diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 6ad0ad9a67b75..4f41161cd4a31 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -13,7 +13,7 @@ namespace cuda { class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { public: - CudaTensorConsoleDumper() = default; + CudaTensorConsoleDumper(); virtual ~CudaTensorConsoleDumper() {} void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index b0ed3ff82226a..b94971ffd44d5 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -119,7 +119,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; return Status::OK(); } @@ -128,7 +128,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; return Status::OK(); } @@ -136,7 +136,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; return Status::OK(); } @@ -146,7 +146,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; return Status::OK(); } @@ -154,7 +154,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format 
== Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; return Status::OK(); } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 349df045becf2..d593bc0012826 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -132,12 +132,6 @@ class CompatRocblasMathModeSetter { } }; -enum AttentionType { - kAttention, - kMultiHeadAttention, - kDecoderMaskedMultiHeadAttention, -}; - enum AttentionMode { // Q,K,V,PastK,PastV,PresentK,PresentV QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 09e7d61b71db9..5997daaca6e8a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -122,9 +122,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, num_heads_, - mask_filter_value_, scale_, false, /*is_unidirectional_*/ - past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ + past_present_share_buffer_, + attn_type_, + device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index dc2b38f3928ac..a9ff623fb6967 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -691,6 +691,9 @@ def create_multihead_attention_node( return None # Add bias to inputs for MHA + # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume + # bias has been added to key and value when they are in BNSH format, so only bias for query is used. + # Need add checks if we found such assumption is not true. if not self.disable_multi_head_attention_bias: bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name) mha_inputs.append(bias_name) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 715a92431e6bf..ec350874af32c 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -88,9 +88,11 @@ def __init__( enable_cuda_graph: bool = False, dtype=torch.float, use_kv_cache: bool = False, + has_past_input: bool = False, share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, verbose: bool = False, + has_bias: bool = False, ): self.operator = "MultiHeadAttention" self.batch_size = batch_size @@ -103,15 +105,25 @@ def __init__( self.causal = causal self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) + # Support the case that there is no past but need present output (for prompt case). 
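In other words, has_past_input and use_kv_cache (which drives has_present_output) now describe three distinct situations; the assignments that follow wire this up. A few illustrative keyword subsets for MultiHeadAttentionConfig are sketched below (not complete constructor calls, and the values are only examples):

# Prompt step: no past yet, but present_key/present_value are produced for later steps.
prompt_case = dict(use_kv_cache=True, has_past_input=False, past_sequence_length=0, sequence_length=128)
# Decode step: a single new token attends to the existing cache.
decode_case = dict(use_kv_cache=True, has_past_input=True, past_sequence_length=128, sequence_length=1)
# Plain MHA without any cache.
no_cache_case = dict(use_kv_cache=False, has_past_input=False, past_sequence_length=0, sequence_length=128)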
+ self.has_past_input = has_past_input + if has_past_input: + assert use_kv_cache + else: # no past input + assert past_sequence_length == 0 + + self.has_present_output = use_kv_cache + self.use_kv_cache = use_kv_cache if not use_kv_cache: assert past_sequence_length == 0 else: assert self.kv_sequence_length == self.sequence_length - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - assert not use_kv_cache + # Only BSNH input format supports past state. + if input_format != InputFormats.Q_K_V_BSNH_BSNH_BSNH: + assert not self.has_past_input + assert not self.has_present_output # Derived values self.total_sequence_length = self.kv_sequence_length + past_sequence_length @@ -130,6 +142,7 @@ def __init__( self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H self.verbose = verbose + self.has_bias = has_bias def __repr__(self): return ( @@ -140,7 +153,8 @@ def __repr__(self): f"causal={self.causal}), softmax_scale={self.softmax_scale}, use_kv_cache={self.use_kv_cache}, " f"share_past_present_buffer={self.share_past_present_buffer}, " f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, " - f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}" + f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, " + f"has_bias={self.has_bias}" ) def shape_dict(self, input_format=None): @@ -176,16 +190,23 @@ def shape_dict(self, input_format=None): "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), } - if self.use_kv_cache: - assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" + if self.has_past_input: shapes = { **shapes, "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), + } + + if self.has_present_output: + shapes = { + **shapes, "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), } + if self.has_bias: + shapes["bias"] = (3 * self.num_heads * self.head_size,) + return shapes def symbolic_shape_dict(self, input_format=None): @@ -221,19 +242,26 @@ def symbolic_shape_dict(self, input_format=None): "value": ("batch_size", self.num_heads, "sequence_length", self.head_size), } - if self.use_kv_cache: - assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" + if self.has_past_input: shapes = { **shapes, "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), + } + + if self.has_present_output: + shapes = { + **shapes, "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), } + if self.has_bias: + shapes["bias"] = (3 * self.num_heads * self.head_size,) + return shapes - def random_inputs(self, seed: int = 123): + def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False): device = self.device dtype = self.dtype @@ -246,6 +274,14 @@ def random_inputs(self, seed: int = 123): q = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) k = 
torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) v = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) + + bias_q = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + bias_k = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + bias_v = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + if no_bias_k_v: + bias_k = torch.zeros_like(bias_k) + bias_v = torch.zeros_like(bias_v) + k_bnsh = k.transpose(1, 2) v_bnsh = v.transpose(1, 2) @@ -277,7 +313,7 @@ def random_inputs(self, seed: int = 123): "value": v_bnsh.contiguous(), } - if self.use_kv_cache: + if self.has_past_input: feeds = { **feeds, "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), @@ -286,6 +322,9 @@ def random_inputs(self, seed: int = 123): ), } + if self.has_bias: + feeds["bias"] = torch.concat([bias_q, bias_k, bias_v], dim=0).reshape(shape_dict["bias"]).contiguous() + return feeds def get_input_output_names(self): @@ -299,15 +338,29 @@ def get_input_output_names(self): else: inputs, outputs = ["query", "key", "value"], ["output"] - if self.use_kv_cache: - return [*inputs, "past_key", "past_value"], [*outputs, "present_key", "present_value"] - else: - return inputs, outputs + if self.has_bias: + inputs = [*inputs, "bias"] + + if self.has_past_input: + inputs = [*inputs, "past_key", "past_value"] + + if self.has_present_output: + outputs = [*outputs, "present_key", "present_value"] + + return inputs, outputs def fill_optional_mha_inputs(input_names): inputs = ["query", "key", "value", "bias", "key_padding_mask", "relative_position_bias", "past_key", "past_value"] - return input_names[:-2] + [""] * (len(inputs) - len(input_names)) + input_names[-2:] + + # Remove optional inputs that are not in input_names with empty string + inputs_with_optional = [input if input in input_names else "" for input in inputs] + + # Remove empty string at the end of the list. 
+ while inputs_with_optional[-1] == "": + inputs_with_optional.pop(-1) + + return inputs_with_optional def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False): @@ -317,7 +370,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use nodes = [ helper.make_node( "MultiHeadAttention", - fill_optional_mha_inputs(input_names) if config.use_kv_cache else input_names, + fill_optional_mha_inputs(input_names), output_names, "MultiHeadAttention_0", num_heads=config.num_heads, @@ -331,11 +384,13 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use inputs = [ helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name])) for input_name in input_names + if input_name ] outputs = [ helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name])) for output_name in output_names + if output_name ] graph = helper.make_graph( @@ -355,6 +410,7 @@ def create_ort_session( session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_symbolic_shape: bool = True, + use_tf32: bool = True, ) -> CudaSession: if config.verbose: print(f"create session for {vars(config)}") @@ -364,6 +420,7 @@ def create_ort_session( device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph) provider_options["sdpa_kernel"] = int(attention_kernel) + provider_options["use_tf32"] = int(use_tf32) providers = [(config.provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] @@ -373,9 +430,11 @@ def create_ort_session( def create_session( - config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT + config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_tf32: bool = True ) -> CudaSession: - ort_session = create_ort_session(config, session_options, attention_kernel, use_symbolic_shape=False) + ort_session = create_ort_session( + config, session_options, attention_kernel, use_symbolic_shape=False, use_tf32=use_tf32 + ) cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph) shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) @@ -385,8 +444,8 @@ def create_session( class OrtMultiHeadAttention: """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" - def __init__(self, config: MultiHeadAttentionConfig, session_options=None): - self.ort_session = create_session(config, session_options) + def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_tf32: bool = True): + self.ort_session = create_session(config, session_options, use_tf32=use_tf32) self.feed_dict = config.random_inputs() def infer(self): diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 0fcbd889847e9..a35d02b0b9d52 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -21,6 +21,47 @@ import onnxruntime +def get_provider_support_info(provider: str, use_kv_cache: bool): + if provider == "CUDAExecutionProvider": + if not use_kv_cache: + formats = [ + InputFormats.Q_K_V_BSNH_BSNH_BSNH, + InputFormats.Q_KV_BSNH_BSN2H, + InputFormats.QKV_BSN3H, + InputFormats.Q_K_V_BSNH_BNSH_BNSH, + ] + else: + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + + device_id = 
torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + else: + assert provider == "CPUExecutionProvider" + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + if not use_kv_cache: + formats.append(InputFormats.Q_K_V_BSNH_BNSH_BNSH) + device = torch.device("cpu") + dtype = torch.float + return device, dtype, formats + + +def get_bias_support(format: InputFormats): + if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + return [True, False] + + if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + return [True, False] + + if format == InputFormats.Q_KV_BSNH_BSN2H: + return [False] + + if format == InputFormats.QKV_BSN3H: + return [True, False] + + raise RuntimeError(f"Unknown format: {format}") + + def attention_reference( head_size: int, query: torch.Tensor, @@ -84,8 +125,8 @@ def attention_reference( def mha_with_past_reference( config: MultiHeadAttentionConfig, - past_k: torch.Tensor, - past_v: torch.Tensor, + past_k: Optional[torch.Tensor], + past_v: Optional[torch.Tensor], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -94,41 +135,23 @@ def mha_with_past_reference( ): assert config.kv_sequence_length == config.sequence_length assert config.use_kv_cache - assert past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1) # both BNSH format - assert past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1) # both BNSH format - - present_k = torch.cat((past_k, k), dim=2) - present_v = torch.cat((past_v, v), dim=2) + if past_k is not None: + assert ( + past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1) + ), f"expect BNSH format: {past_k.shape=} {k.shape=}" + + if past_v is not None: + assert ( + past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1) + ), f"expect BNSH format: {past_v.shape=} {v.shape=}" + + present_k = torch.cat((past_k, k), dim=2) if past_k is not None else k + present_v = torch.cat((past_v, v), dim=2) if past_v is not None else v out = attention_reference(config.head_size, q, present_k, present_v, scale=scale, mask=mask) return out, present_k, present_v -def get_provider_support_info(provider: str, use_kv_cache: bool): - if provider == "CUDAExecutionProvider": - if not use_kv_cache: - formats = [ - InputFormats.Q_K_V_BSNH_BSNH_BSNH, - InputFormats.Q_KV_BSNH_BSN2H, - InputFormats.QKV_BSN3H, - InputFormats.Q_K_V_BSNH_BNSH_BNSH, - ] - else: - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - dtype = torch.float16 - else: - assert provider == "CPUExecutionProvider" - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - if not use_kv_cache: - formats.append(InputFormats.Q_K_V_BSNH_BSNH_BSNH) - device = torch.device("cpu") - dtype = torch.float - return device, dtype, formats - - def get_compute_capability(): if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers(): major, minor = torch.cuda.get_device_capability() @@ -143,35 +166,38 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): yield batch_sizes = [1, 2, 3] - sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 2048] + sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 512] heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] device, dtype, formats = get_provider_support_info(provider, False) if comprehensive: + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: for sequence_length 
in sequence_lengths: for num_heads in heads: for head_size in head_sizes: for format in formats: for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -179,25 +205,27 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] - format = formats[i % len(formats)] for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for format in formats: + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config def kv_cache_test_cases(provider: str, comprehensive: bool): @@ -206,37 +234,42 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): yield batch_sizes = [1, 2, 3] - sequence_lengths = [1, 15, 16, 255, 256, 2048] + sequence_lengths = [1, 15, 16, 255, 256, 512] heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - - sequence_length = 1 device, dtype, formats = get_provider_support_info(provider, True) if comprehensive: + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: for past_sequence_length in sequence_lengths: for num_heads in heads: for head_size in head_sizes: for format in formats: for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_sequence_length, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for has_past_input in [True, False]: 
+ for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -244,31 +277,31 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): past_sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] - format = formats[i % len(formats)] for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_sequence_length, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - share_past_present_buffer=False, - input_format=format, - ) - yield config - - -def mha_test_cases(provider: str, comprehensive: bool): - return itertools.chain( - no_kv_cache_test_cases(provider, comprehensive), kv_cache_test_cases(provider, comprehensive) - ) + for format in formats: + for has_past_input in [True, False]: + for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): @@ -343,6 +376,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): device=device, dtype=dtype, use_kv_cache=True, + has_past_input=True, share_past_present_buffer=False, input_format=format, ) @@ -350,13 +384,6 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): yield configs -def multi_thread_test_cases(provider: str, comprehensive: bool): - return itertools.chain( - no_kv_cache_multi_thread_test_cases(provider, comprehensive), - kv_cache_multi_thread_test_cases(provider, comprehensive), - ) - - def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) @@ -374,28 +401,31 @@ def parity_check_mha( if config.causal and config.provider == "CUDAExecutionProvider": return - ort_mha = OrtMultiHeadAttention(config) + ort_mha = OrtMultiHeadAttention(config, use_tf32=False) ort_outputs = ort_mha.infer() out = ort_outputs["output"] out = 
torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + no_bias_k_v = config.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH - ref_inputs = config.random_inputs() - q = ( - ref_inputs["query"] - .reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - k = ( - ref_inputs["key"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - v = ( - ref_inputs["value"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) + ref_inputs = config.random_inputs(no_bias_k_v=no_bias_k_v) + q = ref_inputs["query"].reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + k = ref_inputs["key"].reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + v = ref_inputs["value"].reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + + if "bias" in ref_inputs: + bias = ref_inputs["bias"] + bias = bias.reshape((3, config.num_heads, config.head_size)) + bias_q = bias[0, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_k = bias[1, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_v = bias[2, :, :].reshape(1, 1, config.num_heads, config.head_size) + q = q + bias_q + k = k + bias_k + v = v + bias_v + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) mask = None if config.causal: @@ -404,8 +434,8 @@ def parity_check_mha( k_cache = None v_cache = None if config.use_kv_cache: - past_k = ref_inputs["past_key"] - past_v = ref_inputs["past_value"] + past_k = ref_inputs.get("past_key", None) + past_v = ref_inputs.get("past_value", None) out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) else: out_ref = attention_reference(config.head_size, q, k, v, mask=mask) @@ -445,7 +475,7 @@ def parity_check_mha_multi_threading( test_inputs: List[Dict], rtol: float = 1e-3, atol: float = 1e-3, - attention_kernel: int = SdpaKernel.DEFAULT, + attention_kernel=SdpaKernel.DEFAULT, max_threads: int = 5, verbose: bool = False, ): @@ -454,6 +484,7 @@ def parity_check_mha_multi_threading( # For now, MHA CUDA kernel does not support causal so skip such test cases. if config.causal and config.provider == "CUDAExecutionProvider": return None + # Some kernel does not support certain input format. if attention_kernel not in [ SdpaKernel.DEFAULT, @@ -462,7 +493,7 @@ def parity_check_mha_multi_threading( ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: return None - ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True) + ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True, use_tf32=False) def convert_to_ort_inputs(feed_dict): ort_inputs = {} @@ -572,18 +603,32 @@ def check_parity_with_config(i: int): return None -# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine. 
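As an aside on the reference path above: the packed bias used by `parity_check_mha` has shape `(3 * num_heads * head_size,)` and is split into per-projection biases that broadcast over batch and sequence before the transpose to BNSH. Below is a minimal standalone sketch of that reshaping, not part of the patch; it assumes the packed layout is (q, k, v) along the leading dimension, as the reference code above implies, and the tensor values are made up.

```python
import torch

batch_size, seq_len, num_heads, head_size = 2, 4, 3, 8

# Q/K/V in BSNH layout and a packed bias of length 3 * num_heads * head_size.
q = torch.randn(batch_size, seq_len, num_heads, head_size)
k = torch.randn(batch_size, seq_len, num_heads, head_size)
v = torch.randn(batch_size, seq_len, num_heads, head_size)
packed_bias = torch.randn(3 * num_heads * head_size)

# Split into per-projection biases; shape (1, 1, N, H) broadcasts over batch and sequence.
bias = packed_bias.reshape(3, num_heads, head_size)
bias_q, bias_k, bias_v = (bias[i].reshape(1, 1, num_heads, head_size) for i in range(3))

# Bias is applied in BSNH, then the tensors are transposed to BNSH for the attention reference.
q = (q + bias_q).transpose(1, 2)
k = (k + bias_k).transpose(1, 2)
v = (v + bias_v).transpose(1, 2)
print(q.shape)  # torch.Size([2, 3, 4, 8])
```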
+def mha_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_test_cases(provider, comprehensive), + kv_cache_test_cases(provider, comprehensive), + ) + + +def multi_thread_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_multi_thread_test_cases(provider, comprehensive), + kv_cache_multi_thread_test_cases(provider, comprehensive), + ) + + +# Off by default so that we do not run too many tests in CI pipeline. comprehensive_mode = False class TestMultiHeadAttention(unittest.TestCase): @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cuda(self, config): - parity_check_mha(config) + parity_check_mha(config, rtol=5e-3, atol=5e-3) @parameterized.expand(mha_test_cases("CPUExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cpu(self, config): - parity_check_mha(config) + parity_check_mha(config, rtol=5e-3, atol=5e-3) def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): @@ -604,19 +649,24 @@ def run_mha_cuda_multi_threading(self, attention_kernel): assert exception is None, f"{attention_kernel=}, {vars(configs[0])}, {exception}" def test_mha_cuda_multi_threading(self): - self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) + if get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) def test_mha_cuda_multi_threading_efficient(self): - self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + if comprehensive_mode and get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + + def test_mha_cuda_multi_threading_math(self): + if comprehensive_mode and get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.MATH) def test_mha_cuda_multi_threading_trt(self): - sm = get_compute_capability() - if sm in [75, 80, 86, 89]: + if get_compute_capability() in [75, 80, 86, 89]: self.run_mha_cuda_multi_threading( SdpaKernel.TRT_FUSED_ATTENTION | SdpaKernel.TRT_FLASH_ATTENTION - | SdpaKernel.TRT_CROSS_ATTENTION | SdpaKernel.TRT_CAUSAL_ATTENTION + | SdpaKernel.TRT_CROSS_ATTENTION ) From a3883af7bfede84315abaa94fbc4cc2a0d2b02a3 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 1 Aug 2024 05:39:21 +0800 Subject: [PATCH 27/40] [WebNN EP] Fixed bug in ConvTranspose (#21569) The constraint of ConvTranspose was placed in wrong place. --- .../webnn/builders/impl/conv_op_builder.cc | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 22049d2519712..76a8a178678df 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -28,7 +28,7 @@ class ConvOpBuilder : public BaseOpBuilder { // Operator support related. 
private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; @@ -378,6 +378,22 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } + // WebNN CPU backend (TFLite) only supports default dilations and group. + // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 + if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { + NodeAttrHelper helper(node); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + const auto group = helper.Get("group", 1); + if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; + return false; + } + if (group != 1) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; + return false; + } + } + return true; } @@ -427,22 +443,6 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - // WebNN CPU backend (TFLite) only supports default dilations and group. - // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 - if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { - NodeAttrHelper helper(node); - const auto dilations = helper.Get("dilations", std::vector{1, 1}); - const auto group = helper.Get("group", 1); - if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { - LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; - return false; - } - if (group != 1) { - LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; - return false; - } - } - return true; } From 8540ac4f78bf06c9b6a0fe90d71e76c947836817 Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Wed, 31 Jul 2024 15:30:33 -0700 Subject: [PATCH 28/40] Fix quant_format argument for 4bit quantizer (#21581) ### Description Original argument accepts Enum QuantFormat.QOperator or QuantFormat.QDQ, but the default value is QOperator. Change the argument to str to accept QOperator or QDQ and convert to QuantFormat after parsing. ### Motivation and Context Bug fix --- .../python/tools/quantization/matmul_4bits_quantizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 40a4a4d26dc1c..cc8bd622df9b1 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -797,8 +797,8 @@ def parse_args(): parser.add_argument( "--quant_format", default="QOperator", - type=QuantFormat, - choices=list(QuantFormat), + type=str, + choices=["QOperator", "QDQ"], help="QuantFormat {QOperator, QDQ}" "QOperator format quantizes the model with quantized operators directly." 
"QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.", @@ -814,7 +814,7 @@ def parse_args(): input_model_path = args.input_model output_model_path = args.output_model - quant_format = args.quant_format + quant_format = QuantFormat[args.quant_format] if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") From 4b8f6dcbb69ee9c74330d7785fe5b7ef656a94f5 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 31 Jul 2024 21:05:11 -0700 Subject: [PATCH 29/40] [QNN EP] Improve INT4 accuracy (#21582) ### Description Masks off top 4-bits of INT4 weights, improving accuracy. ### Motivation and Context This is a workaround as the QNN docs state masking is not required. --- .../qnn/builder/qnn_model_wrapper.cc | 6 + onnxruntime/test/providers/qnn/conv_test.cc | 5 +- .../test/providers/qnn/matmul_test.cpp | 154 ++++++++++++++++-- 3 files changed, 151 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index c8537307ef3ba..9d3f460572d84 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -617,6 +617,12 @@ Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); + + // NOTE: Masking off top 4 bits to workaround a QNN INT4 accuracy bug. + // Docs explicitly state that masking off top 4 bits should not be required. 
+ for (size_t i = 0; i < dst.size(); i++) { + dst[i] &= 0x0F; // -3 (0b1111_1101) becomes 13 (0b0000_1101) + } } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); const size_t num_elems = shape.Size(); diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 99636976b9c05..35889c9fa2307 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -799,7 +799,7 @@ TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_NegativeWeightQuantAxis) { // CPU EP (f32 model): 25.143 21.554 17.964 10.785 7.195 3.605 -3.574 -7.164 -10.753 // CPU EP (qdq model): 24.670 21.103 17.536 10.254 6.689 2.972 -4.161 -7.728 -10.700 // QNN EP (qdq model): 27.186 27.186 27.186 21.541 6.685 -8.022 -10.548 -10.548 -10.548 -TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S4S32_PerChannel_AccuracyIssue) { +TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_AccuracyIssue) { std::vector input_shape = {1, 2, 4, 4}; std::vector weight_shape = {3, 2, 2, 2}; std::vector bias_shape = {3}; @@ -835,7 +835,8 @@ TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S4S32_PerChannel_AccuracyIssue) { "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops - 21); // opset + 21, // opset + QDQTolerance(0.005f)); } // Test per-channel QDQ Conv is rejected with weight axis != 0 diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index dba60b1041696..d8c34d6a6c6ed 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -28,26 +28,25 @@ static GetTestModelFn BuildMatMulOpTestCase(const TestInputDef& input1_de // Returns a function that creates a graph with a QDQ MatMul operator. 
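Stepping back to the masking loop added in `qnn_model_wrapper.cc` above: the unpacked INT4 values sit sign-extended in int8 storage, and the workaround keeps only the low nibble. A short standalone Python sketch of that arithmetic follows; the helper name is invented for illustration and this is not ONNX Runtime code.

```python
def mask_low_nibble(stored_int8: int) -> int:
    """Keep only the raw 4-bit two's-complement pattern of a sign-extended INT4 value."""
    return stored_int8 & 0x0F

for value in (-8, -3, -1, 0, 5, 7):
    stored = value & 0xFF  # int8 two's-complement storage, e.g. -3 -> 0b1111_1101
    masked = mask_low_nibble(stored)
    print(f"{value:3d} -> 0b{stored:08b} -> 0b{masked:08b} ({masked})")
# -3 maps to 13 (0b0000_1101), matching the comment in the loop above.
```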
template -static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input1_def, - const TestInputDef& input2_def, +static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def, bool use_contrib_qdq) { - return [input1_def, input2_def, use_contrib_qdq](ModelTestBuilder& builder, + return [input0_def, input1_def, use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { // input1 -> Q -> DQ -> - NodeArg* input1 = MakeTestInput(builder, input1_def); - QuantParams input1_qparams = GetTestInputQuantParams(input1_def); - auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_qparams = GetTestInputQuantParams(input0_def); + auto* input0_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, input0_qparams.zero_point, use_contrib_qdq); - - // input2 -> Q -> DQ -> - NodeArg* input2 = MakeTestInput(builder, input2_def); - QuantParams input2_qparams = GetTestInputQuantParams(input2_def); - auto* input2_qdq = AddQDQNodePair(builder, input2, input2_qparams.scale, input2_qparams.zero_point, + // input1 -> Q -> DQ -> + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, use_contrib_qdq); // MatMul auto* op_output = builder.MakeIntermediate(); - builder.AddNode("MatMul", {input1_qdq, input2_qdq}, {op_output}); + builder.AddNode("MatMul", {input0_qdq, input1_qdq}, {op_output}); // op_output -> Q -> DQ -> output AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, @@ -55,6 +54,88 @@ static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDe }; } +template +static GetTestQDQModelFn BuildQDQPerChannelMatMulTestCase(const TestInputDef& input_def, + const TestInputDef& weights_def, + int64_t weight_quant_axis, + bool use_contrib_qdq = false) { + return [input_def, weights_def, weight_quant_axis, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector matmul_inputs; + + // input -> Q/DQ -> + auto* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + matmul_inputs.push_back(input_qdq); + + // Quantized(weights) -> DQ -> + ORT_ENFORCE(weights_def.IsInitializer() && weights_def.IsRawData()); + std::vector weight_scales; + std::vector weight_zero_points; + TensorShape weights_shape = weights_def.GetTensorShape(); + int64_t pos_weight_quant_axis = weight_quant_axis; + if (pos_weight_quant_axis < 0) { + pos_weight_quant_axis += static_cast(weights_shape.NumDimensions()); + } + GetTestInputQuantParamsPerChannel(weights_def, weight_scales, weight_zero_points, + static_cast(pos_weight_quant_axis), true); + + std::vector quantized_weights; + size_t num_weight_storage_elems = weights_shape.Size(); + if constexpr (std::is_same_v || std::is_same_v) { + num_weight_storage_elems = Int4x2::CalcNumInt4Pairs(weights_shape.Size()); + } + quantized_weights.resize(num_weight_storage_elems); + QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, + weight_scales, weight_zero_points, pos_weight_quant_axis); + + NodeArg* weights_initializer = builder.MakeInitializer(weights_def.GetShape(), 
quantized_weights); + NodeArg* weights_dq = builder.MakeIntermediate(); + Node& weights_dq_node = builder.AddDequantizeLinearNode(weights_initializer, weight_scales, + weight_zero_points, weights_dq, + nullptr, use_contrib_qdq); + weights_dq_node.AddAttribute("axis", weight_quant_axis); + matmul_inputs.push_back(weights_dq); + + auto* matmul_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", matmul_inputs, {matmul_output}); + + AddQDQNodePairWithOutputAsGraphOutput(builder, matmul_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ per-channel MatMul model on the QNN HTP backend. Checks the graph node assignment, and that the +// QDQ model is accurate on QNN EP (compared to CPU EP). +template +static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, + const TestInputDef& weights_def, + int64_t weight_quant_axis, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 21, + bool use_contrib_qdq = false, + QDQTolerance tolerance = QDQTolerance()) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + TestQDQModelAccuracy(BuildMatMulOpTestCase(input_def, weights_def), + BuildQDQPerChannelMatMulTestCase(input_def, + weights_def, + weight_quant_axis, + use_contrib_qdq), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + // Runs an MatMul model on the QNN CPU backend. Checks the graph node assignment, and that inference // outputs for QNN and CPU match. static void RunMatMulOpOpTest(const TestInputDef& input1_def, @@ -160,6 +241,55 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { true); // Use com.microsoft Q/DQ ops } +// Test QDQ per-channel MatMul with 16-bit act, signed 4-bit weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightInt4) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + +// Test QDQ per-channel MatMul with 16-bit act, unsigned 4-bit weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightUInt4) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + +// Test QDQ per-channel MatMul with int8 act, int4 weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_AS8_WeightInt4) { + std::vector input0_data = GetFloatDataInRange(-5.0f, 5.0f, 6); + std::vector input1_data = {-2.0f, -1.0f, -0.5f, 0.0f, 1.0f, 2.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false, + QDQTolerance(0.007f)); +} + +// Test QDQ per-channel MatMul with 16-bit act, int8 weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightInt8) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 
0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + // Test QDQ MatMul with uint16 activation uint16 weights, both dynamic // Inaccuracy detected for output 'output_0', element 1. // Output quant params: scale=0.0015259021893143654, zero_point=0. From 25722bb9e35fbaf122f778db46329dac1f0dcef2 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 1 Aug 2024 04:23:02 -0700 Subject: [PATCH 30/40] Add CUDA custom op header files to Linux tarball (#21551) ### Description The header files were added in PR #16454. Then, recently I made a PR #21464 that changed how we packed Linux tarballs. The new tarball misses the custom op header files. Therefore I need to make this change. ### Motivation and Context --- cmake/onnxruntime.cmake | 12 ++++++++---- cmake/onnxruntime_providers_cpu.cmake | 1 + cmake/onnxruntime_providers_cuda.cmake | 9 ++++++++- cmake/onnxruntime_providers_rocm.cmake | 7 ++++++- 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index bdb4b00b02a35..927b4ac84b037 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -38,10 +38,14 @@ function(get_c_cxx_api_headers HEADERS_VAR) # need to add header files for enabled EPs foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) - file(GLOB _provider_headers CONFIGURE_DEPENDS - "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h" - ) - list(APPEND _headers ${_provider_headers}) + # The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory + # with onnxruntime_c_api.h . Most other EPs probably also do not work in this way. + if((NOT f STREQUAL cuda) AND (NOT f STREQUAL rocm)) + file(GLOB _provider_headers CONFIGURE_DEPENDS + "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h" + ) + list(APPEND _headers ${_provider_headers}) + endif() endforeach() set(${HEADERS_VAR} ${_headers} PARENT_SCOPE) diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index d2afe19f36691..bbcc709b144a0 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -219,6 +219,7 @@ if (onnxruntime_ENABLE_TRAINING) endif() install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/cpu/cpu_provider_factory.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) +install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/resource.h ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/custom_op_context.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers) set_target_properties(onnxruntime_providers PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 82c31ce6b6b4d..0829be05a3ab0 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -289,8 +289,15 @@ config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) endif() config_cuda_provider_shared_module(onnxruntime_providers_cuda) - + # Cannot use glob because the file cuda_provider_options.h should not be exposed out. 
+ set(ONNXRUNTIME_CUDA_PROVIDER_PUBLIC_HEADERS + "${REPO_ROOT}/include/onnxruntime/core/providers/cuda/cuda_context.h" + "${REPO_ROOT}/include/onnxruntime/core/providers/cuda/cuda_resource.h" + ) + set_target_properties(onnxruntime_providers_cuda PROPERTIES + PUBLIC_HEADER "${ONNXRUNTIME_CUDA_PROVIDER_PUBLIC_HEADERS}") install(TARGETS onnxruntime_providers_cuda + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers/cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake index 71692ddb9391f..559204bd0df88 100644 --- a/cmake/onnxruntime_providers_rocm.cmake +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -223,8 +223,13 @@ if (onnxruntime_ENABLE_ATEN) target_compile_definitions(onnxruntime_providers_rocm PRIVATE ENABLE_ATEN) endif() - + file(GLOB ONNXRUNTIME_ROCM_PROVIDER_PUBLIC_HEADERS CONFIGURE_DEPENDS + "${REPO_ROOT}/include/onnxruntime/core/providers/rocm/*.h" + ) + set_target_properties(onnxruntime_providers_rocm PROPERTIES + PUBLIC_HEADER "${ONNXRUNTIME_ROCM_PROVIDER_PUBLIC_HEADERS}") install(TARGETS onnxruntime_providers_rocm + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers/rocm ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) From 3b73ef2bf7e1597f8379713494f19858a678b483 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 1 Aug 2024 04:28:43 -0700 Subject: [PATCH 31/40] Bump torch from 1.13.1 to 2.2.0 in /tools/ci_build/github/windows/eager (#21505) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [torch](https://github.com/pytorch/pytorch) from 1.13.1 to 2.2.0.
Release notes

Sourced from torch's releases.

PyTorch 2.2: FlashAttention-v2, AOTInductor

PyTorch 2.2 Release Notes

  • Highlights
  • Backwards Incompatible Changes
  • Deprecations
  • New Features
  • Improvements
  • Bug fixes
  • Performance
  • Documentation

Highlights

We are excited to announce the release of PyTorch® 2.2! PyTorch 2.2 offers ~2x performance improvements to scaled_dot_product_attention via FlashAttention-v2 integration, as well as AOTInductor, a new ahead-of-time compilation and deployment tool built for non-python server-side deployments.

This release also includes improved torch.compile support for Optimizers, a number of new inductor optimizations, and a new logging mechanism called TORCH_LOGS.

Please note that we are deprecating macOS x86 support, and PyTorch 2.2.x will be the last version that supports macOS x64.

Along with 2.2, we are also releasing a series of updates to the PyTorch domain libraries. More details can be found in the library updates blog.

This release is composed of 3,628 commits and 521 contributors since PyTorch 2.1. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.2. More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.

Summary:

  • scaled_dot_product_attention (SDPA) now supports FlashAttention-2, yielding around 2x speedups compared to previous versions.
  • PyTorch 2.2 introduces a new ahead-of-time extension of TorchInductor called AOTInductor, designed to compile and deploy PyTorch programs for non-python server-side.
  • torch.distributed supports a new abstraction for initializing and representing ProcessGroups called device_mesh.
  • PyTorch 2.2 ships a standardized, configurable logging mechanism called TORCH_LOGS.
  • A number of torch.compile improvements are included in PyTorch 2.2, including improved support for compiling Optimizers and improved TorchInductor fusion and layout optimizations.
  • Please note that we are deprecating macOS x86 support, and PyTorch 2.2.x will be the last version that supports macOS x64.
  • torch.ao.quantization now offers a prototype torch.export based flow

... (truncated)

Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tools/ci_build/github/windows/eager/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/windows/eager/requirements.txt b/tools/ci_build/github/windows/eager/requirements.txt index 08e7baa76471b..b285defd89f55 100644 --- a/tools/ci_build/github/windows/eager/requirements.txt +++ b/tools/ci_build/github/windows/eager/requirements.txt @@ -3,5 +3,5 @@ wheel numpy==1.21.6 ; python_version < '3.9' numpy==2.0.0 ; python_version >= '3.9' typing_extensions -torch==1.13.1 +torch==2.2.0 parameterized From 8c2ee7b32e28837ef4e109ea200fc7f1404ebb91 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 2 Aug 2024 00:15:31 +0800 Subject: [PATCH 32/40] [WebNN EP] Create MLGraphBuilder for every model builder (#21514) Currently WebNN spec only allows MLGraphBuilder.build() to be called once, we need to create new builder for every subgraph in WebNN EP. Spec change: https://github.com/webmachinelearning/webnn/pull/717 --- .../core/providers/webnn/builders/helper.cc | 4 ++-- .../core/providers/webnn/builders/helper.h | 6 +++--- .../providers/webnn/builders/model_builder.cc | 16 ++++++++++---- .../providers/webnn/builders/model_builder.h | 8 +++---- .../webnn/webnn_execution_provider.cc | 21 +++++++------------ .../webnn/webnn_execution_provider.h | 1 - 6 files changed, 28 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 44e6953db438e..d3c1d06818db2 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -84,7 +84,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons } std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, - const emscripten::val& wnn_builder_, + const emscripten::val& wnn_builder, const WebnnDeviceType device_type, const logging::Logger& logger) { std::vector> supported_node_groups; @@ -103,7 +103,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const auto* node(graph_viewer.GetNode(node_idx)); bool supported = false; // Firstly check if platform supports the WebNN op. - if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) { + if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; supported = IsNodeSupported(*node, graph_viewer, device_type, logger); } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 63fd97abb9a9a..05b783fd17902 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -151,7 +151,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. 
std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, - const emscripten::val& wnn_builder_, + const emscripten::val& wnn_builder, const WebnnDeviceType device_type, const logging::Logger& logger); static const InlinedHashMap op_map = { @@ -241,14 +241,14 @@ static const InlinedHashMap op_map = { {"Where", {"where", true}}, }; -inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_, +inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder, const WebnnDeviceType device_type) { // Returns false if the op_type is not listed in the op_map. if (op_map.find(op_type) == op_map.end()) { return false; } // Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser. - if (!wnn_builder_[op_map.find(op_type)->second.opName].as()) { + if (!wnn_builder[op_map.find(op_type)->second.opName].as()) { return false; } // The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 6b0e1495f552d..b21f717eedc7a 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -20,14 +20,20 @@ namespace onnxruntime { namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - const emscripten::val& context, const emscripten::val& builder, - const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type) + const emscripten::val& context, const DataLayout preferred_layout, + const WebnnDeviceType wnn_device_type) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), - wnn_builder_(builder), preferred_layout_(preferred_layout), - wnn_device_type_(wnn_device_type) {} + wnn_device_type_(wnn_device_type) { + // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() + // is only allowed to be called once. + wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context); + if (!wnn_builder_.as()) { + ORT_THROW("Failed to create WebNN builder."); + } +} Status ModelBuilder::Initialize() { PreprocessInitializers(); @@ -332,6 +338,8 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { if (!wnn_graph.as()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); } + // Explicitly release the WebNN builder to free memory. 
+ wnn_builder_ = emscripten::val::undefined(); model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_)); model->SetInputs(std::move(input_names_)); model->SetOutputs(std::move(output_names_)); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 6a1688f16d2a6..b1561f009aa25 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -22,8 +22,8 @@ class IOpBuilder; class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - const emscripten::val& context, const emscripten::val& builder, - const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type); + const emscripten::val& context, const DataLayout preferred_layout, + const WebnnDeviceType wnn_device_type); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -62,8 +62,8 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; - emscripten::val wnn_context_ = emscripten::val::object(); - emscripten::val wnn_builder_ = emscripten::val::object(); + emscripten::val wnn_context_ = emscripten::val::undefined(); + emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; InlinedHashMap wnn_operands_; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 0da0dfc6dfb26..1cd382c1e75e9 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -38,10 +38,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } - wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); - if (!wnn_builder_.as()) { - ORT_THROW("Failed to create WebNN builder."); - } } WebNNExecutionProvider::~WebNNExecutionProvider() {} @@ -81,14 +77,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view const auto& logger = *GetLogger(); - if (!wnn_builder_.as()) { - // The GetCapability function may be called again after Compile due to the logic in the - // PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc). - // We need to re-create the wnn_builder_ here to avoid it's been released in last Compile. - wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); + emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); + if (!wnn_builder.as()) { + ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger); + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger); + wnn_builder = emscripten::val::undefined(); if (node_groups.empty()) { return result; @@ -218,9 +213,10 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); + // Build map from input name to its index in input definitions. 
{ InlinedHashMap input_map; @@ -329,9 +325,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector Date: Thu, 1 Aug 2024 22:21:16 +0300 Subject: [PATCH 33/40] Add reduce kernels for bigger types (#21490) --- docs/OperatorKernels.md | 40 ++++----- .../providers/cpu/cpu_execution_provider.cc | 64 +++++++++++++- .../providers/cpu/reduction/reduction_ops.cc | 20 +++++ .../cpu/reduction/reduction_ops_test.cc | 85 +++++++++++++++++++ 4 files changed, 187 insertions(+), 22 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f265c9f985070..529c676321bbb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -272,18 +272,18 @@ Do not modify directly.* |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| -|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[13, 17]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int64)| -|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[13, 17]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int64)| -|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[13, 17]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int64)| +|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| @@ -294,20 +294,20 @@ Do not modify directly.* |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| -|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32)| -|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32)| +|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[18, 19]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| -|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[13, 17]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int64)| +|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 7ac68e3a9a69d..7ed776f1358a5 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -179,12 +179,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MaxRoiPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, @@ -202,11 +205,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); @@ -407,12 +412,15 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, @@ -432,11 +440,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, int64_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceProd); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); @@ -724,12 +734,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, GatherND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, @@ -751,6 +764,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceMin); @@ -758,6 +772,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int8_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, uint8_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, 
ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, @@ -881,12 +896,15 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, int8_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceL1); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceL1); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceLogSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceLogSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); @@ -902,6 +920,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceMean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceMean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMin); @@ -909,6 +928,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); @@ -1343,18 +1363,24 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2165,18 +2205,24 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL1, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL1, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL1, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL1, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL1, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL1, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL1, 18); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceL1, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceL1, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL2, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL2, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 11, 12); 
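// Editor's note (descriptive, not part of this change): each newly supported
// (element type, opset) combination for these reduction ops shows up three times in this
// diff: the ONNX_OPERATOR_*_TYPED_KERNEL_CLASS_NAME declaration, a matching
// BuildKernelCreateInfo<...> entry inside RegisterOnnxOperatorKernels, and a
// REGISTER_UNARY_ELEMENTWISE_*_ONLY registration like the ones around this note.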
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL2, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL2, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL2, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL2, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL2, 18); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceL2, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceL2, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceLogSum, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceLogSum, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceLogSum, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceLogSum, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceLogSum, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceLogSum, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSum, 18); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceLogSum, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceLogSum, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 1, 10); @@ -202,6 +214,10 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMean, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMean, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMean, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceMean, 18); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMean, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMean, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMean, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceMean, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMin, 1, 10); @@ -236,12 +252,16 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ReduceMin, 20); REGISTER_UNARY_ELEMENTWISE_KERNEL_BOOL_ONLY(ReduceMin, 20); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceProd, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceProd, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceProd, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceProd, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceProd, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceProd, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 18); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceProd, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceProd, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 98a65b8efffd2..ab24337046b99 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ 
b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -263,6 +263,22 @@ TEST(ReductionOpTest, ReduceL1) { test.Run(); } +TEST(ReductionOpTest, ReduceL1_double) { + OpTester test("ReduceL1"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddOutput("reduced", {1, 2, 1}, {33.0f, 45.0f}); + test.Run(); +} + TEST(ReductionOpTest, ReduceL1_int32) { OpTester test("ReduceL1"); test.AddAttribute("axes", std::vector{0, 2}); @@ -423,6 +439,23 @@ TEST(ReductionOpTest, ReduceL2) { test.Run(); } +TEST(ReductionOpTest, ReduceL2_double) { + OpTester test("ReduceL2"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddOutput("reduced", {2}, {15.71623325f, 20.07485962f}); + test.Run(); +} + #if defined(USE_DNNL) TEST(ReductionOpTest, ReduceL2_bfloat16) { #ifdef USE_DNNL @@ -512,6 +545,25 @@ TEST(ReductionOpTest, ReduceLogSum) { test.Run(); } +TEST(ReductionOpTest, ReduceLogSum_double) { + OpTester test("ReduceLogSum"); + test.AddAttribute("axes", std::vector{1}); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddOutput("reduced", {3, 1, 2}, + {1.38629436f, 1.79175949f, + 2.48490667f, 2.6390574f, + 2.99573231f, 3.09104252f}); + test.Run(); +} + #if defined(USE_DNNL) TEST(ReductionOpTest, ReduceLogSum_bfloat16) { #ifdef USE_DNNL @@ -1820,6 +1872,23 @@ TEST(ReductionOpTest, ReduceMean_int32) { test.Run(); } +TEST(ReductionOpTest, ReduceMean_int64) { + OpTester test("ReduceMean"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {10, 20, + 30, 40, + + 50, 60, + 70, 80, + + 90, 100, + 110, 120}); + test.AddOutput("reduced", {1, 2, 1}, {55, 75}); + test.Run(); +} + TEST(ReductionOpTest, ReduceMean_axes_input) { OpTester test("ReduceMean", 18, onnxruntime::kOnnxDomain); test.AddAttribute("keepdims", (int64_t)1); @@ -2989,6 +3058,22 @@ TEST(ReductionOpTest, ReduceProd) { test.Run(); } +TEST(ReductionOpTest, ReduceProd_double) { + OpTester test("ReduceProd"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddOutput("reduced", {1, 2, 1}, {5400.f, 88704.f}); + test.Run(); +} + TEST(ReductionOpTest, ReduceProdAxesInitializerOpset18) { OpTester test("ReduceProd", 18); test.AddInput("data", {3, 2, 2}, From b87e8edb98ed85640ccf369625267786e91db920 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Fri, 2 Aug 2024 10:20:22 -0700 Subject: [PATCH 34/40] Mlas int4 int8 with avx2/512 (#20687) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description model: phi-3-mini-4k-instruct avx2 symmetric blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change% -|-|-|-|-|-|- 16 |49.5|70.0|-29.2%|9.6|10.8|-34.2% 32 |76.8|52.4|9.7%|15.2|14.6|4.1% 64 |78.2|71.4|9.5%|16.6|16.3|1.8% 128 |72.9|70.6|3.2%|17.1|16.8|1.7% 256 |83.7|63.6|31.6%|18.1|17.4|4% avx2 asymmetric blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen 
tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |50.7|61.5|-17.5%|9.6|9.2|4.3%
32 |77.4|52.4|47.7%|14.6|13.9|5.0%
64 |78.7|63.0|24.9%|16.2|15.9|1.8%
128 |80.0|61.9|29.2%|17.2|16.9|1.7%
256 |81.5|63.3|28.7%|17.9|17.3|3.4%

avx2vnni symmetric

blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |82.9|117.0|-29.0%|15.9|19.3|-17.6%
32 |133.0|100.4|32.4%|26.1|24.5|6.5%
64 |166.9|118.8|40.4%|28.3|27.1|4.4%
128 |165.9|119.6|38.7%|29.3|28.5|2.8%
256 |165.2|119.6|38.1%|30.2|29.0|4.1%

avx2vnni asymmetric

blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |80.2|118.9|-32.5%|15.1|16.7|-9.5%
32 |130.7|99.7|31.0%|25.0|23.8|5.0%
64 |168.7|124.9|35.0%|27.3|26.8|1.8%
128 |169.6|123.8|36.9%|29.2|27.9|4.6%
256 |175.0|125.7|39.0%|30.0|29.7|1.0%

avx512 symmetric

blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |135.2|156.5|-13.6|25.5|23.8|7.1
32 |150.0|159.5|-5.9|34.9|29.6|17.9
64 |167.5|157.5|6.3|39.7|34.4|15.4
128 |177.8|158.0|12.5|40.3|35.4|13.8
256 |182.6|157.3|16.0|41.7|37.7|10.6

avx512 asymmetric

blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |136.1|151.4|-10.1%|26.1|19.9|31.1%
32 |150.0|157.8|-4.9%|34.3|29.3|17.0%
64 |165.7|156.6|5.8%|38.7|30.7|26.0%
128 |180.4|156.6|15.1%|40.2|34.7|15.8%
256 |181.3|158.0|14.7%|41.6|36.6|13.6%

avx512vnni symmetric

blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |143.4|155.4|-7.7%|25.6|23.3|9.8%
32 |159.2|157.0|1.4%|34.1|29.8|14.4%
64 |182.0|159.5|14.1%|38.4|34.8|10.3%
128 |221.2|160.8|37.5%|41.0|36.4|12.6%
256 |250.5|162.4|54.2%|41.6|37.7|10.3%

avx512vnni asymmetric

blklen|updated prompt tps | baseline prompt tps | prompt tps change%|updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |142.5|152.3|-6.4%|26.3|19.7|33.5%
32 |158.2|155.0|2.0%|34.3|29.2|17.4%
64 |184.1|156.6|17.5%|38.3|30.9|23.9%
128 |215.8|156.1|17.5%|41.3|35.0|17.9%
256 |249.2|155.9|59.8%|41.1|36.3|13.2%

4-bit GEMM implementation with AVX using tiles.

1. The tile size is 2 blk by 4. When the remaining size is smaller than a tile, it falls back to 1 blk by 4, then 2 blk by 1, and lastly 1 blk by 1. In the internal kernel, weights and activations are loaded according to the SIMD register width and block length:

avx2 (256-bit registers, 64 weights and activations loaded):
- blklen16: 4 blks are computed by the internal kernel
- blklen32: 2 blks are computed by the internal kernel
- blklen64: 1 blk is computed by the internal kernel
- blklen128: 1 blk is computed 2 times by the internal kernel
- blklen256: 1 blk is computed 4 times by the internal kernel

avx512 (512-bit registers, 128 weights and activations loaded):
- blklen16: 8 blks are computed by the internal kernel
- blklen32: 4 blks are computed by the internal kernel
- blklen64: 2 blks are computed by the internal kernel
- blklen128: 1 blk is computed by the internal kernel
- blklen256: 1 blk is computed 2 times by the internal kernel

2. blksum is precomputed during prepacking; the reformed computation is stated next, with a small numeric sketch below.
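A minimal single-block numeric sketch of the block-sum idea (an editor's illustration, not code from this change: the names are invented, activations are assumed to be symmetrically quantized int8, weights asymmetric 4-bit, and the sign of the zero-point term is folded into `blksum_b` for the example; the packed kernels may fold it elsewhere). The formal reformulation is stated right after the sketch.

```cpp
// Editor's sketch: one quantized block, computed two ways.
#include <cstdint>
#include <cstdio>

int main() {
    constexpr int BlkLen = 4;
    const int8_t qa[BlkLen] = {10, -3, 7, 2};   // quantized activations (symmetric int8)
    const uint8_t qb[BlkLen] = {1, 15, 8, 3};   // quantized 4-bit weights (0..15)
    const float scale_a = 0.05f, scale_b = 0.02f;
    const float zp_b = 8.0f;                    // weight zero point

    // Naive path: dequantize, then dot.
    float naive = 0.0f;
    for (int i = 0; i < BlkLen; ++i) {
        naive += (scale_a * qa[i]) * (scale_b * (qb[i] - zp_b));
    }

    // Reformed path: Sum1 = scale_a * scale_b * Sum_blk(qa_i * qb_i),
    // Sum2 = blksum_a * blksum_b, with blksum_a = scale_a * Sum_blk(qa_i)
    // and blksum_b = -scale_b * zp_b (sign folded in for this sketch).
    int32_t dot = 0, asum = 0;
    for (int i = 0; i < BlkLen; ++i) {
        dot += qa[i] * qb[i];
        asum += qa[i];
    }
    const float sum1 = scale_a * scale_b * static_cast<float>(dot);
    const float blksum_a = scale_a * static_cast<float>(asum);
    const float blksum_b = -scale_b * zp_b;
    const float reformed = sum1 + blksum_a * blksum_b;

    std::printf("naive=%f reformed=%f\n", naive, reformed);  // match up to rounding
    return 0;
}
```

Both paths print the same value, which is why the cross term can be delegated to a small SGEMM over the precomputed block sums (see the GemmFloatKernel call added in sqnbitgemm_kernel_avx2.cpp below).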
computation is reformed: Sum1(scale_a * scale_b * Sum_blk(a_i * b_i)) + Sum2(blksum_a * blksum_b) Sum_blk is over one blk Sum1 is over all blks for one output Sum2 is over all blks for one output Sum is computed with sgemm with the current implementation. Further improvement is possible.   --------- Signed-off-by: Liqun Fu Signed-off-by: liqunfu Signed-off-by: Liqun Fu --- cmake/onnxruntime_mlas.cmake | 13 +- .../cpu/quantization/matmul_nbits.cc | 41 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 56 +- onnxruntime/core/mlas/lib/mlasi.h | 2 + onnxruntime/core/mlas/lib/platform.cpp | 1 + onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 244 +++- onnxruntime/core/mlas/lib/sqnbitgemm.h | 102 ++ .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 292 +++- .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 727 ++++++++++ .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 1049 +++++++++++++++ .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 541 ++++++++ .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 150 ++- .../mlas/lib/sqnbitgemm_kernel_avx512_int8.h | 1171 +++++++++++++++++ .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 581 ++++++++ .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 812 ++++++++++++ .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 852 ++++++++++++ .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 840 ++++++++++++ .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 180 ++- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 273 +++- .../lib/sqnbitgemm_kernel_avx_common_int8.h | 51 +- ...bitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 759 +++++++++++ ...bitgemm_m1_sym_kernel_avx2_int8_blklen64.h | 312 +++++ .../test/contrib_ops/matmul_4bits_test.cc | 5 +- onnxruntime/test/mlas/bench/bench_q4dq.cpp | 24 +- .../test/mlas/bench/bench_sqnbitgemm.cpp | 9 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 47 +- 26 files changed, 8834 insertions(+), 300 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 66f4aea606ef5..c02ac2096db2e 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -555,8 +555,17 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") +message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") + +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") +else() + message(STATUS "Using -mavx2 
-mfma flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") +endif() set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S @@ -575,7 +584,7 @@ else() ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") set(mlas_platform_srcs_avx512vnni ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 995babc857357..5fdd2b017b8a6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -104,6 +104,8 @@ class MatMulNBits final : public OpKernel { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + const Tensor* tensor_zero_point = nullptr; + has_zp_input_ = info.TryGetConstantInput(3, &tensor_zero_point); #ifdef ORT_NEURAL_SPEED const Tensor* tensor_B = nullptr; const Tensor* tensor_scale = nullptr; @@ -139,6 +141,7 @@ class MatMulNBits final : public OpKernel { IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; + bool has_zp_input_{false}; #if defined(ORT_NEURAL_SPEED) bool is_asym_{false}; @@ -207,10 +210,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#else // defined(ORT_NEURAL_SPEED) - +#else // defined(ORT_NEURAL_SPEED) + ORT_UNUSED_PARAMETER(prepacked_weights); + const auto compute_type = static_cast(accuracy_level_); if (input_idx == InputIndex::B) { - const auto compute_type = static_cast(accuracy_level_); if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { return Status::OK(); } @@ -220,12 +223,20 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; + } else if (compute_type == CompInt8) { +#ifdef MLAS_TARGET_AMD64_IX86 + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + is_packed = false; + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); + is_packed = false; + } +#endif } #endif // defined(ORT_NEURAL_SPEED) @@ -332,9 +343,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const auto* bias_data = bias == nullptr ? 
nullptr : bias->Data(); IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); - workspace_size > 0) { + const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + M, N, K, batch_count, nbits_, block_size_, compute_type); + if (workspace_size > 0) { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); @@ -344,14 +355,18 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { for (size_t i = 0; i < batch_count; ++i) { data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; - data[i].QuantBData = packed_b_.get(); +#ifdef MLAS_TARGET_AMD64_IX86 + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } +#endif + data[i].PackedQuantBData = static_cast(packed_b_.get()); data[i].QuantBScale = scales_data; data[i].QuantBZeroPoint = zero_points_data; data[i].Bias = bias_data; data[i].C = y_data + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), thread_pool); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 32e9cc98106d5..232bf2261ef4c 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -43,14 +43,16 @@ typedef enum { * @brief Data parameters for float/n-bit quantized int GEMM routine. */ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) + const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; @@ -159,14 +161,29 @@ MlasSQNBitGemmPackQuantBDataSize( /** * @brief Packs the quantized B data in a format that the kernel expects. 
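 *
 * Illustrative call sequence (an editor's sketch mirroring the matmul_nbits.cc changes in
 * this patch, not literal documentation; `Packed`, `HasZp` and `Pool` are placeholder names,
 * and the call sites in this patch actually pass a null ThreadPool):
 *
 *   // 1st call: pack the quantized weight data only.
 *   MlasSQNBitGemmPackQuantBData(N, K, 4, BlkLen, CompInt8, QuantBData, Packed, nullptr, HasZp, nullptr, Pool);
 *   // 2nd call: fold in the scales; BlkSum is computed with the default zero point of 8.
 *   MlasSQNBitGemmPackQuantBData(N, K, 4, BlkLen, CompInt8, nullptr, Packed, QuantBScale, HasZp, nullptr, Pool);
 *   // 3rd call, only when zero points exist: BlkSum is recomputed with the real zero points.
 *   MlasSQNBitGemmPackQuantBData(N, K, 4, BlkLen, CompInt8, nullptr, Packed, nullptr, HasZp, QuantBZeroPoint, Pool);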
* - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - * @param[in] QuantBData quantized B data - * @param[out] PackedQuantBData packed quantized B data - * @param[in] ThreadPool optional thread pool to use + * If the function is called without QuantBScale and QuantBZeroPoint, + * it just packs QuantBData into PackedQuantBDataAndOrBlkSum. + * + * If the function is called with QuantBData, QuantBScale, and QuantBZeroPoint + * additional BlkSum (Scale * zeropoint) is computed and stored at the second part of PackedQuantBDataAndOrBlkSum. + * + * Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData, + * QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData, + * and then QuantBScale and QuantBZeroPoint. When the function is called with QuantBScale without QuantBZeroPoint, + * BlkSum is computed with default zero point 8 and stored at the second part of PackedQuantBDataAndOrBlkSum. + * If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with provided zeropoint. + * + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] QuantBData quantized B data + * @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum + * @param[in] QuantBScale quantized B scale + * @param[in] has_zp_input whether QuantBZeroPoint is provided + * @param[in] QuantBZeroPoint quantized B zero point + * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL MlasSQNBitGemmPackQuantBData( @@ -176,6 +193,9 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBData, - MLAS_THREADPOOL* ThreadPool = nullptr + void* PackedQuantBDataAndOrBlkSum, + const void* QuantBScale, + bool has_zp_input, + const void* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 83200187963e1..4239e2ecaeb6e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -993,6 +993,8 @@ extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; + extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 859b7c2f560a4..ed437f20f7c2a 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -409,6 +409,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if 
!defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 81789386a3200..a45494ef2e04f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -16,11 +16,10 @@ Module Name: --*/ #include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" #include -#include "sqnbitgemm_q8_block.h" - namespace { @@ -80,9 +79,10 @@ MlasIsSQNBitGemmAvailable( return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; } - case SQNBitGemmVariant_BitWidth4_CompInt8: { - return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && - Dispatch->QuantizeARow_CompInt8 != nullptr; + case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 + return + (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || + (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } default: { return false; @@ -197,6 +197,21 @@ MlasSQNBitGemmPackQuantBDataSize( return 0; } +struct PerGemmQuantAWorkspace { + PerGemmQuantAWorkspace(void* PerGemmWorkspace, size_t M, size_t BlockCountK, size_t BlkLen) + : PerGemmWorkspace_(PerGemmWorkspace), M_(M), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + QuantData = (std::byte*)PerGemmWorkspace; + QuantScale = (float*)(QuantData + M * BlockCountK * BlkLen); + BlockSum = QuantScale + M * BlockCountK; + } + std::byte* QuantData; // NxBlockCountKxBlkLen + float* QuantScale; // NxBlockCountK + float* BlockSum; // NxBlockCountK + void* PerGemmWorkspace_; // memory for above data + size_t M_, BlockCountK_, BlkLen_; +}; + void MLASCALL MlasSQNBitGemmPackQuantBData( size_t N, @@ -205,7 +220,10 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBData, + void* PackedQuantBDataAndOrBlkSumWorkspace, + const void* QuantBScale, + bool has_zp_input, + const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool ) { @@ -214,17 +232,37 @@ MlasSQNBitGemmPackQuantBData( return; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - Dispatch->SQ4BitGemmPackQuantBData( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(PackedQuantBData), - ThreadPool - ); - return; + if (BlkBitWidth == 4) { + if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(QuantBScale), + has_zp_input, + static_cast(QuantBZeroPoint), + packed_quant_b, + ThreadPool + ); + } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. 
+ //assert(QuantBScale == nullptr); + //assert(QuantBZeroPoint == nullptr); + Dispatch->SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); + return; + } } } @@ -293,7 +331,7 @@ SQ4BitGemm_CompFp32( const float* A = DataParams->A + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -373,7 +411,6 @@ SQ4BitGemm_CompFp32( if (bias) { AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); } - if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, @@ -383,7 +420,6 @@ SQ4BitGemm_CompFp32( c_blk += ldc * RowsHandled; a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; } } @@ -402,16 +438,33 @@ SQ4BitGemm_CompInt8( ) { #ifdef MLAS_TARGET_AMD64_IX86 - if (RangeCountM != 1) { - // perf experiment shows fp32 is faster than int8 in M > 1 cases. - // route to fp32 compute before int8 compute is improved. - SQ4BitGemm_CompFp32( - BlkLen, - K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN - ); - return; - } -#endif + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + // quant A scale is embedded in QuantData if QuantScale is nullptr. + const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen)); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; + const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; + + assert(RangeStartN % 4 == 0); + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; +#else constexpr size_t BlkBitWidth = 4; const size_t k_blks = MlasDivRoundup(K, BlkLen); @@ -423,7 +476,7 @@ SQ4BitGemm_CompInt8( const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -433,6 +486,7 @@ SQ4BitGemm_CompInt8( float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; +#endif size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { @@ -446,25 +500,57 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc + ); + } + + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; + } + } +#ifdef MLAS_TARGET_AMD64_IX86 + else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) + { + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + ldc, + ABlockSum, + b_blk_sum ); if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, ldc + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc ); } - - c_blk += RowsHandled * ldc; - a_row += RowsHandled * lda; - - RowsRemaining -= RowsHandled; } +#endif } } @@ -496,23 +582,44 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(N); const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; + // TODO: try parallel on BatchN * M threads because BatchN is usually 1. 
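// Editor's illustration (not part of this patch): a scalar reference for what a
// QuantizeARowComputeBlkSum_CompInt8 implementation such as QuantizeARow_CompInt8_avx2
// does with SIMD: per block of BlkLen activations, compute a symmetric int8 scale,
// quantize, and record the scaled block sum used later for the block-sum GEMM.
// This sketch would live at namespace scope and needs <algorithm>, <cmath>, <cstddef>,
// <cstdint>; exact rounding may differ slightly from the vectorized kernels.
static void
QuantizeARowComputeBlkSum_Reference(
    size_t BlkLen, const float* A, size_t CountK,
    std::byte* QuantA, float* QuantAScale, float* AScaledBlkSum
)
{
    int8_t* qa = reinterpret_cast<int8_t*>(QuantA);
    for (size_t k = 0; k < CountK; k += BlkLen) {
        const size_t step = std::min(BlkLen, CountK - k);

        float absmax = 0.0f;
        for (size_t i = 0; i < step; ++i) {
            absmax = std::max(absmax, std::fabs(A[k + i]));
        }

        const float scale = absmax / 127.0f;
        const float inverse_scale = (absmax != 0.0f) ? 127.0f / absmax : 0.0f;

        int32_t block_sum = 0;
        for (size_t i = 0; i < step; ++i) {
            const int32_t q = static_cast<int32_t>(std::lroundf(A[k + i] * inverse_scale));
            qa[i] = static_cast<int8_t>(q);
            block_sum += q;
        }
        for (size_t i = step; i < BlkLen; ++i) {
            qa[i] = 0;  // zero-pad a trailing partial block, as the AVX2 kernel does
        }

        *QuantAScale++ = scale;
        *AScaledBlkSum++ = scale * static_cast<float>(block_sum);  // scale_k * Sum_blklen(a_i)
        qa += BlkLen;
    }
}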
+ if (QuantizeARow) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); + } else { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); + } } struct Operations { @@ -530,7 +637,6 @@ constexpr auto OperationMap = []() { return ops; }(); - } // namespace void MLASCALL @@ -572,12 +678,23 @@ MlasSQNBitGemmBatch( const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + } else { + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + } } return; } @@ -627,9 +744,6 @@ MlasSQNBitGemmBatch( const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; const auto* Data = &DataParams[gemm_i]; - void* PerGemmWorkspace = reinterpret_cast( - reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride - ); const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; @@ -640,6 +754,18 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + 
PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } else { + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } }); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 8321dcc217e9a..2da336ca2f0ec 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -25,12 +25,50 @@ Module Name: #include "mlas_qnbit.h" #include "mlasi.h" +constexpr MLAS_FORCEINLINE size_t +MlasQNBitQuantBBlkSumAlignment() +{ + // 16 floats. this alignment is required by GemmFloatKernel + return 16 * sizeof(float); +} + constexpr MLAS_FORCEINLINE size_t MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { return BlkLen * BlkBitWidth / 8; } +MLAS_FORCEINLINE void* +MlasAlignAddress(void* addr, const size_t alignment) +{ + const uintptr_t QuantBBlkSumAddr = reinterpret_cast(addr); + addr = (void*)((QuantBBlkSumAddr + alignment - 1) & (~(alignment - 1))); + return addr; +} + +struct PackedQuantBDataStruct { + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) + : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize + constexpr size_t BlkBitWidth = 4; + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + } + std::byte* PackedQuantBData; + float* PackedQuantBScale; + float* QuantBBlkSum; + + void* QuantBWorkspace_; + size_t N_, BlockCountK_, BlkLen_; +}; + template constexpr MLAS_FORCEINLINE size_t MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) @@ -74,6 +112,21 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; + // // Workspace size calculation function prototypes. // @@ -181,6 +234,45 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { // CompInt8 kernel function prototypes. // + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. 
+ * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + */ + typedef size_t(SQ4BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum + ); + + SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. @@ -235,4 +327,14 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { ); QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; + + typedef void(QuantizeARowComputeBlkSum_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledGroupSum // scale_k * Sum_blklen(a_i) + ); + QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 0922f5ef646be..55d86bb9cc18e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -22,6 +22,12 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen64.h" + +#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" MLAS_FORCEINLINE __m256 @@ -338,38 +344,92 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( } } +template +MLAS_FORCEINLINE +void +SQ4BitGemmKernel_CompInt8_avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } +} + +template MLAS_FORCEINLINE void 
SQ4BitGemmM1Kernel_CompInt8_avx2( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, size_t CountN, - size_t CountK, + size_t /*CountK*/, size_t BlockStrideQuantB, const float* Bias ) { - if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; + if (QuantBZeroPoint) { if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -379,36 +439,25 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + MlasQ4Int8GemmKernelBlkLen64Avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, - CountK, BlockStrideQuantB, Bias ); } } else { - constexpr bool HasZeroPoint = false; if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -418,15 +467,15 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + MlasQ4Int8GemmKernelBlkLen64Avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, - CountK, BlockStrideQuantB, Bias ); @@ -434,10 +483,12 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( } } +MLAS_FORCEINLINE size_t -SQ4BitGemmKernel_CompInt8_avx2( - size_t BlkLen, +SQ4BitGemmKernel_BlkSum_CompInt8_avx2( + const size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -446,30 +497,101 @@ SQ4BitGemmKernel_CompInt8_avx2( size_t CountN, size_t CountK, size_t BlockCountK, + const float* Bias, size_t ldc, - const float* Bias + const float* ABlockSum, + const float* QuantBBlkSum ) { - MLAS_UNREFERENCED_PARAMETER(ldc); + if (BlkLen >= 32 && CountM == 1) { + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + return CountM; + } + + SQ4BitGemmKernel_CompInt8_avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); - if (CountM == 0) { - return 0; + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; } + return CountM; +} - SQ4BitGemmM1Kernel_CompInt8_avx2( +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t 
ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen >= 32 && CountM == 1) { + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + return CountM; + } + + SQ4BitGemmKernel_CompInt8_avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, - QuantBZeroPoint, C, + CountM, CountN, CountK, BlockCountK, - Bias + Bias, + ldc ); + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); - return 1; + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; } template @@ -1053,30 +1175,23 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( } } -MLAS_FORCEINLINE __m128i -convert_2_ps_to_epi8(__m256 v0, __m256 v1) -{ - __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); - __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); - - __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); - __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); - - return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); -} - void MLASCALL QuantizeARow_CompInt8_avx2( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i one_16_epi16 = _mm256_srli_epi16( + _mm256_cmpeq_epi16(_mm256_castps_si256(signBit), _mm256_castps_si256(signBit)), 15); int8_t* blob = reinterpret_cast(QuantA); + float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { const size_t step = std::min(BlkLen, CountK - k); @@ -1097,13 +1212,14 @@ QuantizeARow_CompInt8_avx2( // Quantize these floats const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); + *scale_ptr = scale; + scale_ptr++; const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps(inverse_scale); __m128i* dst = reinterpret_cast<__m128i*>(blob); + __m256i sum_16_epi16 = _mm256_setzero_si256(); for (size_t kk = 0; kk < step; kk += 16) { const int klen = std::min(16, (int)(step - kk)); @@ -1122,16 +1238,50 @@ QuantizeARow_CompInt8_avx2( v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); } - __m128i i_8 = convert_2_ps_to_epi8(v0, v1); - _mm_storeu_si128(dst++, i_8); + __m128i i_16_epi8 = convert_2_ps_to_epi8(v0, v1); + _mm_storeu_si128(dst++, i_16_epi8); + + // accumulate Sum(a_i) + __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i_16_epi8); + sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); } if (step < BlkLen) { memset(blob + step, 0, BlkLen - step); } + + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); + AScaledBlkSum++; blob += BlkLen; } } +static void +SQ4BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // TODO: always use SubBlkLen = 64 in CompInt8 + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == CompInt8) { + SubBlkLen = 64; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); +} + // // Kernel dispatch structure definition. // @@ -1140,6 +1290,26 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + + return d; +}(); + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -1147,8 +1317,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h 
b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h new file mode 100644 index 0000000000000..80d67806ea6e8 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -0,0 +1,727 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE __m256 +load_and_broadcast_4_scale_2(const float* scale) +{ + // 3 2 1 0 3 2 1 0 (7) + __m256 scale_2_4_ps = _mm256_broadcast_ps((__m128 const*)scale); + + // 2 1 0 0 2 1 0 0 (1) + __m256 scale_2_4_ps_shifted = _mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_castps_si256(scale_2_4_ps), 4) + ); + + // 3 2 1 0 2 1 0 0: (3) cross lane + __m256 scale_2_4_ps_permutted = _mm256_permute2f128_ps( + scale_2_4_ps_shifted, scale_2_4_ps, 0b00110000 + ); + + // in accumulate_r1_4blk_dot and accumulate_r2_4blk_dot + // _mm256_hadd_epi16 inter leaved dot sum, resulting: + // a31b31|a30b30|a11b11|a10b10|a21b21|a20b20|a01b01|a00b00 + // therefore we need weight to be: + // 3 3 1 1 2 2 0 0 (1) + return _mm256_permute_ps(scale_2_4_ps_permutted, 0b11110101); +} + +MLAS_FORCEINLINE +__m256i +load_16_epi8_as_epi16(const std::byte* ablob) +{ + const __m128i av_epi8 = _mm_lddqu_si128(reinterpret_cast(ablob)); + __m256i av_epi16 = _mm256_cvtepi8_epi16(av_epi8); + return av_epi16; +} + +MLAS_FORCEINLINE void +accumulate_r1_4blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float* scale_a, const float* scale_b, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av0_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av1_32_epi8); + const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); + const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); + + // load 4 scales + __m256 scale_a_4_ps = load_and_broadcast_4_scale_2(scale_a); + __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); + __m256 scale_8_ps = _mm256_mul_ps(scale_a_4_ps, scale_b_4_ps); + acc = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc); +} + +MLAS_FORCEINLINE void +accumulate_r2_4blk_dot( + const __m256i& av00_32_epi8, const __m256i& av01_32_epi8, const __m256i& av10_32_epi8, const __m256i& av11_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float* scale_a0, const float* scale_a1, const float* scale_b, + __m256& acc0, __m256& acc1 +) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); + const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); + + // load 4 scales + __m256 scale_a0_4_ps = load_and_broadcast_4_scale_2(scale_a0); + __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); + __m256 scale_8_ps = _mm256_mul_ps(scale_a0_4_ps, scale_b_4_ps); + acc0 = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc0); + + const __m256i 
dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_inter_leaved_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_inter_leaved_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16_); + const __m256 sum_inter_leaved_ps_ = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32_); + + __m256 scale_a1_4_ps = load_and_broadcast_4_scale_2(scale_a1); + scale_8_ps = _mm256_mul_ps(scale_a1_4_ps, scale_b_4_ps); + acc1 = _mm256_fmadd_ps(sum_inter_leaved_ps_, scale_8_ps, acc1); +} + +static MLAS_FORCEINLINE __m256i +load_4b_packed_1blk_blklen16(const std::byte* QuantBDataPtr) +{ + // | 0 8 |...| 7 15 | + const __m128i bv_packed_64 = _mm_loadl_epi64(reinterpret_cast(QuantBDataPtr)); + const __m128i low_mask = _mm_set1_epi8(0xF); + const __m128i lower_8_epu8 = _mm_and_si128(bv_packed_64, low_mask); // 0~7 + const __m128i upper_8_epu8 = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bv_packed_64, 4), low_mask), 8); // 8~15 + const __m256i bv_16_epu16 = _mm256_cvtepi8_epi16(_mm_add_epi8(upper_8_epu8, lower_8_epu8)); // 0~15 + return bv_16_epu16; +} + +static MLAS_FORCEINLINE void +load_4b_packed_4blk_blklen16(const std::byte* QuantBDataPtr, __m256i& bv0_32_epi8, __m256i& bv1_32_epi8) +{ + // | 0 8 |...| 7 15 | 16 24 |...| 23 31 ||| 32 40 |...| 39 47 | 48 56 |...| 55 63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + // 0~7, 16~22, 32~39, 48~55 + __m256i bv0_32_epi8_ = _mm256_and_si256(bv_packed, low_mask); + // 8~15, 24~31, 40~47, 56~63: (1) + __m256i bv1_32_epi8_ = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8_), 4); + // 0~7, 32~39, 16~22, 48~55 <- cross lane (3) + bv0_32_epi8_ = _mm256_permute4x64_epi64(bv0_32_epi8_, 0b11011000); + // 40~47, 8~15, 56~63, 24~31 <- cross lane (3) + bv1_32_epi8_ = _mm256_permute4x64_epi64(bv1_32_epi8_, 0b01110010); + + // 0~7, 8~15, 16~22, 24~31: (1) + bv0_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b11001100); + + // 40~47, 32~39, 56~63, 48~55: (1) + bv1_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b00110011); + + // 32~39, 40~47, 48~55, 56~63: (1) + bv1_32_epi8 = _mm256_shuffle_epi32(bv1_32_epi8, 0b01001110); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + __m256i bv0_32_epi8, bv1_32_epi8; + load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); + accumulate_r2_4blk_dot(av00_32_epi8, av01_32_epi8, av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, + scale_a0, scale_a1, scale_b, acc0, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk4_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc +) +{ + __m256i bv0_32_epi8, bv1_32_epi8; + load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); + accumulate_r1_4blk_dot(av0_32_epi8, av1_32_epi8, bv0_32_epi8, bv1_32_epi8, scale_a, scale_b, acc); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk1_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + 
const float& combined_scale0, + const float& combined_scale1, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); + + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av0_32_epi8); + __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale0), prod_8_ps, acc0); + + prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av1_32_epi8); + prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale1), prod_8_ps, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk1_avx2( + const __m256i& av_16_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale, + __m256& acc +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); + + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av_16_epi8); + __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale), prod_8_ps, acc); +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 4 blks of 64 4b weights a time + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 3; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + 
accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + // process 4 blks of 64 4b weights a time + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + 32; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + 32; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + + accumulate_blklen16_r2c1blk4_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; 
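+            // A single remainder column is handled per iteration in this kernel, so the bias and output pointers advance by one element below.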
+ + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc[3]); + } + 
QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen16_r1c1blk4_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const __m256i av_16_epi16 = load_16_epi8_as_epi16(QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_16_epi16, QuantBDataPtr, scale_00, acc0); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h new file mode 100644 index 0000000000000..af6f52090adcb --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -0,0 +1,1049 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE void +accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, + const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) +{ + const __m256i dot_16_epi16 = _mm256_maddubs_epi16( + bv_32_epi8, av_32_epi8 + ); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +#if !defined(__GNUC__) || (__GNUC__ > 10) +MLAS_FORCEINLINE void +accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) +{ + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} +#endif + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); + // low_mask = _mm256_packus_epi16(low_mask, low_mask); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + // TODO: this (the second line below) is faster and does not keep low_mask in use. 
+ // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1); + } + } else { +#endif + //{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + // generating constant 1s is faster here. + // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + //} + //{ + const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); + const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); + //} +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); // 00110011 + + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { +#endif + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4x2BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + } + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); + } + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); + } + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + accumulate_blklen32_r2c1blk2_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXx4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + + { + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc[0]); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1] + ); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2] + ); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3] + ); + } + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
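+            // With PerAccuBlk2 == 2, at most one K block can be left over here, so the single tail step below is sufficient.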
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXxXBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
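+            // For the leftover block, the per-block A and B scales are folded into one scalar (scale_00) before the fused multiply-add.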
+ if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8Gemm2x4x2BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8Gemm2xXBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmXx4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmXxXBlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} + +// this function is to explore larger NCols. With Avx2 it does not improve performance. +// Leave it here until the same is implemented in avx512. +template accumulator> +MLAS_FORCEINLINE +size_t +MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + // We process 32 quantized values in a batch. 
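+    // Unlike the kernels above, A is consumed here in the Q8Blk layout (per-block scale stored inline with the quantized data), accessed via Q8BlkScale/Q8BlkData.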
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4]; + + acc[0] = _mm256_setzero_ps(); + acc[1] = _mm256_setzero_ps(); + acc[2] = _mm256_setzero_ps(); + acc[3] = _mm256_setzero_ps(); + + if constexpr (NCols4 == 8) { + acc[4] = _mm256_setzero_ps(); + acc[5] = _mm256_setzero_ps(); + acc[6] = _mm256_setzero_ps(); + acc[7] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
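The TODO above notes that the main K loop consumes quantization blocks in fixed pairs (PerAccuBlk2 == 2), so the tail that follows only ever has to handle a single leftover block. A sketch of the generalized structure the TODO is asking for, with a hypothetical accumulate callback (illustrative only):

    #include <cstddef>

    // Walk the K dimension in groups of PerAccuBlk quantization blocks, then
    // finish the leftovers one block at a time. `accumulate(first_blk, count)`
    // stands in for the SIMD accumulation done by the kernels above.
    template <size_t PerAccuBlk, typename AccumulateFn>
    void WalkKBlocks(size_t BlockCountK, AccumulateFn&& accumulate)
    {
        size_t k = 0;
        for (; k + PerAccuBlk <= BlockCountK; k += PerAccuBlk) {
            accumulate(k, PerAccuBlk);   // the unrolled pair path when PerAccuBlk == 2
        }
        for (; k < BlockCountK; ++k) {
            accumulate(k, size_t{1});    // the `k_blks_remaining > 0` tail
        }
    }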
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + } + } // k_blks_remaining + + if constexpr (NCols4 == 8) { + __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + if (BiasPtr != nullptr) { + acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); + acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); + } + _mm_storeu_ps(SumPtr, acc_0); + _mm_storeu_ps(SumPtr+4, acc_1); + } else { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + + // move to next NCols columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h new file mode 100644 index 0000000000000..174ebc580904c --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -0,0 +1,541 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + + sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av11_32_epi8); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + + } else { +#endif + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + + dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + } else { +#endif + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +#if !defined(__GNUC__) || (__GNUC__ > 9) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 1 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 
lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); 
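The accumulate_blklen64_* helpers above assume the 32 packed bytes of a 64-weight block store weight i in the low nibble of byte i and weight i+32 in the high nibble, which is what the `| v0 v32 | v1 v33 | ...` comment describes. A scalar reference of that unpack, for illustration only:

    #include <cstdint>
    #include <cstddef>

    // Byte i of a packed BlkLen=64 block holds value i (low nibble) and value i+32
    // (high nibble); lo32/hi32 correspond to bv0_32_epi8 and bv1_32_epi8 above.
    void UnpackBlk64(const uint8_t packed[32], uint8_t lo32[32], uint8_t hi32[32])
    {
        for (size_t i = 0; i < 32; ++i) {
            lo32[i] = packed[i] & 0x0F;   // values 0..31
            hi32[i] = packed[i] >> 4;     // values 32..63
        }
    }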
+ const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; 
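Each __m256 accumulator used above still holds eight partial sums per output element; hsum_float_8 collapses one accumulator to a float, and FoldAccumulators collapses four of them into the 4-lane result stored with _mm_storeu_ps. In scalar terms (illustrative only):

    // Scalar equivalents of the reductions used by these kernels.
    float HorizontalSum8(const float lanes[8])
    {
        float s = 0.0f;
        for (int i = 0; i < 8; ++i) s += lanes[i];   // hsum_float_8(acc)
        return s;
    }

    void Fold4(const float acc[4][8], float out[4])
    {
        for (int col = 0; col < 4; ++col) {
            out[col] = HorizontalSum8(acc[col]);     // FoldAccumulators(acc0..acc3): one sum per column
        }
    }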
+ BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen64_r1c1blk1_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index b868906760701..13bd369a065bb 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -22,6 +22,10 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // // CompFp32 kernel implementation. @@ -150,18 +154,115 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( // CompInt8 kernel implementation. 
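SQ4BitGemmKernel_BlkSum_CompInt8_avx512, which follows, first runs the int8 x int4 block kernels and then folds in a per-block correction by multiplying ABlockSum with QuantBBlkSum through GemmFloatKernel. The reasoning, for one quantization block with A values a_i (scale s_a, no zero point) and B nibbles b_i (scale s_b, zero point z):

$$
s_a s_b \sum_i a_i\,(b_i - z) \;=\; s_a s_b \sum_i a_i b_i \;-\; (z\,s_b)\Bigl(s_a \sum_i a_i\Bigr)
$$

The block kernels produce the first term directly from the unsigned nibbles. QuantizeARow_CompInt8_avx512 further down emits the per-block quantity s_a * sum_i a_i as AScaledBlkSum, and the trailing GemmFloatKernel call accumulates ABlockSum x QuantBBlkSum into C, so QuantBBlkSum presumably carries the -z * s_b factor per block, produced during B prepacking outside this hunk.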
// +MLAS_FORCEINLINE +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ4Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL -MlasQ80BlkQuantRow_avx512( +QuantizeARow_CompInt8_avx512( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m256i one_16_epi16 = _mm256_set1_epi16(1); int8_t* blob = reinterpret_cast(QuantA); + float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { const size_t step = std::min(BlkLen, CountK - k); @@ -185,13 +286,14 @@ MlasQ80BlkQuantRow_avx512( // Quantize these floats const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); + *scale_ptr = scale; + scale_ptr++; const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; const __m512 mul = _mm512_set1_ps(inverse_scale); __m128i* dst = reinterpret_cast<__m128i*>(blob); + __m256i sum_16_epi16 = _mm256_setzero_si256(); for (size_t kk = 0; kk < step; kk += 16) { const size_t klen = std::min(size_t(16), step - kk); @@ -208,23 +310,46 @@ MlasQ80BlkQuantRow_avx512( // Convert int32 to int8 __m128i i0_8 = _mm512_cvtepi32_epi8(i0); _mm_storeu_si128(dst++, i0_8); + + // accumulate Sum(a_i) + __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i0_8); + sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); + } if (step < BlkLen) { memset(blob + step, 0, BlkLen - step); } + + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); + AScaledBlkSum++; blob += BlkLen; } } -void MLASCALL -QuantizeARow_CompInt8_avx512( +static void +SQ4BitGemmPackQuantBDataAndBlkSum512( + size_t N, + size_t K, size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool ) { - MlasQ80BlkQuantRow_avx512(BlkLen, A, CountK, QuantA); + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { @@ -232,6 +357,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -239,8 +365,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h new file mode 100644 index 0000000000000..7d9dc36854621 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h @@ -0,0 +1,1171 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE void +accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, + const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) +{ + const __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8) + ); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + 
acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +MLAS_FORCEINLINE void +accumulate_2blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float& combined_scale0, const float& combined_scale1, + const __m256i& one_16_epi16, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale_8_ps = _mm256_set_ps( + combined_scale1, combined_scale1, combined_scale0, combined_scale0, + combined_scale1, combined_scale1, combined_scale0, combined_scale0 + ); + acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + // TODO: will this (the second line below) be faster and not keep low_mask in use? 
+    __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4);  // 32, 33,...62, 63
+
+    int8_t zp0, zp1;
+    get_2_zps<HasZeroPoint>(QuantBZeroPointPtr, zp0, zp1);
+    bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0));
+    bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1));
+
+    //accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0);
+    //accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1);
+    const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)
+    );
+    const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)
+    );
+    const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16);
+
+    __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15);
+    const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16);
+    const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
+
+    __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0));
+    __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b));
+    // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0
+    __m256 scale_8_ps = _mm256_mul_ps(
+        _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)),
+        _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)));
+
+    acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0);
+
+    const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av10_32_epi8, bv0_32_epi8)
+    );
+    const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av11_32_epi8, bv1_32_epi8)
+    );
+    const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_);
+    const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_);
+    const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_);
+
+    __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1));
+    __m256 scale_8_ps_ = _mm256_mul_ps(
+        _mm256_permute_ps(scale_a1_2_ps, _MM_SHUFFLE(1, 1, 0, 0)),
+        _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)));
+    acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1);
+}
+
+template <bool HasZeroPoint>
+static MLAS_FORCEINLINE void
+accumulate_blklen32_r2c1blk2_avx2(
+    const __m256i& av00_32_epi8,
+    const __m256i& av01_32_epi8,
+    const __m256i& av10_32_epi8,
+    const __m256i& av11_32_epi8,
+    const std::byte* QuantBDataPtr,
+    const std::byte* QuantBZeroPointPtr,
+    const float& combined_scale00,
+    const float& combined_scale01,
+    const float& combined_scale10,
+    const float& combined_scale11,
+    __m256& acc0,
+    __m256& acc1
+)
+{
+    // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 |
+    const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(QuantBDataPtr));
+
+    // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F).
+ // however, it is faster to generate one_16_epi16 than calling _mm256_set1_ep16(1); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); + //low_mask = _mm256_packus_epi16(low_mask, low_mask); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this (the second line below) be faster and not keep low_mask in use? + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + + //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + // generating constant 1s is fater here. + // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + // performance gains 7% by calling this (accumulate_2blk_dot) instead of 2 accumulate_1blk_dot calls. + // accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + // accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + // accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one_16_epi16, acc1); + // accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one_16_epi16, acc1); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + const float& combined_scale01, + __m256& acc0) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this be faster and save a use of low_mask? + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + + //// This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + //accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + //accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); +} + +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + constexpr size_t Q8Blk32Size = Q8BlkSize(BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8Blk32Size; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8Blk32Size; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * 
QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += Q8Blk32Size * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
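The `k_blks_remaining > 0` tail below handles the last odd quantization block with accumulate_blklen32_r2c1blk1_avx2, adding one block's contribution for each of the two A rows against a single B column. A scalar reference of that contribution (illustrative; the 4-bit weights are shown already unpacked to bytes):

    #include <cstdint>
    #include <cstddef>

    // One BlkLen=32 block, one B column, two A rows: what one tail call adds to acc0/acc1.
    // b[] holds the unpacked 4-bit weights (0..15), zp the block zero point, scale_b the
    // block scale, scale_a0/scale_a1 the per-row A block scales.
    void AccumulateOneBlockTwoRows(const int8_t a0[32], const int8_t a1[32],
                                   const uint8_t b[32], int8_t zp,
                                   float scale_a0, float scale_a1, float scale_b,
                                   float& acc0, float& acc1)
    {
        int32_t dot0 = 0, dot1 = 0;
        for (size_t i = 0; i < 32; ++i) {
            const int32_t w = static_cast<int32_t>(b[i]) - zp;  // remove the block zero point
            dot0 += w * a0[i];
            dot1 += w * a1[i];
        }
        acc0 += (scale_a0 * scale_b) * static_cast<float>(dot0);
        acc1 += (scale_a1 * scale_b) * static_cast<float>(dot1);
    }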
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + // accumulate_blklen32_r2c1_avx2 + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
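The HasZeroPoint path in these kernels stores two 4-bit zero points per byte, which is why the code tracks "half byte increments" and sizes the buffer with MlasQNBitZeroPointsForBlksSizeInBytes. A sketch of how get_zp / get_2_zps would read them, assuming the even-indexed block sits in the low nibble; the actual packing is defined in the common header, not in this hunk:

    #include <cstdint>

    // Illustrative only: two consecutive blocks' 4-bit zero points packed in one byte,
    // with the even-indexed block assumed to occupy the low nibble.
    inline int8_t GetZeroPoint(uint8_t packed, bool is_low_half)
    {
        return static_cast<int8_t>(is_low_half ? (packed & 0x0F) : (packed >> 4));
    }

    inline void GetTwoZeroPoints(uint8_t packed, int8_t& zp0, int8_t& zp1)
    {
        zp0 = static_cast<int8_t>(packed & 0x0F);   // block 2k
        zp1 = static_cast<int8_t>(packed >> 4);     // block 2k + 1
    }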
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXx4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); + } + { + // Col1 + const float& scale_00 = 
scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); + } + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, acc[3]); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXxXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8Gemm2x4BlkLen32Avx2( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8Gemm2xXBlkLen32Avx2( + QuantA, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmXx4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmXxXBlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + return CountM; +} + +// this function is to explore larger NCols. With Avx2 it does not improve performance. +// Leave it here until the same is implemented in avx512. +template accumulator> +MLAS_FORCEINLINE +size_t +MlasQ4Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + // We process 32 quantized values in a batch. 
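+    // Layout assumed by this kernel: each block packs 32 4-bit weights into 16 bytes
+    // (BlkDataSizeInBytes16) with one float scale per block and, when HasZeroPoint is set,
+    // 4-bit zero points packed two per byte. Columns of B are stored back to back, so moving
+    // to the next column advances by StrideQuantBData / StrideQuantBScale / StrideQuantBZeroPoint.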
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4]; + + acc[0] = _mm256_setzero_ps(); + acc[1] = _mm256_setzero_ps(); + acc[2] = _mm256_setzero_ps(); + acc[3] = _mm256_setzero_ps(); + + if constexpr (NCols4 == 8) { + acc[4] = _mm256_setzero_ps(); + acc[5] = _mm256_setzero_ps(); + acc[6] = _mm256_setzero_ps(); + acc[7] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
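+                // The paired loop above advanced QuantBZeroPointPtr by one byte per two blocks,
+                // so a leftover block's 4-bit zero point sits in the lower nibble of the current
+                // byte; that is why every tail call below passes `true` for the low-half-byte
+                // flag, while the second block of each pair above passed `false`.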
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + } + } // k_blks_remaining + + if constexpr (NCols4 == 8) { + __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + if (BiasPtr != nullptr) { + acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); + acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); + } + _mm_storeu_ps(SumPtr, acc_0); + _mm_storeu_ps(SumPtr+4, acc_1); + } else { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + + // move to next NCols columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h new file mode 100644 index 0000000000000..60a887345d0e0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -0,0 +1,581 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + +//static MLAS_FORCEINLINE __m512i +//combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) +//{ +// __m512i result = _mm512_castsi256_si512(a); +// result = _mm512_inserti64x4(result, b, 1); +// return result; +//} + +//static MLAS_FORCEINLINE void +//load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +//{ +// // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | v64 v96 | ... 
| v95 v127 | +// const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); +// const __m512i low_mask = _mm512_set1_epi8(0x0F); +// __m512i bv0_64_epi8_ = _mm512_and_si512(bv_packed, low_mask); // 0~31, 64~95 +// __m512i bv1_64_epi8_ = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 32~63, 96~127 +// +// // Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 +// __m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); +// __m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); +// __m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); +// __m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); +// +// // Compose new __m512i variables +// bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); +// bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); +//} + +static MLAS_FORCEINLINE void +dot_accumulate_1blk( + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float combined_scale, + __m512& acc +) +{ + __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); + __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); + __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i zeros = _mm512_setzero_si512(); + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); +} + +static MLAS_FORCEINLINE void +dot_accumulate_1blkvnni( + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float combined_scale, + __m512& acc +) +{ + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(dot0_16_epi32, bv1_64_epi8, av1_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen128_r1c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + if constexpr (vnni) { + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc + ); + } else { + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen128_r2c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + if constexpr (vnni) { + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, 
av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); + } else { + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + // process 1 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += 
NCols4; + } // k_blks_remaining + +#if 1 + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + *(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } +#else + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); +#endif + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + 
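+                    // Each 128-wide subblock of A is read as two 64-byte halves per row
+                    // (QuantAPtr and QuantAPtr + SubblkLen / 2), with the second row lda bytes
+                    // further on; the accumulate below unpacks the matching 64 packed bytes of
+                    // 4-bit B once and dot-products them against both rows into acc0 / acc1.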
accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr +=NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_blklen128_r1c1blk1_avx512( + av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h new file mode 100644 index 0000000000000..3cd610796a5e3 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -0,0 +1,812 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + + + +static MLAS_FORCEINLINE void +load_4blk_4b_packed_blklen16(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk8_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1,2~2,3~3 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 4~4,5~5,6~6,7~7 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0044115522663377 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + // TODO: load from memory + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m256 scale_a1_ps 
= _mm256_loadu_ps(scale_a1); // 01234567 + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk8_avx512vnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + // TODO: load from memory + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_16_epi32 = 
_mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); // 01234567 + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i 
av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3] + ); + } else { + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3] + ); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, + acc2[1], acc2[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * 
StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, + acc2[2], acc2[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2( + av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, + acc2[3], acc2[NCols4 + 3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2C1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } else { + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + 
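+                    // Illustrative note: each of these four calls accumulates one of the NCols4 output
+                    // columns from the same 8 blocks of A (128 int8 values held in av_00/av_01). Packed B
+                    // and its scales are stored column-major, so column c is addressed at
+                    // QuantBDataPtr + c * StrideQuantBData and QuantBScalePtr + c * StrideQuantBScale.
+                    // Roughly, per column c: acc[c] += sum_b a_scale[b] * b_scale[c][b] * dot(A_blk[b], B_blk[c][b]).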
accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + } else { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc2[1]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc2[2]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc2[3]); + } + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantAPtr); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t +MlasQ4Int8GemmKernelBlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2C1BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h new file mode 100644 index 0000000000000..ca12cc14a7875 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -0,0 +1,852 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + +static MLAS_FORCEINLINE void +load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 +} + +static const uint32_t index_array[16] = {0, 0, 2, 2, 0, 0, 2, 2, 1, 1, 3, 3, 1, 1, 3, 3}; + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk4_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i 
sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk4_avx512vnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + //__m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + //__m512i idx = _mm512_loadu_epi8(&index_array[0]); + + const 
__m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 + + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000000011111111 + const __m512i dot1_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +MLAS_FORCEINLINE void +accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) +{ + __m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx512( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + if constexpr (vnni) { + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { + accumulate_blklen32_r1c1blk1_avx2(av00_32_epi8, QuantBDataPtr, combined_scale00, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx512( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + if constexpr (vnni) { + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_avx512vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { + accumulate_blklen32_r2c1blk1_avx2(av00_32_epi8, av10_32_epi8, QuantBDataPtr, combined_scale00, combined_scale10, acc0, acc1); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = PerAccuBlk4 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, + acc[3], 
acc[NCols4 + 3] + ); + } else { + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, + acc[3], acc[NCols4 + 3] + ); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); + } + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + 
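+            // Each column of packed B spans BlockCountK blocks, so stepping NCols4 columns advances
+            // NCols4 * BlockCountK quantized blocks of data (BlkDataSizeInBytes16 bytes each) and the
+            // matching NCols4 * BlockCountK per-block scales right below. Illustrative numbers only:
+            // with BlockCountK == 32 this is 4 * 32 * 16 = 2048 bytes of packed B and 128 scales per step.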
QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2C1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } else { + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + 
} + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + 
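+            // The main loop above consumes blocks in groups of PerAccuBlk4; h_add_512 folds each
+            // 16-lane fp32 accumulator into 8 lanes so the leftover (< PerAccuBlk4) blocks can be
+            // finished one block at a time with the 256-bit helpers below. A minimal sketch of the fold:
+            //   __m256 lo = _mm512_castps512_ps256(acc512);
+            //   __m256 hi = _mm512_extractf32x8_ps(acc512, 1);
+            //   __m256 folded = _mm256_add_ps(lo, hi);   // same as h_add_512(acc512)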
+ while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); + } + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; + + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + else { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen32 * 
PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE +size_t +MlasQ4Int8GemmKernelBlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2C1BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h new file mode 100644 index 0000000000000..2a65ac4af0c1d --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -0,0 +1,840 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +static MLAS_FORCEINLINE __m256 +h_add_512(__m512 a) +{ + return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); +} + +static MLAS_FORCEINLINE float +hsum_float_16(const __m512 x) +{ + __m256 hi = h_add_512(x); + __m128 hi128 = _mm256_extractf128_ps(hi, 1); + __m128 lo128 = _mm256_castps256_ps128(hi); + hi128 = _mm_add_ps(hi128, lo128); + hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); + hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); + return _mm_cvtss_f32(hi128); +} + +static MLAS_FORCEINLINE __m512i +combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) +{ + __m512i result = _mm512_castsi256_si512(a); + result = _mm512_inserti64x4(result, b, 1); + return result; +} + +static MLAS_FORCEINLINE void +load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 + + //// Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 + //__m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); + //__m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); + //__m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); + //__m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); + + //// Compose new __m512i variables + //bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); + //bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); +} + +static MLAS_FORCEINLINE __m512i +load_1blk_4b_packed_blklen64(const std::byte* QuantBDataPtr) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16( + _mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + __m512i bv_64_epi8 = combine_two_m256i_to_m512i(bv0_32_epi8, bv1_32_epi8); + return bv_64_epi8; +} + +static MLAS_FORCEINLINE __m512i +horizontal_add_epi32(__m512i a, __m512i b) +{ + __m512i t1 = _mm512_unpacklo_epi32(a, b); + __m512i t2 = _mm512_unpackhi_epi32(a, b); + __m512i sum = _mm512_add_epi32(t1, t2); + return sum; +} + +static MLAS_FORCEINLINE __m512i +generate_ones_32_epi16() +{ + const __m512i zeros = _mm512_setzero_si512(); + return _mm512_srli_epi16(_mm512_ternarylogic_epi64(zeros, zeros, zeros, 1), 15); +} + +static MLAS_FORCEINLINE void +dot_accumulate_2blk( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float* scale_a, + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512& scale_b_16_ps, + //const __m512i& one_32_epi16, + __m512& acc) +{ + __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); + __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); + + __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // sum for blk: 0 0 1 1 0 0 1 1... + __m512i one_32_epi16 = generate_ones_32_epi16(); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // sum for blk: 0 1 0 1... + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +static MLAS_FORCEINLINE void +dot_accumulate_2blkvnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float* scale_a, + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512& scale_b_16_ps, + // const __m512i& one_32_epi16, + __m512& acc +) +{ + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); + + __m512i t1_16_epi32 = _mm512_unpacklo_epi32(dot0_16_epi32, dot1_16_epi32); + __m512i t2_16_epi32 = _mm512_unpackhi_epi32(dot0_16_epi32, dot1_16_epi32); + __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 0 1 1 0 0 1 1... 
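+    // After the unpack/add above, every 128-bit lane holds the int32 partial sums of both blocks
+    // interleaved, so a single fused multiply-add can apply per-block scales. The matching A scales
+    // are produced just below by reinterpreting the two adjacent floats at scale_a as one 64-bit
+    // value and broadcasting that pair across the vector; scale_b_16_ps was built the same way by
+    // the caller. Conceptually, per lane: acc += cvt_ps(sum) * (a_scale[blk] * b_scale[blk]).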
+ __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blkvnni( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); + } else { + dot_accumulate_2blk( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blk( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk2_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); + } else { + dot_accumulate_2blk( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk1_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); + + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } + + { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av1_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } + } else { + const __m512i zeros = _mm512_setzero_si512(); + // const __m512i one_32_epi16_ = 
_mm512_andnot_epi32(zeros, zeros); + // const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_andnot_epi32(zeros, zeros), 15); + + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av0_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } + + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av1_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_avx512( + const __m512i& av_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); + + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av_32_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + } else { + const __m512i one_32_epi16 = _mm512_set1_epi16(1); + + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av_32_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen64Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr 
= QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + +#if 1 + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + 
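+                // Bias is indexed per output column only, so the second row of C (SumPtr + ldc) adds
+                // the same four bias values as the first row.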
*(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } +#else + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); +#endif + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM % NRows2 == 0); + //assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + 
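// acc0 and acc1 each hold 16 lanes of partial sums for the two rows of this tile; + // hsum_float_16 reduces each register to a single float, and the second row's result lands one leading dimension (ldc) below the first. + 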
*(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + //const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM < NRows2); + //assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + QuantAPtr += BlkLen64; + 
QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM < NRows2); + //assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + if (NRows2 == 2) + Q4Int8GemmR2xC4BlkLen64Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + else + Q4Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + if (NRows2 == 2) + Q4Int8GemmR2xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + else + Q4Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 6477a2019b21a..6a5c01162c51b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -23,6 +23,10 @@ Module Name: #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32( @@ -146,6 +150,7 @@ void SQ4BitGemmM1Kernel_CompInt8_avx512vnni( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -157,44 +162,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( ) { if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } + assert(false); } else { constexpr bool HasZeroPoint = false; if (BlkLen == 16) { @@ -212,6 +180,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } else if (BlkLen == 32) { SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -237,52 +206,134 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } } +MLAS_FORCEINLINE size_t -SQ4BitGemmKernel_CompInt8_avx512vnni( - size_t BlkLen, +SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni( + const size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, + const std::byte* /*QuantBZeroPoint*/, float* C, size_t CountM, size_t CountN, - size_t CountK, + size_t /*CountK*/, size_t BlockCountK, + const float* Bias, size_t ldc, - const float* Bias + const float* ABlockSum, + const float* QuantBBlkSum ) { - MLAS_UNREFERENCED_PARAMETER(ldc); - - if (CountM == 0) { - return 0; + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ4Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); } - SQ4BitGemmM1Kernel_CompInt8_avx512vnni( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockCountK, - Bias - ); 
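+ // The BlkLen-specialized kernels above accumulate scale_a * scale_b * sum(q_a * q_b) per block and do not + // apply a zero point. Since b dequantizes as scale_b * (q_b - zp_b), each block also contributes + // -scale_a * scale_b * zp_b * sum(q_a). ABlockSum stores scale_a * sum(q_a) per (row, block) and QuantBBlkSum + // stores -scale_b * zp_b per (block, column), so a small f32 GEMM over the BlockCountK dimension adds exactly + // that correction on top of the results already written to C: + //   C[m][n] += sum_k ABlockSum[m][k] * QuantBBlkSum[k][n]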
+ float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; - return 1; + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; } void MLASCALL -MlasQ80BlkQuantRow_avx512( +QuantizeARow_CompInt8_avx512( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ); +static void +SQ4BitGemmPackQuantBDataAndBlkSum512vnni( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); +} + // // Kernel dispatch structure definition. // @@ -291,6 +342,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -298,8 +350,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni; - d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 706e08fc467ba..177f5518bb891 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -14,13 +14,24 @@ SQ4BitGemmPackQuantBDataSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - constexpr size_t BlkBitWidth = 4; - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + if (ComputeType == CompInt8) { + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ScaleSize = N * BlockCountK * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + 
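// Pad by (alignment - 1) bytes so that the packed weight data and the block-sum section can each be aligned + // up to the required boundary inside the single buffer the caller allocates with this size. + 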
PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; + + return PackedQuantBDataSize + ScaleSize + BlkSumSize; + } else { + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } } static void @@ -100,6 +111,216 @@ SQ4BitGemmPackQuantBData( ); } +static size_t +GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) +{ + size_t T = n / 4, t = n % 4; + bool te = T == N / 4; + size_t scale_dst_offset = T * 4 * SubOrBlkCountK; + if (te) { + scale_dst_offset += t * SubOrBlkCountK + k_sub_or_blk; + } else { + scale_dst_offset += k_sub_or_blk * 4 + t; + } + return scale_dst_offset; +} + +static size_t +GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub) +{ + size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub; + bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub; + size_t scale_dst_offset = T * 4 * BlockCountK; + if (te) { + scale_dst_offset += t * BlockCountK + k_blk; + } else { + scale_dst_offset += k_subblk * blks_per_sub * 4; + if (be) { + scale_dst_offset += b * 4 + t; + } else { + scale_dst_offset += t * blks_per_sub + b; + } + } + return scale_dst_offset; +} + +static void +PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen) +{ + constexpr size_t BlkBitWidth = 4; + const size_t BlkBytePairCount = BlkLen / 4; + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // for avx2 + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // for the remaining blk, it shall be: + // dst blklen32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | + + // for avx512 + // dst: | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | + // for the remaining blk, it shall be: + // dst blklen64: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // dst blklen32: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / SubBlkCountK; + const size_t k_subblk = tid % SubBlkCountK; + + const size_t src_data_offset = n * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + src_data_offset; + + size_t PackBytePairCount = SubBlkBytePairCount; + size_t PackDataSize = SubBlkDataSize; + + auto pack_subblk = []( + const std::byte* QuantBData, std::byte* PackedQuantBData, + size_t pack_byte_pair_count, size_t pack_data_size) { + for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } }; + + if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1 && + SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { + // this is the last subblk of the column. check if it extends out of the + // BlockCountK. If it does, we shall pack per blocks so that can compute + // on each block instead of each subblk. + PackBytePairCount = BlkBytePairCount; + PackDataSize = BlkDataSize; + const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; + for (size_t k = 0; k < k_blks_remaining; k++) { + const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k; + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + // shall not reach here with avx2 + assert(SubBlkLen == 128); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } else { + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t k_blk = k_subblk * blks_per_sub; + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } + ); +} + +//#include + +static void +ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const 
size_t BlockCountK) +{ + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 8; + if (QuantBZPBegin) { + size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); + size_t src_zp_offset = ZPCountK * n + k_blk / 2; + bool low_zp = k_blk % 2 == 0; + const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset; + const std::byte low_mask{0X0F}; + zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + if (BlkLen == 16) { // TODO + + } else if (BlkLen >= SubBlkLen) { + const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk); + *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; + } + } + ); +} + +static void +PackQuantBDataAndBlkSum( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, packed_quant_b.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { + ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + // // Workspace size calculation function implementation. 
// @@ -119,7 +340,8 @@ SQ4BitGemmPerGemmWorkspaceSize( case CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + // QuantData + Scale + BlkSum + const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); return PerGemmWorkspaceSize; } default: { @@ -288,6 +510,20 @@ load_and_mul_sum_s8_quads_with_zp_avx2( acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); } +template +void MLAS_FORCEINLINE +get_2_zps(const std::byte* QuantBZeroPointPtr, int8_t& zp0, int8_t& zp1) +{ + if constexpr (HasZeroPoint) { + zp0 = std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}); + zp1 = std::to_integer((*QuantBZeroPointPtr) >> 4); + } else { + zp0 = 8; + zp1 = 8; + (void)QuantBZeroPointPtr; + } +} + template int8_t MLAS_FORCEINLINE get_zp(bool is_lower_half_byte_zp, const std::byte* QuantBZeroPointPtr) @@ -375,7 +611,7 @@ FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, con return acc_y; } -static inline float +static MLAS_FORCEINLINE float hsum_float_8(const __m256 x) { __m128 res = _mm256_extractf128_ps(x, 1); @@ -417,4 +653,27 @@ FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, con _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); } + +static MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8(__m256 v0, __m256 v1) +{ + __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); + __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); + + __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); + __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); + + return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); +} + +// horizontally add 8 int32_t +static MLAS_FORCEINLINE int +hsum_8_epi32(const __m256i a_8_epi32) +{ + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a_8_epi32), _mm256_extractf128_si256(a_8_epi32, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} } // namespace diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h index 250ffeacd7c2f..895ce6cd091c2 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -7,20 +7,6 @@ #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_q8_block.h" -void -SQ4BitGemmM1Kernel_CompInt8_avx2( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -); - template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( @@ -240,6 +226,7 @@ template accumulator> void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -273,6 +260,7 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( int64_t nblk = 
(int64_t)(CountN)-4; while (nblk >= 0) { const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -286,14 +274,14 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= 2) { const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); + const float& scale_a0 = *QuantAScalePtr; + const float& scale_a1 = *(QuantAScalePtr + 1); // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -320,7 +308,8 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; QuantBDataPtr += 16 * 2; QuantBScalePtr += 2; if constexpr (HasZeroPoint) { @@ -331,9 +320,9 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a0 = *QuantAScalePtr; // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -374,6 +363,7 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( nblk += NCols; for (int64_t n = 0; n < nblk; n++) { const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -383,14 +373,14 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= 2) { const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); + const float& scale_a0 = *QuantAScalePtr; + const float& scale_a1 = *(QuantAScalePtr + 1); // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -399,7 +389,8 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); // increment block pointers 
- QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; QuantBDataPtr += 16 * 2; QuantBScalePtr += 2; if constexpr (HasZeroPoint) { @@ -410,9 +401,9 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a0 = *QuantAScalePtr; // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h new file mode 100644 index 0000000000000..45c3963365e6b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -0,0 +1,759 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_zp_avx2( + const __m256i& av_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale, + const std::byte* QuantBZeroPointPtr, + __m256& acc, + const __m256i& low_mask +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(low_mask, bv_32_epi8); + + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + const std::byte* QuantBZeroPointPtr, + __m256& acc0, + const __m256i& low_mask +) +{ + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = 
_mm256_set1_ps(*(scale_a) * *(scale_b)); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { +#endif + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_is_8_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // accumulate_blklen32_r1c1blk2_zp_is_8_avx2 is much faster than + // accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2: + // BlkBitWidth:4/BlkLen:32/M:1/N:2560/K:2560/Threads:8/Symmetric:1/HasBias:0/ComputeType:4 + // 36591 vs 40270 ns (the main is 51836 ns). both are not as good as main with genai. + // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { +#endif + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps( + _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) + ); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const __m256& scale_a0_8_ps, + const __m256& scale_a1_8_ps, + const std::byte* QuantBDataPtr, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { +#endif + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountN % NCols4 == 0); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + const 
size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); + + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc[0], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale, acc[1], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], low_mask, bzp8); + if constexpr (HasZeroPoint) { + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + + } else { + const __m256i bzp8 = _mm256_set1_epi8(8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); + } + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + if constexpr 
(HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountN < NCols4); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const 
__m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); + + if constexpr (HasZeroPoint) { + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); + } else { + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc0, low_mask); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE +void +MlasQ4Int8GemmM1KernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleCols > 0) { + Q4Int8GemmM1C4BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleCols, + BlockCountK, + Bias); + } + + if (remainingCols > 0) { + Q4Int8GemmM1C1BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr); + } +} + +//#define SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout 1 +void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + // port from neon implementation + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout +#else + constexpr bool HasZeroPoint = false; +#endif + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + //const size_t StrideQuantBScale = BlockCountK; + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const __m256i bzp8 = _mm256_set1_epi8(8); + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + (void)StrideQuantBZeroPoint; +#else + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); +#endif + const size_t NCols = 4; + constexpr size_t StrideQuantBScale2 = 2; + constexpr size_t StrideQuantBScale1 = 1; + + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + (void)QuantBZeroPointPtr; +#endif + __m256 + acc0 = _mm256_setzero_ps(), + acc1 = _mm256_setzero_ps(), + acc2 = _mm256_setzero_ps(), + acc3 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen))); +#else + const float& scale_a0 = QuantAScalePtr[0]; + const float& scale_a1 = QuantAScalePtr[1]; +#endif + + // Col0 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); +#else + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); +#endif + + // Col1 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + 
accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale2, acc1, low_mask, bzp8); +#else + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale2)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc1); +#endif + + // Col2 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale2, acc2, low_mask, bzp8); +#else + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale2)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, false, scale_21, acc2); +#endif + // Col3 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale2, acc3, low_mask, bzp8); +#else + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale2)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); +#endif + // increment block pointers + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2 * NCols; + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a0 = *QuantAScalePtr; + + // Col0 + const float& scale_0 = scale_a0 * QuantBScalePtr[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_0, acc0, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_0, acc0); +#endif + + // Col1 + const float& scale_1 = scale_a0 * (QuantBScalePtr + StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + StrideQuantBData, scale_1, acc1, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_1, acc1); +#endif + + // Col2 + const 
float& scale_2 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_2, acc2, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_2, acc2); +#endif + + // Col3 + const float& scale_3 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_3, acc3, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_3, acc3); +#endif + } + + __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + + // move to next NCols columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + nblk -= NCols; + } + + nblk += NCols; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + (void)QuantBZeroPointPtr; +#endif + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk0)); + const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk1)); +#else + const float& scale_a0 = QuantAScalePtr[0]; + const float& scale_a1 = QuantAScalePtr[1]; +#endif + + // Col0 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); +#else + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); +#endif + // increment block pointers + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a0 = *QuantAScalePtr; + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + 
accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_00, acc0, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); +#endif + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h new file mode 100644 index 0000000000000..e9c3812bde899 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -0,0 +1,312 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_zp_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + const std::byte* QuantBZeroPointPtr, + const bool is_lower_half_byte_zp, + __m256& acc0, + const __m256i& low_mask +) +{ + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + const __m256i bzp8 = _mm256_set1_epi8(get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr)); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_zp_is_8_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t SubblkLen64 = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen64; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountN % NCols4 == 0); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const size_t StrideQuantBData1 = 1 * SubblkDataSizeInBytes; + const size_t StrideQuantBScale1 = 1; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + if constexpr (HasZeroPoint) { + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc[0], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 
StrideQuantBScale1, QuantBZeroPointPtr + StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[1], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[2], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[3], low_mask); + } else { + const __m256i bzp8 = _mm256_set1_epi8(8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale1, acc[1], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, acc[2], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, acc[3], low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += SubblkLen64; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + assert(CountN < NCols4); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + if constexpr (HasZeroPoint) { + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc0, low_mask); + } else { + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE void +MlasQ4Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleCols > 0) { + Q4Int8GemmM1C4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleCols, + BlockCountK, + Bias); + } + + if (remainingCols > 0) { + Q4Int8GemmM1C1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr); + } +} diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index dedc01de9655d..548f24e8ac69e 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -263,9 +263,10 @@ void RunTest(const TestOptions& opts, } // namespace TEST(MatMulNBits, Float32) { + // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); for (auto M : {1, 2, 100}) { - for (auto N : {1, 2, 32, 288}) { - for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto N : {/*2560, */ 1, 2, 32, 288}) { + for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { for (auto accuracy_level : {0, 1, 4}) { TestOptions base_opts{}; diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp index 9d15c9a6bf994..6d21ed2eef864 100644 --- a/onnxruntime/test/mlas/bench/bench_q4dq.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4dq.cpp @@ -9,10 +9,10 @@ #include "core/util/thread_utils.h" static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); @@ -37,10 +37,10 @@ static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) } static void BM_MlasQuantizeBlockwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); @@ -65,10 +65,10 @@ static void BM_MlasQuantizeBlockwise(benchmark::State& state) { } static void 
BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); bool add8 = state.range(4) != 0; int quant_num_M = (M + quant_block_size - 1) / quant_block_size; int blob_size = (quant_block_size + 1) / 2; diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 354621eff42b6..73c78b8cc3d47 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -53,6 +53,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, std::vector QuantBData(QuantBDataSizeInBytes); std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); + bool has_zp_input = !Symmetric; MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), @@ -71,15 +72,17 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), tp.get()); } MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; - params.QuantBData = PackedQuantBData != nullptr - ? static_cast(PackedQuantBData.get()) - : static_cast(QuantBData.data()); + if (PackedQuantBData != nullptr) + params.QuantBDataWorkspace = static_cast(PackedQuantBData.get()); + else + params.QuantBDataWorkspace = static_cast(QuantBData.data()); params.QuantBScale = QuantBScale.data(); params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data(); params.Bias = HasBias ? Bias.data() : nullptr; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index f391027de4d51..0710981fa17c6 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -55,8 +55,8 @@ class MlasSQNBitGemmTest : public MlasTestBase { size_t K, const float* A, size_t lda, - const void* QuantBData, - const void* PackedQuantBData, + const void* /*QuantBData*/, + const void* PackedQuantBDataWorkspace, const float* QuantBScale, const void* QuantBZeroPoint, const float* Bias, @@ -71,7 +71,12 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.Bias = Bias; params.C = C; params.ldc = ldc; - params.QuantBData = PackedQuantBData != nullptr ? 
PackedQuantBData : QuantBData; +#ifdef MLAS_TARGET_AMD64_IX86 + if (ComputeType == CompInt8) { + params.QuantBDataWorkspace = PackedQuantBDataWorkspace; + } +#endif + params.PackedQuantBData = static_cast(PackedQuantBDataWorkspace); params.QuantBScale = QuantBScale; params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; @@ -213,12 +218,19 @@ class MlasSQNBitGemmTest : public MlasTestBase { auto print_matrix = [](size_t nrows, size_t ncols, const float* data) { for (size_t row = 0; row < nrows; ++row) { for (size_t col = 0; col < ncols; ++col) { - std::cout << data[row * ncols + col] << "\t"; + std::cout << data[row * ncols + col] << ", "; } std::cout << "\n"; } }; + auto print_matrix_col = [](size_t nrows, size_t ncols, size_t col, const float* data) { + for (size_t row = 0; row < nrows; ++row) { + std::cout << data[row * ncols + col] << ", "; + } + std::cout << "\n"; + }; + std::cout << "A:\n"; print_matrix(M, K, A); std::cout << "B:\n"; @@ -258,14 +270,25 @@ class MlasSQNBitGemmTest : public MlasTestBase { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } - void* PackedQuantBData = nullptr; + void* PackedQuantBDataWorkspace = nullptr; if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { - PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, + PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); + bool has_zp_input = QuantBZeroPoint != nullptr; + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, + QuantBScale, has_zp_input, QuantBZeroPoint, GetMlasThreadPool()); } + CallGemm(M, N, K, + A, /* lda */ K, + QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, + Bias, + C, /* ldc */ N, + Workspace, + ComputeType, + Threadpool); + if (ComputeType == CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else if (ComputeType == CompInt8) { @@ -275,15 +298,6 @@ class MlasSQNBitGemmTest : public MlasTestBase { << ComputeType << " (" << ComputeTypeName(ComputeType) << ")"; } - CallGemm(M, N, K, - A, /* lda */ K, - QuantBData, PackedQuantBData, QuantBScale, QuantBZeroPoint, - Bias, - C, /* ldc */ N, - Workspace, - ComputeType, - Threadpool); - size_t f = 0; for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++, f++) { @@ -382,7 +396,6 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Fri, 2 Aug 2024 11:02:22 -0700 Subject: [PATCH 35/40] [QNN EP] Support Conv + Clip/Relu fusion (#21537) ### Description - Supports quantized Conv + Activation on the HTP backend: - Translates `DQs -> Conv -> Relu/Clip -> Q` into a single QNN Conv operator if the Relu (or Clip) are redundant. ### Motivation and Context Expands support for QDQ models created with tools that do not wrap Relu or Clip with QDQ nodes. This PR introduces the `IQnnNodeGroup` class. In the same way that a `NodeUnit` represents a collection of `Nodes`, a `IQnnNodeGroup` can represent one or more `NodeUnits` that are translated into a QNN operator. QNN EP parses the ONNX graph to create a list of `IQnnNodeGroup` objects, each representing a single `NodeUnit` or a fusion of multiple `NodeUnits`. 
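For readers unfamiliar with the redundancy checks, the sketch below distills the two conditions this fusion tests before dropping the activation: a Relu ahead of a QuantizeLinear is dead when the Q's zero point is the lowest quantized value, and a Clip is dead when its window already covers the Q's representable float range. This is an illustrative, self-contained example rather than code from the patch; the helper names (`QuantRange`, `IsReluRedundant`, `IsClipRedundant`) are invented for the example, and the real implementation in `conv_activation_fusion.cc` below also handles int8/int16/uint16 zero points and uses epsilon comparisons.

```cpp
// Illustrative sketch only (not part of this patch): the redundancy conditions the fusion
// relies on, specialized to per-tensor uint8 quantization for brevity.
#include <cstdint>
#include <iostream>
#include <limits>

struct QuantRange {
  float rmin;
  float rmax;
};

// Float range representable by a per-tensor uint8 QuantizeLinear with the given scale/zero point.
QuantRange RangeForUint8(float scale, int32_t zero_point) {
  return {scale * (std::numeric_limits<uint8_t>::lowest() - zero_point),
          scale * (std::numeric_limits<uint8_t>::max() - zero_point)};
}

// Relu -> Q: the Relu is a no-op when the zero point sits at the lowest quantized value,
// because the Q output cannot represent negative inputs anyway.
bool IsReluRedundant(int32_t zero_point) {
  return zero_point == std::numeric_limits<uint8_t>::lowest();
}

// Clip -> Q: the Clip is a no-op when its [min, max] window fully covers the Q's float range.
bool IsClipRedundant(float clip_min, float clip_max, const QuantRange& q) {
  return clip_min <= q.rmin && clip_max >= q.rmax;
}

int main() {
  const float scale = 0.03f;
  const int32_t zero_point = 0;  // asymmetric uint8 with the zero point at the minimum
  const QuantRange q = RangeForUint8(scale, zero_point);  // [0, 7.65]

  std::cout << "Q range: [" << q.rmin << ", " << q.rmax << "]\n";
  std::cout << "Relu redundant: " << IsReluRedundant(zero_point) << "\n";             // 1
  std::cout << "Clip(0, 6) redundant: " << IsClipRedundant(0.0f, 6.0f, q) << "\n";    // 0: 6 < 7.65
  std::cout << "Clip(0, 10) redundant: " << IsClipRedundant(0.0f, 10.0f, q) << "\n";  // 1
  return 0;
}
```

When either condition holds, the activation adds no information beyond what the Q already enforces, so the whole `DQ* -> Conv -> activation -> Q` group can map to a single QNN Conv that carries the Q's quantization parameters.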
--- onnxruntime/core/framework/node_unit.cc | 15 + onnxruntime/core/framework/node_unit.h | 4 + .../core/providers/qnn/builder/qnn_fusions.cc | 294 ----------- .../core/providers/qnn/builder/qnn_fusions.h | 38 -- .../core/providers/qnn/builder/qnn_model.cc | 51 +- .../qnn/builder/qnn_model_wrapper.cc | 2 + .../providers/qnn/builder/qnn_node_group.h | 68 +++ .../qnn_node_group/conv_activation_fusion.cc | 480 ++++++++++++++++++ .../qnn_node_group/conv_activation_fusion.h | 63 +++ .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 179 +++++++ .../qnn/builder/qnn_node_group/dq_q_fusion.h | 57 +++ .../qnn_node_group/hardsigmoid_mul_fusion.cc | 144 ++++++ .../qnn_node_group/hardsigmoid_mul_fusion.h | 57 +++ .../builder/qnn_node_group/qnn_node_group.cc | 221 ++++++++ .../qnn/builder/qnn_node_group/utils.cc | 66 +++ .../qnn/builder/qnn_node_group/utils.h | 40 ++ .../providers/qnn/qnn_execution_provider.cc | 121 ++--- .../providers/qnn/qnn_execution_provider.h | 3 - onnxruntime/test/providers/qnn/conv_test.cc | 252 +++++++-- 19 files changed, 1670 insertions(+), 485 deletions(-) delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_fusions.cc delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_fusions.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index e2c06fbdfa621..850cb167a3ece 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -4,6 +4,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #include "node_unit.h" +#include #include "core/graph/graph_viewer.h" namespace onnxruntime { @@ -272,6 +273,20 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } +NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type unit_type, + gsl::span inputs, gsl::span outputs, + size_t input_edge_count, Node::EdgeSet output_edges) + : dq_nodes_(dq_nodes.begin(), dq_nodes.end()), + target_node_(target_node), + q_nodes_(q_nodes.begin(), q_nodes.end()), + type_(unit_type), + inputs_(inputs.begin(), inputs.end()), + outputs_(outputs.begin(), outputs.end()), + input_edge_count_(input_edge_count), + output_edges_(std::move(output_edges)) { +} + const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index e84e62479162f..50bd423d2f547 100644 --- 
a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,6 +68,10 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); + NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type unit_type, + gsl::span inputs, gsl::span outputs, + size_t input_edge_count, Node::EdgeSet output_edges); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc deleted file mode 100644 index b04075f11203c..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/qnn/builder/qnn_fusions.h" - -#include -#include -#include -#include -#include -#include -#include "core/graph/graph_utils.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" -#include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" - -#define QNN_RETURN_OK_IF_ERROR(expr, logger) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - LOGS((logger), VERBOSE) << _status.ErrorMessage(); \ - return Status::OK(); \ - } \ - } while (0) - -namespace onnxruntime { -namespace qnn { - -/** - * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from - * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param start_node_unit The node unit that could potentially start the DQ -> Q sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return An onnxruntime::Status - */ -static Status TryHandleConvertSequence(std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& start_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - - // Looking for a standalone DQ to start the sequence. - if (start_node_unit.OpType() != QDQ::DQOpName || start_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - const Node& dq_node = start_node_unit.GetNode(); - - // DQ must have a single Q child. DQ must not produce a graph output. - auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); - if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { - return Status::OK(); - } - - const Node& q_node = *children[0]; - const auto q_node_unit_it = node_unit_map.find(&q_node); - - ORT_RETURN_IF(q_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); - - const NodeUnit* q_node_unit = q_node_unit_it->second; - - // Check if Q node has already been handled. 
Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(q_node_unit) != 0) { - return Status::OK(); - } - - // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { - return graph_viewer.GetConstantInitializer(initializer_name, true); - }; - - // DQ and Q must have equal scale type and different zp type. - if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { - return Status::OK(); - } - - const auto& node_name = utils::GetNodeName(start_node_unit); - const NodeUnitIODef& input_def = start_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = q_node_unit->Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - // Run QNN validation on the final fused node before committing to doing a fusion. - // Importantly, this validation process does not modify the qnn_model_wrapper. - // If validation fails here, we return Status::OK() to allow QNN EP to use the normal OpBuilder workflow. - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {}), - logger); - - // Validation passed, so we're now committed to doing a fusion. The following statements modify qnn_model_wrapper. - // If we encounter an error, we return it directly to caller. - LOGS(logger, VERBOSE) << " Adding QNN Convert via fusion. dq_node name: [" << dq_node.Name() - << "] dq_node optype: [" << dq_node.OpType() - << "] q_node name: [" << q_node_unit->Name() - << "] q_node optype: [" << q_node_unit->OpType() - << "]"; - - // Add a QNN Convert to the model. Get the input from the DQ node, and the output from the Q node. - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(*q_node_unit), - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - do_op_validation), - "Failed to add fused Convert node."); - - fused_nodes.push_back(&start_node_unit); - fused_nodes.push_back(q_node_unit); - - return Status::OK(); -} - -/** - * Tries to fuse the sequence `x * HardSigmoid(x)` into a single HardSwish(x) operator. - * Should be called in a topologically ordered iteration of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. 
- */ -static Status TryHandleHardSigmoidSequence(std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& start_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { - // Looking for a standalone HardSigmoid to start the sequence. - if (start_node_unit.OpType() != "HardSigmoid" || start_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - NodeAttrHelper hs_attr_helper(start_node_unit); - float alpha = hs_attr_helper.Get("alpha", 0.2f); - float beta = hs_attr_helper.Get("beta", 0.5f); - constexpr float req_alpha = 1.0f / 6.0f; - constexpr float req_beta = 0.5f; - constexpr float alpha_eps = std::numeric_limits::epsilon() * req_alpha; - constexpr float beta_eps = std::numeric_limits::epsilon() * req_beta; - - // Check for explicit values of alpha and beta. - if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { - return Status::OK(); - } - - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& hs_node = start_node_unit.GetNode(); - - // HardSigmoid must have a single Mul child. HardSigmoid must not produce a graph output. - auto children = graph_utils::FindChildrenByType(hs_node, "Mul"); - if (children.size() != 1 || hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { - return Status::OK(); - } - - const Node& mul_node = *children[0]; - const auto mul_node_unit_it = node_unit_map.find(&mul_node); - ORT_RETURN_IF(mul_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); - const NodeUnit* mul_node_unit = mul_node_unit_it->second; - - // Check if Mul node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(mul_node_unit) != 0) { - return Status::OK(); - } - - // Mul child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (mul_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - // Input to HardSigmoid must also be the other input to the Mul. - auto& hs_input_name = start_node_unit.Inputs()[0].node_arg.Name(); - const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || - mul_node.InputDefs()[1]->Name() == hs_input_name; - - if (!same_root_input) { - return Status::OK(); - } - - const auto& node_name = utils::GetNodeName(start_node_unit); - const NodeUnitIODef& input_def = start_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = mul_node_unit->Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - // Run QNN validation on the final fused node before committing to doing a fusion. - // Importantly, this validation process does not modify the qnn_model_wrapper. - // If validation fails here, we return Status::OK() to allow QNN EP to use the normal OpBuilder workflow. - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {}), - logger); - - // Validation passed, so we're now committed to doing a fusion. The following statements modify qnn_model_wrapper. 
- // If we encounter an error, we return it directly to caller. - LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. HardSigmoid name: [" << start_node_unit.Name() - << "] Mul name: [" << mul_node_unit->Name() << "]"; - - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - do_op_validation), - "Failed to add fused HardSwish node."); - - fused_nodes.push_back(&start_node_unit); - fused_nodes.push_back(mul_node_unit); - - return Status::OK(); -} - -using FusionFunc = Status (*)(std::vector&, - QnnModelWrapper&, - const NodeUnit&, - const std::unordered_map&, - const std::unordered_set&, - const logging::Logger&, - bool); - -Status TryFusions(/*out*/ std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& starting_node, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool validate) { - // Maps a starting operator type to the fusion function. - static std::unordered_map fusions = { - {"DequantizeLinear", TryHandleConvertSequence}, - {"HardSigmoid", TryHandleHardSigmoidSequence}, - }; - - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). - if (starting_node.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - auto iter = fusions.find(starting_node.OpType()); - if (iter != fusions.end()) { - fused_nodes.clear(); - - FusionFunc fusion_func = iter->second; - ORT_RETURN_IF_ERROR(fusion_func(fused_nodes, qnn_model_wrapper, starting_node, node_unit_map, - handled_node_units, logger, validate)); - } - - return Status::OK(); -} - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h b/onnxruntime/core/providers/qnn/builder/qnn_fusions.h deleted file mode 100644 index 39e2e71c01d8c..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" - -namespace onnxruntime { -namespace qnn { - -/** - * Tries to fuse a node sequence starting from the given starting node. Should be called in a topologically ordered - * walk of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. 
- */ -Status TryFusions(/*out*/ std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& starting_node, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation); -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 503943dfb636b..83f9184d33611 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -7,7 +7,7 @@ #include "QnnOpDef.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_fusions.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -117,49 +117,20 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } - std::unordered_set handled_node_units; + std::vector> qnn_node_groups; + qnn_node_groups.reserve(node_unit_holder.size()); - // Op builer - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer.GetNode(node_indices[i])); + ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, node_unit_map, + node_unit_holder.size(), logger_)); - // Check whether it's part of NodeUnit - const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); - // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) - const std::string& op_type = node_unit.OpType(); + for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger_); - if (node != &node_unit.GetNode()) { - continue; + if (!status.IsOK()) { + LOGS(logger_, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: " + << status.ErrorMessage() << std::endl; + return status; } - - if (handled_node_units.count(&node_unit) != 0) { - continue; // Already handled. - } - - // Try to see if this node unit can be fused. 
- std::vector fused_nodes; - ORT_RETURN_IF_ERROR(TryFusions(fused_nodes, qnn_model_wrapper, node_unit, node_unit_map, - handled_node_units, logger_, false /*do_op_validation*/)); - - if (!fused_nodes.empty()) { - for (auto fused_node_unit : fused_nodes) { - handled_node_units.insert(fused_node_unit); - } - continue; - } - - LOGS(logger_, VERBOSE) << " node name: [" << node->Name() - << "] node optype: [" << op_type - << "] as part of the NodeUnit type: [" << node_unit.OpType() - << "] name: [" << node_unit.Name() - << "]"; - if (const auto* op_builder = GetOpBuilder(op_type)) { - ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); - } - - handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 9d3f460572d84..657224f68f71b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -239,6 +239,8 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, std::string error_msg; bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg); if (!rt) { + // TODO(adrianlizarraga): Return a Status with the error message so that aggregated logs show a more + // specific validation error (instead of "failed to add node"). LOGS(logger_, WARNING) << error_msg; } return rt; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h new file mode 100644 index 0000000000000..f9ef01411310f --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/common/logging/logging.h" +#include "core/framework/node_unit.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a group of NodeUnits that QNN EP translates into a core QNN operator. Can represent a single NodeUnit +/// or a fusion of multiple NodeUnits (e.g., DQ* -> Conv -> Relu -> Q). +/// +class IQnnNodeGroup { + public: + virtual ~IQnnNodeGroup() = default; + + // Returns an OK status if this IQnnNodeGroup is supported by QNN. + virtual Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const = 0; + + // Adds this IQnnNodeGroup to the QNN model wrapper. + virtual Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const = 0; + + // Returns a list of NodeUnits contained by this IQnnNodeGroup. + virtual gsl::span GetNodeUnits() const = 0; + + /// + /// Returns the "target" NodeUnit of the group. This is important for topological ordering of IQnnNodeGroups. + /// The target should be the first NodeUnit where all input paths (of the IQnnNodeGroup) converge. + /// For example, "Conv" should be the target NodeUnit for the following IQnnNodeGroup with 6 NodeUnits. + /// input0 -> DQ -> Conv -> Relu -> Q + /// ^ + /// | + /// input1 -> DQ ----+ + /// + /// Target NodeUnit in IQnnNodeGroup + virtual const NodeUnit* GetTargetNodeUnit() const = 0; + + // Returns a string representation of the IQnnNodeGroup's type. 
+ virtual std::string_view Type() const = 0; +}; + +/// +/// Traverses the ONNX graph to create IQnnNodeGroup objects, each containing one or more NodeUnits. +/// The returned IQnnNodeGroup objects are sorted in topological order. +/// +/// Output vector into which the resulting IQnnNodeGroup objects are stored. +/// Contains reference to the ONNX GraphViewer and used for validaton on QNN +/// Maps a Node* to a NodeUnit* +/// The number of NodeUnits in the ONNX graph. +/// Logger +/// Status with potential error +Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + size_t num_node_units, + const logging::Logger& logger); +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc new file mode 100644 index 0000000000000..813bba8a5952b --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -0,0 +1,480 @@ +#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +// Gets the scale, zero-point, and zero-point type for a QuantizeLinear node that uses per-tensor quantization. +static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& q_node_unit, + /*out*/ float& scale, + /*out*/ int32_t& zero_point, + /*out*/ int32_t& zp_data_type) { + assert(q_node_unit.OpType() == QUANTIZE_LINEAR); + const auto& q_inputs = q_node_unit.GetNode().InputDefs(); + + // Require an explicit zero-point input for now. + if (q_inputs.size() != 3 || !q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Exists()) { + return false; + } + + std::vector zero_points; + Status status = qnn_model_wrapper.UnpackZeroPoints(q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Name(), + zero_points, zp_data_type); + + // Should only have one zero-point (per-tensor). + if (!status.IsOK() || zero_points.size() != 1) { + return false; + } + zero_point = -zero_points[0]; // QNN zero-points are negated. + + std::vector scales; + status = qnn_model_wrapper.UnpackScales(q_inputs[QDQ_SCALE_INPUT_IDX]->Name(), scales); + + // Should only have one scale (per-tensor). + if (!status.IsOK() || scales.size() != 1) { + return false; + } + + scale = scales[0]; + return true; +} + +// Computes the floating point range (rmin, rmax) from a QuantizeLinear node's scale/zero-point. 
+static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& q_node_unit, + /*out*/ float& rmin, + /*out*/ float& rmax) { + int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; + int32_t zero_point = 0; + float scale = 0.0f; + + if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) { + return false; + } + + switch (zp_data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + default: + return false; + } + + return true; +} + +// Returns true if the Clip in the sequence (Clip -> Q) can be removed because it is made redundant by the Q. +static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& clip_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger) { + assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QUANTIZE_LINEAR); + float rmin = 0.0f; + float rmax = 0.0f; + + if (!GetQRminRmax(qnn_model_wrapper, q_node_unit, rmin, rmax)) { + return false; + } + + float clip_min = std::numeric_limits::lowest(); + float clip_max = std::numeric_limits::max(); + + if (!onnxruntime::GetClipMinMax(qnn_model_wrapper.GetGraphViewer(), clip_node_unit.GetNode(), + clip_min, clip_max, logger)) { + return false; + } + + // The clip range must entirely overlap the quantization range (quantization can be smaller). + // Clip range: [------------------] + // Quant range: [-------------] + constexpr float epsilon = std::numeric_limits::epsilon(); + if ((epsilon < clip_min - rmin) || (epsilon < rmax - clip_max)) { + return false; + } + + return true; +} + +// Returns true if the Relu in the sequence (Relu -> Q) can be removed because it is made redundant by the Q. +static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit) { + assert(q_node_unit.OpType() == QUANTIZE_LINEAR); + int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; + int32_t zero_point = 0; + float scale = 0.0f; + + if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) { + return false; + } + + // Relu is redundant if the zero-point is set to the smallest quantized value. 
+ switch (zp_data_type) { + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT8: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT8: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT16: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT16: + return zero_point == static_cast(std::numeric_limits::lowest()); + default: + return false; + } +} + +// Returns true if the Clip/Relu in the sequence (Clip/Relu -> Q) can be removed because it is made redundant by the Q. +static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger) { + const std::string& activation_type = activation_node_unit.OpType(); + + if (activation_type == "Relu") { + return CanQRelaceRelu(qnn_model_wrapper, q_node_unit); + } + + if (activation_type == "Clip") { + return CanClipBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit, logger); + } + + return false; +} + +// Returns the parent DQ nodes for a given node. +static std::vector FindParentDQNodes(const GraphViewer& graph_viewer, const Node& node) { + // Get all parent DQ nodes sorted by destination argument index. + std::vector parents(node.InputDefs().size(), nullptr); + for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) { + if (it->GetNode().OpType().compare(DEQUANTIZE_LINEAR) == 0) { + parents[it->GetDstArgIndex()] = &(it->GetNode()); + } + } + + // Remove all the nodes which are not in the graph_viewer + parents.erase(std::remove_if(parents.begin(), parents.end(), + [&graph_viewer](const Node* _node) { + return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr; + }), + parents.end()); + + return parents; +} + +// Gets the parent DQ nodes for the given Conv node. This fuction checks that the DQs are not a part of +// any other NodeUnit and that every Conv input comes from a parent DQ. +static bool GetConvDQs( + const GraphViewer& graph_viewer, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const Node& conv_node, + /*out*/ std::array& dq_node_units) { + if (conv_node.OpType() != "Conv" && conv_node.OpType() != "ConvTranspose") { + return false; + } + + // Count number of inputs to Conv node. + const auto& conv_inputs = conv_node.InputDefs(); + const size_t num_conv_inputs = std::count_if(conv_inputs.cbegin(), conv_inputs.cend(), + [](const NodeArg* input) { return input && input->Exists(); }); + + // Get the Conv's parent DQ nodes. + std::vector dq_nodes = FindParentDQNodes(graph_viewer, conv_node); + const size_t num_dqs = dq_nodes.size(); + + // Within a QDQ node group, a target node input is the only consumer of each DQ. + if ((num_conv_inputs != num_dqs) || (num_dqs > dq_node_units.size())) { + return false; + } + + dq_node_units.fill(nullptr); + for (size_t i = 0; i < num_dqs; i++) { + const Node* dq_node = dq_nodes[i]; + + // DQ must not produce a graph output. + if (!dq_node || graph_viewer.NodeProducesGraphOutput(*dq_node)) { + return false; + } + + // Conv should be the only consumer of a parent DQ. 
+ const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index(); + if (!dq_has_single_output_edge_to_target) { + return false; + } + + // DQ node must be part of a "standalone" NodeUnit. + const auto it = node_to_node_unit.find(dq_node); + if (it == node_to_node_unit.end()) { + return false; + } + const NodeUnit* dq_node_unit = it->second; + if (!dq_node_unit || node_unit_to_qnn_node_group.count(dq_node_unit) != 0) { + return false; + } + if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return false; + } + + dq_node_units[i] = dq_node_unit; + } + + return true; +} + +// Checks that the input and output data types are valid for a QDQ Conv. +static bool CheckQDQConvDataTypes(std::array& dq_node_units, + gsl::not_null q_node_unit) { + assert(q_node_unit->OpType() == QUANTIZE_LINEAR); + // input and output types need to be same + int32_t dt_input = dq_node_units[0]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_weight = dq_node_units[1]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node_unit->GetNode().OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input != dt_output) { + return false; + } + + if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + if (dt_weight != dt_input) { + return false; + } + } + + if (dq_node_units[2] != nullptr) { // has bias + int32_t dt_bias = dq_node_units[2]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { + return false; + } + } + + return true; +} + +// Utility function to either validate or create a quantized QNN Conv node. The function creates a temporary +// custom NodeUnit that excludes the Clip/Relu because it is redundant. This custom NodeUnit is passed to our +// existing Conv OpBuilder for creation or validation via QNN APIs. 
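// (Illustrative aside, not part of the original patch.) In effect, the fused group is presented to the
// existing Conv OpBuilder as an ordinary QDQ node unit: each parent DQ contributes a quantized input
// (its scale/zero-point become that input's QuantParam), the Conv remains the target node, the trailing
// Q supplies the output's quantization parameters, and the Clip/Relu in the middle is omitted entirely.
// No Conv-specific QNN builder changes are needed for the fusion.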
+#define ValidateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), true) +#define CreateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span dq_node_units, + const NodeUnit* conv_node_unit, + const NodeUnit* q_node_unit, + const logging::Logger& logger, + bool validate) { + const size_t num_dqs = dq_node_units.size(); + constexpr size_t max_num_dqs = 3; + ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); + ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == QUANTIZE_LINEAR, + "Expected Conv/ConvTranspose and QuantizeLinear but got ", conv_node_unit->OpType(), " and ", + q_node_unit->OpType()); + + std::array dq_nodes_buf = {}; + for (size_t i = 0; i < num_dqs; i++) { + dq_nodes_buf[i] = &dq_node_units[i]->GetNode(); + } + gsl::span dq_nodes(dq_nodes_buf.data(), num_dqs); + + std::array q_nodes = {&q_node_unit->GetNode()}; + const Node& target_node = conv_node_unit->GetNode(); + + // Populate NodeUnit inputs + std::vector inputs; + inputs.reserve(num_dqs); + for (const Node* dq_node : dq_nodes) { + const auto dq_inputs = dq_node->InputDefs(); + const auto& dq_attrs = dq_node->GetAttributes(); + + std::optional axis; + if (auto entry = dq_attrs.find("axis"); entry != dq_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*dq_inputs[1], dq_inputs.size() == 3 ? dq_inputs[2] : nullptr, axis}; + inputs.push_back(NodeUnitIODef{*dq_inputs[0], quant_param}); + } + + // Populate NodeUnit outputs and output edges + std::vector outputs; + Node::EdgeSet output_edges; + for (const Node* q_node : q_nodes) { + const auto q_inputs = q_node->InputDefs(); + const auto& q_attrs = q_node->GetAttributes(); + const auto q_outputs = q_node->OutputDefs(); + + std::optional axis; + if (auto entry = q_attrs.find("axis"); entry != q_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis}; + outputs.push_back(NodeUnitIODef{*q_outputs[0], quant_param}); + + // Gather output edges out of the Q node. + auto q_cur_edge = q_node->OutputEdgesBegin(); + auto q_end_edge = q_node->OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()}); + } + } + + NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup, + inputs, outputs, num_dqs, output_edges); + const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType()); + if (conv_op_builder == nullptr) { + return Status::OK(); + } + + if (validate) { + return conv_op_builder->IsOpSupported(qnn_model_wrapper, custom_node_unit, logger); + } + + return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); +} + +// Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. +// If so, returns a IQnnNodeGroup that contains the constituent NodeUnits. 
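// (Illustrative aside, not part of the original patch.) The redundancy checks used by this fusion
// (CanClipBeRemoved / CanQRelaceRelu above) reduce to: the float range representable after the Q must
// already lie inside the activation's clamp range. A numeric sketch, assuming a uint8 output with
// scale 0.02 and zero_point 0:
//
//   const float rmin = 0.02f * (0.0f   - 0.0f);  // 0.0
//   const float rmax = 0.02f * (255.0f - 0.0f);  // 5.1
//   // Relu clamps to [0, +inf) and Clip(0, 6) clamps to [0, 6]; both contain [0.0, 5.1], so
//   // quantizing gives the same result with or without the activation, which can then be dropped.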
+std::unique_ptr ConvActivationFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Expect that this function is called with a standalone Conv or ConvTranspose. + const auto& conv_type = conv_node_unit.OpType(); + + if ((conv_type != "Conv" && conv_type != "ConvTranspose") || + (conv_node_unit.UnitType() != NodeUnit::Type::SingleNode)) { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Conv must have a single Relu or Clip child. + const std::array activation_op_types = {"Relu", "Clip"}; + const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types, + node_to_node_unit, node_unit_to_qnn_node_group); + if (activation_node_unit == nullptr) { + return nullptr; + } + + // Relu/Clip must have a single Q child. + const std::array q_op_types = {QUANTIZE_LINEAR}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, *activation_node_unit, q_op_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (q_node_unit == nullptr) { + return nullptr; + } + + // Check if Clip/Relu can be removed because the Q node provides an equivalent effect. + if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit, logger)) { + return nullptr; + } + + // Create a QDQ node group with DQ* -> Conv -> Q + const Node& conv_node = conv_node_unit.GetNode(); + std::array dq_node_units = {}; + if (!GetConvDQs(graph_viewer, + node_to_node_unit, + node_unit_to_qnn_node_group, + conv_node, dq_node_units)) { + return nullptr; + } + + if (!CheckQDQConvDataTypes(dq_node_units, q_node_unit)) { + return nullptr; + } + + return std::make_unique(*dq_node_units[0], + *dq_node_units[1], + dq_node_units[2], + conv_node_unit, + *activation_node_unit, + *q_node_unit); +} + +ConvActivationFusion::ConvActivationFusion(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit) + : node_units_{} { + size_t i = 0; + node_units_[i++] = &dq_node_unit_0; + node_units_[i++] = &dq_node_unit_1; + if (dq_node_unit_2 != nullptr) { + node_units_[i++] = dq_node_unit_2; + } + node_units_[i++] = &conv_node_unit; + node_units_[i++] = &activation_node_unit; + node_units_[i++] = &q_node_unit; + assert((!dq_node_unit_2 && i == 5) || (dq_node_unit_2 && i == 6)); +} + +Status ConvActivationFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); + + return ValidateOnQnn(qmw, dq_node_units, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q + logger); +} + +Status ConvActivationFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); + + return CreateOnQnn(qmw, dq_node_units, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q + logger); +} + +gsl::span ConvActivationFusion::GetNodeUnits() const { + const size_t num_node_units = node_units_.back() != nullptr ? 
6 : 5; + return gsl::make_span(node_units_.data(), num_node_units); +} + +const NodeUnit* ConvActivationFusion::GetTargetNodeUnit() const { + const size_t conv_index = node_units_.back() != nullptr ? 3 : 2; + return node_units_[conv_index]; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h new file mode 100644 index 0000000000000..b604b25e943e6 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a DQ* -> Conv -> Relu/Clip -> Q sequence where the Relu (or Clip) is redundant +/// due to the quantization effects of the Q. This sequence is translated to a quantized QNN Conv. +/// All contained NodeUnits are of type SingleNode since they are not a part of an existing QDQ node unit. +/// +class ConvActivationFusion : public IQnnNodeGroup { + public: + ConvActivationFusion(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ConvActivationFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "ConvActivationFusion"; } + + /// + /// Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. + /// If so, returns a IQnnNodeGroup that contains the constituent NodeUnits. + /// + /// Used for validation and to traverse/query the graph + /// Conv node unit (type SingleNode) that be part of the sequence. + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; // Last elem is nullptr if the optional bias DQ is missing. 
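// (Illustrative aside, not part of the original patch.) Layout of node_units_ as filled by the constructor:
//
//   with bias DQ:    { DQ0, DQ1, DQ2, Conv, Activation, Q }
//   without bias DQ: { DQ0, DQ1, Conv, Activation, Q, nullptr }
//
// which is why the .cc file derives num_dqs = (node_units_.back() != nullptr ? 3 : 2), indexes the
// Conv at node_units_[num_dqs], and the Q at node_units_[num_dqs + 2].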
+}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc new file mode 100644 index 0000000000000..ce87ac4a3d21c --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -0,0 +1,179 @@ +#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +// Forward declarations. +#define ValidateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, bool validate); +static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node); + +std::unique_ptr DQQFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + // Expect that this function is called with a standalone DQ. + if (dq_node_unit.OpType() != DEQUANTIZE_LINEAR || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& dq_node = dq_node_unit.GetNode(); + + // DQ must have a single Q child (1 output edge) and must not produce a graph output. + const std::array child_types = {QUANTIZE_LINEAR}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, dq_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (q_node_unit == nullptr) { + return nullptr; + } + + // DQ and Q must have equal scale type and different zp type. 
+ if (!IsDQQConversion(graph_viewer, dq_node, q_node_unit->GetNode())) { + return nullptr; + } + + if (Status status = ValidateOnQnn(qnn_model_wrapper, dq_node_unit, *q_node_unit); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(dq_node_unit, *q_node_unit); +} + +DQQFusion::DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit) + : node_units_{&dq_node_unit, &q_node_unit} { +} + +Status DQQFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +Status DQQFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +gsl::span DQQFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* DQQFusion::GetTargetNodeUnit() const { + return node_units_[0]; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + bool validate) { + assert(dq_node_unit.OpType() == DEQUANTIZE_LINEAR && q_node_unit.OpType() == QUANTIZE_LINEAR); + const auto& node_name = utils::GetNodeName(dq_node_unit); + const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused Convert node."); + } + + return Status::OK(); +} + +static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + auto is_scalar_shape = [](const NodeArg& input_arg) -> bool { + auto shape = input_arg.Shape(); + if (shape == nullptr) { + return false; + } + + auto dim_size = shape->dim_size(); + return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1); + }; + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != QDQ_MAX_NUM_INPUTS || + q_input_defs.size() != QDQ_MAX_NUM_INPUTS || + !is_scalar_shape(*q_input_defs[QDQ_SCALE_INPUT_IDX]) || + !is_scalar_shape(*q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]) || + !is_scalar_shape(*dq_input_defs[QDQ_SCALE_INPUT_IDX]) || + !is_scalar_shape(*dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + 
graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_SCALE_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + graph_viewer.GetConstantInitializer(q_input_defs[QDQ_SCALE_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + graph_viewer.GetConstantInitializer(q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // All TensorProtos must have a data type + if (!q_zp_tensor_proto->has_data_type() || !dq_zp_tensor_proto->has_data_type() || + !q_scale_tensor_proto->has_data_type() || !dq_scale_tensor_proto->has_data_type()) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + return (dq_zp_tensor_proto->data_type() != q_zp_tensor_proto->data_type()) && + (dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h new file mode 100644 index 0000000000000..90fe44c3af059 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a DQ -> Q sequence that converts from one quantization type (e.g., uint8_t) to +/// another (e.g., uint16_t). This is translated into a QNN Convert operator, which is much faster than individual +/// ops. The DQ and Q are standalone NodeUnits that are not part of a QDQ node unit. +/// +class DQQFusion : public IQnnNodeGroup { + public: + DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(DQQFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "DQQFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid DQ -> Q sequence. + /// If so, returns a IQnnNodeGroup that contains the DQ and Q NodeUnits. + /// + /// Used for validation and traverse/query the graph + /// DQ node unit that could start the DQ -> Q sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. 
+ /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc new file mode 100644 index 0000000000000..76b1726646486 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -0,0 +1,144 @@ +#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +namespace onnxruntime { +namespace qnn { + +// Forward declarations. +#define ValidateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, mul_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (hardsigmoid_node_unit), (mul_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, mul_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (hardsigmoid_node_unit), (mul_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, bool validate); + +std::unique_ptr HardSigmoidMulFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + + // Looking for a standalone HardSigmoid to start the sequence. + if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || + hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit); + float alpha = hs_attr_helper.Get("alpha", 0.2f); + float beta = hs_attr_helper.Get("beta", 0.5f); + constexpr float req_alpha = 1.0f / 6.0f; + constexpr float req_beta = 0.5f; + constexpr float alpha_eps = std::numeric_limits::epsilon() * req_alpha; + constexpr float beta_eps = std::numeric_limits::epsilon() * req_beta; + + // Check for explicit values of alpha and beta. + if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { + return nullptr; + } + + // HardSigmoid must have a single Mul child (1 output edge) and must not produce a graph output. + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const std::array child_types = {"Mul"}; + const NodeUnit* mul_node_unit = GetOnlyChildOfType(graph_viewer, hardsigmoid_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (mul_node_unit == nullptr) { + return nullptr; + } + + // Input to HardSigmoid must also be the other input to the Mul. 
+ const Node& mul_node = mul_node_unit->GetNode(); + auto& hs_input_name = hardsigmoid_node_unit.Inputs()[0].node_arg.Name(); + const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || + mul_node.InputDefs()[1]->Name() == hs_input_name; + + if (!same_root_input) { + return nullptr; + } + + if (Status status = ValidateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(hardsigmoid_node_unit, *mul_node_unit); +} + +HardSigmoidMulFusion::HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit) + : node_units_{&hardsigmoid_node_unit, &mul_node_unit} { +} + +Status HardSigmoidMulFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +Status HardSigmoidMulFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +gsl::span HardSigmoidMulFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* HardSigmoidMulFusion::GetTargetNodeUnit() const { + return node_units_[0]; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, + bool validate) { + assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); + const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); + const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused HardSwish node."); + } + + return Status::OK(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h new file mode 100644 index 0000000000000..3b67f13492a46 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a HardSigmoid -> Mul sequence that computes `x * HardSigmoid(x)`. 
+/// This is translated into a QNN HardSwish operator. +/// The contained NodeUnits are of type SingleNode since they are not a part of a QDQ node unit. +/// +class HardSigmoidMulFusion : public IQnnNodeGroup { + public: + HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(HardSigmoidMulFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "HardSigmoidMulFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid HardSigmoid -> Mul sequence. + /// If so, returns a IQnnNodeGroup that contains the HardSigmoid and Mul NodeUnits. + /// + /// Used for validation and traverse/query the graph + /// HardSigmoid node unit that could start the sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc new file mode 100644 index 0000000000000..9fb9e815321c0 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group.h" + +#include +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" + +namespace onnxruntime { +namespace qnn { + +/// +/// A IQnnNodeGroup class that wraps a single NodeUnit. Most NodeUnits in the ONNX graph will +/// be wrapped by this class. 
+/// +class QnnNodeUnitWrapper : public IQnnNodeGroup { + public: + explicit QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {} + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeUnitWrapper); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override { + const std::string& op_type = node_unit_->OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + ORT_RETURN_IF_NOT(op_builder != nullptr, "Operators of type `", op_type, + "` are not supported by QNN EP.", op_type, " node `", + node_unit_->Name(), "` will not be assigned to QNN EP."); + + return op_builder->IsOpSupported(qmw, *node_unit_, logger); + } + + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override { + const std::string& op_type = node_unit_->OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", op_type); + return op_builder->AddToModelBuilder(qmw, *node_unit_, logger, /*do_op_validation*/ false); + } + + gsl::span GetNodeUnits() const override { + return gsl::span{&node_unit_, 1ULL}; + } + + const NodeUnit* GetTargetNodeUnit() const override { return node_unit_; } + std::string_view Type() const override { return "NodeUnit"; } + + private: + const NodeUnit* node_unit_; +}; + +/// +/// The type of a function that tries to fuse NodeUnits into a IQnnNodeGroup. +/// +using FusionFunc = std::unique_ptr (*)( + QnnModelWrapper&, + const NodeUnit&, + const std::unordered_map&, + const std::unordered_map&, + const logging::Logger&); + +/// +/// Given a starting NodeUnit, this function tries all possible fusions that start with that NodeUnit. +/// If successful, returns a IQnnNodeGroup object that represents the fusion of various NodeUnits. +/// Currently only handles standalone NodeUnits that are not in a QDQ unit but that can change in the future. +/// +/// QnnModelWrapper that contains the ONNX GraphViewer. Used for validation. +/// NodeUnit that potentially starts a fusion. +/// Maps a Node* to a NodeUnit* +/// Maps a NodeUnit* to a IQnnNodeGroup* +/// +/// IQnnNodeGroup representing the fusion or an empty std::unique_ptr +static std::unique_ptr TryQnnFusions( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& starting_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Maps a starting operator type to the fusion function. + static std::unordered_map fusions = { + {"DequantizeLinear", DQQFusion::TryFusion}, + {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, + {"Conv", ConvActivationFusion::TryFusion}, + {"ConvTranspose", ConvActivationFusion::TryFusion}, + }; + + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). + if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + auto iter = fusions.find(starting_node_unit.OpType()); + if (iter != fusions.end()) { + FusionFunc fusion_func = iter->second; + return fusion_func(qnn_model_wrapper, starting_node_unit, node_to_node_unit, + node_unit_to_qnn_node_group, logger); + } + return nullptr; +} + +// Traverses the ONNX Graph and groups NodeUnits into IQnnNodeGroup objects. Some IQnnNodeGroup objects +// represent a fusion of various NodeUnits. This function generates a vector of indices that +// represent the topological order of the qnn_node_groups. 
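// (Illustrative aside, not part of the original patch.) Sketch of the two passes below on a small graph
// consisting of DQ, DQ -> Conv -> Relu -> Q plus a standalone Add:
//
//   Pass 1 (fusions only): visiting node units in topological order, the DQs attempt DQQFusion and fail
//   (their child is a Conv, not a Q); the Conv triggers ConvActivationFusion::TryFusion, which claims the
//   two DQs, the Conv, the Relu, and the Q as qnn_node_groups[0]; the Add matches no fusion.
//
//   Pass 2 (leftovers): the DQ/Relu/Q units are skipped because they are already claimed; when the Conv
//   (the group's target) is visited, index 0 is appended to sorted_qnn_node_group_indices; the Add is
//   wrapped in a QnnNodeUnitWrapper as qnn_node_groups[1] and its index is appended afterwards.
//
// GetQnnNodeGroups then moves the groups into the output vector in this sorted (topological) order.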
+static Status GetQnnNodeGroupsImpl(/*out*/ std::vector>& qnn_node_groups, + /*out*/ std::vector& sorted_qnn_node_group_indices, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + const size_t num_node_units, + const logging::Logger& logger) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + sorted_qnn_node_group_indices.reserve(num_node_units); + qnn_node_groups.reserve(num_node_units); + + std::unordered_map node_unit_to_qnn_node_group; + std::unordered_map fused_qnn_node_group_indices; + std::vector> sorted_node_units; + sorted_node_units.reserve(num_node_units); + + // Process just the fusions of NodeUnits first to ensure a correct topological order of all IQnnNodeGroups. + // This is the same approach taken by ORT utilities for grouping Nodes into NodeUnits. + for (NodeIndex node_index : sorted_node_indices) { + gsl::not_null node = graph_viewer.GetNode(node_index); + + // Get the NodeUnit associated with the node. + const auto node_unit_it = node_to_node_unit.find(node); + ORT_RETURN_IF_NOT(node_unit_it != node_to_node_unit.end(), "Could not find NodeUnit for Node ", node->Name()); + gsl::not_null node_unit = node_unit_it->second; + + // Skip this node if it is not the NodeUnit's target node to ensure NodeUnits are visited in topological order. + if (node != &node_unit->GetNode()) { + continue; + } + + sorted_node_units.push_back(node_unit); + + if (node_unit_to_qnn_node_group.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + std::unique_ptr fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit, + node_to_node_unit, node_unit_to_qnn_node_group, + logger); + + if (fused_node_group) { + const size_t index = qnn_node_groups.size(); + fused_qnn_node_group_indices[fused_node_group.get()] = index; + + for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { + assert(fused_node_unit != nullptr); + node_unit_to_qnn_node_group.insert({fused_node_unit, fused_node_group.get()}); + } + + qnn_node_groups.push_back(std::move(fused_node_group)); + } + } + + // Create IQnnNodeGroups for the leftover NodeUnits that were not fused. + for (gsl::not_null node_unit : sorted_node_units) { + const auto it = node_unit_to_qnn_node_group.find(node_unit); + + if (it != node_unit_to_qnn_node_group.end()) { + // Already added this NodeUnit to a IQnnNodeGroup, so we'll skip it. + // However, if this NodeUnit is the "target" for the IQnnNodeGroup, then add its index to + // the sorted list of indices. 
+ gsl::not_null fused_qnn_node_group = it->second; + if (node_unit == fused_qnn_node_group->GetTargetNodeUnit()) { + sorted_qnn_node_group_indices.push_back(fused_qnn_node_group_indices[fused_qnn_node_group]); + } + continue; + } + + const size_t index = qnn_node_groups.size(); + auto qnn_node_group = std::make_unique(*node_unit); + + node_unit_to_qnn_node_group.insert({node_unit, qnn_node_group.get()}); + qnn_node_groups.push_back(std::move(qnn_node_group)); + sorted_qnn_node_group_indices.push_back(index); + } + + assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + + return Status::OK(); +} + +Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + const size_t num_node_units, + const logging::Logger& logger) { + std::vector sorted_qnn_node_group_indices; + std::vector> qnn_node_groups_holder; + ORT_RETURN_IF_ERROR(GetQnnNodeGroupsImpl(qnn_node_groups_holder, sorted_qnn_node_group_indices, qnn_model_wrapper, + node_to_node_unit, num_node_units, logger)); + + // Move IQnnNodeGroups to the output std::vector in sorted (topological) order. + qnn_node_groups.resize(0); + qnn_node_groups.reserve(qnn_node_groups_holder.size()); + for (auto index : sorted_qnn_node_group_indices) { + assert(index < qnn_node_groups_holder.size()); + std::unique_ptr qnn_node_group = std::move(qnn_node_groups_holder[index]); + qnn_node_groups.push_back(std::move(qnn_node_group)); + } + + assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + + return Status::OK(); +} +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc new file mode 100644 index 0000000000000..5548d7d37c378 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -0,0 +1,66 @@ +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +#include +#include +#include + +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map) { + const Node& parent_node = parent_node_unit.GetNode(); + + // Parent must have a single child (1 output edge) and must not produce a graph output. + if (parent_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(parent_node)) { + return nullptr; + } + + // Child must be of a valid type. + const Node& child_node = parent_node.OutputEdgesBegin()->GetNode(); + if (graph_viewer.GetNode(child_node.Index()) == nullptr) { + return nullptr; // Node is not in this GraphViewer + } + const std::string& child_type = child_node.OpType(); + bool is_valid_child_type = false; + + for (const auto& valid_op_type : child_op_types) { + if (valid_op_type == child_type) { + is_valid_child_type = true; + break; + } + } + + if (!is_valid_child_type) { + return nullptr; + } + + const auto child_node_unit_it = node_unit_map.find(&child_node); + if (child_node_unit_it == node_unit_map.end()) { + return nullptr; + } + const NodeUnit* child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. 
Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. + if (qnn_node_group_map.count(child_node_unit) != 0) { + return nullptr; + } + + // child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (child_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + return child_node_unit; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h new file mode 100644 index 0000000000000..0d11d21906ccb --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { +constexpr const char* QUANTIZE_LINEAR = "QuantizeLinear"; +constexpr const char* DEQUANTIZE_LINEAR = "DequantizeLinear"; +constexpr size_t QDQ_MAX_NUM_INPUTS = 3; +constexpr size_t QDQ_SCALE_INPUT_IDX = 1; +constexpr size_t QDQ_ZERO_POINT_INPUT_IDX = 2; + +/// +/// Utility function to get a child NodeUnit. The returned NodeUnit must be the parent's only child, must be +/// of the expected type, and must not be a part of another IQnnNodeGroup. +/// +/// GraphViewer containing all Nodes +/// Parent NodeUnit +/// Valid child types +/// Maps a Node to its NodeUnit +/// Maps a NodeUnit to its IQnnNodeGroup. +/// Used to check that the child has not already been added to another IQnnNodeGroup. +/// +const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_map& node_unit_to_qnn_node_group); + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index c56a47e67497e..fc64d63ede338 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -16,10 +16,10 @@ #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/qnn/builder/qnn_fusions.h" #include "core/providers/partitioning_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/framework/run_options.h" @@ -412,25 +412,35 @@ QNNExecutionProvider::~QNNExecutionProvider() { #endif } -bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const { - const std::string& op_type = node_unit.OpType(); - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." 
- << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); +// Logs information about the supported/unsupported nodes. +static void LogNodeSupport(const logging::Logger& logger, + logging::Severity log_severity, + logging::DataType log_data_type, + const onnxruntime::CodeLocation& call_site, + const qnn::IQnnNodeGroup& qnn_node_group, + Status support_status) { + if (!logger.OutputIsEnabled(log_severity, log_data_type)) { + return; + } + + std::ostringstream oss; + oss << (support_status.IsOK() ? "Validation PASSED " : "Validation FAILED ") << "for nodes (" + << qnn_node_group.Type() << "):" << std::endl; + for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + oss << "\tOperator type: " << node->OpType() + << " Node name: " << node->Name() + << " Node index: " << node->Index() << std::endl; } - supported = (Status::OK() == status); } - return supported; + if (!support_status.IsOK()) { + oss << "\tREASON : " << support_status.ErrorMessage() << std::endl; + } + + logging::Capture(logger, log_severity, logging::Category::onnxruntime, + log_data_type, call_site) + .Stream() + << oss.str(); } std::unordered_set @@ -469,68 +479,33 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, initializer_input_lookup, qnn_backend_manager_->GetQnnBackendType()); - std::unordered_set handled_node_units; - handled_node_units.reserve(node_unit_size); - - auto add_supported_nodes = [](std::unordered_set& supported_nodes, const NodeUnit* node_unit) { - for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { - supported_nodes.insert(node_in_group); - } - }; - - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - gsl::not_null node(graph_viewer.GetNode(node_indices[i])); - - // Get the node_unit associated with the node. Note that the node may not be the node_unit's target node. - const NodeUnit* node_unit = node_unit_map.at(node); - - // Visiting 'nodes' in topological order does not guarantee that 'node_units' are - // also visited in topological order. Skip this node if it is not the node_unit's target node - // to ensure 'node_units' are visited in topological order. - if (node != &node_unit->GetNode()) { - continue; - } + std::vector> qnn_node_groups; + qnn_node_groups.reserve(node_unit_size); - if (handled_node_units.count(node_unit) != 0) { - continue; // Already handled this node unit - } + if (Status status = qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, + node_unit_map, node_unit_size, logger); + !status.IsOK()) { + LOGS(logger, ERROR) << status.ErrorMessage(); + return {}; + } - // Try to see if this node unit can be fused. 
- std::vector fused_nodes; - Status fusion_status = TryFusions(fused_nodes, qnn_model_wrapper, *node_unit, node_unit_map, - handled_node_units, logger, true /*do_op_validation*/); + for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group->IsSupported(qnn_model_wrapper, logger); + const bool supported = status.IsOK(); - if (!fusion_status.IsOK()) { - LOGS(logger, WARNING) << "Failed to apply fusion: " << fusion_status.ErrorMessage(); - handled_node_units.insert(node_unit); - continue; - } - - if (!fused_nodes.empty()) { - for (auto fused_node_unit : fused_nodes) { - handled_node_units.insert(fused_node_unit); - add_supported_nodes(supported_nodes, fused_node_unit); - } - continue; + constexpr auto log_severity = logging::Severity::kVERBOSE; + constexpr auto log_data_type = logging::DataType::SYSTEM; + if (logger.OutputIsEnabled(log_severity, log_data_type)) { + LogNodeSupport(logger, log_severity, log_data_type, ORT_WHERE, *qnn_node_group, status); } - // Couldn't fuse the node unit. See if it is supported by itself. - const bool supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; - if (supported) { - add_supported_nodes(supported_nodes, node_unit); + for (const NodeUnit* node_unit : qnn_node_group->GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + supported_nodes.insert(node); + } + } } - - handled_node_units.insert(node_unit); } return supported_nodes; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index f00ffb6cfdb96..4c48370492ef7 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -53,9 +53,6 @@ class QNNExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; private: - bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const; - std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const size_t node_unit_size, diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 35889c9fa2307..95673586677ef 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -15,6 +15,12 @@ namespace onnxruntime { namespace test { +// Information for activation node placed between the Conv and Q. +struct OutputActivationInfo { + std::string op_type; // Relu or Clip + std::vector const_inputs; +}; + // Creates a graph with a single float32 Conv operator. Used for testing CPU backend. 
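// (Illustrative aside, not part of the original patch.) Tests opt into the trailing activation by passing
// an OutputActivationInfo through the build helpers below; each entry of const_inputs becomes a scalar
// initializer feeding the activation node. For example (values are hypothetical):
//
//   OutputActivationInfo relu_info{"Relu", {}};
//   OutputActivationInfo clip_info{"Clip", {0.0f, 6.0f}};  // Clip min/max inputs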
static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, const TestInputDef& input_def, const TestInputDef& weights_def, @@ -23,9 +29,10 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons const std::vector& pads, const std::vector& dilations, std::optional group, - const std::string& auto_pad = "NOTSET") { + const std::string& auto_pad = "NOTSET", + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, group, auto_pad](ModelTestBuilder& builder) { + dilations, group, auto_pad, output_activation](ModelTestBuilder& builder) { std::vector conv_inputs = { MakeTestInput(builder, input_def), MakeTestInput(builder, weights_def)}; @@ -34,9 +41,9 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons conv_inputs.push_back(MakeTestInput(builder, bias_def)); } - auto* output = builder.MakeOutput(); + auto* conv_output = output_activation.has_value() ? builder.MakeIntermediate() : builder.MakeOutput(); - Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {output}); + Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {conv_output}); conv_node.AddAttribute("auto_pad", auto_pad); if (group.has_value()) { @@ -54,6 +61,15 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons if (!dilations.empty()) { conv_node.AddAttribute("dilations", dilations); } + + if (output_activation.has_value()) { + NodeArg* output = builder.MakeOutput(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {output}); + } }; } @@ -88,19 +104,22 @@ static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef // Creates a graph with a single Q/DQ Conv operator. Used for testing HTP backend. 
template -static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& conv_op_type, - const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - std::optional group, - const std::string& auto_pad = "NOTSET", - bool use_contrib_qdq = false) { +static GetTestQDQModelFn BuildQDQConvTestCase( + const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false, + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, group, auto_pad, use_contrib_qdq](ModelTestBuilder& builder, - std::vector>& output_qparams) { + dilations, group, auto_pad, + use_contrib_qdq, output_activation](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> @@ -144,27 +163,39 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string conv_node.AddAttribute("dilations", dilations); } - AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + NodeArg* q_input = conv_output; + if (output_activation.has_value()) { + q_input = builder.MakeIntermediate(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); + } + + AddQDQNodePairWithOutputAsGraphOutput(builder, q_input, output_qparams[0].scale, output_qparams[0].zero_point, use_contrib_qdq); }; } template -static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const std::string& conv_op_type, - const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - int64_t weight_quant_axis, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - std::optional group, - const std::string& auto_pad = "NOTSET", - bool use_contrib_qdq = false) { +static GetTestQDQModelFn BuildQDQPerChannelConvTestCase( + const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + int64_t weight_quant_axis, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false, + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, group, auto_pad, use_contrib_qdq, - weight_quant_axis](ModelTestBuilder& builder, - std::vector>& output_qparams) { + weight_quant_axis, output_activation](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> @@ -248,7 +279,17 @@ static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const s conv_node.AddAttribute("dilations", dilations); } - AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + NodeArg* q_input = conv_output; + if (output_activation.has_value()) { + q_input = builder.MakeIntermediate(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + 
activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); + } + + AddQDQNodePairWithOutputAsGraphOutput(builder, q_input, output_qparams[0].scale, output_qparams[0].zero_point, use_contrib_qdq); }; } @@ -267,7 +308,8 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef ExpectedEPNodeAssignment expected_ep_assignment, bool use_contrib_qdq = false, int opset = 13, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + std::optional output_activation = std::nullopt) { ProviderOptions provider_options; #if defined(_WIN32) @@ -277,10 +319,11 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef #endif TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad), + group, auto_pad, output_activation), BuildQDQConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad, use_contrib_qdq), + group, auto_pad, use_contrib_qdq, + output_activation), provider_options, opset, expected_ep_assignment, @@ -302,7 +345,8 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te ExpectedEPNodeAssignment expected_ep_assignment, bool use_contrib_qdq = false, int opset = 13, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + std::optional output_activation = std::nullopt) { ProviderOptions provider_options; #if defined(_WIN32) @@ -312,11 +356,11 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te #endif auto f32_fn = BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad); + group, auto_pad, output_activation); auto qdq_fn = BuildQDQPerChannelConvTestCase(conv_op_type, input_def, weights_def, bias_def, weight_quant_axis, strides, pads, dilations, group, auto_pad, - use_contrib_qdq); + use_contrib_qdq, output_activation); TestQDQModelAccuracy(f32_fn, qdq_fn, provider_options, opset, expected_ep_assignment, tolerance); } @@ -764,6 +808,140 @@ TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel) { 21); // opset } +// Test fusion of DQs -> Conv -> Relu/Clip -> Q. +// User per-tensor quantization. 
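The tests that follow (per-tensor, then per-channel quantization) check that the whole DQ -> Conv -> Relu/Clip -> Q chain stays assigned to the HTP backend. As background, here is a sketch of the usual folding criterion behind such fusions; it is conceptual only and not the QNN EP's actual implementation: an activation in front of a Q node is redundant once the Q's scale and zero point already clamp the output at least as tightly as the activation would.

```cpp
// Conceptual sketch (not the QNN EP's code): when can Relu/Clip fold into the Q?
#include <cstdint>
#include <utility>

struct QuantParams {
  float scale;
  int32_t zero_point;
};

// Real-valued range representable by a uint8 quantized tensor:
// q in [0, 255] maps back to scale * (q - zero_point).
inline std::pair<float, float> U8Range(const QuantParams& qp) {
  return {qp.scale * (0 - qp.zero_point), qp.scale * (255 - qp.zero_point)};
}

// Relu is foldable if the representable range cannot go below zero.
inline bool ReluFoldableIntoQ(const QuantParams& qp) {
  return U8Range(qp).first >= 0.0f;
}

// Clip(min, max) is foldable if the representable range is already inside [min, max].
inline bool ClipFoldableIntoQ(const QuantParams& qp, float clip_min, float clip_max) {
  const auto [lo, hi] = U8Range(qp);
  return lo >= clip_min && hi <= clip_max;
}
```

With scale 6.0f/255 and zero point 0, for example, the representable range is [0, 6], so both a trailing Relu and the Clip(0, 6) used in the per-channel test are foldable.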
+TEST_F(QnnHTPBackendTests, ConvU8U8S32_ReluClipFusion) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + // DQs -> Conv (w/ bias) -> Relu -> Q + OutputActivationInfo relu_info = {"Relu", {}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + relu_info); + + // DQs -> Conv (NO bias) -> Relu -> Q + RunHTPConvOpTest("Conv", + input_def, + weight_def, + TestInputDef(), + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + relu_info); + + // DQs -> Conv (w/ bias) -> Clip -> Q + // Opset 6 Clip uses attributes for min/max + OutputActivationInfo clip_info = {"Clip", {0.0f, 2.0f}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 19, // opset + QDQTolerance(), + clip_info); + + // DQs -> Conv (NO bias) -> Clip -> Q + OutputActivationInfo clip_info_2 = {"Clip", {-6.0f, 6.0f}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + TestInputDef(), + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + clip_info_2); +} + +// Test fusion of DQs -> Conv -> Relu/Clip -> Q. +// User per-channel quantization. 
+TEST_F(QnnHTPBackendTests, ConvS8S8S32_PerChannel_ReluClipFusion) {
+  std::vector input_shape = {1, 2, 4, 4};
+  std::vector weight_shape = {3, 2, 2, 2};
+  std::vector bias_shape = {3};
+
+  TestInputDef input_def(input_shape, false,
+                         GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size()));
+  TestInputDef weight_def(weight_shape, true,
+                          GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size()));
+  TestInputDef bias_def(bias_shape, true,
+                        GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size()));
+
+  // DQs -> Conv (w/ bias) -> Relu -> Q
+  OutputActivationInfo relu_info = {"Relu", {}};
+  RunHTPConvOpPerChannelTest("Conv",
+                             input_def,
+                             weight_def,
+                             bias_def,
+                             0,             // weight quant axis
+                             {1, 1},        // Strides
+                             {0, 0, 0, 0},  // Pads
+                             {1, 1},        // Dilations
+                             1,             // default group
+                             "NOTSET",
+                             ExpectedEPNodeAssignment::All,
+                             false,  // use_qdq_contrib_ops
+                             21,     // opset
+                             QDQTolerance(),
+                             relu_info);
+
+  // DQs -> Conv (w/ bias) -> Clip -> Q
+  OutputActivationInfo clip_info = {"Clip", {0.0f, 6.0f}};
+  RunHTPConvOpPerChannelTest("Conv",
+                             input_def,
+                             weight_def,
+                             bias_def,
+                             0,             // weight quant axis
+                             {1, 1},        // Strides
+                             {0, 0, 0, 0},  // Pads
+                             {1, 1},        // Dilations
+                             1,             // default group
+                             "NOTSET",
+                             ExpectedEPNodeAssignment::All,
+                             false,  // use_qdq_contrib_ops
+                             21,     // opset
+                             QDQTolerance(),
+                             clip_info);
+}
+
 // Test per-channel QDQ Conv with INT4 weights and a negative weight quantization axis that still points to dimension 0.
 TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_NegativeWeightQuantAxis) {
   std::vector input_shape = {1, 2, 4, 4};

From 1391354265b7adab1a240599c63e524f4628c418 Mon Sep 17 00:00:00 2001
From: Julius Tischbein
Date: Sat, 3 Aug 2024 00:16:42 +0200
Subject: [PATCH 36/40] Adding CUDNN Frontend and use for CUDA NN Convolution (#19470)

### Description
Added the cuDNN Frontend and use it for NHWC convolutions, optionally fusing the activation.

#### Backward compatible
- Models that already contain FusedConv can still run.
- If ORT is built with cuDNN 8, the cuDNN frontend will not be built into the binary. The old kernels (using cuDNN backend APIs) are used.

#### Major Changes
- For cuDNN 9, the cuDNN frontend is enabled to fuse convolution and bias when the provider option `fuse_conv_bias=1` is set.
- Remove the FusedConv fusion from the graph transformer for the CUDA provider, so FusedConv will no longer be added to the graph for the CUDA EP.
- Update cmake files for the cuDNN settings. The search order for the cuDNN installation at build time is the following:
  * environment variable `CUDNN_PATH`
  * the `onnxruntime_CUDNN_HOME` cmake extra define. If a build starts from build.py/build.sh, the user can pass it through the `--cudnn_home` parameter, or via the environment variable `CUDNN_HOME` if `--cudnn_home` is not used.
  * the cudnn python package installation directory, like python3.xx/site-packages/nvidia/cudnn
  * the CUDA installation path

#### Potential Issues
- If ORT is built with cuDNN 8, the FusedConv fusion is no longer done automatically, so some models might see a performance regression. Users who still want the FusedConv operator for performance reasons have several workarounds: use an older version of onnxruntime, or use an older version of ORT to save the optimized onnx model and then run it with the latest version of ORT. We believe the majority of users will have moved to cuDNN 9 by the 1.20 release (the default in ORT and PyTorch will have been cuDNN 9 for three months by then), so the impact is small. (A usage sketch of the new `fuse_conv_bias` provider option follows this item.)
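Usage sketch (not part of the patch): with a cuDNN 9 build, the new fusion is opted into through the `fuse_conv_bias` provider option added in this PR, here set through the public provider-options API together with `prefer_nhwc`, since the frontend path targets NHWC convolutions. The model path is illustrative.

```cpp
// Sketch: enable the cuDNN frontend conv+bias fusion via CUDA provider options.
// Assumes a cuDNN 9 build of onnxruntime; "fuse_conv_bias" defaults to 0.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "fuse_conv_bias_demo");
  Ort::SessionOptions session_options;

  const OrtApi& api = Ort::GetApi();
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  // Key/value strings go through the same parser that reads "fuse_conv_bias"
  // in cuda_execution_provider_info.cc below.
  const char* keys[] = {"prefer_nhwc", "fuse_conv_bias"};
  const char* values[] = {"1", "1"};
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 2));

  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  api.ReleaseCUDAProviderOptions(cuda_options);

  Ort::Session session(env, "resnet50.onnx", session_options);  // model path is illustrative
  return 0;
}
```

Per the provider-options change later in this patch, `fuse_conv_bias` defaults to 0, so omitting it keeps the existing behavior.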
- The cuDNN graph uses TF32 by default, and the user cannot disable TF32 through the use_tf32 CUDA provider option. If a user encounters accuracy issues (for example in testing), they have to set the environment variable `NVIDIA_TF32_OVERRIDE=0` to disable TF32. The use_tf32 documentation needs to be updated later.

#### Follow ups
This is one of the PRs that aim to enable NHWC convolution in the CUDA EP by default when the device supports it. Other changes will follow to make that possible: (1) Enable `prefer_nhwc` by default for devices with sm >= 70. (2) Change to `fuse_conv_bias=1` by default after more testing. (3) Add other NHWC operators (like Resize or UpSample).

### Motivation and Context
The new CUDNN Frontend library provides the functionality to fuse operations and provides new heuristics for kernel selection. Here it fuses the convolution with the pointwise bias operation. On the [NVIDIA ResNet50](https://pytorch.org/hub/nvidia_deeplearningexamples_resnet50/) we get a performance boost from 49.1144 ms to 42.4643 ms per inference on a 2560x1440 input (`onnxruntime_perf_test -e cuda -I -q -r 100 -d 1 -i 'prefer_nhwc|1' resnet50.onnx`).

---------

Co-authored-by: Tianlei Wu
Co-authored-by: Maximilian Mueller
---
 cgmanifests/generated/cgmanifest.json | 10 +
 cmake/CMakeLists.txt | 6 -
 cmake/deps.txt | 1 +
 cmake/external/cuDNN.cmake | 111 ++++
 cmake/external/cudnn_frontend.cmake | 12 +
 .../external/onnxruntime_external_deps.cmake | 16 +-
 cmake/onnxruntime_framework.cmake | 2 +-
 cmake/onnxruntime_providers_cuda.cmake | 14 +-
 cmake/onnxruntime_providers_tensorrt.cmake | 5 +-
 cmake/onnxruntime_python.cmake | 6 +-
 cmake/onnxruntime_rocm_hipify.cmake | 2 +
 cmake/onnxruntime_session.cmake | 4 +-
 cmake/onnxruntime_training.cmake | 7 -
 cmake/onnxruntime_unittests.cmake | 8 +-
 .../core/providers/cuda/cuda_context.h | 2 +
 .../providers/cuda/cuda_provider_options.h | 1 +
 .../core/providers/cuda/cuda_resource.h | 1 +
 .../contrib_ops/cuda/diffusion/nhwc_conv.cc | 2 +-
 onnxruntime/contrib_ops/cuda/fused_conv.cc | 419 +++++++++++--
 .../core/optimizer/conv_activation_fusion.cc | 61 +-
 .../core/optimizer/graph_transformer_utils.cc | 7 +-
 onnxruntime/core/providers/common.h | 25 +-
 onnxruntime/core/providers/cuda/cuda_call.cc | 16 +-
 .../providers/cuda/cuda_execution_provider.cc | 17 +
 .../providers/cuda/cuda_execution_provider.h | 1 +
 .../cuda/cuda_execution_provider_info.cc | 4 +
 .../cuda/cuda_execution_provider_info.h | 4 +-
 .../providers/cuda/cuda_provider_factory.cc | 2 +
 .../core/providers/cuda/cuda_stream_handle.cc | 3 +
 .../core/providers/cuda/cudnn_common.cc | 110 +++-
 .../core/providers/cuda/cudnn_common.h | 31 +-
 .../core/providers/cuda/cudnn_fe_call.cc | 119 ++++
 onnxruntime/core/providers/cuda/nn/conv.cc | 569 +++++++++---------
 onnxruntime/core/providers/cuda/nn/conv.h | 50 +-
 onnxruntime/core/providers/cuda/nn/conv_8.h | 484 +++++++++++++++
 .../providers/cuda/shared_inc/cuda_call.h | 6 +-
 .../providers/cuda/shared_inc/cudnn_fe_call.h | 23 +
 .../providers/cuda/shared_inc/fpgeneric.h | 74 +--
 .../core/session/provider_bridge_ort.cc | 1 +
 .../test/contrib_ops/nhwc_conv_op_test.cc | 3 +-
 onnxruntime/test/onnx/main.cc | 1 +
 .../test/optimizer/graph_transform_test.cc | 90 +--
 .../test/providers/cpu/nn/conv_op_test.cc | 30 +-
 setup.py | 1 +
 .../templates/download-deps.yml | 4 +-
 45 files changed, 1806 insertions(+), 559 deletions(-)
 create mode 100644 cmake/external/cuDNN.cmake
 create mode 100644 cmake/external/cudnn_frontend.cmake
 create mode 100644 onnxruntime/core/providers/cuda/cudnn_fe_call.cc
 create mode 100644
onnxruntime/core/providers/cuda/nn/conv_8.h create mode 100644 onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 7de3f346f6386..f9e702b894f56 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -351,6 +351,16 @@ }, "comments": "directx_headers" } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "98ca4e1941fe3263f128f74f10063a3ea35c7019", + "repositoryUrl": "https://github.com/NVIDIA/cudnn-frontend.git" + }, + "comments": "cudnn_frontend" + } } ] } diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5555fa692eae8..2e9a50e522171 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -729,9 +729,6 @@ set(ORT_PROVIDER_FLAGS) set(ORT_PROVIDER_CMAKE_FLAGS) if (onnxruntime_USE_CUDA) - if (onnxruntime_USE_CUDA_NHWC_OPS) - add_compile_definitions(ENABLE_CUDA_NHWC_OPS) - endif() enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") @@ -1445,9 +1442,6 @@ if (onnxruntime_USE_CUDA) file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) endif() find_package(CUDAToolkit REQUIRED) - if(onnxruntime_CUDNN_HOME) - file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) - endif() if (NOT CMAKE_CUDA_ARCHITECTURES) if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu") # Support for Jetson/Tegra ARM devices diff --git a/cmake/deps.txt b/cmake/deps.txt index d0edf963451d5..56d77d9e70023 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -58,3 +58,4 @@ utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e +cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029 diff --git a/cmake/external/cuDNN.cmake b/cmake/external/cuDNN.cmake new file mode 100644 index 0000000000000..3d05f6406a80e --- /dev/null +++ b/cmake/external/cuDNN.cmake @@ -0,0 +1,111 @@ +add_library(CUDNN::cudnn_all INTERFACE IMPORTED) + +find_path( + CUDNN_INCLUDE_DIR cudnn.h + HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS} + PATH_SUFFIXES include + REQUIRED +) + +file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header) +string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}") +string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}") + +function(find_cudnn_library NAME) + find_library( + ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}" + HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR} + PATH_SUFFIXES lib64 lib/x64 lib + REQUIRED + ) + + if(${NAME}_LIBRARY) + add_library(CUDNN::${NAME} UNKNOWN IMPORTED) + set_target_properties( + CUDNN::${NAME} PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR} + IMPORTED_LOCATION ${${NAME}_LIBRARY} + ) + message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.") + else() + message(STATUS "${NAME} not found.") + endif() + + +endfunction() + 
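A note on the version handling above: the find module parses `CUDNN_MAJOR` out of `cudnn_version.h` at configure time and uses it below to select the cuDNN 8 versus cuDNN 9 component libraries. A complementary runtime check is sketched here, assuming only that `cudnn.h` is on the include path; it mirrors the `cudnnGetVersion()` logging the CUDA EP gains later in this patch and is not part of the patch itself.

```cpp
// Minimal sketch: compare the cuDNN version the binary was compiled against
// with the version of the shared library actually loaded at runtime.
#include <cstdio>
#include <cudnn.h>  // pulls in cudnn_version.h, the header the CMake regex parses

int main() {
  // Compile-time version macros from cudnn_version.h.
  std::printf("built against cuDNN %d.%d.%d\n", CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);

  // Version of the loaded library, encoded the same way as CUDNN_VERSION.
  size_t runtime_version = cudnnGetVersion();
  std::printf("loaded cuDNN runtime %zu\n", runtime_version);

  if (runtime_version != static_cast<size_t>(CUDNN_VERSION)) {
    std::printf("note: cuDNN headers and loaded library do not match exactly\n");
  }
  return 0;
}
```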
+find_cudnn_library(cudnn) + +include (FindPackageHandleStandardArgs) +find_package_handle_standard_args( + LIBRARY REQUIRED_VARS + CUDNN_INCLUDE_DIR cudnn_LIBRARY +) + +if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY) + + message(STATUS "cuDNN: ${cudnn_LIBRARY}") + message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}") + + set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found") + +else() + + set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found") + +endif() + +target_include_directories( + CUDNN::cudnn_all + INTERFACE + $ + $ +) + +target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn +) + +if(CUDNN_MAJOR_VERSION EQUAL 8) + find_cudnn_library(cudnn_adv_infer) + find_cudnn_library(cudnn_adv_train) + find_cudnn_library(cudnn_cnn_infer) + find_cudnn_library(cudnn_cnn_train) + find_cudnn_library(cudnn_ops_infer) + find_cudnn_library(cudnn_ops_train) + + target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn_adv_train + CUDNN::cudnn_ops_train + CUDNN::cudnn_cnn_train + CUDNN::cudnn_adv_infer + CUDNN::cudnn_cnn_infer + CUDNN::cudnn_ops_infer + ) +elseif(CUDNN_MAJOR_VERSION EQUAL 9) + find_cudnn_library(cudnn_cnn) + find_cudnn_library(cudnn_adv) + find_cudnn_library(cudnn_graph) + find_cudnn_library(cudnn_ops) + find_cudnn_library(cudnn_engines_runtime_compiled) + find_cudnn_library(cudnn_engines_precompiled) + find_cudnn_library(cudnn_heuristic) + + target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn_adv + CUDNN::cudnn_ops + CUDNN::cudnn_cnn + CUDNN::cudnn_graph + CUDNN::cudnn_engines_runtime_compiled + CUDNN::cudnn_engines_precompiled + CUDNN::cudnn_heuristic + ) +endif() + +mark_as_advanced(CUDNN_INCLUDE_DIR) diff --git a/cmake/external/cudnn_frontend.cmake b/cmake/external/cudnn_frontend.cmake new file mode 100644 index 0000000000000..7ac51deaf93bc --- /dev/null +++ b/cmake/external/cudnn_frontend.cmake @@ -0,0 +1,12 @@ +include(FetchContent) +FetchContent_Declare( + cudnn_frontend + URL ${DEP_URL_cudnn_frontend} + URL_HASH SHA1=${DEP_SHA1_cudnn_frontend} +) + +set(CUDNN_FRONTEND_BUILD_SAMPLES OFF) +set(CUDNN_FRONTEND_BUILD_UNIT_TESTS OFF) +set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF) +set(CUDNN_PATH ${onnxruntime_CUDNN_HOME}) +FetchContent_MakeAvailable(cudnn_frontend) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 775576a771529..4e52707474052 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -587,20 +587,16 @@ endif() message("Finished fetching external dependencies") - set(onnxruntime_LINK_DIRS ) + if (onnxruntime_USE_CUDA) - #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same find_package(CUDAToolkit REQUIRED) - if (WIN32) - if(onnxruntime_CUDNN_HOME) - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64) - endif() - else() - if(onnxruntime_CUDNN_HOME) - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64) - endif() + + if(onnxruntime_CUDNN_HOME) + file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) + set(CUDNN_PATH ${onnxruntime_CUDNN_HOME}) endif() + include(cuDNN) endif() if(onnxruntime_USE_SNPE) diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index b85edbf37d447..9f8d807fad8f4 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -69,7 +69,7 @@ endif() if(onnxruntime_USE_TENSORRT OR onnxruntime_USE_NCCL) 
# TODO: for now, core framework depends on CUDA. It should be moved to TensorRT EP # TODO: provider_bridge_ort.cc should not include nccl.h -target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_CUDNN_HOME}/include PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) else() target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 0829be05a3ab0..774b7a4f6bd77 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -197,12 +197,16 @@ target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart) else() - target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart - ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) - if(onnxruntime_CUDNN_HOME) - target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + include(cudnn_frontend) # also defines CUDNN::* + if (onnxruntime_USE_CUDA_NHWC_OPS) + if(CUDNN_MAJOR_VERSION GREATER 8) + add_compile_definitions(ENABLE_CUDA_NHWC_OPS) + else() + message( WARNING "To compile with NHWC ops enabled please compile against cuDNN 9 or newer." 
) + endif() endif() + target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas CUDNN::cudnn_all cudnn_frontend CUDA::curand CUDA::cufft CUDA::cudart + ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) endif() if (onnxruntime_USE_TRITON_KERNEL) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 3d46c139feea9..468aaa44ec4ee 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -159,7 +159,7 @@ if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() - set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) + set(trt_link_libs CUDNN::cudnn_all cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) endif() file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h" @@ -183,9 +183,6 @@ endif() target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) - if(onnxruntime_CUDNN_HOME) - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) - endif() # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 270139ceaff7b..372db15b108fb 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -98,11 +98,7 @@ endif() onnxruntime_add_include_to_target(onnxruntime_pybind11_state Python::Module Python::NumPy) target_include_directories(onnxruntime_pybind11_state PRIVATE ${ONNXRUNTIME_ROOT} ${pybind11_INCLUDE_DIRS}) if(onnxruntime_USE_CUDA) - target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - # cudnn_home is optional for Window when cuda and cudnn are installed in the same directory. 
- if(onnxruntime_CUDNN_HOME) - target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CUDNN_HOME}/include) - endif() + target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR}) endif() if(onnxruntime_USE_CANN) target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CANN_HOME}/include) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index a8c876d30873e..1740144cf6553 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -145,6 +145,7 @@ set(provider_excluded_files "rnn/rnn_impl.cu" "rnn/rnn_impl.h" "shared_inc/cuda_call.h" + "shared_inc/cudnn_fe_call.h" "shared_inc/fpgeneric.h" "cuda_allocator.cc" "cuda_allocator.h" @@ -171,6 +172,7 @@ set(provider_excluded_files "cuda_utils.cu" "cudnn_common.cc" "cudnn_common.h" + "cudnn_fe_call.cc" "cupti_manager.cc" "cupti_manager.h" "fpgeneric.cu" diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 79bee3bdb65d6..b51c875951135 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -44,9 +44,7 @@ if (onnxruntime_USE_EXTENSIONS) endif() add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime") -if (onnxruntime_USE_CUDA) - target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -endif() + if (onnxruntime_USE_ROCM) target_compile_options(onnxruntime_session PRIVATE -Wno-sign-compare -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include) diff --git a/cmake/onnxruntime_training.cmake b/cmake/onnxruntime_training.cmake index 01590a431205c..b633a9c2de378 100644 --- a/cmake/onnxruntime_training.cmake +++ b/cmake/onnxruntime_training.cmake @@ -39,10 +39,6 @@ endif() target_include_directories(onnxruntime_training PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header} ${MPI_CXX_INCLUDE_DIRS}) -if (onnxruntime_USE_CUDA) - target_include_directories(onnxruntime_training PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -endif() - if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_training PRIVATE ${NCCL_INCLUDE_DIRS}) endif() @@ -81,9 +77,6 @@ if (onnxruntime_BUILD_UNIT_TESTS) target_include_directories(onnxruntime_training_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header}) target_link_libraries(onnxruntime_training_runner PRIVATE nlohmann_json::nlohmann_json) - if (onnxruntime_USE_CUDA) - target_include_directories(onnxruntime_training_runner PUBLIC ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_training_runner PRIVATE ${NCCL_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0c1e5e93c6844..ec396fbcc6117 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -67,7 +67,7 @@ function(AddTest) 
if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. - target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart cudnn_frontend) endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -75,7 +75,7 @@ function(AddTest) onnxruntime_add_include_to_target(${_UT_TARGET} date::date flatbuffers::flatbuffers) target_include_directories(${_UT_TARGET} PRIVATE ${TEST_INC_DIR}) if (onnxruntime_USE_CUDA) - target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) + target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR}) if (onnxruntime_USE_NCCL) target_include_directories(${_UT_TARGET} PRIVATE ${NCCL_INCLUDE_DIRS}) endif() @@ -392,7 +392,7 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R ) list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src}) - if (onnxruntime_USE_CUDA_NHWC_OPS) + if (onnxruntime_USE_CUDA_NHWC_OPS AND CUDNN_MAJOR_VERSION GREATER 8) file(GLOB onnxruntime_test_providers_cuda_nhwc_src CONFIGURE_DEPENDS "${TEST_SRC_DIR}/providers/cuda/nhwc/*.cc" ) @@ -1498,7 +1498,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") list(APPEND custom_op_src_patterns "${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu" "${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*") - list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) + list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR}) if (HAS_QSPECTRE) list(APPEND custom_op_lib_option "$<$:SHELL:--compiler-options /Qspectre>") endif() diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 462b31bb433a5..12a19759bb979 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -40,6 +40,7 @@ struct CudaContext : public CustomOpContext { bool enable_skip_layer_norm_strict_mode = false; bool prefer_nhwc = false; bool use_tf32 = true; + bool fuse_conv_bias = true; void Init(const OrtKernelContext& kernel_ctx) { cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t); @@ -57,6 +58,7 @@ struct CudaContext : public CustomOpContext { kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t); prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t); use_tf32 = FetchResource(kernel_ctx, CudaResource::use_tf32_t); + fuse_conv_bias = FetchResource(kernel_ctx, CudaResource::fuse_conv_bias_t); } template diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 01a14de699dc4..3b7a1e99346d7 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -38,5 +38,6 @@ struct OrtCUDAProviderOptionsV2 { int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not int use_tf32 = 1; // use TF32 + int fuse_conv_bias = 0; // Enable CUDNN Frontend kernel fusing, results in JIT 
compiles int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index 555023c442c01..b248d33035bc1 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -19,4 +19,5 @@ enum CudaResource : int { enable_skip_layer_norm_strict_mode_t, prefer_nhwc_t, use_tf32_t, + fuse_conv_bias_t }; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc index 79f0a18ba515f..c1459cb9d08d2 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc @@ -21,7 +21,7 @@ namespace cuda { T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); + onnxruntime::cuda::Conv); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc index e126f8bcb3d11..279df73ee3d45 100644 --- a/onnxruntime/contrib_ops/cuda/fused_conv.cc +++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc @@ -3,17 +3,66 @@ #include "core/common/status.h" #include "core/providers/cuda/nn/conv.h" +#include "core/providers/cuda/tensor/slice.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { namespace contrib { namespace cuda { +Status SliceOutUnwantedOutputSection(cudaStream_t stream, + const void* input_data, gsl::span input_dims, + void* output_data, + const gsl::span& output_dims, + const gsl::span& starts, + const gsl::span& ends, + const gsl::span& axes, + size_t element_size) { + SliceOp::PrepareForComputeMetadata compute_metadata(input_dims); + + ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata)); + + // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape + ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims)); + + return ::onnxruntime::cuda::SliceCuda::Impl(stream, input_data, input_dims, output_data, + compute_metadata, element_size); +} + +static cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, + const ::onnxruntime::cuda::CudnnConvState& s, + cudnnConvolutionFwdAlgo_t algo, + size_t* sz) { + return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz); +} + +static size_t GetMaxWorkspaceSize(cudnnHandle_t handle, + const ::onnxruntime::cuda::CudnnConvState& s, + const cudnnConvolutionFwdAlgo_t* algo, int n_algo) { + // TODO: get maximum available size from memory arena + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation + free = static_cast(static_cast(free) * 0.9); + size_t max_ws_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t err; + size_t sz; + err = GetWorkspaceSize(handle, s, algo[i], &sz); + if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue; + max_ws_size = sz; + } + return max_ws_size; +} + template -class FusedConv : public onnxruntime::cuda::Conv { +class FusedConv : public onnxruntime::cuda::CudaKernel { + using CudaT = typename ::onnxruntime::cuda::ToCudaType::MappedType; + public: - using Base = onnxruntime::cuda::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv(info) { + FusedConv(const OpKernelInfo& info) : 
onnxruntime::cuda::CudaKernel(info), conv_attrs_(info) { + auto pads_size = conv_attrs_.pads.size(); + ORT_ENFORCE(pads_size % 2 == 0); std::string activation; ORT_THROW_IF_ERROR(info.GetAttr("activation", &activation)); ORT_THROW_IF_ERROR(MapMode(activation)); @@ -32,66 +81,331 @@ class FusedConv : public onnxruntime::cuda::Conv { } } + Status UpdateState(OpKernelContext* context, bool bias_expected) const { + // set X + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + const auto x_dims = x_shape.AsShapeVector(); + s_.x_data = reinterpret_cast(X->Data()); + s_.element_size = X->DataType()->Size(); + // set W + const Tensor* W = context->Input(1); + const TensorShape& w_shape = W->Shape(); + auto w_dims = w_shape.AsShapeVector(); + s_.w_data = reinterpret_cast(W->Data()); + + // set B + if (context->InputCount() >= 3) { + const Tensor* B = context->Input(2); + s_.b_data = reinterpret_cast(B->Data()); + } else { + s_.b_data = nullptr; + } + // set Z + if (context->InputCount() >= 4) { + const Tensor* Z = context->Input(3); + ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), + ::onnxruntime::cuda::CudnnTensor::GetDataType())); + s_.z_data = reinterpret_cast(Z->Data()); + } else { + s_.z_data = nullptr; + } + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); + + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + s_.cached_benchmark_results.clear(); + } + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape())); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); + + const size_t kernel_rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(kernel_rank * 2, 0); + } + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_rank, 1); + } + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_rank, 1); + } + + TensorShapeVector y_dims; + y_dims.reserve(2 + kernel_rank); // add 2 to account for 'N' and 'C' + + const int64_t N = X->Shape()[0]; + const int64_t M = W->Shape()[0]; + y_dims.insert(y_dims.begin(), {N, M}); + + bool post_slicing_required = false; + TensorShapeVector slice_starts; + slice_starts.reserve(kernel_rank); + + TensorShapeVector slice_ends; + slice_ends.reserve(kernel_rank); + + TensorShapeVector slice_axes; + slice_axes.reserve(kernel_rank); + + constexpr size_t spatial_dim_start = 2; + const size_t spatial_dim_end = spatial_dim_start + kernel_rank; + TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); + + TensorShapeVector y_dims_with_adjusted_pads(y_dims); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, + strides, dilations, pads, y_dims, + y_dims_with_adjusted_pads, post_slicing_required, + slice_starts, slice_ends, slice_axes)); + + ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); + s_.y_dims = gsl::make_span(y_dims); + s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; + s_.post_slicing_required = post_slicing_required; + s_.slice_starts = slice_starts; + s_.slice_ends = slice_ends; + s_.slice_axes = slice_axes; + + s_.Y = context->Output(0, TensorShape(s_.y_dims)); + if (post_slicing_required) { + // Post slicing needed. 
Create and fill in the Conv results in an intermediate buffer. + s_.memory_for_cudnn_conv_results = + GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, + context->GetComputeStream()); + s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); + } else { + // No post slicing needed. Fill the output tensor's buffer directly. + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; + TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; + if (kernel_rank < 2) { + // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] + // especially for EXHAUSTIVE algo search which may result in a better algo selection. + // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to + // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape + // [N,C,D] to [N,C,1,D], expecially on A100, and especially for ConvGrad. + // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems + // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. + // See PR #7348 and #7702 for more context. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + if (w_dims_changed) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + } + + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + ::onnxruntime::cuda::CudnnTensor::GetDataType(), UseTF32())); + + if (context->InputCount() >= 3) { + const Tensor* B = context->Input(2); + const auto& b_shape = B->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = b_shape[0]; + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + // s_.b_data = reinterpret_cast(B->Data()); + } else if (bias_expected) { + TensorShapeVector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = w_dims[0]; + auto malloc_size = b_dims[1] * sizeof(CudaT); + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + if (s_.b_zero) { + CUDA_CALL_THROW(cudaFree(s_.b_zero)); + s_.b_zero = nullptr; + } + CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, 
malloc_size)); + CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); + } + + if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { + // set math type to tensor core before algorithm search + if constexpr (std::is_same::value) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } + + cudnnConvolutionFwdAlgoPerf_t perf; + int algo_count = 1; + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", + cudnn_conv_algo); + switch (cudnn_conv_algo) { + case 0: { + static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), + s_, kAllAlgos, num_algos) + : ::onnxruntime::cuda::AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( + GetCudnnHandle(context), + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.y_tensor, + s_.y_data, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf, + algo_search_workspace.get(), + max_ws_size)); + break; + } + case 1: + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7( + GetCudnnHandle(context), + s_.x_tensor, + s_.w_desc, + s_.conv_desc, + s_.y_tensor, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf)); + break; + + default: + perf.algo = kDefaultConvAlgo; + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); + + if constexpr (std::is_same::value) { + perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; + } else { + perf.mathType = CUDNN_DEFAULT_MATH; + } + } + s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType}); + } + const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); + s_.algo = perf.algo; + s_.workspace_bytes = perf.memory; + } else { + // set Y + s_.Y = context->Output(0, s_.y_dims); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + if (s_.post_slicing_required) { + s_.memory_for_cudnn_conv_results = GetScratchBuffer( + TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); + } else { + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + } + return Status::OK(); + } + Status ComputeInternal(OpKernelContext* context) const override { - std::lock_guard lock(Base::s_.mutex); + std::lock_guard lock(s_.mutex); auto cudnnHandle = this->GetCudnnHandle(context); - ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); - if (Base::s_.Y->Shape().Size() == 0) { + ORT_RETURN_IF_ERROR(UpdateState(context, true)); + if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - typedef typename 
onnxruntime::cuda::ToCudaType::MappedType CudaT; + bool has_z = nullptr != s_.z_data; + bool has_b = nullptr != s_.b_data; const auto alpha = onnxruntime::cuda::Consts::One; const auto beta = onnxruntime::cuda::Consts::Zero; - IAllocatorUniquePtr workspace = Base::GetWorkSpace(context->GetComputeStream()); + IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); auto cudnn_status = cudnnConvolutionBiasActivationForward(cudnnHandle, &alpha, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.w_desc, - Base::s_.w_data, - Base::s_.conv_desc, - Base::s_.algo, + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.algo, workspace.get(), - Base::s_.workspace_bytes, + s_.workspace_bytes, has_z ? &alpha : &beta, - has_z ? Base::s_.z_tensor : Base::s_.y_tensor, - has_z ? Base::s_.z_data : Base::s_.y_data, - Base::s_.b_tensor, - has_b ? Base::s_.b_data : Base::s_.b_zero, + has_z ? s_.z_tensor : s_.y_tensor, + has_z ? s_.z_data : s_.y_data, + s_.b_tensor, + has_b ? s_.b_data : s_.b_zero, activation_desc_, - Base::s_.y_tensor, - Base::s_.y_data); + s_.y_tensor, + s_.y_data); if (CUDNN_STATUS_SUCCESS != cudnn_status) { CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnnHandle, &alpha, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.w_desc, - Base::s_.w_data, - Base::s_.conv_desc, - Base::s_.algo, + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.algo, workspace.get(), - Base::s_.workspace_bytes, + s_.workspace_bytes, &beta, - Base::s_.y_tensor, - Base::s_.y_data)); + s_.y_tensor, + s_.y_data)); if (has_b) { - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, Base::s_.b_tensor, Base::s_.b_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data)); + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); } if (has_z) { - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, Base::s_.z_tensor, Base::s_.z_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data)); + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, s_.z_tensor, s_.z_data, + &alpha, s_.y_tensor, s_.y_data)); } - CUDNN_RETURN_IF_ERROR(cudnnActivationForward(cudnnHandle, activation_desc_, &alpha, Base::s_.y_tensor, - Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data)); + CUDNN_RETURN_IF_ERROR(cudnnActivationForward(cudnnHandle, activation_desc_, &alpha, s_.y_tensor, + s_.y_data, &beta, s_.y_tensor, s_.y_data)); } - if (Base::s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection( - this->Stream(context), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(), - Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size)); + if (s_.post_slicing_required) { + ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection( + this->Stream(context), s_.y_data, s_.y_dims_with_adjusted_pads, s_.Y->MutableDataRaw(), + s_.y_dims.GetDims(), s_.slice_starts, s_.slice_ends, s_.slice_axes, s_.element_size)); } return Status::OK(); } @@ -107,6 +421,25 @@ class FusedConv : public onnxruntime::cuda::Conv { } return Status::OK(); } + + inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + return GetScratchBuffer(s_.workspace_bytes, stream); + } + + ConvAttributes conv_attrs_; + mutable ::onnxruntime::cuda::CudnnConvState s_; + constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + constexpr static cudnnConvolutionFwdAlgo_t kAllAlgos[] = { + 
CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + }; + cudnnActivationMode_t activation_mode_; cudnnActivationDescriptor_t activation_desc_ = nullptr; }; @@ -122,4 +455,4 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( } // namespace cuda } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 12746ad53123a..c7d2d95e6121d 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -58,16 +58,11 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { } bool ConvFusionDataTypeCheck(const Node& conv_node) { - // TODO(hasesh): The CPU and CUDA EP only support float type for the Conv+Activation + // TODO(hasesh): The CPU EP only supports float type for the Conv+Activation // and the Conv+Add+Relu fusions. // Assess the support level for the other compatible EPs and if they also // only support float, remove the EP check altogether. const std::string_view node_ep = conv_node.GetExecutionProviderType(); - if (node_ep == kCudaExecutionProvider) { - if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) { - return false; - } - } if (node_ep == kCpuExecutionProvider) { #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT) && @@ -120,7 +115,9 @@ class ConvActivationSelector : public NodeSelector { } // check EP type and activation - if (node_ep == kCudaExecutionProvider || node_ep == kRocmExecutionProvider) { + if (node_ep == kCudaExecutionProvider) { + return std::nullopt; + } else if (node_ep == kRocmExecutionProvider) { if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { return std::nullopt; } @@ -142,43 +139,6 @@ class ConvActivationSelector : public NodeSelector { } }; -class ConvAddRelu : public NodeSelector { - public: - ConvAddRelu() = default; - - std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override { - const std::string_view node_ep = node.GetExecutionProviderType(); - // only for CUDA EP - if (node_ep != kCudaExecutionProvider) { - return std::nullopt; - } - - if (!ConvFusionDataTypeCheck(node)) { - return std::nullopt; - } - - const auto* add_node = GetLoneConsumerNode(graph_viewer, node); - if (!add_node || - !graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {6, 7, 13, 14}) || - add_node->GetExecutionProviderType() != node_ep) { - return std::nullopt; - } - - const auto* relu_node = GetLoneConsumerNode(graph_viewer, *add_node); - if (!relu_node || - !graph_utils::IsSupportedOptypeVersionAndDomain(*relu_node, "Relu", {6, 13, 14}) || - relu_node->GetExecutionProviderType() != node_ep) { - return std::nullopt; - } - - NodesToOptimizeIndicesBuilder builder{}; - builder.target_node = node.Index(); - builder.output_nodes = {add_node->Index(), - relu_node->Index()}; - return builder.Build(); - } -}; - } // namespace selectors #endif // !defined(ORT_MINIMAL_BUILD) @@ -304,22 +264,9 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { #endif } -void 
RegisterConvAddReluFusionRules(SelectorActionRegistry& registry) { - const auto name = "ConvAddRelu"; - auto action = std::make_unique(); -#if !defined(ORT_MINIMAL_BUILD) - auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, - std::move(selector), std::move(action)); -#else - registry.RegisterAction(name, std::move(action)); -#endif -} - SelectorActionRegistry CreateSelectorActionRegistry() { SelectorActionRegistry registry{}; RegisterConvActivationFusionRules(registry); - RegisterConvAddReluFusionRules(registry); return registry; } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 54bd44ec2dba4..08284e67277e1 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -282,6 +282,11 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, @@ -318,7 +323,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_dml_eps)); transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); + transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/providers/common.h b/onnxruntime/core/providers/common.h index 564c7690c8da1..7576dfba5c85e 100644 --- a/onnxruntime/core/providers/common.h +++ b/onnxruntime/core/providers/common.h @@ -83,6 +83,14 @@ inline AutoPadType StringToAutoPadType(const std::string& str) { // helper function +constexpr inline int64_t ComputeOutputShape(const int64_t in_dim, + const int64_t stride, const int64_t kernel, const int64_t dilation, + const int64_t pad_head, const int64_t pad_tail) { + const SafeInt dkernel = SafeInt(dilation) * (kernel - 1) + 1; + int64_t dkernel_value = SafeInt(in_dim) + pad_head + pad_tail - dkernel; + return static_cast(static_cast(dkernel_value) / stride + 1); +} + inline Status ComputePad(const int64_t in_dim, const int64_t stride, const int64_t kernel, const int64_t dilation, AutoPadType pad_type, @@ -106,6 +114,15 @@ inline Status ComputePad(const int64_t in_dim, // is retained as is SafeInt legacy_target_size = (SafeInt(in_dim) + stride - 1) / stride; SafeInt pad_needed = (legacy_target_size - 1) * stride + kernel - in_dim; + // out_dim = floor((in_dim + 2p - k) / s) + 1 + // => if (in_dim + 2p - k) is not divisible by s we can remove the floor with following equation: + // out_dim + eps = ((in_dim + 2p - k) / s) + 1 ;where eps is in [0.0, 1.0] + // therefore in edge cases padding can lower calculated above than it should be + SafeInt actual_out_size = ComputeOutputShape(in_dim, stride, kernel, /*dilation*/ 1, + pad_needed, pad_needed); + if (actual_out_size < legacy_target_size) { + pad_needed += 1; + } // make sure padding is symmetric if (force_symmetric_auto_padding) { // 
Inlining math::roundUpPow2() from util/math.h to avoid bringing in the transitive dependencies. @@ -126,14 +143,6 @@ inline Status ComputePad(const int64_t in_dim, return Status::OK(); } -constexpr inline int64_t ComputeOutputShape(const int64_t in_dim, - const int64_t stride, const int64_t kernel, const int64_t dilation, - const int64_t pad_head, const int64_t pad_tail) { - const SafeInt dkernel = SafeInt(dilation) * (kernel - 1) + 1; - int64_t dkernel_value = SafeInt(in_dim) + pad_head + pad_tail - dkernel; - return static_cast(static_cast(dkernel_value) / stride + 1); -} - inline Status ComputePadAndOutputShape(const int64_t in_dim, const int64_t stride, const int64_t kernel, const int64_t dilation, AutoPadType pad_type, diff --git a/onnxruntime/core/providers/cuda/cuda_call.cc b/onnxruntime/core/providers/cuda/cuda_call.cc index c73b23f3762ed..511a6e2dce199 100644 --- a/onnxruntime/core/providers/cuda/cuda_call.cc +++ b/onnxruntime/core/providers/cuda/cuda_call.cc @@ -34,7 +34,6 @@ const char* CudaErrString(cudaError_t x) { template <> const char* CudaErrString(cublasStatus_t e) { cudaDeviceSynchronize(); - switch (e) { CASE_ENUM_TO_STR(CUBLAS_STATUS_SUCCESS); CASE_ENUM_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED); @@ -87,9 +86,15 @@ const char* CudaErrString(ncclResult_t e) { } #endif -template +template +int GetErrorCode(ERRTYPE err) { + return static_cast(err); +} + +template std::conditional_t CudaCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { + ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, + const char* file, const int line) { if (retCode != successCode) { try { #ifdef _WIN32 @@ -108,7 +113,7 @@ std::conditional_t CudaCall( cudaGetLastError(); // clear last CUDA error static char str[1024]; snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", - libName, (int)retCode, CudaErrString(retCode), currentCudaDevice, + libName, GetErrorCode(retCode), CudaErrString(retCode), currentCudaDevice, hostname, file, line, exprString, msg); if constexpr (THRW) { @@ -118,7 +123,8 @@ std::conditional_t CudaCall( LOGS_DEFAULT(ERROR) << str; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); } - } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, so we'd never get to see the error + } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, + // so we'd never get to see the error if constexpr (THRW) { ORT_THROW(e.what()); } else { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 5771380433b35..f74754c3cd064 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -180,6 +180,7 @@ CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_)); CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream)); + LOGS_DEFAULT(INFO) << "cuDNN version: " << cudnnGetVersion(); #endif cuda_graph_.SetStream(stream); } @@ -2469,6 +2470,19 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const return false; } +static bool NhwcConvNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger, + [[maybe_unused]] const GraphViewer& graph_viewer, + 
[[maybe_unused]] const bool prefer_nhwc) { + // NHWC implementation doesn't handle W in NHWC layout if it's not an initializer + if (!graph_viewer.IsConstantInitializer(node.InputDefs()[1]->Name(), true)) { + LOGS(logger, WARNING) << "Dropping the NhwcConv node: " << node.Name() + << " to CPU because the Cuda EP requires W as initializer for NHWC operation."; + return true; + } + + return false; +} + static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { const auto& node_attributes = node.GetAttributes(); // Check attributes @@ -2539,6 +2553,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); // cast is not compute heavy, and may be placed outside + } else if ("NhwcConv" == node.OpType()) { + not_supported = NhwcConvNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred()); + force_inside = !not_supported; } if (!force_inside && not_supported) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 9c8a8712ca51c..0871f7e4d0a74 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -82,6 +82,7 @@ class CUDAExecutionProvider : public IExecutionProvider { bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; } bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; } bool IsNHWCPreferred() const { return info_.prefer_nhwc; } + bool IsFuseConvBias() const { return info_.fuse_conv_bias; } bool UseTF32() const { return info_.use_tf32; } #ifndef DISABLE_CONTRIB_OPS diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 31cf991a34fc9..14195703d5963 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -34,6 +34,7 @@ constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_s constexpr const char* kPreferNHWCMode = "prefer_nhwc"; constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; constexpr const char* kUseTF32 = "use_tf32"; +constexpr const char* kFuseConvBias = "fuse_conv_bias"; constexpr const char* kSdpaKernel = "sdpa_kernel"; } // namespace provider_option_names @@ -119,6 +120,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) .AddAssignmentToReference(cuda::provider_option_names::kSdpaKernel, info.sdpa_kernel) + .AddAssignmentToReference(cuda::provider_option_names::kFuseConvBias, info.fuse_conv_bias) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -173,6 +175,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, + {cuda::provider_option_names::kFuseConvBias, 
MakeStringWithClassicLocale(info.fuse_conv_bias)}, }; return options; @@ -195,6 +198,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, + {cuda::provider_option_names::kFuseConvBias, MakeStringWithClassicLocale(info.fuse_conv_bias)}, {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, }; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 0efad80f743df..bfd50ca8d40a1 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -78,6 +78,7 @@ struct CUDAExecutionProviderInfo { // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. bool use_tf32{true}; + bool fuse_conv_bias{false}; int sdpa_kernel{0}; @@ -107,7 +108,8 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { (static_cast(info.enable_skip_layer_norm_strict_mode) << 27) ^ (static_cast(info.prefer_nhwc) << 28) ^ (static_cast(info.use_ep_level_unified_stream) << 29) ^ - (static_cast(info.use_tf32) << 30); + (static_cast(info.use_tf32) << 30) ^ + (static_cast(info.fuse_conv_bias) << 31); onnxruntime::HashCombine(data, value); onnxruntime::HashCombine(info.gpu_mem_limit, value); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index b1d54e56ded4e..83a5d02d16c6c 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -219,6 +219,7 @@ struct CUDA_Provider : Provider { info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0; info.enable_cuda_graph = params->enable_cuda_graph != 0; info.prefer_nhwc = params->prefer_nhwc; + info.fuse_conv_bias = params->fuse_conv_bias; info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; @@ -262,6 +263,7 @@ struct CUDA_Provider : Provider { cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; cuda_options.use_tf32 = internal_options.use_tf32; cuda_options.sdpa_kernel = internal_options.sdpa_kernel; + cuda_options.fuse_conv_bias = internal_options.fuse_conv_bias; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 14b75d2383b58..e9b159516dad9 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -215,6 +215,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::prefer_nhwc_t: return reinterpret_cast(ep_info_.prefer_nhwc); break; + case CudaResource::fuse_conv_bias_t: + return reinterpret_cast(ep_info_.fuse_conv_bias); + break; case CudaResource::use_tf32_t: return reinterpret_cast(ep_info_.use_tf32); break; diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc 
b/onnxruntime/core/providers/cuda/cudnn_common.cc index 31fc63a86d640..72482266a2eef 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -3,6 +3,8 @@ // Licensed under the MIT License. #include +#include +#include #include "core/providers/cuda/cudnn_common.h" #include "core/common/inlined_containers.h" @@ -60,7 +62,25 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat dims[1] = gsl::narrow_cast(input_dims[rank - 1]); strides[1] = gsl::narrow_cast(pitches[rank - 1]); } - CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); + CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), + dims.data(), strides.data())); + return Status::OK(); +} + +Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dataType, + gsl::span input_strides) { + ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); + + int rank = gsl::narrow_cast(input_dims.size()); + InlinedVector dims(rank); + InlinedVector strides(rank); + + for (int i = 0; i < rank; i++) { + dims[i] = gsl::narrow_cast(input_dims[i]); + strides[i] = gsl::narrow_cast(input_strides[i]); + } + CUDNN_RETURN_IF_ERROR( + cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); return Status::OK(); } @@ -100,7 +120,8 @@ Status CudnnDataTensor::Set(cudnnDataType_t dataType, const int32_t* seq_lengths) { ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); - // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences + // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, + // so that it will auto fill 0 for the shorter sequences cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED; float padding_fill = 0.0f; CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(tensor_, dataType, layout, @@ -238,6 +259,91 @@ const Float8E5M2 Consts::One = Float8E5M2(1.0f, true); #endif +std::vector generateStrides(const std::vector& shape, bool channels_last) { + // For INT8x4 and INT8x32 we still compute standard strides here to input + // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. 
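For orientation, the stride rule used by this new helper can be shown in isolation. The sketch below is illustrative only: it assumes shapes ordered {N, C, spatial...} with rank of at least 3 when channels_last is set, and the ComputeStrides name and the example shape are hypothetical, not taken from this change.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative sketch (hypothetical helper, not from this change): packed strides
// for a shape ordered {N, C, D1, ..., Dk}. channels_last yields the NHWC-style
// layout where C varies fastest; otherwise the usual NCHW row-major strides.
std::vector<int64_t> ComputeStrides(const std::vector<int64_t>& shape, bool channels_last) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  std::vector<int64_t> strides(shape.size(), 1);
  if (rank <= 1) return strides;
  if (channels_last) {                    // assumes rank >= 3
    strides[1] = 1;                       // C is the innermost dimension
    strides[rank - 1] = shape[1];         // last spatial dim steps over C
    for (int64_t d = rank - 2; d >= 2; --d) strides[d] = strides[d + 1] * shape[d + 1];
    strides[0] = strides[2] * shape[2];   // N steps over all spatial dims and C
  } else {
    for (int64_t d = rank - 2; d >= 0; --d) strides[d] = strides[d + 1] * shape[d + 1];
  }
  return strides;
}

int main() {
  for (int64_t s : ComputeStrides({1, 3, 5, 5}, true)) std::cout << s << ' ';   // prints 75 1 15 3
  std::cout << '\n';
  for (int64_t s : ComputeStrides({1, 3, 5, 5}, false)) std::cout << s << ' ';  // prints 75 25 5 1
  std::cout << '\n';
  return 0;
}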
+ std::vector strides(shape.size()); + int64_t nbDims = strides.size(); + if (nbDims <= 1) { + strides[0] = 1; + return strides; + } + if (channels_last) { + // Here we assume that the format is CUDNN_TENSOR_NHWC + strides[1] = 1; + strides[nbDims - 1] = strides[1] * shape[1]; + for (int64_t d = nbDims - 2; d >= 2; d--) { + strides[d] = strides[d + 1] * shape[d + 1]; + } + strides[0] = strides[2] * shape[2]; + } else { + strides[nbDims - 1] = 1; + for (int64_t d = nbDims - 2; d >= 0; d--) { + strides[d] = strides[d + 1] * shape[d + 1]; + } + } + return strides; +} + +#if !defined(__CUDACC__) +CudnnFeTensor::CudnnFeTensor(const onnxruntime::TensorShapeVector& shape, + const std::string& name, + std::optional dtype, + const bool nhwc) { + std::vector shape_vec; + if (shape.size() == 1) { + shape_vec = {1, shape[0], 1, 1}; + } else if (shape.size() >= 4) { + for (size_t i = 0; i < shape.size(); i++) { + shape_vec.push_back(shape[i]); + } + } else { + ORT_THROW("Invalid tensor shape size, tensor name: ", name, ", shape size: ", shape.size()); + } + auto strides = generateStrides(shape_vec, nhwc); + + if (dtype.has_value()) { + tensor_ = cudnn_frontend::graph::Tensor_attributes() + .set_name(name) + .set_dim(shape_vec) + .set_stride(strides) + .set_data_type(dtype.value()); + } else { + tensor_ = cudnn_frontend::graph::Tensor_attributes().set_name(name).set_dim(shape_vec).set_stride(strides); + } +} + +template +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::NOT_SET; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::FLOAT; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::HALF; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::DOUBLE; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::INT8; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::UINT8; +} +#endif + } // namespace cuda } // namespace onnxruntime #endif diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index 2cbeb13696270..b267ef6bed64f 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -5,12 +5,22 @@ #pragma once #include +#include +#include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" + #ifndef USE_CUDA_MINIMAL +#if !defined(__CUDACC__) +#include +#endif + namespace onnxruntime { namespace cuda { +#define CUDNN_FE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDNN_FE_CALL(expr)) + class CudnnTensor final { public: CudnnTensor(); @@ -19,6 +29,7 @@ class CudnnTensor final { Status Set(gsl::span input_dims, cudnnDataType_t dataType, bool is_nhwc = false); Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode); + Status Set(gsl::span input_dims, cudnnDataType_t dataType, gsl::span input_strides); // Set 4D tensor format (for NHWC) Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w); @@ -139,7 +150,8 @@ struct Consts { inline double ClampCudnnBatchNormEpsilon(double epsilon) { if (epsilon < CUDNN_BN_MIN_EPSILON) { if (CUDNN_BN_MIN_EPSILON - epsilon > FLT_EPSILON) - LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. 
Setting it to CUDNN_BN_MIN_EPSILON"; + LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. " + << "Setting it to CUDNN_BN_MIN_EPSILON"; return CUDNN_BN_MIN_EPSILON; } return epsilon; @@ -258,6 +270,23 @@ SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc, return cudnnSetPoolingNdDescriptor(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); } +std::vector generateStrides(const std::vector& shape, bool channels_last); + +#if !defined(__CUDACC__) +class CudnnFeTensor final { + public: + CudnnFeTensor(const onnxruntime::TensorShapeVector& shape, const std::string& name, + std::optional dtype, const bool nhwc); + + template + static cudnn_frontend::DataType_t GetDataType(); + cudnn_frontend::graph::Tensor_attributes Get() { return tensor_; } + + private: + cudnn_frontend::graph::Tensor_attributes tensor_; +}; +#endif + } // namespace cuda } // namespace onnxruntime #endif diff --git a/onnxruntime/core/providers/cuda/cudnn_fe_call.cc b/onnxruntime/core/providers/cuda/cudnn_fe_call.cc new file mode 100644 index 0000000000000..640025c248187 --- /dev/null +++ b/onnxruntime/core/providers/cuda/cudnn_fe_call.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" +#include "core/providers/shared_library/provider_api.h" +#include +#if !defined(__CUDACC__) +#include +#endif +#ifdef _WIN32 +#else // POSIX +#include +#include +#endif + +namespace onnxruntime { + +using namespace common; + +template +const char* CudaErrString(ERRTYPE) { + ORT_NOT_IMPLEMENTED(); +} + +#if !defined(__CUDACC__) +#define CASE_ENUM_TO_STR_CUDNN_FE(x) \ + case cudnn_frontend::error_code_t::x: \ + return #x + +template <> +const char* CudaErrString(cudnn_frontend::error_t x) { + cudaDeviceSynchronize(); + LOGS_DEFAULT(ERROR) << x.get_message(); + switch (x.get_code()) { + CASE_ENUM_TO_STR_CUDNN_FE(OK); + CASE_ENUM_TO_STR_CUDNN_FE(ATTRIBUTE_NOT_SET); + CASE_ENUM_TO_STR_CUDNN_FE(SHAPE_DEDUCTION_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(INVALID_TENSOR_NAME); + CASE_ENUM_TO_STR_CUDNN_FE(INVALID_VARIANT_PACK); + CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_NOT_SUPPORTED); + CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_EXECUTION_PLAN_CREATION_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_EXECUTION_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(HEURISTIC_QUERY_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(UNSUPPORTED_GRAPH_FORMAT); + CASE_ENUM_TO_STR_CUDNN_FE(CUDA_API_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(CUDNN_BACKEND_API_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(INVALID_CUDA_DEVICE); + CASE_ENUM_TO_STR_CUDNN_FE(HANDLE_ERROR); + default: + return "Unknown CUDNN_FRONTEND error status"; + } +} + +template +int GetErrorCode(ERRTYPE err) { + return static_cast(err); +} + +template <> +int GetErrorCode(cudnn_frontend::error_t err) { + return static_cast(err.get_code()); +} + +template +std::conditional_t CudaCall( + ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, + const char* file, const int line) { + if (retCode != successCode) { + try { +#ifdef _WIN32 + std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); + if (hostname_str.empty()) { + hostname_str = "?"; + } + const char* hostname = hostname_str.c_str(); +#else + char hostname[HOST_NAME_MAX]; + if (gethostname(hostname, HOST_NAME_MAX) != 0) + strcpy(hostname, "?"); +#endif + int currentCudaDevice; + cudaGetDevice(¤tCudaDevice); + cudaGetLastError(); // clear last CUDA error + static char 
str[1024]; + snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", + libName, GetErrorCode(retCode), CudaErrString(retCode), currentCudaDevice, + hostname, + file, line, exprString, msg); + if constexpr (THRW) { + // throw an exception with the error info + ORT_THROW(str); + } else { + LOGS_DEFAULT(ERROR) << str; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + } + } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, + // so we'd never get to see the error + if constexpr (THRW) { + ORT_THROW(e.what()); + } else { + LOGS_DEFAULT(ERROR) << e.what(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); + } + } + } + if constexpr (!THRW) { + return Status::OK(); + } +} + +template Status CudaCall( + cudnn_frontend::error_t retCode, const char* exprString, const char* libName, + cudnn_frontend::error_code_t successCode, const char* msg, const char* file, const int line); +template void CudaCall( + cudnn_frontend::error_t retCode, const char* exprString, const char* libName, + cudnn_frontend::error_code_t successCode, const char* msg, const char* file, const int line); + +#endif +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 764feadcf4cb3..95ba698b707ac 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -3,14 +3,20 @@ // Licensed under the MIT License. #include +#include +#include +#include "core/common/status.h" #include "core/providers/cuda/nn/conv.h" #include "core/common/span_utils.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/providers/cuda/tensor/slice.h" #include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" +#if CUDNN_MAJOR < 9 +// if compiled with cuDNN 8 we want to use the legacy cuDNN API +#include "conv_8.h" +#endif namespace onnxruntime { namespace cuda { @@ -43,58 +49,7 @@ REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) #endif -template -const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { - CUDNN_CONVOLUTION_FWD_ALGO_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_FFT, - CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, -}; - -cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, cudnnConvolutionFwdAlgo_t algo, size_t* sz) { - return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz); -} - -size_t GetMaxWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, - const cudnnConvolutionFwdAlgo_t* algo, int n_algo) { - // TODO: get maximum available size from memory arena - size_t free, total; - CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); - // Assuming 10% of fragmentation - free = static_cast(static_cast(free) * 0.9); - size_t max_ws_size = 0; - for (int i = 0; i < n_algo; i++) { - cudnnStatus_t err; - size_t sz; - err = GetWorkspaceSize(handle, s, algo[i], &sz); - if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue; - max_ws_size = sz; - } - return max_ws_size; -} - -Status SliceOutUnwantedOutputSection(cudaStream_t stream, - 
const void* input_data, gsl::span input_dims, - void* output_data, - const gsl::span& output_dims, - const gsl::span& starts, - const gsl::span& ends, - const gsl::span& axes, - size_t element_size) { - SliceOp::PrepareForComputeMetadata compute_metadata(input_dims); - - ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata)); - - // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape - ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims)); - - return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); -} - +// First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { @@ -104,14 +59,20 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); + auto shape_size = orig_shape.GetDims().size(); - InlinedVector perm{0, 2, 3, 1}; - gsl::span permutation(perm.data(), 4); - TensorShapeVector new_dims{orig_shape[0], - orig_shape[2], - orig_shape[3], - orig_shape[1]}; - W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + InlinedVector perm; + perm.push_back(0); + for (size_t i = 2; i < shape_size; i++) perm.push_back(i); + perm.push_back(1); + gsl::span permutation(perm.data(), shape_size); + + TensorShapeVector nhwc_dims; + for (size_t i = 0; i < shape_size; i++) { + nhwc_dims.push_back(orig_shape[perm[i]]); + } + + W_ = Tensor::Create(tensor.DataType(), TensorShape(nhwc_dims), std::move(alloc)); auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), @@ -122,6 +83,8 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr } CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; + } else { + W_already_nhwc = true; } } else { ORT_UNUSED_PARAMETER(tensor); @@ -132,45 +95,205 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr return Status::OK(); } -template -Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { +#if CUDNN_MAJOR >= 9 +#if !defined(__CUDACC__) + +template +Status Conv::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const Tensor* Z, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool bias_expected, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const { + s_.bias_fused = fuse_bias; + s_.act_fused = fuse_act; + s_.variant_pack.clear(); // clear variant pack, as stored pointers to tensors change + s_.cudnn_fe_graph = std::make_unique(); + cudnn_frontend::DataType_t data_type = CudnnFeTensor::GetDataType(); + s_.cudnn_fe_graph->set_io_data_type(data_type).set_intermediate_data_type(data_type); + if (data_type == cudnn_frontend::DataType_t::HALF) { + s_.cudnn_fe_graph->set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + } else { + s_.cudnn_fe_graph->set_compute_data_type(data_type); + } + + s_.cudnn_fe_X = 
s_.cudnn_fe_graph->tensor(CudnnFeTensor(x_dims, "x", data_type, Layout == LAYOUT_NHWC).Get()); + s_.cudnn_fe_W = s_.cudnn_fe_graph->tensor(CudnnFeTensor(w_dims, "w", data_type, w_in_nhwc).Get()); + + auto conv_options = cudnn_frontend::graph::Conv_fprop_attributes() + .set_pre_padding(std::vector(pads.begin(), + pads.begin() + pads.size() / 2)) + .set_post_padding(std::vector(pads.begin() + pads.size() / 2, pads.end())) + .set_stride(strides) + .set_dilation(dilations); + s_.cudnn_fe_conv_Y = s_.cudnn_fe_graph->conv_fprop(s_.cudnn_fe_X, s_.cudnn_fe_W, conv_options); + auto cudnn_fe_y_tensor = CudnnFeTensor(y_dims, "y", data_type, Layout == LAYOUT_NHWC).Get(); + + if (!bias_expected && B == nullptr) { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + } else { + int64_t bias_size; + if (B != nullptr) { + bias_size = B->Shape()[0]; + } else { + bias_size = w_dims[0]; + } + + std::optional cudnn_fe_z_tensor; + if (Z) { + const auto& z_shape = Z->Shape().AsShapeVector(); + cudnn_fe_z_tensor = CudnnFeTensor(z_shape, "z", data_type, Layout == LAYOUT_NHWC).Get(); + } else if (fuse_bias && Layout == LAYOUT_NCHW) { + // Z is required for NCHW precompiled kernels in cuDNN + s_.z_data = s_.y_data; + cudnn_fe_z_tensor = cudnn_fe_y_tensor; + } + + if (fuse_bias) { + std::shared_ptr add_output; + if (cudnn_fe_z_tensor.has_value()) { + s_.cudnn_fe_Z = s_.cudnn_fe_graph->tensor(cudnn_fe_z_tensor.value()); + auto add_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + add_output = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_conv_Y, s_.cudnn_fe_Z, add_options); + } else { + add_output = s_.cudnn_fe_conv_Y; + } + + onnxruntime::TensorShapeVector b_dims; + for (size_t i = 0; i < x_dims.size(); i++) { + b_dims.push_back(i == 1 ? bias_size : 1); + } + auto bias_tensor = CudnnFeTensor(b_dims, "b", data_type, Layout == LAYOUT_NHWC).Get(); + auto bias_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + s_.cudnn_fe_B = s_.cudnn_fe_graph->tensor(bias_tensor); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(add_output, s_.cudnn_fe_B, bias_options); + } else { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + + TensorShapeVector b_dims(y_dims.size(), 1); + TensorShapeVector b_strides(y_dims.size(), 1); + b_dims[1] = bias_size; + b_strides[0] = bias_size; + if (Z) { + ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().AsShapeVector(), + CudnnTensor::GetDataType(), + cudnn_fe_z_tensor->get_stride())); + } + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), b_strides)); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType(), cudnn_fe_y_tensor.get_stride())); + + /* Creating an own CUDNN Frontend graph for the bias addition. + s_.cudnn_fe_bias_graph = std::make_unique(); + s_.cudnn_fe_bias_graph->set_io_data_type(data_type) + .set_compute_data_type(data_type == cudnn_frontend::DataType_t::HALF ? 
+ cudnn_frontend::DataType_t::FLOAT : data_type) + .set_intermediate_data_type(data_type); + s_.cudnn_fe_bias_X = s_.cudnn_fe_bias_graph->tensor(CudnnFeTensor(y_dims, "x", data_type).Get()); + + s_.cudnn_fe_B = s_.cudnn_fe_bias_graph->tensor(bias_tensor); + s_.cudnn_fe_bias_Y = s_.cudnn_fe_bias_graph->pointwise(s_.cudnn_fe_bias_X, s_.cudnn_fe_B, bias_options); + s_.cudnn_fe_bias_Y->set_output(true); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->validate()); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_operation_graph(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->create_execution_plans({heur_mode})); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->check_support(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_plans(handle));*/ + } + } + if (fuse_act && s_.cudnn_fe_act_attr.has_value()) { + auto& activation_attr = s_.cudnn_fe_act_attr.value(); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_Y, activation_attr); + } + + s_.cudnn_fe_Y->set_dim(cudnn_fe_y_tensor.get_dim()); + s_.cudnn_fe_Y->set_stride(cudnn_fe_y_tensor.get_stride()); + s_.cudnn_fe_Y->set_output(true); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->validate()); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); + } catch (const std::exception& ex) { + std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } + + if (!use_tf32) s_.cudnn_fe_graph->deselect_numeric_notes({cudnn_frontend::NumericalNote_t::TENSOR_CORE}); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->check_support(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); + } catch (const std::exception& ex) { + if (!fuse_bias && !fuse_act && use_tf32) { + std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } + + // Try fallback. + return CreateCudnnFeExecutionPlan(x_dims, w_dims, B, Z, y_dims, handle, heur_mode, + pads, strides, dilations, bias_expected, false, false, w_in_nhwc, true); + } + + s_.workspace_bytes = s_.cudnn_fe_graph->get_workspace_size(); + return Status::OK(); +} + +#endif + +template +Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { + constexpr bool channels_last = Layout; + // set X const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); + // X incl. x_dims is in NHWC Format iff. NHWC == true const auto x_dims = x_shape.AsShapeVector(); + s_.x_data = reinterpret_cast(X->Data()); s_.element_size = X->DataType()->Size(); // set W + bool w_in_nhwc; const Tensor* W; if (!W_) { W = context->Input(1); + w_in_nhwc = W_already_nhwc; + // Dims and memory layout are in NCHW format } else { W = W_.get(); + w_in_nhwc = true; + // W got prepacked, therfore if NHWC == true, then dims and memory layout are in NHWC } const TensorShape& w_shape = W->Shape(); - auto w_dims = w_shape.AsShapeVector(); + onnxruntime::TensorShapeVector w_dims = w_shape.AsShapeVector(); s_.w_data = reinterpret_cast(W->Data()); - // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. 
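The graph construction above goes through several cudnn_frontend stages before the convolution can run. The condensed sketch below only restates that sequence for readability; it reuses the graph calls that appear in this change, assumes cudnn_frontend's error_t::is_good() helper, and the BuildAndPlanGraph name is hypothetical.

#if !defined(__CUDACC__)
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <stdexcept>

// Condensed, illustrative restatement of the build sequence used above
// (hypothetical helper; the real code returns an ORT Status instead of throwing).
inline int64_t BuildAndPlanGraph(cudnn_frontend::graph::Graph& graph,
                                 cudnnHandle_t handle,
                                 cudnn_frontend::HeurMode_t heur_mode) {
  // 1. Validate the declared tensors and operations.
  if (!graph.validate().is_good()) throw std::runtime_error("validate failed");
  // 2. Lower the graph to a backend operation graph bound to this cuDNN handle.
  if (!graph.build_operation_graph(handle).is_good()) throw std::runtime_error("lowering failed");
  // 3. Ask the chosen heuristic mode (A, B or FALLBACK) for candidate plans.
  if (!graph.create_execution_plans({heur_mode}).is_good()) throw std::runtime_error("no plans");
  // 4. Drop configs the current device or cuDNN build cannot run, then build the plan.
  if (!graph.check_support(handle).is_good()) throw std::runtime_error("not supported");
  if (!graph.build_plans(handle).is_good()) throw std::runtime_error("plan build failed");
  // The caller sizes the workspace from the finished plan.
  return graph.get_workspace_size();
}
#endif

The kernel follows the same order but converts each failure into a Status, and on an unsupported fused graph it retries once without bias and activation fusion before giving up.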
- constexpr bool channels_last = NHWC; - if constexpr (channels_last) { - if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); - } - } - // set B + // Always in NCHW format + const Tensor* B = nullptr; if (context->InputCount() >= 3) { - const Tensor* B = context->Input(2); + B = context->Input(2); s_.b_data = reinterpret_cast(B->Data()); } else { s_.b_data = nullptr; } + // set Z + const Tensor* Z = nullptr; if (context->InputCount() >= 4) { - const Tensor* Z = context->Input(3); - ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType())); + Z = context->Input(3); s_.z_data = reinterpret_cast(Z->Data()); } else { s_.z_data = nullptr; @@ -183,13 +306,12 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_results.clear(); } - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, w_in_nhwc)); TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, channels_last)); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, w_in_nhwc)); const size_t kernel_rank = kernel_shape.size(); @@ -211,59 +333,46 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const int64_t N = X->Shape()[0]; const int64_t M = W->Shape()[0]; - if (channels_last) { + + if constexpr (channels_last) { y_dims.push_back(N); } else { y_dims.insert(y_dims.begin(), {N, M}); } - bool post_slicing_required = false; - TensorShapeVector slice_starts; - slice_starts.reserve(kernel_rank); - - TensorShapeVector slice_ends; - slice_ends.reserve(kernel_rank); - - TensorShapeVector slice_axes; - slice_axes.reserve(kernel_rank); - constexpr size_t spatial_dim_start = channels_last ? 1 : 2; const size_t spatial_dim_end = spatial_dim_start + kernel_rank; TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); - TensorShapeVector y_dims_with_adjusted_pads(y_dims); - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, - strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, - post_slicing_required, slice_starts, slice_ends, slice_axes, - channels_last)); - if (channels_last) { + ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(spatial_shape, kernel_shape, + strides, dilations, pads, y_dims)); + if constexpr (channels_last) { y_dims.push_back(M); - y_dims_with_adjusted_pads.push_back(M); } - ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); s_.y_dims = gsl::make_span(y_dims); - s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; - s_.post_slicing_required = post_slicing_required; - s_.slice_starts = slice_starts; - s_.slice_ends = slice_ends; - s_.slice_axes = slice_axes; - s_.Y = context->Output(0, TensorShape(s_.y_dims)); - if (post_slicing_required) { - // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. - s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); - s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); - } else { - // No post slicing needed. 
Fill the output tensor's buffer directly. - s_.y_data = reinterpret_cast(s_.Y->MutableData()); - } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); const CUDAExecutionProvider* cuda_ep = static_cast(this->Info().GetExecutionProvider()); TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; - TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; + TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; + TensorShapeVector w_dims_cudnn{w_dims.begin(), w_dims.end()}; + + if constexpr (channels_last) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, *(x_dims_cudnn.end() - 1)); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, *(y_dims_cudnn.end() - 1)); + x_dims_cudnn.erase(x_dims_cudnn.end() - 1); + y_dims_cudnn.erase(y_dims_cudnn.end() - 1); + + if (w_in_nhwc) { + w_dims_cudnn.insert(w_dims_cudnn.begin() + 1, *(w_dims_cudnn.end() - 1)); + w_dims_cudnn.erase(w_dims_cudnn.end() - 1); + } + } + if (kernel_rank < 2) { // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] // especially for EXHAUSTIVE algo search which may result in a better algo selection. @@ -276,7 +385,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (cuda_ep->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); - w_dims.insert(w_dims.begin() + 2, 1); + w_dims_cudnn.insert(w_dims.begin() + 2, 1); pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.begin(), 0); kernel_shape.insert(kernel_shape.begin(), 1); @@ -285,7 +394,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) } else { x_dims_cudnn.push_back(1); y_dims_cudnn.push_back(1); - w_dims.push_back(1); + w_dims_cudnn.push_back(1); pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.end(), 0); kernel_shape.push_back(1); @@ -294,188 +403,105 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) } } - if (w_dims_changed) { - if (!channels_last) { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, - CudnnTensor::GetDataType(), - static_cast(w_dims[0]), - static_cast(w_dims[3]), - static_cast(w_dims[1]), - static_cast(w_dims[2]))); - } - } - // We must delay returning early until here so that the weight dims have been cached properly if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - if (channels_last) { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, - CudnnTensor::GetDataType(), - static_cast(x_dims_cudnn[0]), - static_cast(x_dims_cudnn[3]), - static_cast(x_dims_cudnn[1]), - static_cast(x_dims_cudnn[2]))); - - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, - CudnnTensor::GetDataType(), - static_cast(y_dims_cudnn[0]), - static_cast(y_dims_cudnn[3]), - static_cast(y_dims_cudnn[1]), - static_cast(y_dims_cudnn[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + auto handle = GetCudnnHandle(context); + + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); +#if !defined(__CUDACC__) + cudnn_frontend::HeurMode_t heur_mode; + switch (cudnn_conv_algo) { + case 0: + heur_mode = cudnn_frontend::HeurMode_t::B; + break; + case 1: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + case 2: + heur_mode = cudnn_frontend::HeurMode_t::FALLBACK; + LOGS_DEFAULT(WARNING) << "OP " << CudaKernel::Node().OpType() << 
"(" << CudaKernel::Node().Name() + << ") running in Fallback mode. May be extremely slow."; + break; + default: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; } - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, - gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), - UseTF32())); - - if (context->InputCount() >= 3) { - const Tensor* B = context->Input(2); - const auto& b_shape = B->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); - TensorShapeVector b_dims(2 + kernel_shape.size(), 1); - b_dims[1] = b_shape[0]; - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); - // s_.b_data = reinterpret_cast(B->Data()); - } else if (bias_expected) { - TensorShapeVector b_dims(2 + kernel_shape.size(), 1); - b_dims[1] = w_dims[0]; - auto malloc_size = b_dims[1] * sizeof(CudaT); - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); - if (s_.b_zero) { - CUDA_CALL_THROW(cudaFree(s_.b_zero)); - s_.b_zero = nullptr; - } - CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size)); - CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); - } - - if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { - // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); - } else if constexpr (std::is_same::value) { - if (!UseTF32()) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); - } - } - - cudnnConvolutionFwdAlgoPerf_t perf; - int algo_count = 1; - int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); - ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo); - switch (cudnn_conv_algo) { - case 0: { - static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; - size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. 
- IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( - GetCudnnHandle(context), - s_.x_tensor, - s_.x_data, - s_.w_desc, - s_.w_data, - s_.conv_desc, - s_.y_tensor, - s_.y_data, - 1, // requestedAlgoCount - &algo_count, // returnedAlgoCount - &perf, - algo_search_workspace.get(), - max_ws_size)); - break; - } - case 1: - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7( - GetCudnnHandle(context), - s_.x_tensor, - s_.w_desc, - s_.conv_desc, - s_.y_tensor, - 1, // requestedAlgoCount - &algo_count, // returnedAlgoCount - &perf)); - break; - - default: - perf.algo = kDefaultConvAlgo; - CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); - - if constexpr (std::is_same::value) { - perf.mathType = CUDNN_TENSOR_OP_MATH; - } else if (std::is_same::value && !UseTF32()) { - perf.mathType = CUDNN_FMA_MATH; - } else { - perf.mathType = CUDNN_DEFAULT_MATH; - } - } - s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType}); - } - const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); - s_.algo = perf.algo; - s_.workspace_bytes = perf.memory; + const auto use_tf32 = cuda_ep->UseTF32(); + // fuse if this op is part of a FusedConv or if the EP is set to fuse ops + const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + const auto fuse_act = is_fused_node_; + + ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, Z, y_dims_cudnn, handle, heur_mode, + std::vector(pads.begin(), + pads.end()), + std::vector(strides.begin(), + strides.end()), + std::vector(dilations.begin(), + dilations.end()), + bias_expected, fuse_bias, fuse_act, w_in_nhwc, use_tf32)); +#endif } else { // set Y s_.Y = context->Output(0, s_.y_dims); if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - if (s_.post_slicing_required) { - s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); - s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); - } else { - s_.y_data = reinterpret_cast(s_.Y->MutableData()); - } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); } return Status::OK(); } -template -Status Conv::ComputeInternal(OpKernelContext* context) const { +template +Status Conv::ComputeInternal(OpKernelContext* context) const { std::lock_guard lock(s_.mutex); ORT_RETURN_IF_ERROR(UpdateState(context)); if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - const auto alpha = Consts::One; - const auto beta = Consts::Zero; - IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); + const auto alpha = onnxruntime::cuda::Consts::One; auto cudnn_handle = GetCudnnHandle(context); - CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnn_handle, - &alpha, - s_.x_tensor, - s_.x_data, - s_.w_desc, - s_.w_data, - s_.conv_desc, - s_.algo, - workspace.get(), - s_.workspace_bytes, - &beta, - s_.y_tensor, - s_.y_data)); - if (nullptr != s_.b_data) { - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, +#if !defined(__CUDACC__) + s_.variant_pack.insert_or_assign(s_.cudnn_fe_X, const_cast(s_.x_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_W, const_cast(s_.w_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Y, s_.y_data); + if (s_.bias_fused && s_.b_data != nullptr) { + 
s_.variant_pack.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + } + if (s_.bias_fused && s_.z_data != nullptr) { + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Z, const_cast(s_.z_data)); + if (Layout == LAYOUT_NCHW && s_.z_data == s_.y_data) { + // memset Z if it's required for a succesful fusion + CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); + } + } + auto ws = GetWorkSpace(context->GetComputeStream()); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, + s_.variant_pack, + ws.get())); + + if (!s_.bias_fused && s_.z_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.z_tensor, s_.z_data, &alpha, s_.y_tensor, s_.y_data)); } - // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions - // This may have lead to extra results that are unnecessary and hence we slice that off here - if (s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads), - s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts, - s_.slice_ends, s_.slice_axes, s_.element_size)); + if (!s_.bias_fused && s_.b_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); + + /* For the standalone bias addition graph. + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_X, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_Y, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->execute(cudnn_handle, + s_.variant_pack_bias, + GetWorkSpace(context->GetComputeStream()).get()));*/ } +#endif + return Status::OK(); } @@ -536,6 +562,7 @@ Status CudnnConvolutionDescriptor::Set( return Status::OK(); } +#endif #ifndef DISABLE_CONTRIB_OPS // template instantiation for NhwcConv diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 3aec654224e39..484d66081018b 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -6,6 +6,12 @@ #include #include +#include +#include + +#if !defined(__CUDACC__) +#include +#endif #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_kernel.h" @@ -150,6 +156,24 @@ struct CudnnConvState { CudnnTensor z_tensor; const void* z_data = nullptr; CudnnConvolutionDescriptor conv_desc; + bool bias_fused = true; + bool act_fused = true; + +#if !defined(__CUDACC__) + std::unique_ptr cudnn_fe_graph; + std::unique_ptr cudnn_fe_bias_graph; + std::shared_ptr cudnn_fe_X; + std::shared_ptr cudnn_fe_W; + std::shared_ptr cudnn_fe_conv_Y; + std::shared_ptr cudnn_fe_Z; + std::shared_ptr cudnn_fe_B; + std::shared_ptr cudnn_fe_Y; + + std::optional cudnn_fe_act_attr = std::nullopt; + + std::unordered_map, void*> variant_pack; + std::unordered_map, void*> variant_pack_bias; +#endif struct PerfResultParams { decltype(AlgoPerfType().algo) algo; @@ -183,7 +207,7 @@ enum : size_t { // ONNX Conv operator uses NCHW format for input, weights and output. // NhwcConv contrib ops uses NHWC format: last dimension of input, weights and output are channels. 
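For the channels-last path, UpdateState re-orders the dimension vectors so that cuDNN always receives dims in NCHW order while the NHWC memory layout is carried by the strides. A minimal sketch of that re-ordering with a worked example follows; the MoveChannelsToSecondPosition name is hypothetical and not part of this change.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch (hypothetical helper): turn dims stored in NHWC order into
// the NCHW-ordered dims that cuDNN tensor descriptors expect. The buffer itself
// stays NHWC; that layout is described separately through strides.
std::vector<int64_t> MoveChannelsToSecondPosition(std::vector<int64_t> dims_nhwc) {
  if (dims_nhwc.size() < 3) return dims_nhwc;          // nothing to re-order
  const int64_t channels = dims_nhwc.back();           // C is last in NHWC
  dims_nhwc.pop_back();
  dims_nhwc.insert(dims_nhwc.begin() + 1, channels);   // now {N, C, spatial...}
  return dims_nhwc;
}

int main() {
  const std::vector<int64_t> reordered = MoveChannelsToSecondPosition({1, 5, 5, 3});
  assert((reordered == std::vector<int64_t>{1, 3, 5, 5}));  // NHWC {1,5,5,3} -> NCHW-ordered {1,3,5,5}
  return 0;
}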
-template +template class Conv : public CudaKernel { public: using CudaT = typename ToCudaType::MappedType; @@ -205,12 +229,32 @@ class Conv : public CudaKernel { } Status UpdateState(OpKernelContext* context, bool bias_expected = false) const; + +#if !defined(__CUDACC__) && CUDNN_MAJOR >= 9 + Status CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const Tensor* Z, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool bias_expected, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const; +#endif + ConvAttributes conv_attrs_; mutable CudnnConvState s_; constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; - static const cudnnConvolutionFwdAlgo_t kAllAlgos[]; std::unique_ptr W_; - bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain + bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain + bool is_fused_node_ = false; // ensures the node is fused although the session option is not set + bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain }; Status SliceOutUnwantedOutputSection(cudaStream_t stream, diff --git a/onnxruntime/core/providers/cuda/nn/conv_8.h b/onnxruntime/core/providers/cuda/nn/conv_8.h new file mode 100644 index 0000000000000..10239d09041fe --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/conv_8.h @@ -0,0 +1,484 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. 
+ +#include "core/common/status.h" +#include "core/providers/cuda/nn/conv.h" +#include "core/common/span_utils.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" +#include "core/providers/cuda/tensor/slice.h" + +namespace onnxruntime { +namespace cuda { + +static const cudnnConvolutionFwdAlgo_t kAllAlgos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, +}; + +static cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, cudnnConvolutionFwdAlgo_t algo, size_t* sz) { + return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz); +} + +size_t GetMaxWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, + const cudnnConvolutionFwdAlgo_t* algo, int n_algo) { + // TODO: get maximum available size from memory arena + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation + free = static_cast(static_cast(free) * 0.9); + size_t max_ws_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t err; + size_t sz; + err = GetWorkspaceSize(handle, s, algo[i], &sz); + if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue; + max_ws_size = sz; + } + return max_ws_size; +} + +Status SliceOutUnwantedOutputSection(cudaStream_t stream, + const void* input_data, gsl::span input_dims, + void* output_data, + const gsl::span& output_dims, + const gsl::span& starts, + const gsl::span& ends, + const gsl::span& axes, + size_t element_size) { + SliceOp::PrepareForComputeMetadata compute_metadata(input_dims); + + ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata)); + + // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape + ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims)); + + return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); +} + +template +Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { + // set X + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + const auto x_dims = x_shape.AsShapeVector(); + s_.x_data = reinterpret_cast(X->Data()); + s_.element_size = X->DataType()->Size(); + bool w_in_nhwc; + const Tensor* W; + if (!W_) { + W = context->Input(1); + w_in_nhwc = W_already_nhwc; + // Dims and memory layout are in NCHW format + } else { + W = W_.get(); + w_in_nhwc = true; + // W got prepacked, therfore if NHWC == true, then dims and memory layout are in NHWC + } + const TensorShape& w_shape = W->Shape(); + auto w_dims = w_shape.AsShapeVector(); + s_.w_data = reinterpret_cast(W->Data()); + + // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. 
+ constexpr bool channels_last = NHWC; + if constexpr (channels_last) { + if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + } + } + + // set B + if (context->InputCount() >= 3) { + const Tensor* B = context->Input(2); + s_.b_data = reinterpret_cast(B->Data()); + } else { + s_.b_data = nullptr; + } + // set Z + if (context->InputCount() >= 4) { + const Tensor* Z = context->Input(3); + ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType())); + s_.z_data = reinterpret_cast(Z->Data()); + } else { + s_.z_data = nullptr; + } + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); + + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + s_.cached_benchmark_results.clear(); + } + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, w_in_nhwc)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, w_in_nhwc)); + + const size_t kernel_rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(kernel_rank * 2, 0); + } + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_rank, 1); + } + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_rank, 1); + } + + TensorShapeVector y_dims; + y_dims.reserve(2 + kernel_rank); // add 2 to account for 'N' and 'C' + + const int64_t N = X->Shape()[0]; + const int64_t M = W->Shape()[0]; + if (channels_last) { + y_dims.push_back(N); + } else { + y_dims.insert(y_dims.begin(), {N, M}); + } + + bool post_slicing_required = false; + TensorShapeVector slice_starts; + slice_starts.reserve(kernel_rank); + + TensorShapeVector slice_ends; + slice_ends.reserve(kernel_rank); + + TensorShapeVector slice_axes; + slice_axes.reserve(kernel_rank); + + constexpr size_t spatial_dim_start = channels_last ? 1 : 2; + const size_t spatial_dim_end = spatial_dim_start + kernel_rank; + TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); + + TensorShapeVector y_dims_with_adjusted_pads(y_dims); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, + strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, + post_slicing_required, slice_starts, slice_ends, slice_axes, + channels_last)); + if (channels_last) { + y_dims.push_back(M); + y_dims_with_adjusted_pads.push_back(M); + } + + ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); + s_.y_dims = gsl::make_span(y_dims); + s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; + s_.post_slicing_required = post_slicing_required; + s_.slice_starts = slice_starts; + s_.slice_ends = slice_ends; + s_.slice_axes = slice_axes; + + s_.Y = context->Output(0, TensorShape(s_.y_dims)); + if (post_slicing_required) { + // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. + s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); + } else { + // No post slicing needed. 
Fill the output tensor's buffer directly. + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; + TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; + if (kernel_rank < 2) { + // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] + // especially for EXHAUSTIVE algo search which may result in a better algo selection. + // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to + // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape + // [N,C,D] to [N,C,1,D], especially on A100, and especially for ConvGrad. + // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems + // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. + // See PR #7348 and #7702 for more context. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + if (w_dims_changed) { + if (!channels_last) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } else if (w_in_nhwc) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(w_dims[0]), + static_cast(w_dims[3]), + static_cast(w_dims[1]), + static_cast(w_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(w_dims[0]), + static_cast(w_dims[1]), + static_cast(w_dims[2]), + static_cast(w_dims[3]))); + } + } + + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + + if (channels_last) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(x_dims_cudnn[0]), + static_cast(x_dims_cudnn[3]), + static_cast(x_dims_cudnn[1]), + static_cast(x_dims_cudnn[2]))); + + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(y_dims_cudnn[0]), + static_cast(y_dims_cudnn[3]), + static_cast(y_dims_cudnn[1]), + static_cast(y_dims_cudnn[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + } + + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), + CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), + UseTF32())); + + if (context->InputCount() >= 3) { + const Tensor* B = context->Input(2); + const auto& b_shape = B->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = b_shape[0]; + 
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + // s_.b_data = reinterpret_cast(B->Data()); + } else if (bias_expected) { + TensorShapeVector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = w_dims[0]; + auto malloc_size = b_dims[1] * sizeof(CudaT); + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + if (s_.b_zero) { + CUDA_CALL_THROW(cudaFree(s_.b_zero)); + s_.b_zero = nullptr; + } + CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size)); + CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); + } + + if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { + // set math type to tensor core before algorithm search + if constexpr (std::is_same::value) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } + + cudnnConvolutionFwdAlgoPerf_t perf; + int algo_count = 1; + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo); + switch (cudnn_conv_algo) { + case 0: { + static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( + GetCudnnHandle(context), + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.y_tensor, + s_.y_data, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf, + algo_search_workspace.get(), + max_ws_size)); + break; + } + case 1: + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7( + GetCudnnHandle(context), + s_.x_tensor, + s_.w_desc, + s_.conv_desc, + s_.y_tensor, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf)); + break; + + default: + perf.algo = kDefaultConvAlgo; + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); + + if constexpr (std::is_same::value) { + perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; + } else { + perf.mathType = CUDNN_DEFAULT_MATH; + } + } + s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType}); + } + const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); + s_.algo = perf.algo; + s_.workspace_bytes = perf.memory; + } else { + // set Y + s_.Y = context->Output(0, s_.y_dims); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + if (s_.post_slicing_required) { + s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); + } else { + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + } + return Status::OK(); +} + +template +Status Conv::ComputeInternal(OpKernelContext* context) const { + 
std::lock_guard lock(s_.mutex); + ORT_RETURN_IF_ERROR(UpdateState(context)); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + const auto alpha = Consts::One; + const auto beta = Consts::Zero; + IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); + auto cudnn_handle = GetCudnnHandle(context); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnn_handle, + &alpha, + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.algo, + workspace.get(), + s_.workspace_bytes, + &beta, + s_.y_tensor, + s_.y_data)); + if (nullptr != s_.b_data) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); + } + // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions + // This may have lead to extra results that are unnecessary and hence we slice that off here + if (s_.post_slicing_required) { + ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads), + s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts, + s_.slice_ends, s_.slice_axes, s_.element_size)); + } + return Status::OK(); +} + +CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() : desc_(nullptr) { +} + +CudnnConvolutionDescriptor::~CudnnConvolutionDescriptor() { + if (desc_ != nullptr) { + cudnnDestroyConvolutionDescriptor(desc_); + desc_ = nullptr; + } +} + +Status CudnnConvolutionDescriptor::Set( + size_t rank, + const gsl::span& pads, + const gsl::span& strides, + const gsl::span& dilations, + int groups, + cudnnConvolutionMode_t mode, + cudnnDataType_t data_type, + bool use_tf32) { + if (!desc_) + CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_)); + + InlinedVector pad_dims(rank); + InlinedVector stride_dims(rank); + InlinedVector dilation_dims(rank); + for (size_t i = 0; i < rank; i++) { + pad_dims[i] = gsl::narrow_cast(pads[i]); + stride_dims[i] = gsl::narrow_cast(strides[i]); + dilation_dims[i] = gsl::narrow_cast(dilations[i]); + } + + // This piece of code is copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h + // Setting math_type to CUDNN_DATA_FLOAT for half input + cudnnDataType_t math_type = data_type; + if (data_type == CUDNN_DATA_HALF) math_type = CUDNN_DATA_FLOAT; + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionNdDescriptor( + desc_, + gsl::narrow_cast(rank), + pad_dims.data(), + stride_dims.data(), + dilation_dims.data(), + mode, + math_type)); + + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(desc_, groups)); + + // Copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h + // See Note [behavior of cudnnFind and cudnnGet] at /pytorch/aten/src/ATen/native/cudnn/Conv_v7.cpp + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); + if (data_type == CUDNN_DATA_HALF) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH)); + } + + return Status::OK(); +} +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h index 51a5631b930b0..2b2b726e62c79 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h @@ -4,16 +4,16 @@ #pragma once #include "core/common/common.h" #include 
"core/providers/cuda/cuda_pch.h" - namespace onnxruntime { // ----------------------------------------------------------------------- // Error handling // ----------------------------------------------------------------------- -template +template std::conditional_t CudaCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); + ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, + const char* file, const int line); #define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) #define CUBLAS_CALL(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) diff --git a/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h b/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h new file mode 100644 index 0000000000000..a51d84a7efa59 --- /dev/null +++ b/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_pch.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#if !defined(__CUDACC__) +#include +#endif +namespace onnxruntime { + +// ----------------------------------------------------------------------- +// Error handling +// ----------------------------------------------------------------------- + +#define CUDNN_FE_CALL(expr) (CudaCall((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \ + cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__)) +#define CUDNN_FE_CALL_THROW(expr) (CudaCall((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \ + cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__)) +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 053c66ddcb34a..240272923a3a7 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -15,9 +15,6 @@ #include "core/providers/cuda/cuda_common.h" -using namespace onnxruntime; -using namespace onnxruntime::cuda; - // Generalize library calls to be use in template functions inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, @@ -84,7 +81,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, half* C, int ldc, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { return cublasGemmEx(handle, @@ -127,7 +124,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, half* C, int ldc, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { // The alpha and beta shall have same precision as compute type. 
@@ -162,8 +159,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, #if defined(USE_CUDA) inline cublasStatus_t cublasGemmHelper( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const BFloat16* alpha, const BFloat16* A, int lda, - const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, + int n, int k, const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* A, int lda, + const onnxruntime::BFloat16* B, int ldb, const onnxruntime::BFloat16* beta, onnxruntime::BFloat16* C, int ldc, const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -174,8 +171,9 @@ inline cublasStatus_t cublasGemmHelper( } #else inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, - const BFloat16*, const BFloat16*, int, const BFloat16*, int, const BFloat16*, - BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { + const onnxruntime::BFloat16*, const onnxruntime::BFloat16*, int, + const onnxruntime::BFloat16*, int, const onnxruntime::BFloat16*, + onnxruntime::BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -250,7 +248,7 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, int batch_count, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { return cublasGemmBatchedEx(handle, @@ -286,9 +284,9 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, #if defined(USE_CUDA) inline cublasStatus_t cublasGemmBatchedHelper( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], - int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, - BFloat16* Carray[], int ldc, int batch_count, + int m, int n, int k, const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* Aarray[], + int lda, const onnxruntime::BFloat16* Barray[], int ldb, const onnxruntime::BFloat16* beta, + onnxruntime::BFloat16* Carray[], int ldc, int batch_count, const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -300,8 +298,9 @@ inline cublasStatus_t cublasGemmBatchedHelper( } #else inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, - const BFloat16*, const BFloat16*[], int, const BFloat16*[], int, - const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&, + const onnxruntime::BFloat16*, const onnxruntime::BFloat16*[], int, + const onnxruntime::BFloat16*[], int, const onnxruntime::BFloat16*, + onnxruntime::BFloat16*[], int, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } @@ -314,12 +313,12 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, int m, int n, int k, const float* alpha, const float* A, int lda, - long long int strideA, + int64_t strideA, const float* B, int ldb, - long long int strideB, + int64_t strideB, const float* beta, float* C, int ldc, - long long int strideC, + int64_t strideC, int batch_count, const cudaDeviceProp& prop, bool use_tf32) { 
@@ -349,12 +348,12 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, int m, int n, int k, const double* alpha, const double* A, int lda, - long long int strideA, + int64_t strideA, const double* B, int ldb, - long long int strideB, + int64_t strideB, const double* beta, double* C, int ldc, - long long int strideC, + int64_t strideC, int batch_count, const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { @@ -376,16 +375,16 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, int m, int n, int k, const __half* alpha, const __half* A, int lda, - long long int strideA, + int64_t strideA, const __half* B, int ldb, - long long int strideB, + int64_t strideB, const __half* beta, __half* C, int ldc, - long long int strideC, + int64_t strideC, int batch_count, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { return cublasGemmStridedBatchedEx(handle, @@ -425,16 +424,16 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, int m, int n, int k, const float* alpha, const __half* A, int lda, - long long int strideA, + int64_t strideA, const __half* B, int ldb, - long long int strideB, + int64_t strideB, const float* beta, __half* C, int ldc, - long long int strideC, + int64_t strideC, int batch_count, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { // The alpha and beta shall have same precision as compute type. 
@@ -472,10 +471,10 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, inline cublasStatus_t cublasGemmStridedBatchedHelper( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const BFloat16* alpha, const BFloat16* A, int lda, - long long int strideA, const BFloat16* B, int ldb, - long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, - long long int strideC, int batch_count, + const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* A, int lda, + int64_t strideA, const onnxruntime::BFloat16* B, int ldb, + int64_t strideB, const onnxruntime::BFloat16* beta, onnxruntime::BFloat16* C, int ldc, + int64_t strideC, int batch_count, const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -488,9 +487,9 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper( #else inline cublasStatus_t cublasGemmStridedBatchedHelper( cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, - int, const BFloat16*, const BFloat16*, int, long long int, - const BFloat16*, int, long long int, const BFloat16*, BFloat16*, - int, long long int, int, const cudaDeviceProp&, bool /*use_tf32*/) { + int, const onnxruntime::BFloat16*, const onnxruntime::BFloat16*, int, int64_t, + const onnxruntime::BFloat16*, int, int64_t, const onnxruntime::BFloat16*, onnxruntime::BFloat16*, + int, int64_t, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -531,4 +530,5 @@ cublasStatus_t cublasCopyHelper( cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); cublasStatus_t cublasCopyHelper( - cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); + cudaStream_t stream, cublasHandle_t handle, int n, const onnxruntime::BFloat16* x, + int incx, onnxruntime::BFloat16* y, int incy); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 62e9cfe24c367..6d6940590602a 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1782,6 +1782,7 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const cuda_options_converted.cudnn_conv_use_max_workspace = 1; cuda_options_converted.enable_cuda_graph = 0; cuda_options_converted.prefer_nhwc = 0; + cuda_options_converted.fuse_conv_bias = 0; cuda_options_converted.cudnn_conv1d_pad_to_nc1d = 0; cuda_options_converted.enable_skip_layer_norm_strict_mode = 0; cuda_options_converted.use_ep_level_unified_stream = 0; diff --git a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc index 3f298b0a8f8e1..e780d35df08bd 100644 --- a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc +++ b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc @@ -30,7 +30,8 @@ void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes, bool use_float16, bool weight_is_initializer = false) { int min_cuda_architecture = use_float16 ? 
530 : 0; - bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + // NHWC implementation doesn't handle W in NHWC layout if it's not an initializer + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && weight_is_initializer; bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 9886d98dcc6d6..343027b42c894 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -438,6 +438,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (enable_cuda) { #ifdef USE_CUDA OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.device_id = device_id; cuda_options.do_copy_in_default_stream = true; cuda_options.use_tf32 = false; // TODO: Support arena configuration for users of test runner diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 3e4e845440117..7ab0268c3509c 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1993,67 +1993,22 @@ TEST_F(GraphTransformationTests, NotWhereFusion) { ASSERT_TRUE(op_to_count["Not"] == 1); // can't remove Not if it is graph output/ has consumer that's not where } -#if (defined(USE_CUDA) || defined(USE_JSEP)) && !defined(DISABLE_CONTRIB_OPS) -// Conv->Add->Relu will be transformed to FusedConv -TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); - } - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Add"] == 1); - ASSERT_TRUE(op_to_count["Relu"] == 1); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Add"] == 0); // Add removed from graph - ASSERT_TRUE(op_to_count["Relu"] == 0); // Relu removed from graph -} - -// Currently the ConvAddRelu fusion is only backed by a float kernel for the -// the CUDA EP. - -// When we see the corresponding pattern for the fp16 data type, the fusion -// should not be triggered as there is no kernel to back the fused pattern. - -// TODO(hasesh): Limit the test to using the CUDA EP for now as the level of -// data type support in other compatible EPs is still yet to be ascertained. - -// TODO(hasesh): If at all the fp16 type is supported for the fusion, adjust/remove -// this test. 
-TEST_F(GraphTransformationTests, FuseCudaConvAddRelu_UnsupportedType) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); - } - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["Add"], 1); - ASSERT_EQ(op_to_count["Relu"], 1); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register( - std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["Add"], 1); // Add not removed from graph (fusion not triggered) - ASSERT_EQ(op_to_count["Relu"], 1); // Relu not removed from graph (fusion not triggered) -} - +#if !defined(DISABLE_CONTRIB_OPS) // Conv->Add->Relu will be left intact since there is Identity depend on Add TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); +#if defined(USE_JSEP) for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); + node.SetExecutionProviderType(kJsExecutionProvider); } +#else + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kCpuExecutionProvider); + } +#endif std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 1); ASSERT_TRUE(op_to_count["Relu"] == 1); @@ -2073,9 +2028,15 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) { std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); +#if defined(USE_JSEP) for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); + node.SetExecutionProviderType(kJsExecutionProvider); } +#else + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kCpuExecutionProvider); + } +#endif std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 1); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -2165,17 +2126,13 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); -#ifdef USE_CUDA - for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); - } -#elif defined(USE_ROCM) +#if defined(USE_JSEP) for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); + node.SetExecutionProviderType(kJsExecutionProvider); } -#elif defined(USE_JSEP) +#else for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kJsExecutionProvider); + node.SetExecutionProviderType(kCpuExecutionProvider); } #endif std::map op_to_count_before_fusion = CountOpsInGraph(graph); @@ -2187,14 +2144,7 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map 
op_to_count_after_fusion = CountOpsInGraph(graph); -#if defined(USE_CUDA) || defined(USE_ROCM) - std::set cuda_rocm_supported = {"Relu"}; - if (cuda_rocm_supported.find(model.second) == cuda_rocm_supported.end()) { - ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); - } else { - ASSERT_EQ(op_to_count_after_fusion[model.second], 0); - } -#elif defined(USE_JSEP) +#if defined(USE_JSEP) std::set js_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu"}; if (js_supported.find(model.second) == js_supported.end()) { ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 2d885ee9d479f..25caa732efa25 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -25,6 +25,7 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, const std::initializer_list& expected_output, const vector& expected_output_shape, bool weight_is_initializer = false, + optional epsilon = optional(), OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", int opset = 7) { @@ -56,11 +57,13 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, test.AddOutput("Y", expected_output_shape, expected_output); + if (epsilon.has_value()) { + test.SetOutputTolerance(*epsilon); + } + std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); - // Disable CUDA NHWC execution provider as it is currently flaky - excluded_providers.insert(kCudaNHWCExecutionProvider); // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs. excluded_providers.insert(kQnnExecutionProvider); @@ -189,10 +192,15 @@ TEST(ConvTest, Conv1D_Bias) { vector Y_shape = {2, 1, 4}; auto expected_vals = {0.37892162799835205f, 0.4625728130340576f, 0.4934738576412201f, 0.44801419973373413f, 0.37892162799835205f, 0.2499445676803589f, 0.31682088971138f, 0.32773756980895996f}; - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + // For the CUDA EP: Due to CUDNN Frontend using TF32 for FP32 operations we get a higher error than using FP32 only, + // as TF32 has a 10 bit mantissa. + float epsilon = 1.1e-5f; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon); // CoreML EP requires weight to be an initializer - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon); } // Conv47 @@ -240,7 +248,7 @@ TEST(ConvTest, Conv1D_Invalid_Input_Shape) { vector X_shape = {1, 1, 1}; vector dummy_shape = {1, 1, 2}; auto dummy_vals = {0.0f, 0.0f}; - TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, + TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, optional(), OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " "Both inferred and declared dimension have values but they differ. 
Inferred=0 Declared=2 Dimension=2", @@ -263,7 +271,7 @@ TEST(ConvTest, Conv2D_Invalid_Input_Shape) { vector dummy_shape = {2, 2, 1, 2}; auto dummy_vals = {-0.0f, 0.0f, -0.0f, -0.0f, -0.0f, 0.0f, -0.0f, -0.0f}; - TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, + TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, optional(), OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " "Both inferred and declared dimension have values but they differ. Inferred=1 Declared=2 Dimension=0", @@ -620,7 +628,12 @@ TEST(ConvTest, Conv3D_Bias) { -0.47542816400527954f, -0.5078460574150085f, -0.4205915927886963f, -0.5584549903869629f, -0.39770257472991943f, -0.45317384600639343f, -0.5598302483558655f, -0.2542789578437805f, -0.5359901785850525f, -0.48090484738349915f, -0.38603779673576355f, -0.4991581439971924f}; - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + // For the CUDA EP: Due to CUDNN Frontend using TF32 for FP32 operations we get a higher error than using FP32 only, + // as TF32 has a 10 bit mantissa. + float epsilon = 2.1e-4f; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon); } TEST(ConvTest, Conv2D_group) { @@ -902,7 +915,8 @@ TEST(ConvTest, ConvDimWithZero) { // not handled by ACL attrs.excluded_providers.insert(kAclExecutionProvider); - TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, OpTester::ExpectResult::kExpectSuccess, "", 10); + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, optional(), + OpTester::ExpectResult::kExpectSuccess, "", 10); } TEST(ConvTest, Conv1D_asymmetric_padding) { diff --git a/setup.py b/setup.py index 51feedcfd3286..1fa297e22acd9 100644 --- a/setup.py +++ b/setup.py @@ -196,6 +196,7 @@ def run(self): to_preload_cann = [] cuda_dependencies = [ + "libcuda.so.1", "libcublas.so.11", "libcublas.so.12", "libcublasLt.so.11", diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 01965343c4592..45c79a677b68d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.173 + version: 1.0.175 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.173 + version: 1.0.175 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. From 54d6614ad6e25c1519f9e1829c9c3650be89be30 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 2 Aug 2024 15:45:42 -0700 Subject: [PATCH 37/40] Security fuzz address sanitizer fix Bug (continue) (#21579) Add a check of node.InputDefs()[2]->Exists() for Layernorm bias (Follow up https://github.com/microsoft/onnxruntime/pull/21528/files#r1694026327) Format the file: break long line to be within 120 chars limit. 
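The guard described above generalizes to any optional ONNX input: check both that the input slot is populated and that the `NodeArg` exists before dereferencing it. A minimal sketch of such a helper, built from the `Node`/`NodeArg` accessors used in this diff (the helper name and include path are illustrative, not part of this change):

```cpp
#include "core/graph/graph.h"  // assumed include path for Node/NodeArg

namespace {
// Returns the i-th input of `node` if it is both present and non-empty,
// nullptr otherwise (slot missing entirely, or supplied as the "" placeholder
// for an omitted optional input, for which NodeArg::Exists() returns false).
const onnxruntime::NodeArg* TryGetOptionalInput(const onnxruntime::Node& node, size_t i) {
  const auto& input_defs = node.InputDefs();
  if (input_defs.size() <= i) return nullptr;
  const onnxruntime::NodeArg* arg = input_defs[i];
  return (arg != nullptr && arg->Exists()) ? arg : nullptr;
}
}  // namespace
```

With a helper like this, the fusion condition below could read `TryGetOptionalInput(node, 2) != nullptr` instead of repeating the size and `Exists()` checks inline.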
--- .../core/optimizer/attention_fusion.cc | 53 +++++++++++++------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 64a38214caff0..ff8943de79679 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -50,7 +50,8 @@ void MergeWeights(const T* q, const T* k, const T* v, std::vector& result, in // Merge 2-D weights (q, k and v) by concatenating them row by row. template -void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight, std::vector& result, int64_t hidden_size) { +void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight, + std::vector& result, int64_t hidden_size) { const T* q = q_weight; const T* k = k_weight; const T* v = v_weight; @@ -144,7 +145,8 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, return graph_utils::AddInitializer(graph, initializer); } -static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, const logging::Logger& logger) { +static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, + const logging::Logger& logger) { // Validate mask input shape (batch_size, sequence_length) and data type. // Note that batch_size and sequence_length could be symbolic. const TensorShapeProto* mask_shape = mask_input->Shape(); @@ -208,9 +210,11 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& node = *p_node; ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert + // Add node.GetOutputEdgesCount() == 5/6 for distilbert + if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) && - graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.InputDefs().size() > 2) { + graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && + node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) { // Bias is an optional input for LayerNorm // Get hidden size from layer norm bias tensor shape. 
const NodeArg& layer_norm_bias = *(node.InputDefs()[2]); if (!optimizer_utils::IsShapeKnownOnAllDims(layer_norm_bias, 1)) { @@ -242,8 +246,10 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, fused_count++; modified = true; } - } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && (static_cast(reshape_count) + shape_count) == node.GetOutputEdgesCount()) { // GPT - if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, logger)) { + } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && + (static_cast(reshape_count) + shape_count) == node.GetOutputEdgesCount()) { // GPT + if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, + logger)) { fused_count++; modified = true; } @@ -301,7 +307,8 @@ static bool FuseSubGraphQKImpl(Node& layer_norm, return false; } - if (!AttentionFusionHelper::CheckNodesInPathQ(graph, pivot_nodes[1].get(), q_reshape, q_transpose, num_heads, head_size, logger)) { + if (!AttentionFusionHelper::CheckNodesInPathQ(graph, pivot_nodes[1].get(), + q_reshape, q_transpose, num_heads, head_size, logger)) { DEBUG_LOG("CheckNodesInPathQ returns false"); return false; } @@ -365,7 +372,8 @@ static bool FuseSubGraphQKImpl(Node& layer_norm, } // Now everything is ready, we will start fusing subgraph. - NodeArg* mask_int32 = ConvertMaskToInt32(graph, mask_input, mask_int32_map, layer_norm.GetExecutionProviderType(), logger); + NodeArg* mask_int32 = ConvertMaskToInt32(graph, mask_input, mask_int32_map, layer_norm.GetExecutionProviderType(), + logger); if (nullptr == mask_int32) { DEBUG_LOG("Failed to convert mask to int32"); return false; @@ -438,7 +446,8 @@ static bool FuseSubGraphQK(Node& layer_norm, } std::vector nodes_to_remove; - if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, + if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, + mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, num_heads, head_size, mask_nodes.mask_filter_value, logger)) { return false; } @@ -529,7 +538,8 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm, } std::vector nodes_to_remove; - if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, + if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, + mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, num_heads, head_size, mask_nodes.mask_filter_value, logger)) { return false; } @@ -615,7 +625,12 @@ After Fusion: | | Add */ -bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_int32_map, const logging::Logger& logger) { +bool AttentionFusion::FuseSubGraph(Node& layer_norm, + const Node& add_after_layer_norm, + Graph& graph, + int64_t hidden_size, + std::map& mask_int32_map, + const logging::Logger& logger) { std::vector parent_path{ {0, 0, "Add", {7, 13}, kOnnxDomain}, {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}, @@ -657,7 +672,9 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer int64_t num_heads = 0; // will be updated in CheckNodesInPathV int64_t head_size = 0; // will be updated in CheckNodesInPathV NodeIndex record_node_idx = 0; // will be updated in CheckNodesInPathV if it's distilbert model - if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose, 
qkv_matmul, v_transpose, v_reshape, num_heads, head_size, hidden_size, record_node_idx, logger)) { + if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose, + qkv_matmul, v_transpose, v_reshape, num_heads, + head_size, hidden_size, record_node_idx, logger)) { DEBUG_LOG("CheckNodesInPathV return false"); return false; } @@ -672,7 +689,8 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer } // store parent path - std::vector> parent_path_nodes{reshape, transpose, qkv_matmul, v_transpose, v_reshape, v_add, v_matmul}; + std::vector> parent_path_nodes{ + reshape, transpose, qkv_matmul, v_transpose, v_reshape, v_add, v_matmul}; // Find mask nodes: Unsqueeze -> Unsqueeze -> (Cast) -> Sub -> Mul -> Add -> Softmax --> [MatMul] // The "Cast" node in parentheses is optional. @@ -681,10 +699,13 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, qkv_matmul, mask_nodes, logger, false)) { NodeArg* mask_input = graph.GetNode(mask_nodes.unsqueeze_1->Index())->MutableInputDefs()[0]; - return FuseSubGraphQK(layer_norm, graph, mask_nodes, mask_input, parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger); - } else if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, layer_norm, qkv_matmul, mask_nodes_distilbert, record_node_idx, logger)) { + return FuseSubGraphQK(layer_norm, graph, mask_nodes, mask_input, + parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger); + } else if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, layer_norm, qkv_matmul, + mask_nodes_distilbert, record_node_idx, logger)) { NodeArg* mask_input = graph.GetNode(mask_nodes_distilbert.equal->Index())->MutableInputDefs()[0]; - return FuseSubGraphQKDistilBert(layer_norm, graph, mask_nodes_distilbert, mask_input, parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger); + return FuseSubGraphQKDistilBert(layer_norm, graph, mask_nodes_distilbert, mask_input, + parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger); } else { DEBUG_LOG("Failed in match input mask subgraph"); return false; From 45b7c41ef04fa91cc41a1726d3ee4c0bb9aabaa8 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Fri, 2 Aug 2024 19:19:04 -0400 Subject: [PATCH 38/40] [MIGraphX EP] Set External Data Path (#21598) ### Description Changes to add in Set external data path for model weight files. 
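In essence, the EP now resolves external weight files relative to the directory that contains the model, falling back to the current working directory when the stored model path has no parent. A standalone sketch of just that path-resolution step, assuming only `<filesystem>` (the free function name is illustrative; the actual change inlines this logic):

```cpp
#include <filesystem>
#include <string>

// Directory from which external weight files should be resolved:
// the model's own directory if it has one, otherwise the current working directory.
std::string ExternalDataDir(const std::filesystem::path& model_path) {
  return model_path.has_parent_path() ? model_path.parent_path().string()
                                      : std::filesystem::current_path().string();
}
```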
Additional fixes to ensure this compiles off the latest v1.19 Onnxruntime ### Motivation and Context Separate weights used for larger models (like stable diffusion) is motivation for this change set --------- Co-authored-by: Jeff Daily Co-authored-by: Artur Wojcik Co-authored-by: Ted Themistokleous --- .../core/providers/migraphx/migraphx_execution_provider.cc | 7 +++++-- .../core/providers/migraphx/migraphx_execution_provider.h | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 097b16ecde536..314e278695c49 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT @@ -990,6 +991,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string onnx_string_buffer; model_proto->SerializeToString(onnx_string_buffer); + model_path_ = graph_viewer.ModelPath(); // dump onnx file if environment var is set if (dump_model_ops_) { @@ -1168,7 +1170,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto param_shapes = prog.get_parameter_shapes(); // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : dynamic_range_map) { + for (auto& [cal_key, cal_val] : dynamic_range_map_) { auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } @@ -1217,7 +1219,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, - int8_calibration_cache_available_, dynamic_range_map, + int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; *state = p.release(); @@ -1297,6 +1299,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!input_shape_match) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl; + cmp_options.set_external_data_path(model_path_.has_parent_path() ? 
model_path_.parent_path().string() : std::filesystem::current_path().string()); prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index f34ca320d0a5a..21b582de8f86e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -11,6 +11,7 @@ #include #include +#include namespace onnxruntime { @@ -91,7 +92,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool int8_calibration_cache_available_ = false; bool int8_use_native_migraphx_calibration_table_ = false; std::string calibration_cache_path_; - std::unordered_map dynamic_range_map; + std::unordered_map dynamic_range_map_; bool save_compiled_model_ = false; std::string save_compiled_path_; bool load_compiled_model_ = false; @@ -100,6 +101,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { migraphx::target t_; OrtMutex mgx_mu_; hipStream_t stream_ = nullptr; + mutable std::filesystem::path model_path_; std::unordered_map map_progs_; std::unordered_map map_onnx_string_; From 8c641d718250cd839844316eff2aed082c022227 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 3 Aug 2024 07:25:04 +0800 Subject: [PATCH 39/40] [WebNN EP] Support Dropout op (#21586) ### Description WebNN only supports test mode, so we don't care about other inputs or attributes about training mode, use WebNN's identity op to implement the Dropout op directly. --- js/web/docs/webnn-operators.md | 1 + .../core/providers/webnn/builders/helper.h | 1 + .../webnn/builders/impl/dropout_op_builder.cc | 101 ++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 4 + .../webnn/builders/op_builder_factory.h | 1 + 5 files changed, 108 insertions(+) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 8711d4d20e370..9934c758621cb 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -25,6 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). 
WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | +| Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode | | Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value is 1.0 | | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal | ✓ | ✓ | | | Erf | ai.onnx(7-9, 10-12, 13+) | erf | ✗ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 05b783fd17902..fc13ce201f2e9 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -171,6 +171,7 @@ static const InlinedHashMap op_map = { {"Cos", {"cos", true}}, {"Div", {"div", true}}, {"DequantizeLinear", {"dequantizeLinear", false}}, + {"Dropout", {"identity", true}}, {"DynamicQuantizeLinear", {"dynamicQuantizeLinear", false}}, {"Elu", {"elu", true}}, {"Equal", {"equal", true}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc new file mode 100644 index 0000000000000..469acbc7a7e18 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class DropoutOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +// Add operator related. + +void DropoutOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip ratio and training_mode if present. + for (size_t i = 1; i < node.InputDefs().size(); i++) { + const auto input_name = node.InputDefs()[i]->Name(); + model_builder.AddInitializerToSkip(input_name); + model_builder.AddInputToSkip(input_name); + } +} + +Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& output_defs = node.OutputDefs(); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + + // WebNN EP only supports test mode. So we don't need to care about other inputs or + // attributes about training mode. Simply use WebNN's identity op to copy the input. 
+ emscripten::val output = model_builder.GetBuilder().call("identity", input, options); + + model_builder.AddOperand(output_defs[0]->Name(), std::move(output)); + + // If mask output is requested as output it will contain all ones (bool tensor). + if (output_defs.size() > 1) { + std::vector mask_shape; + ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape"); + std::vector dims = GetVecUint32FromVecInt64(mask_shape); + + emscripten::val desc = emscripten::val::object(); + desc.set("dataType", "uint8"); + desc.set("dimensions", emscripten::val::array(dims)); + const auto num_elements = narrow(Product(mask_shape)); + emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements); + ones_buffer.call("fill", 1); + + emscripten::val mask_output = model_builder.GetBuilder().call("constant", desc, ones_buffer); + + emscripten::val options = emscripten::val::object(); + options.set("label", output_defs[1]->Name() + "_identity"); + // Add additional identity op in case the mask is the output of a WebNN graph, + // beacuse WebNN does not support a constant operand as output. + mask_output = model_builder.GetBuilder().call("identity", mask_output, options); + model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); + } + return Status::OK(); +} + +// Operator support related. +bool DropoutOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + return true; +} + +void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index dfe015725c121..862cf5ded15bc 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -81,6 +81,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateConcatOpBuilder("Concat", op_registrations); } + { // Dropout + CreateDropoutOpBuilder("Dropout", op_registrations); + } + { // Quantize/Dequantize CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations); CreateDequantizeLinearOpBuilder("DequantizeLinear", op_registrations); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 818ff094fb640..e11938d8fa406 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -26,6 +26,7 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); 
void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); From 2653226ed0ef5819f1df5dd1303e2224c329369a Mon Sep 17 00:00:00 2001 From: "Po-Wei (Vincent)" Date: Fri, 2 Aug 2024 18:27:36 -0700 Subject: [PATCH 40/40] Fail tests gracefully for the minimal cuda build (#21391) ### Description Several tests result in segfaults during the minimal cuda build. Although test failures are expected due to the limitation of the minimal cuda EP, failing gracefully would be much preferred. ### Motivation and Context To reproduce: 1. Build ORT with: ```bash ./build.sh --build_shared_lib --use_full_protobuf --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --tensorrt_home /TensorRT-10.0.1.6 --parallel --skip_tests --skip_submodule_sync --allow_running_as_root --use_tensorrt --cmake_extra_defines onnxruntime_CUDA_MINIMAL=1 ``` 2. Run `onnxruntime_test_all` ```bash ... [----------] 1 test from AllocationPlannerTest [ RUN ] AllocationPlannerTest.ReusedInputCrossDifferentStreams Segmentation fault (core dumped) ``` --- cmake/onnxruntime_unittests.cmake | 3 +++ onnxruntime/test/contrib_ops/bias_dropout_op_test.cc | 2 +- onnxruntime/test/framework/allocation_planner_test.cc | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index ec396fbcc6117..d5c3af748e528 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -79,6 +79,9 @@ function(AddTest) if (onnxruntime_USE_NCCL) target_include_directories(${_UT_TARGET} PRIVATE ${NCCL_INCLUDE_DIRS}) endif() + if(onnxruntime_CUDA_MINIMAL) + target_compile_definitions(${_UT_TARGET} PRIVATE -DUSE_CUDA_MINIMAL) + endif() endif() if (onnxruntime_USE_TENSORRT) # used for instantiating placeholder TRT builder to mitigate TRT library load/unload overhead diff --git a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc index f700e31003012..027d4b3fff1b0 100644 --- a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. // BiasDropout kernel is only implemented for CUDA/ROCM -#if defined(USE_CUDA) || defined(USE_ROCM) +#if (defined(USE_CUDA) && !defined(USE_CUDA_MINIMAL)) || defined(USE_ROCM) #ifdef _MSC_VER #pragma warning(disable : 4389) diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 4e9e80b180e9c..43d3782be3280 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -2078,6 +2078,7 @@ TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape"; int gather_count = 0; + ASSERT_GT(plan->execution_plan.size(), 1) << "Number of execution plans should be greater than 1"; for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex());