diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f523e97293427..e295dfa203ae5 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2795,7 +2795,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Constrain input A data type to 8-bit integer tensor.
T2 : tensor(int8), tensor(uint8)
Constrain input B data type to 8-bit integer tensor.
-T3 : tensor(float)
+T3 : tensor(float), tensor(float16)
Constrain input a_scale, b_scale and output Y data type as float tensor.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 1eaf0fb6dad76..0e60b4622f2fb 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1268,6 +1268,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)|
+|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 4313fae767fe5..22a79ef652515 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -434,7 +434,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Output(0, "Y", "Matrix multiply results from A * B", "T3")
.TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data type to 8-bit integer tensor.")
.TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data type to 8-bit integer tensor.")
- .TypeConstraint("T3", {"tensor(float)"},
+ .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"},
"Constrain input a_scale, b_scale and output Y data type as float tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 2, 0);
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 8376b87aee6b2..f319e7254568d 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -278,7 +278,8 @@ InlinedVector> GenerateTransformers(
onnxruntime::kAclExecutionProvider,
onnxruntime::kArmNNExecutionProvider,
onnxruntime::kJsExecutionProvider};
-
+  const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider,
+                                                        onnxruntime::kDmlExecutionProvider};
#ifdef MLAS_TARGET_AMD64_IX86
const bool avx2_precision_mode =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
@@ -296,7 +297,7 @@ InlinedVector> GenerateTransformers(
}
transformers.emplace_back(std::make_unique(cpu_ep));
-  transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_ep));
+  transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_dml_eps));
transformers.emplace_back(std::make_unique(cpu_ep));
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps));
diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc
index 56e51cb787931..4fee1a6ce224e 100644
--- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc
+++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc
@@ -31,6 +31,24 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) {
return bias_last_dim > 1;
}
+bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
+ if (!node_arg.Exists()) {
+ return false;
+ }
+
+ const auto* type_proto = node_arg.TypeAsProto();
+ if (!type_proto) {
+ return false;
+ }
+
+ int32_t actual_data_type;
+ if (!utils::TryGetElementDataType(*type_proto, actual_data_type)) {
+ return false;
+ }
+
+ return data_type == actual_data_type;
+}
+
/**
MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat:
@@ -63,9 +81,10 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g
auto& mul_node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger));
-
+ const bool is_dml_ep = node_ptr->GetExecutionProviderType() == kDmlExecutionProvider;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
- !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) {
+ !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) ||
+ (!is_dml_ep && HasElementDataType(*mul_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) {
continue;
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index e1e7eacfbd85d..7c25755a7d09e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -879,6 +879,12 @@ struct OperatorDescTraits
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY;
};
+template <>
+struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
+{
+    static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
+};
+
template <>
struct OperatorDescTraits
{
@@ -1041,12 +1047,6 @@ struct OperatorDescTraits
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
};
-template <>
-struct OperatorDescTraits
-{
- static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
-};
-
template <>
struct OperatorDescTraits
{
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
index 5fe6603c2a0bf..da57c2aa235fd 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
@@ -1885,6 +1885,25 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHE
DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS,
};
+constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] {
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
+    "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
+    static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
+    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+    8,
+    DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
+};
+
constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA_FIELDS[11] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
@@ -2395,24 +2414,6 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHE
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
};
-constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] {
- DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
- DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
- DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
- DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
- DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
- DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
- DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
- DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
-};
-
-constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
- "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
- DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT,
- DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
- 8,
- DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
-};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
index 4be41ad3924a2..86c66d8cca26c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
@@ -1139,6 +1139,19 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))),
};
}
+inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
+{
+    return {
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+    };
+}
inline std::vector GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc)
{
return {
@@ -1487,19 +1500,6 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_P
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))),
};
}
-inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
-{
- return {
- OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))),
- OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))),
- OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))),
- OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))),
- OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))),
- OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))),
- OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))),
- OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))),
- };
-}
inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
{
return {
@@ -1829,6 +1829,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA;
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA;
+ case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA;
case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA;
@@ -1856,7 +1857,6 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA;
case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA;
- case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
@@ -2360,6 +2360,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA,
GetFields(*static_cast(opDesc.Desc)));
+    case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
+        return AbstractOperatorDesc(
+            &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
+            GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_CONVOLUTION_INTEGER:
return AbstractOperatorDesc(
&DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA,
@@ -2468,10 +2472,6 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast(opDesc.Desc)));
- case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
- return AbstractOperatorDesc(
- &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
- GetFields(*static_cast(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_ELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp
new file mode 100644
index 0000000000000..b5a3dd0960b86
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp
@@ -0,0 +1,111 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+
+class DmlOperatorMatMulIntegerToFloat : public DmlOperator
+{
+ enum OrtInputTensors : uint32_t
+ {
+ ortA,
+ ortB,
+ ortAScale,
+ ortBScale,
+ ortAZeroPoint,
+ ortBZeroPoint,
+ ortBias,
+ ortInputCount
+ };
+
+ enum DmlInputIndex : uint32_t
+ {
+ dmlA,
+ dmlAScale,
+ dmlAZeroPoint,
+ dmlB,
+ dmlBScale,
+ dmlBZeroPoint,
+ dmlBias,
+ dmlInputCount,
+ };
+
+public:
+ DmlOperatorMatMulIntegerToFloat(const MLOperatorKernelCreationContext& kernelInfo)
+ : DmlOperator(kernelInfo)
+ {
+        std::vector<std::optional<uint32_t>> inputIndices = { OrtInputTensors::ortA, OrtInputTensors::ortAScale, OrtInputTensors::ortAZeroPoint, OrtInputTensors::ortB, OrtInputTensors::ortBScale, OrtInputTensors::ortBZeroPoint, OrtInputTensors::ortBias };
+ DmlOperator::Initialize(kernelInfo, inputIndices);
+
+        std::vector<uint32_t> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortA);
+        std::vector<uint32_t> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortB);
+        std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
+
+ OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape);
+
+ // Initialize the input descriptions with broadcasting
+ m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortA, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
+ m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortB, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
+
+ // Broadcast Bias tensor to the shape of the output tensor.
+ if(kernelInfo.IsInputValid(OrtInputTensors::ortBias)) {
+ m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortBias, TensorAxis::DoNotCoerce,
+ TensorAxis::W, TensorAxis::RightAligned, outputShape);
+ }
+
+ uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount();
+ // Resize the A Scale to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput(
+ kernelInfo,
+ OrtInputTensors::ortAScale,
+ TensorAxis::DoNotCoerce,
+ TensorAxis::H,
+ TensorAxis::LeftAligned,
+ std::nullopt,
+ dmlDimSize
+ );
+
+ // Resize the A ZeroPoint to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ if (kernelInfo.IsInputValid(OrtInputTensors::ortAZeroPoint))
+ {
+ m_inputTensorDescs[DmlInputIndex::dmlAZeroPoint] = CreateTensorDescFromInput(
+ kernelInfo,
+ OrtInputTensors::ortAZeroPoint,
+ TensorAxis::DoNotCoerce,
+ TensorAxis::H,
+ TensorAxis::LeftAligned,
+ std::nullopt,
+ dmlDimSize
+ );
+ }
+
+ // B Zeropoint and BScale are already aligned in the W dimension so no need to align them
+
+ // Initialize the output description while overriding the shape
+ m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
+
+        std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
+        std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
+
+ DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulDesc = {};
+ matMulDesc.ATensor = &inputDescs[DmlInputIndex::dmlA];
+ matMulDesc.AScaleTensor = &inputDescs[DmlInputIndex::dmlAScale];
+ matMulDesc.AZeroPointTensor = inputDescs[DmlInputIndex::dmlAZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlAZeroPoint] : nullptr;
+ matMulDesc.BTensor = &inputDescs[DmlInputIndex::dmlB];
+ matMulDesc.BScaleTensor = &inputDescs[DmlInputIndex::dmlBScale];
+ matMulDesc.BZeroPointTensor = inputDescs[DmlInputIndex::dmlBZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBZeroPoint] : nullptr;
+ matMulDesc.BiasTensor = inputDescs[DmlInputIndex::dmlBias].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBias] : nullptr;
+ matMulDesc.OutputTensor = &outputDescs[0];
+
+ DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulDesc };
+ SetDmlOperatorDesc(opDesc, kernelInfo);
+ }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(MatMulIntegerToFloat, DmlOperatorMatMulIntegerToFloat);
+
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 9c136ed8c9484..f08151b61197a 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -503,6 +503,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat);
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
+DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat);
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
@@ -622,6 +623,13 @@ constexpr static std::array supportedTypeListQLinea
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
};
+
+constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat = {
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Float16to32
+};
+
constexpr static std::array supportedTypeListQLinearConv = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
@@ -1083,6 +1091,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, MatMulIntegerToFloat, typeNameListThree, supportedTypeListMatMulIntegerToFloat, DmlGraphSupport::Supported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 1b2521a86613f..06bacc1b28c99 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -870,7 +870,6 @@ class QLinearMatMulHelper : public MatMulHelperBase
QLinearMatMulHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 3) {}
};
-
class TopKHelper
{
void Initialize(
@@ -1776,6 +1775,7 @@ using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MatMul = MatMulHelper;
using ShapeInferenceHelper_MatMulInteger = MatMulHelper;
+using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper;
using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper;
using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index e725ba085113d..d081aa2e29148 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -449,6 +449,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMulActivation = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_Attention = 1;
+ static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
index 26ce5272d25ee..6f3ca7e239671 100644
--- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
@@ -23,135 +23,407 @@ using namespace std;
namespace onnxruntime {
namespace test {
-template <typename IType, typename WType>
-void TestMatMulIntegerToFloat(const std::vector<int64_t>& A_dims,
-                              std::vector<int64_t> B_dims,
-                              const std::string& reference_model,
-                              bool is_matrix_b_constant,
+template <typename IType, typename WType, typename OType>
+static void CalculateMatMulIntegerToFloat(const int64_t M, const int64_t N, const int64_t K,
+                                          const std::vector<IType>& A_data, const std::vector<OType>& A_scale,
+                                          const std::vector<IType>& A_zero_point, const std::vector<WType>& B_data,
+                                          std::vector<OType>& B_scale, std::vector<WType>& B_zero_point,
+                                          const std::vector<OType>& Bias, std::vector<OType>& Y_data,
+                                          bool per_column, bool has_zp, bool has_bias) {
+ if (!per_column) {
+ B_zero_point.resize(N, B_zero_point[0]);
+ B_scale.resize(N, B_scale[0]);
+ }
+
+ for (int64_t m = 0; m < M; m++) {
+ for (int64_t n = 0; n < N; n++) {
+ float sum = 0.0f;
+ for (int64_t k = 0; k < K; k++) {
+ float A_dequantized = has_zp ? (static_cast(A_data[m * K + k]) - static_cast(A_zero_point[0])) * A_scale[0] : A_data[m * K + k] * A_scale[0];
+ float B_dequantized = has_zp ? (static_cast(B_data[k * N + n]) - static_cast(B_zero_point[n])) * B_scale[n] : B_data[k * N + n] * B_scale[n];
+
+ sum += A_dequantized * B_dequantized;
+ }
+ if (has_bias) {
+ sum += Bias[n];
+ }
+      Y_data[m * N + n] = static_cast<OType>(sum);
+ }
+ }
+}
+
+template
+void TestMatMulIntegerToFloat(bool is_matrix_b_constant,
bool per_column = false,
bool has_zp = true,
bool has_bias = false) {
// create rand inputs
RandomValueGenerator random{};
-
+ int64_t M = 4;
+ int64_t N = 128;
+ int64_t K = 128;
+  std::vector<int64_t> A_dims{M, K};
+  std::vector<int64_t> B_dims{K, N};
+  std::vector<int64_t> Y_dims{M, K};
   std::vector<IType> A_data;
- std::vector tmp_A_data = random.Uniform(A_dims,
- std::numeric_limits::lowest(),
- std::numeric_limits::max());
- std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> WType {
+  std::vector<int32_t> tmp_A_data = random.Uniform<int32_t>(A_dims,
+                                                            std::numeric_limits<IType>::lowest(),
+                                                            std::numeric_limits<IType>::max());
+  std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> IType {
     return static_cast<IType>(v);
   });
   std::vector<WType> B_data;
- std::vector tmp_B_data = random.Uniform(B_dims,
- std::numeric_limits::lowest(),
- std::numeric_limits::max());
+
+  std::vector<int32_t> tmp_B_data;
+  tmp_B_data = random.Uniform<int32_t>(B_dims,
+                                       std::is_signed<WType>::value ? std::numeric_limits<WType>::lowest() / 2 : std::numeric_limits<WType>::lowest(),
+                                       std::numeric_limits<WType>::max() / 2);
   std::transform(tmp_B_data.begin(), tmp_B_data.end(), std::back_inserter(B_data), [](int32_t v) -> WType {
     return static_cast<WType>(v);
   });
- std::vector A_scale = random.Uniform(AsSpan({1}), -0.1f, 0.1f);
+  std::vector<OType> A_scale = random.Uniform<OType>(AsSpan<int64_t>({1}), -0.1f, 0.1f);
   std::vector<IType> A_zero_point{(std::numeric_limits<IType>::lowest() + std::numeric_limits<IType>::max() + IType(2)) / 2};
int64_t b_scale_zp_size = per_column ? B_dims.back() : 1;
- std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f);
+  std::vector<OType> B_scale = random.Uniform<OType>(AsSpan({b_scale_zp_size}), -0.1f, 0.1f);
   std::vector<WType> B_zero_point(b_scale_zp_size);
std::for_each(B_zero_point.begin(),
B_zero_point.end(),
[&random](WType& zp) {
- zp = static_cast(random.Uniform(std::array{1},
- std::numeric_limits::lowest(),
- std::numeric_limits::max())[0]);
+                  zp = static_cast<WType>(random.Uniform<int32_t>(std::array<int64_t, 1>{1},
+                                                                  std::numeric_limits<WType>::lowest(),
+                                                                  std::numeric_limits<WType>::max())[0]);
});
- std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f);
+  std::vector<OType> Bias = random.Uniform<OType>(AsSpan({B_dims.back()}), -0.1f, 0.1f);
OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain);
   test.AddInput<IType>("A", A_dims, A_data);
   test.AddInput<WType>("B", B_dims, B_data, is_matrix_b_constant);
- test.AddInput("a_scale", {1}, A_scale);
- test.AddInput("b_scale", {b_scale_zp_size}, B_scale);
+  test.AddInput<OType>("a_scale", {1}, A_scale);
+  test.AddInput<OType>("b_scale", {b_scale_zp_size}, B_scale);
if (has_zp) {
     test.AddInput<IType>("a_zero_point", {1}, A_zero_point);
     test.AddInput<WType>("b_zero_point", {b_scale_zp_size}, B_zero_point);
} else {
- test.AddOptionalInputEdge();
+    test.AddOptionalInputEdge<IType>();
     test.AddOptionalInputEdge<WType>();
}
if (has_bias) {
- test.AddInput("bias", {B_dims.back()}, Bias);
+    test.AddInput<OType>("bias", {B_dims.back()}, Bias);
} else {
- test.AddOptionalInputEdge();
+    test.AddOptionalInputEdge<OType>();
}
- test.AddReferenceOutputs(reference_model);
- test.SetOutputRelErr("Y", 1e-4f);
- test.Run();
-}
+  std::vector<float> Y_data(M * N);
+ CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point,
+ B_data, B_scale, B_zero_point, Bias, Y_data,
+ per_column, has_zp, has_bias);
-template
-void RunMatMulIntegerToFloatTest(const string& model_path) {
- std::vector A_dims{4, 128};
- std::vector B_dims{128, 128};
- std::vector Y_dims{4, 128};
+  if (std::is_same_v<OType, float>) {
+    test.AddOutput<float>("Y", {M, N}, Y_data);
+    test.SetOutputRelErr("Y", 0.02f);
+  } else {
+    test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+    test.SetOutputAbsErr("Y", 0.5f);
+  }
- TestMatMulIntegerToFloat(A_dims,
- B_dims,
- model_path,
- false, /*is_matrix_b_constant*/
- false, /*per_column*/
- HasZeroPoint, /*has_zp*/
- HasBias /*has_bias*/
+ // Only DML EP supports these data type combinations for now
+  if (std::is_same_v<OType, MLFloat16> ||
+      (std::is_same_v<IType, int8_t> &&
+       std::is_same_v<WType, uint8_t> &&
+       std::is_same_v<OType, float>)) {
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultDmlExecutionProvider());
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ } else {
+ test.Run();
+ }
+}
+
+template <typename IType, typename WType, typename OType, bool HasZeroPoint, bool HasBias>
+void RunMatMulIntegerToFloatTest() {
+  TestMatMulIntegerToFloat<IType, WType, OType>(
+      false, /*is_matrix_b_constant*/
+      false, /*per_column*/
+      HasZeroPoint, /*has_zp*/
+      HasBias /*has_bias*/
);
- TestMatMulIntegerToFloat(A_dims,
- B_dims,
- model_path,
- true, /*is_matrix_b_constant*/
- false, /*per_column*/
- HasZeroPoint, /*has_zp*/
- HasBias /*has_bias*/
+  TestMatMulIntegerToFloat<IType, WType, OType>(
+ true, /*is_matrix_b_constant*/
+ false, /*per_column*/
+ HasZeroPoint, /*has_zp*/
+ HasBias /*has_bias*/
);
- TestMatMulIntegerToFloat(A_dims,
- B_dims,
- model_path,
- false, /*is_matrix_b_constant*/
- true, /*per_column*/
- HasZeroPoint, /*has_zp*/
- HasBias /*has_bias*/
+  TestMatMulIntegerToFloat<IType, WType, OType>(
+ false, /*is_matrix_b_constant*/
+ true, /*per_column*/
+ HasZeroPoint, /*has_zp*/
+ HasBias /*has_bias*/
);
- TestMatMulIntegerToFloat(A_dims,
- B_dims,
- model_path,
- true, /*is_matrix_b_constant*/
- true, /*per_column*/
- HasZeroPoint, /*has_zp*/
- HasBias /*has_bias*/
+  TestMatMulIntegerToFloat<IType, WType, OType>(
+ true, /*is_matrix_b_constant*/
+ true, /*per_column*/
+ HasZeroPoint, /*has_zp*/
+ HasBias /*has_bias*/
);
}
-TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8X8) {
- RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8.onnx");
- RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_uint8.onnx");
+TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) {
+  RunMatMulIntegerToFloatTest<int8_t, int8_t, float, true, false>();
}
-TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8X8) {
- RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_bias.onnx");
- RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_uint8_bias.onnx");
+TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) {
+  RunMatMulIntegerToFloatTest<int8_t, int8_t, float, false, true>();
}
-TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) {
- RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_int8.onnx");
+TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8S8) {
+  RunMatMulIntegerToFloatTest<int8_t, int8_t, float, false, false>();
}
-TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) {
- RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_int8_bias.onnx");
+TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8S8) {
+  RunMatMulIntegerToFloatTest<int8_t, int8_t, float, true, true>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8U8) {
+  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, float, true, false>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8U8) {
+  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, float, false, true>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8U8) {
+  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, float, false, false>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8X8) {
+  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, float, true, true>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8S8) {
+  RunMatMulIntegerToFloatTest<uint8_t, int8_t, float, true, false>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8S8) {
+  RunMatMulIntegerToFloatTest<uint8_t, int8_t, float, false, true>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8S8) {
+  RunMatMulIntegerToFloatTest<uint8_t, int8_t, float, false, false>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8S8) {
+  RunMatMulIntegerToFloatTest<uint8_t, int8_t, float, true, true>();
+}
+
+// DML EP supports Float16 output type and Signed A Matrix and Unsigned B Matrix for Float32 output
+#if defined(USE_DML)
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8U8) {
+  RunMatMulIntegerToFloatTest<int8_t, uint8_t, float, true, false>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8U8) {
+  RunMatMulIntegerToFloatTest<int8_t, uint8_t, float, false, true>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8U8) {
+  RunMatMulIntegerToFloatTest<int8_t, uint8_t, float, false, false>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8U8) {
+  RunMatMulIntegerToFloatTest<int8_t, uint8_t, float, true, true>();
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8U8) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 5;
+  int64_t N = 5;
+  int64_t K = 2;
+
+  std::vector<uint8_t> A_data = {1, 5, 2, 1, 9,
+                                 1, 1, 3, 7, 2};
+  std::vector<uint8_t> B_data = {3, 7, 2, 1, 1,
+                                 2, 1, 9, 1, 1};
+  std::vector<MLFloat16> A_scale = ToFloat16({3.0f});
+  std::vector<MLFloat16> B_scale = ToFloat16({2.0f});
+  test.AddInput<uint8_t>("A", {M, K}, A_data);
+  test.AddInput<uint8_t>("B", {K, N}, B_data);
+  std::vector<uint8_t> A_zero_point = {1};
+  std::vector<uint8_t> B_zero_point = {1};
+
+  test.AddInput<MLFloat16>("a_scale", {1}, A_scale);
+  test.AddInput<MLFloat16>("b_scale", {1}, B_scale);
+  test.AddInput<uint8_t>("a_zero_point", {1}, A_zero_point);
+  test.AddInput<uint8_t>("b_zero_point", {1}, B_zero_point);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<uint8_t, uint8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                             B_data, B_scale, B_zero_point, {}, Y_data,
+                                                             false, true, false);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8S8) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 5;
+  int64_t N = 5;
+  int64_t K = 2;
+
+  std::vector<uint8_t> A_data = {3, 7, 2, 1, 1,
+                                 2, 1, 9, 1, 1};
+  std::vector<int8_t> B_data = {2, -1, -9, 1, 1,
+                                -1, 0, -3, 1, -4};
+  std::vector<MLFloat16> A_scale = ToFloat16({-4.0f});
+  std::vector<MLFloat16> B_scale = ToFloat16({2.0f});
+  test.AddInput<uint8_t>("A", {M, K}, A_data);
+  test.AddInput<int8_t>("B", {K, N}, B_data);
+  std::vector<uint8_t> A_zero_point = {1};
+  std::vector<int8_t> B_zero_point = {3};
+  std::vector<MLFloat16> Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f});
+
+  test.AddInput<MLFloat16>("a_scale", {1}, A_scale);
+  test.AddInput<MLFloat16>("b_scale", {1}, B_scale);
+  test.AddInput<uint8_t>("a_zero_point", {1}, A_zero_point);
+  test.AddInput<int8_t>("b_zero_point", {1}, B_zero_point);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<uint8_t, int8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                            B_data, B_scale, B_zero_point, {}, Y_data,
+                                                            false, true, false);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8S8) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 5;
+  int64_t N = 5;
+  int64_t K = 2;
+
+  std::vector<int8_t> A_data = {3, 7, -2, 1, 1,
+                                2, -1, -9, 1, 1};
+  std::vector<int8_t> B_data = {2, -1, -9, 1, 1,
+                                -1, 0, -3, 1, -4};
+  std::vector<MLFloat16> A_scale = ToFloat16({-4.0f});
+  std::vector<MLFloat16> B_scale = ToFloat16({2.0f});
+  test.AddInput<int8_t>("A", {M, K}, A_data);
+  test.AddInput<int8_t>("B", {K, N}, B_data);
+  std::vector<int8_t> A_zero_point = {-1};
+  std::vector<int8_t> B_zero_point = {3};
+  std::vector<MLFloat16> Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f});
+
+  test.AddInput<MLFloat16>("a_scale", {1}, A_scale);
+  test.AddInput<MLFloat16>("b_scale", {1}, B_scale);
+  test.AddInput<int8_t>("a_zero_point", {1}, A_zero_point);
+  test.AddInput<int8_t>("b_zero_point", {1}, B_zero_point);
+  test.AddInput<MLFloat16>("bias", {N}, Bias);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<int8_t, int8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                           B_data, B_scale, B_zero_point, Bias, Y_data,
+                                                           false, true, true);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8U8) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 5;
+  int64_t N = 5;
+  int64_t K = 2;
+
+  std::vector<int8_t> A_data = {3, 7, -2, 1, 1,
+                                2, -1, -9, 1, 1};
+  std::vector<uint8_t> B_data = {3, 7, 2, 1, 1,
+                                 2, 1, 9, 1, 1};
+  std::vector<MLFloat16> A_scale = ToFloat16({-4.0f});
+  std::vector<MLFloat16> B_scale = ToFloat16({2.0f});
+  test.AddInput<int8_t>("A", {M, K}, A_data);
+  test.AddInput<uint8_t>("B", {K, N}, B_data);
+  std::vector<int8_t> A_zero_point = {-1};
+  std::vector<uint8_t> B_zero_point = {1};
+  std::vector<MLFloat16> Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f});
+
+  test.AddInput<MLFloat16>("a_scale", {1}, A_scale);
+  test.AddInput<MLFloat16>("b_scale", {1}, B_scale);
+  test.AddInput<int8_t>("a_zero_point", {1}, A_zero_point);
+  test.AddInput<uint8_t>("b_zero_point", {1}, B_zero_point);
+  test.AddInput<MLFloat16>("bias", {N}, Bias);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<int8_t, uint8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                            B_data, B_scale, B_zero_point, Bias, Y_data,
+                                                            false, true, true);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16) {
+ OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+ int64_t M = 2;
+ int64_t N = 2;
+ int64_t K = 3;
+
+ std::vector A_data = {11, -2, 5,
+ -1, 3, 10};
+ std::vector