diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f523e97293427..e295dfa203ae5 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2795,7 +2795,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 Constrain input A data type to 8-bit integer tensor.
 T2 : tensor(int8), tensor(uint8)
 Constrain input B data type to 8-bit integer tensor.
-T3 : tensor(float)
+T3 : tensor(float), tensor(float16)
 Constrain input a_scale, b_scale and output Y data type as float tensor.
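For reference, the operator documented above performs a dequantized matrix multiply. Below is an illustrative NumPy sketch of that semantics, not part of this patch; the function name, shapes, and values are chosen only for the example, and per-column b_scale/b_zero_point work through broadcasting.

```python
import numpy as np

def matmul_integer_to_float_reference(A, B, a_scale, b_scale,
                                      a_zero_point=0, b_zero_point=0, bias=None):
    """Illustrative reference for com.microsoft.MatMulIntegerToFloat:
    dequantize the int8/uint8 inputs, multiply, then add the optional bias."""
    a_deq = (A.astype(np.int32) - a_zero_point) * a_scale  # per-tensor scale
    b_deq = (B.astype(np.int32) - b_zero_point) * b_scale  # per-tensor or per-column scale
    y = a_deq @ b_deq
    if bias is not None:
        y = y + bias
    return y.astype(a_scale.dtype)  # float32 or float16, matching T3

# Example usage with small uint8 inputs and float16 scales.
A = np.array([[1, 5], [2, 1]], dtype=np.uint8)
B = np.array([[3, 7], [2, 1]], dtype=np.uint8)
y = matmul_integer_to_float_reference(A, B,
                                      a_scale=np.float16(3.0), b_scale=np.float16(2.0),
                                      a_zero_point=1, b_zero_point=1)
print(y)
```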
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 1eaf0fb6dad76..0e60b4622f2fb 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1268,6 +1268,7 @@ Do not modify directly.*
 |FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br> **T** = tensor(float), tensor(float16)|
+|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float), tensor(float16)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 4313fae767fe5..22a79ef652515 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -434,7 +434,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
         .Output(0, "Y", "Matrix multiply results from A * B", "T3")
         .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data type to 8-bit integer tensor.")
         .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data type to 8-bit integer tensor.")
-        .TypeConstraint("T3", {"tensor(float)"},
+        .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"},
                         "Constrain input a_scale, b_scale and output Y data type as float tensor.")
         .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
           propagateElemTypeFromInputToOutput(ctx, 2, 0);
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 8376b87aee6b2..f319e7254568d 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -278,7 +278,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
                                                             onnxruntime::kAclExecutionProvider,
                                                             onnxruntime::kArmNNExecutionProvider,
                                                             onnxruntime::kJsExecutionProvider};
-
+      const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider,
+                                                            onnxruntime::kDmlExecutionProvider};
 #ifdef MLAS_TARGET_AMD64_IX86
       const bool avx2_precision_mode =
           session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
@@ -296,7 +297,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       }

       transformers.emplace_back(std::make_unique(cpu_ep));
-      transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_ep));
+      transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_dml_eps));
       transformers.emplace_back(std::make_unique(cpu_ep));
       transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps));
diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc
index 56e51cb787931..4fee1a6ce224e 100644
--- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc
+++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc
@@ -31,6 +31,24 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) {
   return bias_last_dim > 1;
 }

+bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
+  if (!node_arg.Exists()) {
+    return false;
+  }
+
+  const auto* type_proto = node_arg.TypeAsProto();
+  if (!type_proto) {
+    return false;
+  }
+
+  int32_t actual_data_type;
+  if (!utils::TryGetElementDataType(*type_proto, actual_data_type)) {
+    return false;
+  }
+
+  return data_type == actual_data_type;
+}
+
 /**
 MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat:
@@ -63,9 +81,10 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g
     auto& mul_node = *node_ptr;

     ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger));
-
+    const bool is_dml_ep = node_ptr->GetExecutionProviderType() == kDmlExecutionProvider;
     if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
-        !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) {
+        !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) ||
+        (!is_dml_ep &&
HasElementDataType(*mul_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) { continue; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index e1e7eacfbd85d..7c25755a7d09e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -879,6 +879,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; +}; + template <> struct OperatorDescTraits { @@ -1041,12 +1047,6 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; }; -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; -}; - template <> struct OperatorDescTraits { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 5fe6603c2a0bf..da57c2aa235fd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -1885,6 +1885,25 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHE DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { + "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", + static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA_FIELDS[11] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, @@ -2395,24 +2414,6 @@ constexpr DML_OPERATOR_SCHEMA 
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHE DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { - "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", - DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 8, - DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, -}; constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 4be41ad3924a2..86c66d8cca26c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -1139,6 +1139,19 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + 
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc) { return { @@ -1487,19 +1500,6 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_P OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } -inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1829,6 +1829,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA; case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA; @@ -1856,7 +1857,6 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA; case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA; - case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -2360,6 +2360,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return 
AbstractOperatorDesc( + &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_CONVOLUTION_INTEGER: return AbstractOperatorDesc( &DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA, @@ -2468,10 +2472,6 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: - return AbstractOperatorDesc( - &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp new file mode 100644 index 0000000000000..b5a3dd0960b86 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorMatMulIntegerToFloat : public DmlOperator +{ + enum OrtInputTensors : uint32_t + { + ortA, + ortB, + ortAScale, + ortBScale, + ortAZeroPoint, + ortBZeroPoint, + ortBias, + ortInputCount + }; + + enum DmlInputIndex : uint32_t + { + dmlA, + dmlAScale, + dmlAZeroPoint, + dmlB, + dmlBScale, + dmlBZeroPoint, + dmlBias, + dmlInputCount, + }; + +public: + DmlOperatorMatMulIntegerToFloat(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperator(kernelInfo) + { + std::vector> inputIndices = { OrtInputTensors::ortA, OrtInputTensors::ortAScale, OrtInputTensors::ortAZeroPoint, OrtInputTensors::ortB, OrtInputTensors::ortBScale, OrtInputTensors::ortBZeroPoint, OrtInputTensors::ortBias }; + DmlOperator::Initialize(kernelInfo, inputIndices); + + std::vector inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortA); + std::vector inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortB); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape); + + // Initialize the input descriptions with broadcasting + m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortA, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0); + m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortB, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1); + + // Broadcast Bias tensor to the shape of the output tensor. + if(kernelInfo.IsInputValid(OrtInputTensors::ortBias)) { + m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortBias, TensorAxis::DoNotCoerce, + TensorAxis::W, TensorAxis::RightAligned, outputShape); + } + + uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount(); + // Resize the A Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. 
+ m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput( + kernelInfo, + OrtInputTensors::ortAScale, + TensorAxis::DoNotCoerce, + TensorAxis::H, + TensorAxis::LeftAligned, + std::nullopt, + dmlDimSize + ); + + // Resize the A ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortAZeroPoint)) + { + m_inputTensorDescs[DmlInputIndex::dmlAZeroPoint] = CreateTensorDescFromInput( + kernelInfo, + OrtInputTensors::ortAZeroPoint, + TensorAxis::DoNotCoerce, + TensorAxis::H, + TensorAxis::LeftAligned, + std::nullopt, + dmlDimSize + ); + } + + // B Zeropoint and BScale are already aligned in the W dimension so no need to align them + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulDesc = {}; + matMulDesc.ATensor = &inputDescs[DmlInputIndex::dmlA]; + matMulDesc.AScaleTensor = &inputDescs[DmlInputIndex::dmlAScale]; + matMulDesc.AZeroPointTensor = inputDescs[DmlInputIndex::dmlAZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlAZeroPoint] : nullptr; + matMulDesc.BTensor = &inputDescs[DmlInputIndex::dmlB]; + matMulDesc.BScaleTensor = &inputDescs[DmlInputIndex::dmlBScale]; + matMulDesc.BZeroPointTensor = inputDescs[DmlInputIndex::dmlBZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBZeroPoint] : nullptr; + matMulDesc.BiasTensor = inputDescs[DmlInputIndex::dmlBias].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBias] : nullptr; + matMulDesc.OutputTensor = &outputDescs[0]; + + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(MatMulIntegerToFloat, DmlOperatorMatMulIntegerToFloat); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 9c136ed8c9484..f08151b61197a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -503,6 +503,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul); DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat); DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); +DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat); DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); DML_OP_EXTERN_CREATION_FUNCTION(Trilu); @@ -622,6 +623,13 @@ constexpr static std::array supportedTypeListQLinea SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; + +constexpr static std::array supportedTypeListMatMulIntegerToFloat = { + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Float16to32 +}; + constexpr static std::array supportedTypeListQLinearConv = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, @@ -1083,6 +1091,7 @@ constexpr static 
OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, MatMulIntegerToFloat, typeNameListThree, supportedTypeListMatMulIntegerToFloat, DmlGraphSupport::Supported)}, {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 1b2521a86613f..06bacc1b28c99 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -870,7 +870,6 @@ class QLinearMatMulHelper : public MatMulHelperBase QLinearMatMulHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 3) {} }; - class TopKHelper { void Initialize( @@ -1776,6 +1775,7 @@ using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; +using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index e725ba085113d..d081aa2e29148 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -449,6 +449,7 @@ namespace OperatorHelper static const int sc_sinceVer_FusedMatMulActivation = 1; static const int sc_sinceVer_QLinearSigmoid = 1; static const int sc_sinceVer_Attention = 1; + static const int sc_sinceVer_MatMulIntegerToFloat = 1; static const int sc_sinceVer_MultiHeadAttention = 1; static const int sc_sinceVer_SkipLayerNormalization = 1; static const int sc_sinceVer_EmbedLayerNormalization = 1; diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 26ce5272d25ee..6f3ca7e239671 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -23,135 +23,407 @@ using namespace std; namespace onnxruntime { namespace test { -template -void TestMatMulIntegerToFloat(const std::vector& A_dims, - std::vector B_dims, - const std::string& reference_model, - bool is_matrix_b_constant, +template +static void CalculateMatMulIntegerToFloat(const int64_t M, const int64_t N, const int64_t K, + const std::vector& A_data, const std::vector& 
A_scale, + const std::vector& A_zero_point, const std::vector& B_data, + std::vector& B_scale, std::vector& B_zero_point, + const std::vector& Bias, std::vector& Y_data, + bool per_column, bool has_zp, bool has_bias) { + if (!per_column) { + B_zero_point.resize(N, B_zero_point[0]); + B_scale.resize(N, B_scale[0]); + } + + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + float A_dequantized = has_zp ? (static_cast(A_data[m * K + k]) - static_cast(A_zero_point[0])) * A_scale[0] : A_data[m * K + k] * A_scale[0]; + float B_dequantized = has_zp ? (static_cast(B_data[k * N + n]) - static_cast(B_zero_point[n])) * B_scale[n] : B_data[k * N + n] * B_scale[n]; + + sum += A_dequantized * B_dequantized; + } + if (has_bias) { + sum += Bias[n]; + } + Y_data[m * N + n] = static_cast(sum); + } + } +} + +template +void TestMatMulIntegerToFloat(bool is_matrix_b_constant, bool per_column = false, bool has_zp = true, bool has_bias = false) { // create rand inputs RandomValueGenerator random{}; - + int64_t M = 4; + int64_t N = 128; + int64_t K = 128; + std::vector A_dims{M, K}; + std::vector B_dims{K, N}; + std::vector Y_dims{M, K}; std::vector A_data; - std::vector tmp_A_data = random.Uniform(A_dims, - std::numeric_limits::lowest(), - std::numeric_limits::max()); - std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> WType { + std::vector tmp_A_data = random.Uniform(A_dims, + std::numeric_limits::lowest(), + std::numeric_limits::max()); + std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> IType { return static_cast(v); }); std::vector B_data; - std::vector tmp_B_data = random.Uniform(B_dims, - std::numeric_limits::lowest(), - std::numeric_limits::max()); + + std::vector tmp_B_data; + tmp_B_data = random.Uniform(B_dims, + std::is_signed::value ? std::numeric_limits::lowest() / 2 : std::numeric_limits::lowest(), + std::numeric_limits::max() / 2); std::transform(tmp_B_data.begin(), tmp_B_data.end(), std::back_inserter(B_data), [](int32_t v) -> WType { return static_cast(v); }); - std::vector A_scale = random.Uniform(AsSpan({1}), -0.1f, 0.1f); + std::vector A_scale = random.Uniform(AsSpan({1}), -0.1f, 0.1f); std::vector A_zero_point{(std::numeric_limits::lowest() + std::numeric_limits::max() + IType(2)) / 2}; int64_t b_scale_zp_size = per_column ? 
B_dims.back() : 1; - std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); + std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); std::vector B_zero_point(b_scale_zp_size); std::for_each(B_zero_point.begin(), B_zero_point.end(), [&random](WType& zp) { - zp = static_cast(random.Uniform(std::array{1}, - std::numeric_limits::lowest(), - std::numeric_limits::max())[0]); + zp = static_cast(random.Uniform(std::array{1}, + std::numeric_limits::lowest(), + std::numeric_limits::max())[0]); }); - std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); + std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain); test.AddInput("A", A_dims, A_data); test.AddInput("B", B_dims, B_data, is_matrix_b_constant); - test.AddInput("a_scale", {1}, A_scale); - test.AddInput("b_scale", {b_scale_zp_size}, B_scale); + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {b_scale_zp_size}, B_scale); if (has_zp) { test.AddInput("a_zero_point", {1}, A_zero_point); test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point); } else { - test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); test.AddOptionalInputEdge(); } if (has_bias) { - test.AddInput("bias", {B_dims.back()}, Bias); + test.AddInput("bias", {B_dims.back()}, Bias); } else { - test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); } - test.AddReferenceOutputs(reference_model); - test.SetOutputRelErr("Y", 1e-4f); - test.Run(); -} + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + per_column, has_zp, has_bias); -template -void RunMatMulIntegerToFloatTest(const string& model_path) { - std::vector A_dims{4, 128}; - std::vector B_dims{128, 128}; - std::vector Y_dims{4, 128}; + if (std::is_same_v) { + test.AddOutput("Y", {M, N}, Y_data); + test.SetOutputRelErr("Y", 0.02f); + } else { + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + test.SetOutputAbsErr("Y", 0.5f); + } - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - false, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + // Only DML EP supports these data type combinations for now + if (std::is_same_v || + (std::is_same_v && + std::is_same_v && + std::is_same_v)) { + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } else { + test.Run(); + } +} + +template +void RunMatMulIntegerToFloatTest() { + TestMatMulIntegerToFloat( + false, /*is_matrix_b_constant*/ + false, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + true, /*is_matrix_b_constant*/ + false, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - false, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + false, /*is_matrix_b_constant*/ + true, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - true, /*is_matrix_b_constant*/ - true, /*per_column*/ - 
HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + true, /*is_matrix_b_constant*/ + true, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); } -TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8X8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8.onnx"); - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_uint8.onnx"); +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8X8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_bias.onnx"); - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_uint8_bias.onnx"); +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_int8.onnx"); +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_int8_bias.onnx"); +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8X8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +// DML EP supports Float16 output type and Signed A Matrix and Unsigned B Matric for Float32 output +#if defined(USE_DML) + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8U8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {1, 5, 2, 1, 9, + 1, 1, 3, 7, 2}; + std::vector B_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector A_scale = ToFloat16({3.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {1}; + std::vector B_zero_point = {1}; + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, {}, Y_data, + false, true, false); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + 
std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8S8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector B_data = {2, -1, -9, 1, 1, + -1, 0, -3, 1, -4}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {1}; + std::vector B_zero_point = {3}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, {}, Y_data, + false, true, false); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8S8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, -2, 1, 1, + 2, -1, -9, 1, 1}; + std::vector B_data = {2, -1, -9, 1, 1, + -1, 0, -3, 1, -4}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {-1}; + std::vector B_zero_point = {3}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + false, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8U8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, -2, 1, 1, + 2, -1, -9, 1, 1}; + std::vector B_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {-1}; + std::vector B_zero_point = {1}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + 
B_data, B_scale, B_zero_point, Bias, Y_data, + false, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 2; + int64_t N = 2; + int64_t K = 3; + + std::vector A_data = {11, -2, 5, + -1, 3, 10}; + std::vector B_data = {-13, -2, + 9, 55, + -1, 23}; + std::vector A_scale = ToFloat16({0.910f}); + std::vector B_scale = ToFloat16({1.10f, 1.123f}); + + std::vector A_zero_point = {113}; + std::vector B_zero_point = {98, 71}; + + std::vector Bias = ToFloat16({0.10f, 1.123f}); + + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + + test.AddInput("a_scale", {}, {A_scale}); + test.AddInput("b_scale", {N}, B_scale); + test.AddInput("a_zero_point", {}, {A_zero_point}); + test.AddInput("b_zero_point", {N}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + true, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + test.SetOutputRelErr("Y", 2e-2f); + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +#endif TEST(MatMulIntegerToFloat, MatMulInteger_With_ZeroPoint) { auto test_case = [&](const std::vector& input_shape, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 16f38bac62713..1535e2b60a3bd 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5679,6 +5679,24 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) { EXPECT_EQ(op_to_count["Add"], 1); } +#ifdef USE_DML +TEST_F(GraphTransformationTests, MatMulIntegerToFloat16Test) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_integer_to_float16_int8.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType(kDmlExecutionProvider); + } + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); +} +#endif // USE_DML + #endif #ifndef DISABLE_CONTRIB_OPS diff --git a/onnxruntime/test/testdata/matmul_integer_to_float.py b/onnxruntime/test/testdata/matmul_integer_to_float.py index b898390044cf4..e6c51009018f9 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/matmul_integer_to_float.py @@ -4,7 +4,7 @@ from onnx import TensorProto, helper -def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: N802 +def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bias=False): # noqa: N802 nodes = [ # subgraph helper.make_node( "MatMulInteger", @@ -13,7 +13,13 @@ def 
GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: "MatMulInteger", ), helper.make_node("Mul", ["a_scale", "b_scale"], ["multiplier"], "mul_right"), - helper.make_node("Cast", ["matmul_output_int32"], ["matmul_output_float"], "cast", to=1), + helper.make_node( + "Cast", + ["matmul_output_int32"], + ["matmul_output_float"], + "cast", + to=TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, + ), helper.make_node( "Mul", ["matmul_output_float", "multiplier"], @@ -25,8 +31,8 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: inputs = [ # inputs helper.make_tensor_value_info("A", TensorProto.INT8 if sign_i else TensorProto.UINT8, ["M", "K"]), helper.make_tensor_value_info("B", TensorProto.INT8 if sign_w else TensorProto.UINT8, ["K", "N"]), - helper.make_tensor_value_info("a_scale", TensorProto.FLOAT, [1]), - helper.make_tensor_value_info("b_scale", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("a_scale", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("b_scale", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["C"]), ] if has_zp: @@ -48,14 +54,22 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: if bias: nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) - inputs.extend([helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["N"])]) + inputs.extend( + [ + helper.make_tensor_value_info( + "bias", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["N"] + ) + ] + ) graph = helper.make_graph( nodes, "DynamicQuantizeMatMul_fusion", # name inputs, [ # outputs - helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info( + "Y", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["M", "N"] + ), ], ) @@ -64,10 +78,32 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: if __name__ == "__main__": - GenerateModel("matmul_integer_to_float_int8.onnx", False, True) - GenerateModel("matmul_integer_to_float_uint8.onnx", False, False) - GenerateModel("matmul_integer_to_float_int8_bias.onnx", False, True, False, True) - GenerateModel("matmul_integer_to_float_uint8_bias.onnx", False, False, False, True) + GenerateModel("matmul_integer_to_float16_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=True) + GenerateModel("matmul_integer_to_float_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=False) + GenerateModel("matmul_integer_to_float_uint8.onnx", sign_i=False, sign_w=False, output_type_fp16=False) + GenerateModel( + "matmul_integer_to_float_int8_bias.onnx", + sign_i=False, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + ) + GenerateModel( + "matmul_integer_to_float_uint8_bias.onnx", + sign_i=False, + sign_w=False, + output_type_fp16=False, + has_zp=False, + bias=True, + ) - GenerateModel("matmul_integer_to_float_int8_int8.onnx", True, True) - GenerateModel("matmul_integer_to_float_int8_int8_bias.onnx", True, True, False, True) + GenerateModel("matmul_integer_to_float_int8_int8.onnx", sign_i=True, sign_w=True, output_type_fp16=False) + GenerateModel( + "matmul_integer_to_float_int8_int8_bias.onnx", + sign_i=True, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + ) diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx index 9f4465a914963..906dec542a4fa 100644 --- 
a/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx index 01b7e15aa4a1f..16cdf03c7ae59 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx index 9d38828e25d6a..55102757a0b57 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx index 4d9a55af50a87..d9d7222a1acaa 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx index a4c6d20d59be8..5373ce145688e 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx index a5be0c63f4dcb..e407414b23b24 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx index 7ea69c580ee43..aa8e67bcbc59e 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx new file mode 100644 index 0000000000000..22293b0d10756 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx @@ -0,0 +1,51 @@ + :Ì +U +A +B + a_zero_point + 
b_zero_pointmatmul_output_int32 MatMulInteger" MatMulInteger +. +a_scale +b_scale +multiplier mul_right"Mul +A +matmul_output_int32matmul_output_floatcast"Cast* +to +  +5 +matmul_output_float + +multiplierY +mul_bottom"MulDynamicQuantizeMatMul_fusionZ +A + + +M +KZ +B + + +K +NZ +a_scale + + + +Z +b_scale +  + +CZ + a_zero_point + + +Z + b_zero_point +  +Cb +Y + + + +M +NB \ No newline at end of file
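As a minimal end-to-end sketch (not part of this patch), the new fusion path can be exercised from Python against the matmul_integer_to_float16_int8.onnx model generated by onnxruntime/test/testdata/matmul_integer_to_float.py above; this assumes a DirectML-enabled onnxruntime build, and the input names/shapes follow that generator script.

```python
import numpy as np
import onnxruntime as ort

so = ort.SessionOptions()
# Level 2 graph optimizations include the MatMulIntegerToFloat fusion.
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

sess = ort.InferenceSession("matmul_integer_to_float16_int8.onnx",
                            sess_options=so,
                            providers=["DmlExecutionProvider"])

M, K, N = 4, 8, 16
feeds = {
    "A": np.random.randint(0, 255, size=(M, K), dtype=np.uint8),
    "B": np.random.randint(-128, 127, size=(K, N), dtype=np.int8),
    "a_scale": np.array([0.05], dtype=np.float16),
    "b_scale": np.full((N,), 0.02, dtype=np.float16),
    "a_zero_point": np.array([128], dtype=np.uint8),
    "b_zero_point": np.zeros((N,), dtype=np.int8),
}
y = sess.run(["Y"], feeds)[0]
print(y.shape, y.dtype)  # expected (4, 16), float16 when the fused op runs on DML
```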