From c3d96a7b35c975a4eb4ad5a4c94349797defbc78 Mon Sep 17 00:00:00 2001
From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com>
Date: Tue, 2 Jan 2024 18:06:26 -0800
Subject: [PATCH 01/25] Update DML version to 1.13.0 (#18978)
Update DML nuget version to 1.13.0
---
.pipelines/nuget_config/x64/packages.config | 2 +-
.pipelines/nuget_config/x86/packages.config | 2 +-
cmake/external/dml.cmake | 2 +-
packages.config | 2 +-
tools/nuget/generate_nuspec_for_native_nuget.py | 2 +-
5 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config
index 2ac650b0e6dc9..2583e0d1b2ead 100644
--- a/.pipelines/nuget_config/x64/packages.config
+++ b/.pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@
-  <package id="Microsoft.AI.DirectML" version="1.12.1" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config
index f80f96194a230..5ca659941c159 100644
--- a/.pipelines/nuget_config/x86/packages.config
+++ b/.pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@
-  <package id="Microsoft.AI.DirectML" version="1.12.1" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake
index 5d25b9529e030..d777306722cd6 100644
--- a/cmake/external/dml.cmake
+++ b/cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
- set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.1)
+ set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0)
# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
diff --git a/packages.config b/packages.config
index da61a10adfa74..b67219d6d6913 100644
--- a/packages.config
+++ b/packages.config
@@ -1,6 +1,6 @@
-  <package id="Microsoft.AI.DirectML" version="1.12.1" targetFramework="native" />
+  <package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py
index 66248565a3e3a..56e50750ac153 100644
--- a/tools/nuget/generate_nuspec_for_native_nuget.py
+++ b/tools/nuget/generate_nuspec_for_native_nuget.py
@@ -219,7 +219,7 @@ def add_common_dependencies(xml_text, package_name, version):
def generate_dependencies(xml_text, package_name, version):
- dml_dependency = '<dependency id="Microsoft.AI.DirectML" version="1.12.1"/>'
+ dml_dependency = '<dependency id="Microsoft.AI.DirectML" version="1.13.0"/>'
if package_name == "Microsoft.AI.MachineLearning":
xml_text.append("<dependencies>")
From 9bbe425d7f805d4dfaad2e69c3edca40063fe673 Mon Sep 17 00:00:00 2001
From: Xiang Zhang
Date: Thu, 27 Jul 2023 19:13:15 -0700
Subject: [PATCH 02/25] Register LPpool18 and AvgPool 19 (#16880)
---
.../src/External/DirectMLHelpers/ApiTraits.h | 30 ++++++++++++++
.../External/DirectMLHelpers/DirectMLSchema.h | 40 +++++++++++++++++++
.../DirectMLHelpers/GeneratedSchemaHelpers.h | 39 ++++++++++++++++++
.../src/Operators/DmlOperatorPooling.cpp | 40 +++++++++++++++++--
.../src/Operators/OperatorRegistration.cpp | 2 +
.../OperatorAuthorHelper/OperatorVersions.h | 6 +++
.../test/providers/cpu/nn/pool_op_test.cc | 25 ------------
7 files changed, 153 insertions(+), 29 deletions(-)
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index c75b662af788d..94f2220fcc168 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -459,12 +459,24 @@ struct OperatorDescTraits
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING;
};
+template <>
+struct OperatorDescTraits<DML_AVERAGE_POOLING1_OPERATOR_DESC>
+{
+ static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING1;
+};
+
template <>
struct OperatorDescTraits<DML_LP_POOLING_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING;
};
+template <>
+struct OperatorDescTraits<DML_LP_POOLING1_OPERATOR_DESC>
+{
+ static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING1;
+};
+
template <>
struct OperatorDescTraits<DML_MAX_POOLING_OPERATOR_DESC>
{
@@ -1448,12 +1460,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING>
using DescType = DML_AVERAGE_POOLING_OPERATOR_DESC;
};
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING1>
+{
+ using DescType = DML_AVERAGE_POOLING1_OPERATOR_DESC;
+};
+
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING>
{
using DescType = DML_LP_POOLING_OPERATOR_DESC;
};
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING1>
+{
+ using DescType = DML_LP_POOLING1_OPERATOR_DESC;
+};
+
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING>
{
@@ -2259,8 +2283,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_AVERAGE_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
+ case DML_OPERATOR_AVERAGE_POOLING1:
+ return std::invoke(std::forward<Visitor>(visitor), DML_AVERAGE_POOLING1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_LP_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_LP_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
+ case DML_OPERATOR_LP_POOLING1:
+ return std::invoke(std::forward<Visitor>(visitor), DML_LP_POOLING1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MAX_POOLING:
return std::invoke(std::forward<Visitor>(visitor), DML_MAX_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MAX_POOLING1:
@@ -2554,7 +2582,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN";
case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX";
case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING";
+ case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1";
case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING";
+ case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1";
case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING";
case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1";
case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING";
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
index 1ebd52d4ed427..9eae1c1fe8158 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
@@ -757,6 +757,26 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_OPERATOR_SCHEMA {
DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
};
+constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS[9] {
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING1_OPERATOR_SCHEMA {
+ "DML_OPERATOR_AVERAGE_POOLING1",
+ DML_OPERATOR_AVERAGE_POOLING1,
+ DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+ 9,
+ DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS,
+};
+
constexpr DML_SCHEMA_FIELD DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@@ -776,6 +796,26 @@ constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING_OPERATOR_SCHEMA {
DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS,
};
+constexpr DML_SCHEMA_FIELD DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS[9] {
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING1_OPERATOR_SCHEMA {
+ "DML_OPERATOR_LP_POOLING1",
+ DML_OPERATOR_LP_POOLING1,
+ DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+ 9,
+ DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS,
+};
+
constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS[7] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
index 833871de0bbd9..ad4cceb85cfd2 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
@@ -425,6 +425,21 @@ inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING_OPERATOR_D
OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<uint32_t>(desc.IncludePadding))),
};
}
+
+inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc)
+{
+ return {
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<uint32_t>(desc.DimensionCount))),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const uint32_t*>(desc.Strides), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const uint32_t*>(desc.WindowSize), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const uint32_t*>(desc.StartPadding), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const uint32_t*>(desc.EndPadding), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const uint32_t*>(desc.Dilations), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<uint32_t>(desc.IncludePadding))),
+ };
+}
inline std::vector<OperatorField> GetFields(const DML_LP_POOLING_OPERATOR_DESC& desc)
{
return {
@@ -438,6 +453,20 @@ inline std::vector<OperatorField> GetFields(const DML_LP_POOLING_OPERATOR_DESC&
OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<uint32_t>(desc.P))),
};
}
+inline std::vector<OperatorField> GetFields(const DML_LP_POOLING1_OPERATOR_DESC& desc)
+{
+ return {
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<uint32_t>(desc.DimensionCount))),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const uint32_t*>(desc.Strides), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const uint32_t*>(desc.WindowSize), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const uint32_t*>(desc.StartPadding), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const uint32_t*>(desc.EndPadding), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const uint32_t*>(desc.Dilations), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<uint32_t>(desc.P))),
+ };
+}
inline std::vector<OperatorField> GetFields(const DML_MAX_POOLING_OPERATOR_DESC& desc)
{
return {
@@ -1684,7 +1713,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA;
case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA;
+ case DML_OPERATOR_AVERAGE_POOLING1: return DML_AVERAGE_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA;
+ case DML_OPERATOR_LP_POOLING1: return DML_LP_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING1: return DML_MAX_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_ROI_POOLING: return DML_ROI_POOLING_OPERATOR_SCHEMA;
@@ -2002,10 +2033,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_AVERAGE_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
+ case DML_OPERATOR_AVERAGE_POOLING1:
+ return AbstractOperatorDesc(
+ &DML_AVERAGE_POOLING1_OPERATOR_SCHEMA,
+ GetFields(*static_cast<const DML_AVERAGE_POOLING1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_LP_POOLING:
return AbstractOperatorDesc(
&DML_LP_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_LP_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
+ case DML_OPERATOR_LP_POOLING1:
+ return AbstractOperatorDesc(
+ &DML_LP_POOLING1_OPERATOR_SCHEMA,
+ GetFields(*static_cast<const DML_LP_POOLING1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MAX_POOLING:
return AbstractOperatorDesc(
&DML_MAX_POOLING_OPERATOR_SCHEMA,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
index e8d5b2746aa13..10ff1d8be8a29 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
@@ -34,7 +34,7 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
kernelOutputIndices.emplace_back(1);
}
DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices);
-
+
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1.");
@@ -98,6 +98,21 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
SetOpDesc(desc);
break;
}
+ case DML_OPERATOR_AVERAGE_POOLING1:
+ {
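+ // DML_AVERAGE_POOLING1 extends DML_AVERAGE_POOLING with a Dilations field,
+ // so the older desc is used whenever the ONNX node carries no dilations.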
+ if (hasDilations) {
+ DML_AVERAGE_POOLING1_OPERATOR_DESC desc = {};
+ desc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
+ desc.Dilations = m_kernel.dilations;
+ SetOpDesc(desc);
+ }
+ else {
+ DML_AVERAGE_POOLING_OPERATOR_DESC desc = {};
+ desc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
+ SetOpDesc(desc);
+ }
+ break;
+ }
case DML_OPERATOR_LP_POOLING:
{
DML_LP_POOLING_OPERATOR_DESC desc = {};
@@ -106,6 +121,23 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
SetOpDesc(desc);
break;
}
+ case DML_OPERATOR_LP_POOLING1:
+ {
+ if (hasDilations) {
+ DML_LP_POOLING1_OPERATOR_DESC desc = {};
+ desc.P = kernelInfo.GetOptionalAttribute<int>(AttrName::P, 2);
+ ML_CHECK_VALID_ARGUMENT(desc.P > 0);
+ desc.Dilations = m_kernel.dilations;
+ SetOpDesc(desc);
+ }
+ else {
+ DML_LP_POOLING_OPERATOR_DESC desc = {};
+ desc.P = kernelInfo.GetOptionalAttribute<int>(AttrName::P, 2);
+ ML_CHECK_VALID_ARGUMENT(desc.P > 0);
+ SetOpDesc(desc);
+ }
+ break;
+ }
case DML_OPERATOR_MAX_POOLING:
case DML_OPERATOR_MAX_POOLING1:
case DML_OPERATOR_MAX_POOLING2:
@@ -152,7 +184,7 @@ class DmlOperatorPoolingTemplate : public DmlOperatorPooling
void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
{
*isSupported = false;
-
+
MLOperatorAttributes attributes(context);
int storageOrder = attributes.GetOptionalAttribute<int>(AttrName::StorageOrder, 0);
@@ -164,11 +196,11 @@ void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool*
*isSupported = true;
}
-DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate<DML_OPERATOR_AVERAGE_POOLING>);
+DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate<DML_OPERATOR_AVERAGE_POOLING1>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalAveragePool, DmlOperatorPoolingTemplate<DML_OPERATOR_AVERAGE_POOLING>);
DML_OP_DEFINE_CREATION_FUNCTION(MaxPool, DmlOperatorPoolingTemplate<DML_OPERATOR_MAX_POOLING2>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalMaxPool, DmlOperatorPoolingTemplate<DML_OPERATOR_MAX_POOLING>);
-DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate<DML_OPERATOR_LP_POOLING>);
+DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate<DML_OPERATOR_LP_POOLING1>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalLpPool, DmlOperatorPoolingTemplate<DML_OPERATOR_LP_POOLING>);
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 28360f09bcba3..dbe9f5da4f569 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -667,6 +667,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
+ {REG_INFO( 19, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
@@ -677,6 +678,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
+ {REG_INFO( 18, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index e18ba31def48a..3eb35faeba82f 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -406,6 +406,12 @@ namespace OperatorHelper
static const int sc_sinceVer_BitwiseNot = 18;
static const int sc_sinceVer_Pad = 18;
static const int sc_sinceVer_Split = 18;
+ static const int sc_sinceVer_LpPool = 18;
+ }
+
+ namespace OnnxOperatorSet19
+ {
+ static const int sc_sinceVer_AveragePool = 19;
}
namespace MsftOperatorSet1
diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index 10476ada2fa69..4b194ec18b31b 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -777,11 +777,6 @@ TEST(PoolTest, GlobalMaxPool3D) {
}
TEST(PoolTest, AveragePool) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("AveragePool");
test.AddAttribute("auto_pad", "");
@@ -863,11 +858,6 @@ TEST(PoolTest, AveragePool) {
}
TEST(PoolTest, AveragePool_IncludePadPixel) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("AveragePool");
test.AddAttribute("auto_pad", "");
@@ -911,11 +901,6 @@ TEST(PoolTest, AveragePool_DefaultStrides) {
}
TEST(PoolTest, AveragePool_10_ceil1_2d) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("AveragePool", 10);
test.AddAttribute("auto_pad", "");
@@ -939,11 +924,6 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) {
}
TEST(PoolTest, AveragePool_19_dilation_2d) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("AveragePool", 19);
test.AddAttribute("auto_pad", "");
@@ -1070,11 +1050,6 @@ TEST(PoolTest, GlobalAveragePool_Large_256) {
}
TEST(PoolTest, LpPool) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("LpPool");
test.AddAttribute("auto_pad", "");
From 9ff5e3b7b0f5e9683422fa3609175fe4f82ccb77 Mon Sep 17 00:00:00 2001
From: raoanag <127366241+raoanag@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:34:35 -0700
Subject: [PATCH 03/25] Add QLinearConcat for DML EP (#16971) (#18268)
### Description
[Cherry Pick Reviewed]
```
[ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms)
[ RUN ] QLinearConcatS8.InputOne_Dynamic
[ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms)
[ RUN ] QLinearConcatS8.InputOne_Const
[ OK ] QLinearConcatS8.InputOne_Const (255 ms)
[----------] 11 tests from QLinearConcatS8 (3385 ms total)
[----------] Global test environment tear-down
[==========] 21 tests from 3 test suites ran. (9355 ms total)
[ PASSED ] 21 tests.
```
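For orientation, this sketch shows the arithmetic the fused graph reproduces; it is illustrative only (the name `QLinearConcatRef`, its signature, and the flat-uint8 layout are invented here, not part of the patch):

```
// Reference semantics of QLinearConcat: per-input dequantize, concatenate,
// then requantize with the shared output scale/zero point.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<uint8_t> QLinearConcatRef(
    const std::vector<std::vector<uint8_t>>& inputs,
    const std::vector<float>& scales,
    const std::vector<uint8_t>& zeroPoints,
    float yScale,
    uint8_t yZeroPoint)
{
    std::vector<uint8_t> output;
    for (size_t i = 0; i < inputs.size(); ++i)
    {
        for (uint8_t x : inputs[i])
        {
            // Dequantize with the per-input scale/zero point...
            float value = (static_cast<int>(x) - zeroPoints[i]) * scales[i];
            // ...then requantize with the output scale/zero point.
            int q = static_cast<int>(std::lround(value / yScale)) + yZeroPoint;
            output.push_back(static_cast<uint8_t>(std::clamp(q, 0, 255)));
        }
    }
    return output;
}
```

The kernel below builds the same pipeline as one fused DML graph: a DEQUANTIZE_LINEAR node per input, a JOIN across the dequantized tensors, and a final QUANTIZE_LINEAR driven by y_scale/y_zero_point.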
[#16971](https://github.com/microsoft/onnxruntime/pull/16971)
### Motivation and Context
Co-authored-by: Xiang Zhang
---
.../Operators/DmlOperatorQLinearConcat.cpp | 236 ++++++++++++++++++
.../src/Operators/OperatorRegistration.cpp | 16 +-
.../src/Operators/OperatorUtility.cpp | 4 +-
.../src/Operators/OperatorUtility.h | 3 +-
.../OperatorAuthorHelper/OperatorHelper.cpp | 18 +-
.../dml/OperatorAuthorHelper/OperatorHelper.h | 25 +-
.../OperatorAuthorHelper/OperatorVersions.h | 1 +
7 files changed, 290 insertions(+), 13 deletions(-)
create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
new file mode 100644
index 0000000000000..67711fdc28b84
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
@@ -0,0 +1,236 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+// QLinearConcat = Dequantize + Join + Quantize
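+// The graph built below contains one DEQUANTIZE_LINEAR node per input
+// (node indices 0..inputCount-1), followed by a single JOIN node and a single
+// QUANTIZE_LINEAR node, connected through the edge lists set up in the constructor.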
+class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
+{
+ // This order matches the ONNX schema.
+ enum OnnxInputIndex
+ {
+ YScale,
+ YZeroPoint,
+ Count,
+ };
+
+public:
+ DmlOperatorQLinearConcat(const MLOperatorKernelCreationContext& kernelCreationContext)
+ : DmlOperator(kernelCreationContext),
+ QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
+ {
+ DmlOperator::Initialize(kernelCreationContext);
+
+ auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);
+
+ // inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)}
+ uint32_t inputDefinitionCount = kernelCreationContext.GetInputCount();
+ ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs.");
+ ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be a (tensor, scale, zero_point) tuple!");
+
+ uint32_t inputCount = (inputDefinitionCount - 2) / 3;
+
+ auto yScaleDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YScale).tensorDataType;
+ auto yZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YZeroPoint).tensorDataType;
+
+ // broadcast y_scale and y_zero_point to output shape
+ m_inputTensorDescs[OnnxInputIndex::YScale] = TensorDesc(
+ yScaleDataType,
+ outputShape,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YScale),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+
+ m_inputTensorDescs[OnnxInputIndex::YZeroPoint] = TensorDesc(
+ yZeroPointDataType,
+ outputShape,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YZeroPoint),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+
+ // Validate input tensors
+ for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
+ {
+ // Inputs (tensor, scale, zero_point) arrive as tuples starting at index 2
+ auto tupleStartIndex = 2 + inputIndex * 3;
+ auto xScaleDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 1).tensorDataType;
+ auto xZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 2).tensorDataType;
+ ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale");
+ ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point");
+
+ // broadcast x_scale and x_zero_point to shape of corresponding x
+ m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc(
+ xScaleDataType,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 1),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+
+ m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc(
+ xZeroPointDataType,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 2),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+ }
+
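+ // The concat axis attribute is interpreted against the x tensors, which start
+ // at graph input index 2 (after y_scale and y_zero_point), hence the
+ // firstInputIndex argument of 2 below.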
+ uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount(), 2);
+
+ std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
+ std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
+
+ // 1. output edges between Dequantize and Join node
+ // 2. input edge between Join and Quantize node
+ std::vector<TensorDesc> intermediateOutputTensorDescs(inputCount);
+ std::vector<DML_TENSOR_DESC> namedDequantizeOperatorDescs(inputCount);
+ std::vector<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> dequantizeOperatorDescs(inputCount);
+ std::vector<DML_OPERATOR_DESC> dmlOpDesc(inputCount);
+ std::vector<const DML_OPERATOR_DESC*> opDescs;
+ for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
+ {
+ auto tupleStartIndex = 2 + inputIndex * 3;
+ intermediateOutputTensorDescs[inputIndex] = TensorDesc(
+ MLOperatorTensorDataType::Float,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment)
+ );
+ namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc();
+
+ dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex];
+ dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
+ dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
+ dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
+
+ dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
+ opDescs.push_back(&dmlOpDesc[inputIndex]);
+ }
+
+ TensorDesc joinOutputTensorDesc = TensorDesc(
+ MLOperatorTensorDataType::Float,
+ outputShape,
+ outputShape,
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+ DML_TENSOR_DESC namedJoinOutputTensorDesc = joinOutputTensorDesc.GetDmlDesc();
+
+ DML_JOIN_OPERATOR_DESC joinDesc = {};
+ joinDesc.InputCount = gsl::narrow_cast<uint32_t>(namedDequantizeOperatorDescs.size());
+ joinDesc.InputTensors = namedDequantizeOperatorDescs.data();
+ joinDesc.OutputTensor = &namedJoinOutputTensorDesc;
+ joinDesc.Axis = dmlAxis;
+
+ const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc};
+ opDescs.push_back(&opJoinDesc);
+
+ DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC quantizeOperatorDesc = {};
+ quantizeOperatorDesc.InputTensor = joinDesc.OutputTensor;
+ quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
+ quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
+ quantizeOperatorDesc.OutputTensor = &outputDescs[0];
+ const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
+ opDescs.push_back(&opQuantizeDesc);
+
+ MLOperatorGraphDesc operatorGraphDesc = {};
+ operatorGraphDesc.nodeCount = static_cast<uint32_t>(opDescs.size());
+ operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+
+ uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2;
+ uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1;
+
+ std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
+ // Input edges to Dequantize nodes
+ for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
+ {
+ auto tupleStartIndex = 2 + inputIndex * 3;
+ for (auto edge_index = 0; edge_index < 3; ++edge_index)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
+ inputEdge.GraphInputIndex = tupleStartIndex + edge_index;
+ inputEdge.ToNodeIndex = inputIndex;
+ inputEdge.ToNodeInputIndex = edge_index;
+ inputEdges.push_back(inputEdge);
+ }
+ }
+
+ // Input edge from y_scale to quantize node
+ DML_INPUT_GRAPH_EDGE_DESC yScaleInputEdge = {};
+ yScaleInputEdge.GraphInputIndex = 0; // Y_scale
+ yScaleInputEdge.ToNodeIndex = quantizeNodeIndex;
+ yScaleInputEdge.ToNodeInputIndex = 1;
+ inputEdges.push_back(yScaleInputEdge);
+
+ // Input edge from y_zero_point to quantize node
+ DML_INPUT_GRAPH_EDGE_DESC yZeroPointInputEdge = {};
+ yZeroPointInputEdge.GraphInputIndex = 1; // Y_zero_point
+ yZeroPointInputEdge.ToNodeIndex = quantizeNodeIndex;
+ yZeroPointInputEdge.ToNodeInputIndex = 2;
+ inputEdges.push_back(yZeroPointInputEdge);
+
+ operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
+ operatorGraphDesc.inputEdges = inputEdges.data();
+
+ // set intermediate edges
+ std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
+ for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
+ {
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC dequantizeToJoinEdge = {};
+ dequantizeToJoinEdge.FromNodeIndex = inputIndex;
+ dequantizeToJoinEdge.FromNodeOutputIndex = 0;
+ dequantizeToJoinEdge.ToNodeIndex = joinNodeIndex; // The second last node Join
+ dequantizeToJoinEdge.ToNodeInputIndex = inputIndex;
+ intermediateEdges.push_back(dequantizeToJoinEdge);
+ }
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC joinToQuantizeEdge = {};
+ joinToQuantizeEdge.FromNodeIndex = joinNodeIndex;
+ joinToQuantizeEdge.FromNodeOutputIndex = 0;
+ joinToQuantizeEdge.ToNodeIndex = quantizeNodeIndex; // The last node, Quantize
+ joinToQuantizeEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(joinToQuantizeEdge);
+
+ operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
+ operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+
+ // set the output edges
+ std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
+ DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
+ outputEdge.FromNodeIndex = quantizeNodeIndex;
+ outputEdge.FromNodeOutputIndex = 0;
+ outputEdge.GraphOutputIndex = 0;
+ outputEdges.push_back(outputEdge);
+ operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
+ operatorGraphDesc.outputEdges = outputEdges.data();
+
+ SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+ }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(QLinearConcat, DmlOperatorQLinearConcat);
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index dbe9f5da4f569..fa2750a22425f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -496,6 +496,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ScatterND);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
+DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat);
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
@@ -547,6 +548,7 @@ constexpr static std::array<const char*, 2> typeNameListEyeLike = { "T1", "T2" }
constexpr static std::array<const char*, 2> typeNameShape = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameSize = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameListGroupNorm = {"T", "M"};
+constexpr static std::array<const char*, 3> typeNameListQLinearConcat = {"TF", "T8", "TV"};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
@@ -618,7 +620,18 @@ constexpr static std::array supportedTypeListQLinea
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeLinear = {
SupportedTensorDataTypes::Float32,
- SupportedTensorDataTypes::UInt8,
+ SupportedTensorDataTypes::Ints8Bit
+};
+
+constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeMatMul = {
+ SupportedTensorDataTypes::Float32,
+ SupportedTensorDataTypes::Ints8Bit,
+};
+
+constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearConcat = {
+ SupportedTensorDataTypes::Float32,
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
};
template
@@ -1012,6 +1025,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)},
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp
index d8290bbdaee3e..2965fa32ce131 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp
@@ -419,9 +419,9 @@ namespace Dml
} // namespace FusionHelpers
- uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount)
+ uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex)
{
- const std::vector<uint32_t> inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
+ const std::vector<uint32_t> inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(firstInputIndex);
uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(inputDimensions.size());
onnxAxis = HandleNegativeAxis(onnxAxis, onnxDimCount);
return GetDmlAdjustedAxis(onnxAxis, onnxDimCount, dmlDimCount);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h
index f0fad6a05ffb0..8b2da6084242d 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h
@@ -64,8 +64,7 @@ namespace Dml
} // namespace FusionHelpers
// Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced.
- // Note this function presumes the axis attribute is relative to the first input tensor (which is always the case).
- uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount);
+ uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex = 0);
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, uint32_t onnxDimCount, uint32_t dmlDimCount);
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 370f336ff5203..4d59964dcc664 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -1862,7 +1862,7 @@ namespace OperatorHelper
return { std::move(outputShape) };
}
- void ConcatHelper::Initialize(
+ void ConcatHelperBase::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions
)
@@ -1872,13 +1872,13 @@ namespace OperatorHelper
ML_CHECK_VALID_ARGUMENT(m_axis < static_cast<int>(inputDimensions.size()));
}
- std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ std::vector<EdgeShapes> ConcatHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const
{
- auto outputShape = shapeInfo.GetInputTensorShape(0);
+ auto outputShape = shapeInfo.GetInputTensorShape(firstInputIndex);
uint32_t inputCount = shapeInfo.GetInputCount();
- for (uint32_t i = 1; i < inputCount; ++i)
+ for (uint32_t i = firstInputIndex + step; i < inputCount; i += step)
{
auto inputShape = shapeInfo.GetInputTensorShape(i);
for (size_t j = 0; j < outputShape.size(); ++j)
@@ -1893,6 +1893,16 @@ namespace OperatorHelper
return { EdgeShapes(outputShape) };
}
+ std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ return ConcatHelperBase::GetOutputShapes(shapeInfo, 0, 1);
+ }
+
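+ // QLinearConcat inputs are {y_scale, y_zero_point, (x, x_scale, x_zero_point)...},
+ // so the x tensors start at input index 2 and recur with a stride of 3.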
+ std::vector<EdgeShapes> QLinearConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ return ConcatHelperBase::GetOutputShapes(shapeInfo, 2, 3);
+ }
+
void CropHelper::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index f7e545d9d99a9..55a01c59ee4b5 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -864,7 +864,7 @@ class RecurrentHelper
int m_hiddenSize = 0;
};
-class ConcatHelper
+class ConcatHelperBase
{
public:
void Initialize(
@@ -875,17 +875,33 @@ class ConcatHelper
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
- ConcatHelper(const Info_t& info, const Shape_t& shape)
+ ConcatHelperBase(const Info_t& info, const Shape_t& shape, uint32_t firstInputIndex)
{
- Initialize(info, shape.GetInputTensorShape(0));
+ Initialize(info, shape.GetInputTensorShape(firstInputIndex));
}
- std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+ std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const;
protected:
int m_axis;
};
+class ConcatHelper: public ConcatHelperBase
+{
+public:
+ template <typename Info_t, typename Shape_t>
+ ConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 0) {}
+ std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+};
+
+class QLinearConcatHelper: public ConcatHelperBase
+{
+public:
+ template <typename Info_t, typename Shape_t>
+ QLinearConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 2) {}
+ std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+};
+
class CropHelper
{
public:
using ShapeInferenceHelper_Split13 = VersionedOpsetHelper<SplitHelper, 13>;
using ShapeInferenceHelper_Split18 = VersionedOpsetHelper<SplitHelper, 18>;
using ShapeInferenceHelper_Transpose = TransposeHelper;
using ShapeInferenceHelper_Concat = ConcatHelper;
+using ShapeInferenceHelper_QLinearConcat = QLinearConcatHelper;
using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper<SliceHelper, 7>;
using ShapeInferenceHelper_Slice10 = VersionedOpsetHelper<SliceHelper, 10>;
using ShapeInferenceHelper_Slice11 = VersionedOpsetHelper<SliceHelper, 11>; // Note 11 and 10 are identical - no functional change.
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 3eb35faeba82f..996ea1ddcb52c 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -443,6 +443,7 @@ namespace OperatorHelper
static const int sc_sinceVer_BiasAdd = 1;
static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
+ static const int sc_sinceVer_QLinearConcat = 1;
static const int sc_sinceVer_RotaryEmbedding = 1;
} // namespace MsftOperatorSet1
From cb7f28a16ab78601363c2e694679d2dada149dd1 Mon Sep 17 00:00:00 2001
From: raoanag <127366241+raoanag@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:43:49 -0700
Subject: [PATCH 04/25] Register Resize for INT8 and UINT8 (#18252)
### Description
### Motivation and Context
Co-authored-by: Adrian Tsai
---
.../DmlExecutionProvider/src/Operators/OperatorRegistration.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index fa2750a22425f..d7910a6c6849f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -589,7 +589,7 @@ constexpr static std::array supportedTypeListLogica
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int64 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32};
-constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */};
+constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */};
constexpr static std::array supportedTypeListResize13 = supportedTypeListResize11;
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
From dcfff10f57fbbdf81ea9ebcae008be712be42700 Mon Sep 17 00:00:00 2001
From: raoanag <127366241+raoanag@users.noreply.github.com>
Date: Mon, 6 Nov 2023 09:09:11 -0800
Subject: [PATCH 05/25] Enable QLinearAveragePooling DML EP (#17384) (#18240)
[Cherry Pick Reviewed]
DML EP Implementation for
[QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool)
```
Note: Google Test filter = *QLinear*Pool*
[==========] Running 72 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 36 tests from QLinearGlobalAveragePool
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32
[ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1
[ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256
[ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7
[ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255
[ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255
[ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7
[ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256
[ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256
[ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7
[ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255
[ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255
[ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7
[ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256
[ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8
[ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8
[ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8
[ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8
[ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms)
[----------] 36 tests from QLinearGlobalAveragePool (5968 ms total)
[----------] 36 tests from QLinearPoolTest
[ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel
[ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel
[ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel
[ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel
[ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel
[ OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel
[ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel
[ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_BigImage
[ OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_Global
[ OK ] QLinearPoolTest.AveragePool2D_Global (481 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_S8
[ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_BigImage_S8
[ OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_Global_S8
[ OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms)
[----------] 36 tests from QLinearPoolTest (12914 ms total)
[----------] Global test environment tear-down
[==========] 72 tests from 2 test suites ran. (18885 ms total)
[ PASSED ] 72 tests.
memleakdbg:
----- No memory leaks detected -----
```
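For reference, the per-window arithmetic behind the operator; a minimal sketch, assuming uint8 quantization (`QLinearAverageOfWindow` is a name invented for illustration, not part of the change):

```
// QLinearAveragePool over one window: dequantize, average, requantize.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

uint8_t QLinearAverageOfWindow(
    const std::vector<uint8_t>& window,
    float xScale, uint8_t xZeroPoint,
    float yScale, uint8_t yZeroPoint)
{
    float sum = 0.0f;
    for (uint8_t x : window)
    {
        sum += (static_cast<int>(x) - xZeroPoint) * xScale; // dequantize
    }
    float avg = sum / static_cast<float>(window.size());
    int q = static_cast<int>(std::lround(avg / yScale)) + yZeroPoint; // requantize
    return static_cast<uint8_t>(std::clamp(q, 0, 255));
}
```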
### Description
### Motivation and Context
---
.../src/External/DirectMLHelpers/ApiTraits.h | 20 ++-
.../External/DirectMLHelpers/DirectMLSchema.h | 25 +++
.../DirectMLHelpers/GeneratedSchemaHelpers.h | 26 +++
.../DmlOperatorQLinearAveragePooling.cpp | 150 ++++++++++++++++++
.../src/Operators/OperatorRegistration.cpp | 8 +
.../DmlExecutionProvider/src/TensorDesc.cpp | 36 +++++
.../dml/DmlExecutionProvider/src/TensorDesc.h | 3 +
.../dml/OperatorAuthorHelper/Attributes.h | 2 +-
.../OperatorAuthorHelper/OperatorHelper.cpp | 42 ++++-
.../dml/OperatorAuthorHelper/OperatorHelper.h | 28 +++-
.../OperatorAuthorHelper/OperatorVersions.h | 2 +
.../qlinear_global_average_pool_test.cc | 3 +
12 files changed, 339 insertions(+), 6 deletions(-)
create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index 94f2220fcc168..a5415ba85f3d3 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -24,7 +24,7 @@ struct EnumTraits
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
- static constexpr auto ValueCount = 160;
+ static constexpr auto ValueCount = 161;
static constexpr size_t ActivationFunctionCount = 24;
};
@@ -495,6 +495,12 @@ struct OperatorDescTraits
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING;
};
+template <>
+struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
+{
+ static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
+};
+
template <>
struct OperatorDescTraits<DML_SLICE_OPERATOR_DESC>
{
@@ -1496,6 +1502,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING>
using DescType = DML_ROI_POOLING_OPERATOR_DESC;
};
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING>
+{
+ using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC;
+};
+
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE>
{
@@ -2522,6 +2534,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
case DML_OPERATOR_ACTIVATION_GELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
+#pragma warning(push)
+#pragma warning(disable: 4063)
+ case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
+ return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
+#pragma warning(pop)
+
default:
ORT_THROW_HR(E_INVALIDARG);
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
index 9eae1c1fe8158..2a82c12872a72 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
@@ -869,6 +869,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA {
DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS,
};
+
+constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] {
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
+ "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
+ static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
+ DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+ 13,
+ DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
+};
+
constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
index ad4cceb85cfd2..99218c135f058 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
@@ -502,6 +502,24 @@ inline std::vector<OperatorField> GetFields(const DML_ROI_POOLING_OPERATOR_DESC&
OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SIZE_2D>(desc.PooledSize))),
};
}
+inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc)
+{
+ return {
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputScaleTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputZeroPointTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<uint32_t>(desc.DimensionCount))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const uint32_t*>(desc.Strides), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const uint32_t*>(desc.WindowSize), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const uint32_t*>(desc.StartPadding), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<const uint32_t*>(desc.EndPadding), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const uint32_t*>(desc.Dilations), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<uint32_t>(desc.IncludePadding))),
+ };
+}
inline std::vector<OperatorField> GetFields(const DML_SLICE_OPERATOR_DESC& desc)
{
return {
@@ -2509,6 +2527,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ACTIVATION_GELU_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_GELU_OPERATOR_DESC*>(opDesc.Desc)));
+#pragma warning(push)
+#pragma warning(disable: 4063)
+ case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
+ return AbstractOperatorDesc(
+ &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
+ GetFields(*static_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
+#pragma warning(pop)
+
default:
ORT_THROW_HR(E_INVALIDARG);
return AbstractOperatorDesc(
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
new file mode 100644
index 0000000000000..0fccedfe311c1
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
@@ -0,0 +1,150 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+
+class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase
+{
+ // For QLinearAveragePool, ORT and DML use the same input tensor indexing order
+ enum OrtInputTensors : uint32_t
+ {
+ ortInput,
+ ortInputScale,
+ ortInputZeroPoint,
+ ortOutputScale,
+ ortOutputZeroPoint,
+ ortInputCount
+ };
+
+public:
+ using Self = DmlOperatorQLinearAveragePooling;
+
+ DmlOperatorQLinearAveragePooling(
+ const MLOperatorKernelCreationContext& kernelInfo,
+ bool useGlobalPooling
+ )
+ : DmlOperator(kernelInfo),
+ PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling)
+ {
+ DmlOperator::Initialize(kernelInfo);
+
+ bool isNhwc = m_kernel.channelsLast;
+ std::vector<uint32_t> inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput);
+ std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
+
+ uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount();
+ ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2);
+
+ // DML requires DimensionCount to equal dmlDimSize - 2 for pooling
+ uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
+ if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
+ {
+ size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount;
+
+ for (int i = gsl::narrow_cast<int>(m_kernel.spatialDimensionCount) - 1; i >= 0; i--)
+ {
+ m_kernel.windowSize[i + shift] = m_kernel.windowSize[i];
+ m_kernel.windowSize[i] = 1;
+
+ m_kernel.strides[i + shift] = m_kernel.strides[i];
+ m_kernel.strides[i] = 1;
+
+ m_kernel.startPadding[i + shift] = m_kernel.startPadding[i];
+ m_kernel.startPadding[i] = 0;
+
+ m_kernel.endPadding[i + shift] = m_kernel.endPadding[i];
+ m_kernel.endPadding[i] = 0;
+
+ m_kernel.dilations[i + shift] = m_kernel.dilations[i];
+ m_kernel.dilations[i] = 1;
+ }
+
+ m_kernel.spatialDimensionCount = expectedSpatialDimCount;
+ }
+
+ // Initialize dimensionMapping for NCHW or NHWC layout
+ std::vector<uint32_t> dimensionMapping = {0u, dmlDimSize - 1u};
+ dimensionMapping.resize(dmlDimSize);
+ if (isNhwc)
+ {
+ // Form a remapping for dimensions so C is moved before the spatial dimensions.
+ // e.g. NWC -> {0,2,1} -> NCW
+ // NHWC -> {0,3,1,2} -> NCHW
+ // NDHWC -> {0,4,1,2,3} -> NCDHW
+ std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u);
+ }
+ else
+ {
+ // Use NCHW {0,1,2,3} format with indices in increasing order
+ std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u);
+ }
+ m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+
+ // Reshape the Input Scale to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+
+ // Reshape the Input ZeroPoint to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint))
+ {
+ m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+ }
+
+ // Reshape the Output Scale to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+
+ // Reshape the Output ZeroPoint to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint))
+ {
+ m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+ }
+
+ // Initialize the output description while overriding the shape
+ m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+
+ assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize));
+
+ std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
+ std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
+
+ DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {};
+
+ qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
+ qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
+ qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
+ qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];
+ qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];
+ qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
+ qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
+ qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
+ qLinearAvgPooldesc.Strides = m_kernel.strides;
+ qLinearAvgPooldesc.StartPadding = m_kernel.startPadding;
+ qLinearAvgPooldesc.EndPadding = m_kernel.endPadding;
+ qLinearAvgPooldesc.Dilations = m_kernel.dilations;
+ qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false);
+
+ DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
+ SetDmlOperatorDesc(opDesc, kernelInfo);
+ }
+};
+
+template <bool UseGlobalPooling>
+class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling
+{
+public:
+ DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo)
+ : DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling)
+ {
+ }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate<false>);
+DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate<true>);
+
+} // namespace Dml
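Review note: the NHWC handling above hinges on the `dimensionMapping` permutation. The following standalone sketch (illustrative names, not part of the patch) reproduces that mapping with `std::iota` and applies it to a concrete shape:
```
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Build the permutation used above: move C from last place to position 1.
// e.g. NHWC {0,3,1,2} -> NCHW; NDHWC {0,4,1,2,3} -> NCDHW.
std::vector<uint32_t> MakeNhwcToNchwMapping(uint32_t rank)
{
    std::vector<uint32_t> mapping = {0u, rank - 1u};
    mapping.resize(rank);
    std::iota(mapping.begin() + 2, mapping.end(), 1u);
    return mapping;
}

int main()
{
    const std::vector<uint32_t> nhwcShape = {1, 5, 7, 3}; // N,H,W,C
    const auto mapping = MakeNhwcToNchwMapping(4);        // {0,3,1,2}

    std::vector<uint32_t> nchwShape(nhwcShape.size());
    for (size_t i = 0; i < mapping.size(); ++i)
    {
        nchwShape[i] = nhwcShape[mapping[i]];             // {1,3,5,7}
    }
    for (uint32_t d : nchwShape) std::cout << d << ' ';   // prints: 1 3 5 7
}
```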
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index d7910a6c6849f..0234bb6b7ec1e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -320,6 +320,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool);
DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
+DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool);
+DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
@@ -634,6 +636,10 @@ constexpr static std::array supportedTypeListQLinea
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
};
+constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearAveragePool = {
+ SupportedTensorDataTypes::Ints8Bit
+};
+
template <typename... Args>
constexpr auto requiredConstantCpuInputs(Args... args)
{
@@ -1040,6 +1046,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
+ {REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp
index 067a320dd8000..a2183aab52eed 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp
@@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
}
m_bufferTensorDesc.DimensionCount = newDimensionCount;
}
+
+// Uses dimensionMapping to reorder m_sizes and m_strides to match a specific tensor layout
+void TensorDesc::PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment)
+{
+ EnsureStridesExist();
+ SetDimensionCount(static_cast<uint32_t>(dimensionMapping.size()), alignment);
+
+ // Shuffle m_sizes and m_strides according to the indices in dimensionMapping
+ std::vector<uint32_t> tempSizes{m_sizes, m_sizes + MaximumDimensionCount};
+ std::vector<uint32_t> tempStrides{m_strides, m_strides + MaximumDimensionCount};
+
+ for (size_t i = 0; i < dimensionMapping.size(); i++)
+ {
+ m_sizes[i] = tempSizes[dimensionMapping[i]];
+ m_strides[i] = tempStrides[dimensionMapping[i]];
+ }
+
+ m_bufferTensorDesc.Sizes = m_sizes;
+ m_bufferTensorDesc.Strides = m_strides;
+}
+
+void TensorDesc::EnsureStridesExist()
+{
+ if (m_bufferTensorDesc.Strides != nullptr)
+ {
+ // Strides are populated
+ return;
+ }
+
+ uint32_t stride = 1;
+ for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;)
+ {
+ m_strides[i] = stride;
+ stride *= m_sizes[i];
+ }
+}
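Review note: `PermuteDimensions` assumes strides are populated, which is why `EnsureStridesExist` derives packed row-major strides first. A minimal standalone sketch of that derivation (not part of the patch):
```
#include <cstdint>
#include <vector>

// Packed row-major strides, as EnsureStridesExist computes them: the last
// dimension has stride 1, and each earlier stride is the product of all
// later sizes. For sizes {1,3,5,7} this yields {105,35,7,1}.
std::vector<uint32_t> PackedStrides(const std::vector<uint32_t>& sizes)
{
    std::vector<uint32_t> strides(sizes.size());
    uint32_t stride = 1;
    for (size_t i = sizes.size(); i-- > 0;)
    {
        strides[i] = stride;
        stride *= sizes[i];
    }
    return strides;
}
```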
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h
index ff70dec5b8871..909e2084d0163 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h
@@ -44,6 +44,7 @@ namespace Dml
gsl::span<const uint32_t> GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
gsl::span<const uint32_t> GetStrides() const;
void SetStrides(gsl::span<const uint32_t> strides);
+ void PermuteDimensions(gsl::span<const uint32_t> dimensionMapping, const TensorAxis alignment);
inline uint64_t GetBufferSizeInBytes() const
{
@@ -90,6 +91,8 @@ namespace Dml
uint32_t m_sizes[MaximumDimensionCount] = {};
uint32_t m_strides[MaximumDimensionCount] = {};
DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};
+
+ void EnsureStridesExist();
};
class TensorDescBuilder
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
index e9591cfce6870..85333aa77b686 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
@@ -23,8 +23,8 @@ namespace AttrName
static constexpr const char* BlockSize = "blocksize";
static constexpr const char* Border = "border";
static constexpr const char* Broadcast = "broadcast";
- static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* CeilMode = "ceil_mode";
+ static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* Clip = "clip";
static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode";
static constexpr const char* CountIncludePad = "count_include_pad";
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 4d59964dcc664..1fcd3b04300f4 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -365,13 +365,20 @@ namespace OperatorHelper
}
// Creates a kernel that spans the entire spatial dimensions of the input.
- KernelArgs InitializeGlobalKernel(gsl::span<const DimensionType> inputDimensions)
+ KernelArgs InitializeGlobalKernel(
+ const MLOperatorAttributes& kernelInfo,
+ gsl::span<const DimensionType> inputDimensions)
{
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount); // Must be at least 1D convolution (in 3D tensor)
uint32_t spatialDimensionCount = gsl::narrow_cast<uint32_t>(inputDimensions.size()) - NonspatialDimensionCount;
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor).
KernelArgs args(spatialDimensionCount);
+ args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0);
+ args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0);
+ // For global pooling, the kernel size equals the spatial dimensions of the input tensor.
+ // NHWC layouts need to offset by one dimension to account for the channel placed at the end.
+ int dimOffset = args.channelsLast ? 1 : 0;
for (size_t dim = 0; dim < spatialDimensionCount; ++dim)
{
@@ -379,7 +386,7 @@ namespace OperatorHelper
args.dilations[dim] = 1;
args.startPadding[dim] = 0;
args.endPadding[dim] = 0;
- args.windowSize[dim] = gsl::narrow_cast<uint32_t>(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim]);
+ args.windowSize[dim] = gsl::narrow_cast<uint32_t>(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim - dimOffset]);
}
return args;
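Review note: the `dimOffset` fix above is easiest to verify with concrete numbers. In this standalone sketch (illustrative values, not part of the patch), an NHWC input {1,5,7,3} with two spatial dimensions would, without the offset, read {7,3} and mistake C for a spatial extent; with `dimOffset = 1` it reads {5,7} as intended:
```
#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
    const uint32_t inputDims[] = {1, 5, 7, 3}; // NHWC: N,H,W,C
    const size_t rank = 4;
    const size_t spatialDimensionCount = 2;
    const int dimOffset = 1; // channels_last

    for (size_t dim = 0; dim < spatialDimensionCount; ++dim)
    {
        // dim 0 -> inputDims[1] == 5 (H), dim 1 -> inputDims[2] == 7 (W)
        std::cout << inputDims[rank - spatialDimensionCount + dim - dimOffset] << ' ';
    }
}
```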
@@ -495,6 +502,7 @@ namespace OperatorHelper
}
args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0);
+ args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0);
return args;
}
@@ -2012,7 +2020,37 @@ namespace OperatorHelper
}
return outputShapes;
}
+
+ std::vector<EdgeShapes> QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ auto inputShape = shapeInfo.GetInputTensorShape(0);
+ std::vector<DimensionType> outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast);
+
+ const uint32_t outputCount = shapeInfo.GetOutputCount();
+
+ std::vector<EdgeShapes> outputShapes;
+ for (uint32_t i = 0; i < outputCount; ++i)
+ {
+ outputShapes.push_back(outputDimensions);
+ }
+ return outputShapes;
+ }
+
+ std::vector<EdgeShapes> QLinearGlobalAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ auto inputShape = shapeInfo.GetInputTensorShape(0);
+ std::vector<DimensionType> outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast);
+ const uint32_t outputCount = shapeInfo.GetOutputCount();
+
+ std::vector<EdgeShapes> outputShapes;
+ for (uint32_t i = 0; i < outputCount; ++i)
+ {
+ outputShapes.push_back(outputDimensions);
+ }
+ return outputShapes;
+ }
+
std::vector<EdgeShapes> RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS);
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 55a01c59ee4b5..d8d09efd8d6e8 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -160,6 +160,7 @@ struct KernelArgs
bool autoPad = false;
bool autoPadSameUpper = false;
bool useCeilingOutputShape = false;
+ bool channelsLast = false;
uint32_t spatialDimensionCount = 0;
KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount)
@@ -188,6 +189,7 @@ struct KernelArgs
KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount)
: autoPad(kernelArgs.autoPad),
autoPadSameUpper(kernelArgs.autoPadSameUpper),
+ channelsLast(kernelArgs.channelsLast),
spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount))
{
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount);
@@ -211,7 +213,9 @@ std::vector<DimensionType> InitializeKernelOutputDimsTranspose(
gsl::span<const DimensionType> inputDimensions,
const KernelArgs& args);
-KernelArgs InitializeGlobalKernel(gsl::span<const DimensionType> inputDimensions);
+KernelArgs InitializeGlobalKernel(
+ const MLOperatorAttributes& kernelInfo,
+ gsl::span<const DimensionType> inputDimensions);
KernelArgs InitializeKernel(
const MLOperatorAttributes& kernelInfo,
@@ -1059,7 +1063,7 @@ class PoolingHelperBase
bool useGlobalPooling
)
: m_kernel(useGlobalPooling
- ? InitializeGlobalKernel(shape.GetInputTensorShape(0))
+ ? InitializeGlobalKernel(info, shape.GetInputTensorShape(0))
: InitializeKernel(info, static_cast<uint32_t>(shape.GetInputTensorShape(0).size()), gsl::span<const uint32_t>()))
{
if (!useGlobalPooling)
@@ -1161,6 +1165,24 @@ class RoiAlignHelper : public RoiPoolingHelperBase
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
+class QLinearAveragePoolingHelper : public PoolingHelperBase
+{
+public:
+ template <typename Info_t, typename Shape_t>
+ QLinearAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {}
+ std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+};
+
+class QLinearGlobalAveragePoolingHelper : public PoolingHelperBase
+{
+public:
+ template <typename Info_t, typename Shape_t>
+ QLinearGlobalAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {}
+ std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+};
+
class SqueezeHelper
{
public:
@@ -1490,6 +1512,8 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper;
using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
+using ShapeInferenceHelper_QLinearAveragePool = QLinearAveragePoolingHelper;
+using ShapeInferenceHelper_QLinearGlobalAveragePool = QLinearGlobalAveragePoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper<RoiAlignHelper, 16>;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 996ea1ddcb52c..e9d88adf3e221 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -445,6 +445,8 @@ namespace OperatorHelper
static const int sc_sinceVer_GroupNorm = 1;
static const int sc_sinceVer_QLinearConcat = 1;
static const int sc_sinceVer_RotaryEmbedding = 1;
+ static const int sc_sinceVer_QLinearAveragePool = 1;
+ static const int sc_sinceVer_QLinearGlobalAveragePool = 1;
} // namespace MsftOperatorSet1
} // namespace OperatorHelper
diff --git a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc
index 8fb245819fd26..71b6f27b5391f 100644
--- a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc
+++ b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc
@@ -66,6 +66,9 @@ void RunQLinearGlobalAveragePool(
test.AddInput("y_scale", {}, {y_scale});
test.AddInput("y_zero_point", {}, {y_zero_point});
test.AddOutput("Y", y_dims, y_data);
+ if (channels_last) {
+ test.AddAttribute("channels_last", (int64_t)1LL);
+ }
auto q8checker = [&](const std::vector<OrtValue>& fetches, const std::string& provider_type) {
const OrtValue& ort_value = fetches[0];
From d5f3aae3fd30a79d9bbeed26a70907618ac84835 Mon Sep 17 00:00:00 2001
From: raoanag <127366241+raoanag@users.noreply.github.com>
Date: Fri, 17 Nov 2023 16:43:09 -0800
Subject: [PATCH 06/25] Utilize DML constant input graph node (#18267)
### Description
This PR also includes:
8b0a55e7cc DML constant pow operator
7520974970 Enable custom heaps based on query-
### Motivation and Context
---------
Co-authored-by: Jeff Bloomfield
---
.../src/DmlGraphFusionHelper.cpp | 27 ++++-
.../src/ExecutionProvider.cpp | 31 ++++++
.../src/ExecutionProvider.h | 2 +
.../src/GraphDescBuilder.cpp | 104 ++++++++++++++----
.../src/GraphDescBuilder.h | 5 +-
.../src/IExecutionProvider.h | 1 +
.../src/MLOperatorAuthorImpl.cpp | 29 ++++-
.../src/MLOperatorAuthorImpl.h | 7 ++
.../src/Operators/DmlOperatorElementWise.cpp | 38 +++++--
.../MLOperatorAuthorHelper.h | 13 +++
.../MLOperatorAuthorPrivate.h | 10 ++
11 files changed, 229 insertions(+), 38 deletions(-)
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
index 4f7ec188140b5..18cdc5d1bf86e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
@@ -226,8 +226,7 @@ namespace DmlGraphFusionHelper
{
ComPtr<ID3D12Resource> initializeInputBuffer;
- // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
- if (providerImpl->IsMcdmDevice())
+ if (!providerImpl->CustomHeapsSupported())
{
initializeInputBuffer = CreateResource(providerImpl, tensorPtr, tensorByteSize);
}
@@ -294,6 +293,7 @@ namespace DmlGraphFusionHelper
const uint32_t inputCount,
const uint32_t outputCount,
_Inout_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
+ _Inout_ std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC>& dmlConstantGraphNodes,
_Inout_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
@@ -302,8 +302,24 @@ namespace DmlGraphFusionHelper
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
{
auto& nodeInfo = graphDesc.nodes[i];
- dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()};
- dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
+
+ if (std::holds_alternative<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef))
+ {
+ dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()};
+ dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
+ }
+ else
+ {
+ auto& nodeDefinitionData = std::get<std::vector<uint8_t>>(nodeInfo.nodeDef);
+ dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{
+ nodeDefinitionData.data(),
+ nodeDefinitionData.size(),
+ nodeInfo.name.data()
+ };
+
+ // TODO: Change as new header is ingested
+ dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast<DML_GRAPH_NODE_TYPE>(2), &dmlConstantGraphNodes[i]};
+ }
}
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
@@ -392,6 +408,8 @@ namespace DmlGraphFusionHelper
// convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator
DML_GRAPH_DESC dmlGraphDesc = {};
std::vector<DML_OPERATOR_GRAPH_NODE_DESC> dmlOperatorGraphNodes(graphDesc.nodes.size());
+ std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC> dmlConstantGraphNodes(graphDesc.nodes.size());
+
std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes(graphDesc.nodes.size());
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges(graphDesc.inputEdges.size());
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges(graphDesc.outputEdges.size());
@@ -402,6 +420,7 @@ namespace DmlGraphFusionHelper
fusedNodeInputCount,
fusedNodeOutputCount,
dmlOperatorGraphNodes,
+ dmlConstantGraphNodes,
dmlGraphNodes,
dmlInputEdges,
dmlOutputEdges,
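Review note: the `nodeDef` variant introduced by this change carries either a compiled operator or raw constant bytes, and the loop above dispatches on which alternative is held. A standalone sketch of the same pattern, with a placeholder type standing in for `Microsoft::WRL::ComPtr<IDMLOperator>` (not part of the patch):
```
#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

struct FakeOperator { std::string opType; }; // stand-in for ComPtr<IDMLOperator>

struct NodeInfo
{
    std::variant<FakeOperator, std::vector<uint8_t>> nodeDef;
    std::string name;
};

void Describe(const NodeInfo& node)
{
    if (std::holds_alternative<FakeOperator>(node.nodeDef))
    {
        std::cout << node.name << ": operator node\n";
    }
    else
    {
        const auto& bytes = std::get<std::vector<uint8_t>>(node.nodeDef);
        std::cout << node.name << ": constant node, " << bytes.size() << " bytes\n";
    }
}

int main()
{
    Describe({FakeOperator{"CONVOLUTION"}, "node0"});
    Describe({std::vector<uint8_t>(16), "node1"});
}
```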
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
index 8644b8d56a426..49a64c4810252 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
@@ -182,6 +182,32 @@ namespace Dml
}
m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE);
+ m_areCustomHeapsSupported = !m_isMcdmDevice;
+
+ if (m_isMcdmDevice)
+ {
+
+ // TODO: Ingest updated header file
+ typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS19
+ {
+ BOOL MismatchingOutputDimensionsSupported;
+ UINT SupportedSampleCountsWithNoOutputs;
+ BOOL PointSamplingAddressesNeverRoundUp;
+ BOOL RasterizerDesc2Supported;
+ BOOL NarrowQuadrilateralLinesSupported;
+ BOOL AnisoFilterWithPointMipSupported;
+ UINT MaxSamplerDescriptorHeapSize;
+ UINT MaxSamplerDescriptorHeapSizeWithStaticSamplers;
+ UINT MaxViewDescriptorHeapSize;
+ _Out_ BOOL ComputeOnlyCustomHeapSupported;
+ } D3D12_FEATURE_DATA_D3D12_OPTIONS19;
+
+ D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {};
+
+ // The call may fail in which case the default value is false
+ d3d12Device->CheckFeatureSupport(static_cast<D3D12_FEATURE>(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19));
+ m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported;
+ }
m_context = std::make_shared<ExecutionContext>(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);
@@ -1089,6 +1115,11 @@ namespace Dml
return m_isMcdmDevice;
}
+ bool __stdcall ExecutionProviderImpl::CustomHeapsSupported() const noexcept
+ {
+ return m_areCustomHeapsSupported;
+ }
+
bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept
{
return m_areMetacommandsEnabled;
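Review note: the OPTIONS19 query above follows the standard D3D12 feature-detection pattern: zero-initialize the options struct, call CheckFeatureSupport, and treat failure as "unsupported" since the zeroed struct reads as false. A hedged sketch of that pattern, reusing the locally defined struct and the raw enum value 48 from the patch:
```
// Sketch only; assumes a valid ID3D12Device* and the locally defined
// D3D12_FEATURE_DATA_D3D12_OPTIONS19 struct from the patch above.
bool QueryComputeOnlyCustomHeapSupport(ID3D12Device* d3d12Device)
{
    D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {};

    // The HRESULT is deliberately ignored: on failure the struct stays
    // zeroed, so the feature reads as unsupported.
    d3d12Device->CheckFeatureSupport(
        static_cast<D3D12_FEATURE>(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/,
        &options19,
        sizeof(options19));

    return options19.ComputeOnlyCustomHeapSupported != FALSE;
}
```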
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
index 3aaa11cdee479..ab932fb8a4367 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
@@ -150,6 +150,7 @@ namespace Dml
}
STDMETHOD_(bool, IsMcdmDevice)() const noexcept final;
+ STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final;
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final;
bool DynamicGraphFusionEnabled() const noexcept;
@@ -186,6 +187,7 @@ namespace Dml
ComPtr<ID3D12Device> m_d3d12Device;
ComPtr<IDMLDevice> m_dmlDevice;
bool m_isMcdmDevice = false;
+ bool m_areCustomHeapsSupported = false;
bool m_areMetacommandsEnabled = true;
bool m_dynamicGraphFusionEnabled = false;
bool m_native16BitShaderOpsSupported = false;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
index 3fc8f415e5a58..ba022533a1e94 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
@@ -149,7 +149,7 @@ namespace Dml::GraphDescBuilder
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
const std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
- const void* executionHandle,
+ const ExecutionProviderImpl* executionHandle,
const onnxruntime::Path& modelPath,
gsl::span<const onnxruntime::Node* const> subgraphNodes,
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
@@ -198,7 +198,7 @@ namespace Dml::GraphDescBuilder
const uint32_t minNodeCountToReuseCommandList = 5;
bool reuseCommandList = false;
- if (subgraphNodes.size() >= minNodeCountToReuseCommandList)
+ if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice())
{
reuseCommandList = true;
}
@@ -232,14 +232,22 @@ namespace Dml::GraphDescBuilder
{
ComPtr tensor = nullptr;
- // Check whether this specific node requested support for constant CPU inputs
- if (std::find(requiredConstantCpuInputs.begin(), requiredConstantCpuInputs.end(), inputIndex) != requiredConstantCpuInputs.end())
+ auto inputDefs = node.InputDefs();
+
+ if (inputIndex < inputDefs.size())
{
- auto inputDefs = node.InputDefs();
- if (inputIndex < inputDefs.size())
+ const onnxruntime::NodeArg* arg = inputDefs[inputIndex];
+ tensor = constantCpuGraphInputGetter(arg->Name());
+
+ if (tensor == nullptr)
{
- const onnxruntime::NodeArg* arg = inputDefs[inputIndex];
- tensor = constantCpuGraphInputGetter(arg->Name());
+ bool inputRequiredAsConstant = std::find(
+ requiredConstantCpuInputs.begin(),
+ requiredConstantCpuInputs.end(),
+ inputIndex) != requiredConstantCpuInputs.end();
+
+ // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
+ ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
}
}
@@ -289,6 +297,7 @@ namespace Dml::GraphDescBuilder
std::unordered_map<uint32_t, uint32_t> operatorGraphNodeIndexToMainGraphNodeIndexMap;
uint32_t graphNodeCount = gsl::narrow_cast<uint32_t>(graphNodes.size());
const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0;
+ size_t firstOpDescGraphNodeIndex = graphNodes.size();
if (isNodeAsOpDesc)
{
@@ -298,6 +307,8 @@ namespace Dml::GraphDescBuilder
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]);
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
}
+
+ graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount);
}
else
{
@@ -306,7 +317,7 @@ namespace Dml::GraphDescBuilder
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get());
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
NodeInfo nodeInfo = {};
- nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
+ nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
graphNodes.push_back(std::move(nodeInfo));
}
}
@@ -328,21 +339,59 @@ namespace Dml::GraphDescBuilder
const uint32_t dmlFusedNodeInputIndex = iter->second;
- DML_INPUT_GRAPH_EDGE_DESC edge = {};
- edge.GraphInputIndex = dmlFusedNodeInputIndex;
- edge.ToNodeIndex = mainGraphNodeIndex;
- edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex
- graphInputEdges.push_back(edge);
-
// If this is a constant input, set the appropriate flags on the desc
if (isNodeAsOpDesc &&
dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
isConstGpuGraphInput[dmlFusedNodeInputIndex])
{
- auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
- std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
- DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
- tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
+ // This is a highly inefficient approach to generating constant nodes. It duplicates constant data
+ // across the graph input as well as every consumer's unique constant node. However it is currently
+ // only used for small inputs.
+
+ // TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming
+ // nodes. This would allow the size to be reduced while handling the case of 1D scale and zero point
+ // values that have been de-duplicated via conversion to scalars in kernels.
+ uint32_t c_maxConstNodeDataSize = 1024 * 1024;
+
+ ComPtr constantInput = constantCpuGraphInputGetter(arg->Name());
+
+ if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize)
+ {
+ auto data = static_cast<const uint8_t*>(constantInput->GetData());
+ std::vector<uint8_t> tensorData(data, data + constantInput->GetTensorByteSize());
+
+ NodeInfo nodeInfo = {};
+ nodeInfo.nodeDef = std::move(tensorData);
+ graphNodes.push_back(std::move(nodeInfo));
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {};
+ edge.FromNodeIndex = static_cast(graphNodes.size() - 1);
+ edge.FromNodeOutputIndex = 0;
+ edge.ToNodeIndex = mainGraphNodeIndex;
+ edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
+ graphIntermediateEdges.push_back(edge);
+ }
+ else
+ {
+ DML_INPUT_GRAPH_EDGE_DESC edge = {};
+ edge.GraphInputIndex = dmlFusedNodeInputIndex;
+ edge.ToNodeIndex = mainGraphNodeIndex;
+ edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
+ graphInputEdges.push_back(edge);
+
+ auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
+ std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
+ DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
+ tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
+ }
+ }
+ else
+ {
+ DML_INPUT_GRAPH_EDGE_DESC edge = {};
+ edge.GraphInputIndex = dmlFusedNodeInputIndex;
+ edge.ToNodeIndex = mainGraphNodeIndex;
+ edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
+ graphInputEdges.push_back(edge);
}
}
else
@@ -387,17 +436,28 @@ namespace Dml::GraphDescBuilder
if (isNodeAsOpDesc)
{
- for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc)
+ for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i)
{
+ auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];
+
DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);
+
+ // TODO: Change as new header is ingested
+ if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
+ dmlDesc.Type = (DML_OPERATOR_TYPE) 169;
+
+ // TODO: Change as new header is ingested
+ if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
+ dmlDesc.Type = (DML_OPERATOR_TYPE) 170;
+
ComPtr<IDMLOperator> op;
ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
allocator.Reset();
NodeInfo nodeInfo = {};
- nodeInfo.op = std::move(op);
+ nodeInfo.nodeDef = std::move(op);
nodeInfo.name = node.Name();
- graphNodes.push_back(std::move(nodeInfo));
+ graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo);
}
}
}
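Review note: the net effect of the GraphDescBuilder changes is a per-constant policy: constants under 1 MiB are duplicated into in-graph DML constant nodes, while larger ones remain graph inputs flagged DML_TENSOR_FLAG_OWNED_BY_DML. A compact sketch of that decision (illustrative names, not part of the patch):
```
#include <cstddef>

// Mirrors c_maxConstNodeDataSize above: small constants are inlined as
// DML_CONSTANT_DATA_GRAPH_NODE_DESC nodes, larger ones stay graph inputs.
constexpr size_t kMaxConstNodeDataSize = 1024 * 1024;

enum class ConstantBinding { InGraphConstantNode, OwnedGraphInput };

ConstantBinding ChooseBinding(const void* data, size_t byteSize)
{
    return (data != nullptr && byteSize < kMaxConstNodeDataSize)
        ? ConstantBinding::InGraphConstantNode
        : ConstantBinding::OwnedGraphInput;
}
```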
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
index 0039678c00e59..c95e89b45541b 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
@@ -4,6 +4,7 @@
#pragma once
#include "MLOperatorAuthorImpl.h"
+#include "ExecutionProvider.h"
namespace Dml
{
@@ -27,7 +28,7 @@ namespace Dml
struct NodeInfo
{
- Microsoft::WRL::ComPtr<IDMLOperator> op;
+ std::variant<Microsoft::WRL::ComPtr<IDMLOperator>, std::vector<uint8_t>> nodeDef;
std::string name;
};
@@ -47,7 +48,7 @@ namespace Dml
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
const std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
- const void* executionHandle,
+ const ExecutionProviderImpl* executionHandle,
const onnxruntime::Path& modelPath,
gsl::span subgraphNodes,
gsl::span subgraphInputs,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h
index a8a6d6745e908..17fd7c18ba4a1 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h
@@ -76,6 +76,7 @@ namespace Dml
STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0;
STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0;
+ STDMETHOD_(bool, CustomHeapsSupported)() const noexcept = 0;
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0;
};
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
index 4deec620fe5fb..dbd06abf82f72 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
@@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter
}
ORT_CATCH_RETURN
}
-
+
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
{
@@ -1153,6 +1153,33 @@ namespace Windows::AI::MachineLearning::Adapter
ORT_CATCH_RETURN
}
+ template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
+ HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::TryGetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
+ {
+ ORT_TRY
+ {
+ auto constantInput = m_constantInputGetter(inputIndex);
+ ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative