# [DML] MatrixMultiplyIntegerToFloat (microsoft#19608)
### Description
DML implementation of [com.microsoft.MatMulIntegerToFloat](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulIntegerToFloat).
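
For reference, the contrib op dequantizes both inputs and accumulates in floating point: `Y = ((A - a_zero_point) * a_scale) matmul ((B - b_zero_point) * b_scale) + bias`. Below is a minimal sketch of that reference math, assuming per-tensor scales and zero points and row-major layouts; it is illustrative only, not the DML kernel and not the test helper added in this change.

```cpp
// Illustrative reference for com.microsoft.MatMulIntegerToFloat with
// per-tensor scales/zero points; the op also allows per-column values.
// Y[m][n] = sum_k ((A[m][k] - a_zp) * a_scale) * ((B[k][n] - b_zp) * b_scale) + bias[n]
#include <cstdint>
#include <vector>

std::vector<float> MatMulIntegerToFloatRef(
    const std::vector<uint8_t>& A, const std::vector<int8_t>& B,
    size_t M, size_t K, size_t N,
    float a_scale, float b_scale,
    uint8_t a_zero_point, int8_t b_zero_point,
    const std::vector<float>& bias) {  // pass an empty vector if there is no bias
  std::vector<float> Y(M * N, 0.0f);
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k) {
        const float a = (static_cast<int32_t>(A[m * K + k]) - a_zero_point) * a_scale;
        const float b = (static_cast<int32_t>(B[k * N + n]) - b_zero_point) * b_scale;
        acc += a * b;
      }
      Y[m * N + n] = acc + (bias.empty() ? 0.0f : bias[n]);
    }
  }
  return Y;
}
```

The gtest run below exercises the zero-point/bias/datatype combinations: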

```
.\onnxruntime_test_all.exe --gtest_filter="*MatMulIntegerToFloat.*"
Note: Google Test filter = *MatMulIntegerToFloat.*
[==========] Running 22 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 22 tests from MatMulIntegerToFloat
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8 (620 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8 (497 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8S8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8S8 (488 ms)
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8S8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8S8 (503 ms)
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8U8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8U8 (495 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8U8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8U8 (488 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8U8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8U8 (492 ms)
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8X8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8X8 (502 ms)
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8U8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8U8 (452 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8U8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8U8 (454 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8U8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8U8 (446 ms)
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8U8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8U8 (508 ms)
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8S8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8S8 (456 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8S8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8S8 (455 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8 (447 ms)
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8S8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8S8 (465 ms)
[ RUN      ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8U8
[       OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8U8 (111 ms)
[ RUN      ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8S8
[       OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8S8 (115 ms)
[ RUN      ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8S8
[       OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8S8 (114 ms)
[ RUN      ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8U8
[       OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8U8 (110 ms)
[ RUN      ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16
[       OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16 (112 ms)
[ RUN      ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint
[       OK ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint (337 ms)
[----------] 22 tests from MatMulIntegerToFloat (8679 ms total)

[----------] Global test environment tear-down
[==========] 22 tests from 1 test suite ran. (8680 ms total)
[  PASSED  ] 22 tests.
memleakdbg:
----- No memory leaks detected -----
```


### Motivation and Context
* Added `CalculateMatMulIntegerToFloat` to compute expected outputs, replacing the CPU EP run as the reference (see the reference sketch in the description above).
* Added more FP32 test cases to isolate every input datatype combination.
* Added fixed inputs to the `MatMulIntegerToFloat_FP16*` test cases, since these run in FP16 (see the test sketch after this list).
* `onnxruntime/test/testdata/matmul_integer_to_float.py` is capable of generating FP16 models, but we do not produce any for now.
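
For illustration, here is a sketch of what a fixed-input FP16 case could look like with the ORT `OpTester` helper. The test name, shapes, and values are made up for this example and are not copied from the new tests; the expected outputs follow the reference math sketched in the description.

```cpp
// Hypothetical fixed-input FP16 test sketch using OpTester; shapes, values,
// and the test name are illustrative, not copied from the new tests.
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

TEST(MatMulIntegerToFloat, Fp16_FixedInput_Sketch) {
  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
  test.AddInput<uint8_t>("A", {2, 2}, {1, 2, 3, 4});
  test.AddInput<uint8_t>("B", {2, 2}, {5, 6, 7, 8});
  test.AddInput<MLFloat16>("a_scale", {1}, {MLFloat16(0.5f)});
  test.AddInput<MLFloat16>("b_scale", {1}, {MLFloat16(0.25f)});
  test.AddInput<uint8_t>("a_zero_point", {1}, {0});
  test.AddInput<uint8_t>("b_zero_point", {1}, {0});
  // Expected values follow the dequantize-matmul reference above and are
  // exactly representable in FP16, so the comparison is stable.
  test.AddOutput<MLFloat16>("Y", {2, 2},
                            {MLFloat16(2.375f), MLFloat16(2.75f),
                             MLFloat16(5.375f), MLFloat16(6.25f)});
  test.Run();
}

}  // namespace test
}  // namespace onnxruntime
```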
raoanag authored and Zhenze Wang committed Mar 7, 2024
1 parent 09c907f commit 28bfe9b
Showing 23 changed files with 664 additions and 144 deletions.
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -2795,7 +2795,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Constrain input A data type to 8-bit integer tensor.</dd>
<dt><tt>T2</tt> : tensor(int8), tensor(uint8)</dt>
<dd>Constrain input B data type to 8-bit integer tensor.</dd>
<dt><tt>T3</tt> : tensor(float)</dt>
<dt><tt>T3</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input a_scale, b_scale and output Y data type as float tensor.</dd>
</dl>

1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1268,6 +1268,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -434,7 +434,7 @@
.Output(0, "Y", "Matrix multiply results from A * B", "T3")
.TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data type to 8-bit integer tensor.")
.TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data type to 8-bit integer tensor.")
.TypeConstraint("T3", {"tensor(float)"},
.TypeConstraint("T3", {"tensor(float)", "tensor(float16)"},
"Constrain input a_scale, b_scale and output Y data type as float tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 2, 0);
5 changes: 3 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -278,7 +278,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
onnxruntime::kAclExecutionProvider,
onnxruntime::kArmNNExecutionProvider,
onnxruntime::kJsExecutionProvider};

const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider,
onnxruntime::kDmlExecutionProvider};
#ifdef MLAS_TARGET_AMD64_IX86
const bool avx2_precision_mode =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
@@ -296,7 +297,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
}

transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_ep));
transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_dml_eps));
transformers.emplace_back(std::make_unique<DynamicQuantizeMatMulFusion>(cpu_ep));

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_cuda_rocm_acl_armnn_js_eps));
23 changes: 21 additions & 2 deletions onnxruntime/core/optimizer/matmul_integer_to_float.cc
@@ -31,6 +31,24 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) {
return bias_last_dim > 1;
}

bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
if (!node_arg.Exists()) {
return false;
}

const auto* type_proto = node_arg.TypeAsProto();
if (!type_proto) {
return false;
}

int32_t actual_data_type;
if (!utils::TryGetElementDataType(*type_proto, actual_data_type)) {
return false;
}

return data_type == actual_data_type;
}

/**
MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat:
@@ -63,9 +81,10 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g
auto& mul_node = *node_ptr;

ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger));

const bool is_dml_ep = node_ptr->GetExecutionProviderType() == kDmlExecutionProvider;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
!graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) {
!graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) ||
(!is_dml_ep && HasElementDataType(*mul_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) {
continue;
}

@@ -879,6 +879,12 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY;
};

template <>
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
};

template <>
struct OperatorDescTraits<DML_CONVOLUTION_INTEGER_OPERATOR_DESC>
{
@@ -1041,12 +1047,6 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
};

template <>
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
};

template <>
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
{
@@ -1885,6 +1885,25 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHE
DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
8,
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA_FIELDS[11] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
@@ -2395,24 +2414,6 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHE
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
8,
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@@ -1139,6 +1139,19 @@ inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc)
{
return {
@@ -1487,19 +1500,6 @@ inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_P
OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
{
return {
@@ -1829,6 +1829,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA;
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA;
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA;
case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA;
@@ -1856,7 +1857,6 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA;
case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA;
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
@@ -2360,6 +2360,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return AbstractOperatorDesc(
&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_CONVOLUTION_INTEGER:
return AbstractOperatorDesc(
&DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA,
@@ -2468,10 +2472,6 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return AbstractOperatorDesc(
&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_ELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,