diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp index 09922310b56c1..facac88985e20 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp @@ -17,6 +17,14 @@ namespace Dml { + GraphTransformer::GraphTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) : onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + onnxruntime::common::Status GraphTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -27,7 +35,7 @@ namespace Dml // Perform fusion { bool transformModifiedGraph = false; - PerformOperatorFusion(&graph, &transformModifiedGraph); + PerformOperatorFusion(&graph, m_providerImpl->IsMcdmDevice(), &transformModifiedGraph); modified |= transformModifiedGraph; if (modified) @@ -50,7 +58,7 @@ namespace Dml return ss.str(); } - void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const + void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const { struct NodeToAdd { @@ -112,7 +120,8 @@ namespace Dml gsl::narrow_cast(node.InputDefs().size()), outputNode.OpType(), outputNode.Domain(), - outputNode.Op()->SinceVersion()); + outputNode.Op()->SinceVersion(), + isMcdmDevice); if (!fusedOpProperties) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h index a7f8186fb3b64..b4e2999ba925b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h @@ -10,6 +10,7 @@ namespace Dml { + class ExecutionProviderImpl; // Applies transforms to a Lotus graph. The graph transformer is responsible for setting the execution provider // on the graph nodes which DML supports. @@ -17,16 +18,17 @@ namespace Dml { public: GraphTransformer( - const std::string& name - ) : onnxruntime::GraphTransformer(name) - { - } + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); private: onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const final; private: - void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const; + void PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const; + + const ExecutionProviderImpl* m_providerImpl = nullptr; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index 2965fa32ce131..ff4b2b13c5d2d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -132,6 +132,8 @@ namespace Dml std::string_view domain; int sinceVersion; std::vector activationFilter; + bool enableOnMcdm; + std::vector extraMcdmActivationFilter; std::optional inputCountFilter; }; @@ -142,10 +144,10 @@ namespace Dml static const OperatorInfo c_fusableOps[] = { - OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv }, - OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv }, - OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose }, - OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose }, + OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_BatchNormalization }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_BatchNormalization }, OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_BatchNormalization }, @@ -163,11 +165,11 @@ namespace Dml OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul }, // The filter for activation functions maps to what DML's fused op internally fuses at the shader level. - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, - OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 }, - OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true }, + OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 }, + OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 }, }; // Not all activations can be fused - only simple elementwise activations (i.e. activation functions which @@ -205,7 +207,8 @@ namespace Dml int candidateOpInputCount, std::string_view activationOpType, std::string_view activationOpDomain, - int activationOpSinceVersion) + int activationOpSinceVersion, + bool isMcdmDevice) { auto opIt = std::find( std::begin(c_fusableOps), @@ -233,6 +236,20 @@ namespace Dml return std::nullopt; } + if (isMcdmDevice) + { + if (!opIt->enableOnMcdm) + { + return std::nullopt; + } + + if (!opIt->extraMcdmActivationFilter.empty() && + std::find(opIt->extraMcdmActivationFilter.begin(), opIt->extraMcdmActivationFilter.end(), activationOpType) == opIt->extraMcdmActivationFilter.end()) + { + return std::nullopt; + } + } + if (opIt->inputCountFilter && *opIt->inputCountFilter != static_cast(candidateOpInputCount)) { return std::nullopt; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h index 8b2da6084242d..d3483cb5e8de2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h @@ -40,7 +40,8 @@ namespace Dml int candidateOpInputCount, std::string_view activationOpType, std::string_view activationOpDomain, - int activationOpSinceVersion); + int activationOpSinceVersion, + bool isMcdmDevice); // Returns true if the given activation operator type supports being fused with a fusable operator, false // otherwise. diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 34aa1093457db..377f33d92a383 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1524,7 +1524,8 @@ common::Status InferenceSession::Initialize() { // This transformer applies DML-specific fusions that go beyond what ORT offers by default bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2; if (dml_operator_fusion_enabled) { - std::unique_ptr dmlOperatorFusionTransformer = std::make_unique("DmlOperatorFusionTransformer"); + std::unique_ptr dmlOperatorFusionTransformer = std::make_unique("DmlOperatorFusionTransformer", + execution_providers_.Get(kDmlExecutionProvider)); if (dmlOperatorFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr"); }