diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 431113b3e1650..074f13b309181 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -89,8 +89,6 @@ namespace Windows::AI::MachineLearning::Adapter std::vector inputEdges; std::vector outputEdges; std::vector intermediateEdges; - EdgeShapes outputShapes; - const std::unordered_map>* inferredOutputShapes; }; using GraphNodeFactory = std::function; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index e86737fa9665a..eb068087de4ad 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -492,6 +492,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo ) { @@ -499,10 +500,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( onnxruntime::OpNodeProtoHelper protoHelper(&nodeContext); // Use the same list of required constant inputs for the shape inferrer and the kernel. - EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, outputShapes); - - graphNodeCreateInfo->outputShapes = outputShapes; + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes); // Create the kernel while allowing input shape and output shape queries according to options ComPtr kernelInfoWrapper = wil::MakeOrThrow( @@ -510,7 +508,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( executionHandle, true, inputShapesOverrides, - &outputShapes, + outputShapes, &defaultAttributesCapture, graphNodeCreateInfo, constantCpuInputCapture, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 546fb56bc780e..c620859495b15 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -246,7 +246,6 @@ namespace Dml::GraphDescBuilder return tensor; }; - DmlGraphNodeCreateInfo graphNodeCreateInfo; EdgeShapes inputShapesOverrides(node.InputDefs().size()); // Override the input shapes with shapes that were previously inferred @@ -269,19 +268,21 @@ namespace Dml::GraphDescBuilder } } - graphNodeCreateInfo.inferredOutputShapes = &inferredOutputShapes; + EdgeShapes outputShapes; + DmlGraphNodeCreateInfo graphNodeCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, &inputShapesOverrides, + /*out*/ &outputShapes, /*out*/ &graphNodeCreateInfo ); - ORT_THROW_HR_IF(E_UNEXPECTED, graphNodeCreateInfo.outputShapes.EdgeCount() != node.OutputDefs().size()); + ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); for (int i = 0; i < node.OutputDefs().size(); ++i) { - inferredOutputShapes[node.OutputDefs()[i]->Name()] = graphNodeCreateInfo.outputShapes.GetShape(i); + inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); } // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. @@ -380,7 +381,7 @@ namespace Dml::GraphDescBuilder operatorGraphOutputEdge.FromNodeOutputIndex }; - nodeOutputShapes[arg->Name()] = graphNodeCreateInfo.outputShapes; + nodeOutputShapes[arg->Name()] = outputShapes; } }