diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
index 099777b85452d..2c9d27b0546f6 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
@@ -20,7 +20,6 @@ namespace Dml
             const onnxruntime::OpKernelInfo& kernelInfo,
             std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
             const onnxruntime::Path& modelPath,
-            std::shared_ptr<std::vector<std::vector<std::string>>> inputDimParams,
             std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
             std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
             std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
@@ -30,7 +29,6 @@ namespace Dml
         : OpKernel(kernelInfo),
           m_indexedSubGraph(std::move(indexedSubGraph)),
           m_modelPath(modelPath),
-          m_inputDimParams(std::move(inputDimParams)),
           m_subgraphNodes(std::move(subgraphNodes)),
           m_subgraphInputs(std::move(subgraphInputs)),
           m_subgraphOutputs(std::move(subgraphOutputs)),
@@ -68,8 +66,6 @@ namespace Dml
             std::vector<ComPtr<ID3D12Resource>>& initializeResourceRefs,
             std::vector<DML_BUFFER_BINDING> initInputBindings) const
         {
-            std::optional<DML_BUFFER_BINDING> persistentResourceBinding;
-
             // Allocate a persistent resource and initialize the operator
             UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize;
             if (persistentResourceSize > 0)
@@ -80,12 +76,12 @@ namespace Dml
                     m_persistentResource.GetAddressOf(),
                     m_persistentResourceAllocatorUnk.GetAddressOf()));
 
-                persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize };
+                m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize };
             }
 
             ORT_THROW_IF_FAILED(m_provider->InitializeOperator(
                 m_compiledExecutionPlanOperator.Get(),
-                persistentResourceBinding ? &*persistentResourceBinding : nullptr,
+                m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
                 gsl::make_span(initInputBindings)));
 
             // Queue references to objects which must be kept alive until resulting GPU work completes
@@ -303,17 +299,10 @@ namespace Dml
         ComPtr<IWinmlExecutionProvider> m_winmlProvider;
         ComPtr<Dml::IExecutionProvider> m_provider;
 
-        // Re-usable command list, supporting descriptor heap, and DML binding table to update that heap.
-        ComPtr<ID3D12GraphicsCommandList> m_graphicsCommandList;
-        ComPtr<ID3D12CommandAllocator> m_commandAllocator;
-        ComPtr<ID3D12DescriptorHeap> m_heap;
-        ComPtr<IDMLBindingTable> m_bindingTable;
-        std::optional<DML_BUFFER_BINDING> m_persistentResourceBinding;
+        mutable std::optional<DML_BUFFER_BINDING> m_persistentResourceBinding;
         std::shared_ptr<const onnxruntime::IndexedSubGraph> m_indexedSubGraph;
         const onnxruntime::Path& m_modelPath;
 
-        // TODO (pavignol): Remove m_inputDimParams if truly not needed
-        std::shared_ptr<std::vector<std::vector<std::string>>> m_inputDimParams;
         std::vector<std::shared_ptr<onnxruntime::Node>> m_subgraphNodes;
         std::vector<const onnxruntime::NodeArg*> m_subgraphInputs;
         std::vector<const onnxruntime::NodeArg*> m_subgraphOutputs;
@@ -326,26 +315,17 @@ namespace Dml
         // Bindings from previous executions of a re-used command list
         mutable std::vector<std::unique_ptr<onnxruntime::Tensor>> m_ownedCpuInputs;
         mutable ComPtr<IDMLCompiledOperator> m_compiledExecutionPlanOperator;
-        mutable std::vector<uint64_t> m_inputBindingAllocIds;
-        mutable std::vector<uint64_t> m_outputBindingAllocIds;
-        mutable uint64_t m_tempBindingAllocId = 0;
         mutable std::vector<bool> m_inputsUsed;
         mutable ComPtr<ID3D12Resource> m_persistentResource;
         mutable ComPtr<IUnknown> m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator
         mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes;
         mutable std::unordered_map<std::string, onnxruntime::TensorShape> m_inferredInputShapes;
-
-        // Fence tracking the status of the command list's last execution, and whether its descriptor heap
-        // can safely be updated.
-        mutable ComPtr<ID3D12Fence> m_fence;
-        mutable uint64_t m_completionValue = 0;
     };
 
     onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel(
         const onnxruntime::OpKernelInfo& info,
         std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
         const onnxruntime::Path& modelPath,
-        std::shared_ptr<std::vector<std::vector<std::string>>> inputDimParams,
         std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
         std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
         std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
@@ -357,7 +337,6 @@ namespace Dml
             info,
             std::move(indexedSubGraph),
             modelPath,
-            std::move(inputDimParams),
             std::move(subgraphNodes),
             std::move(subgraphInputs),
             std::move(subgraphOutputs),
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h
index d18a6d4671bc4..d679c5aa5667c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h
@@ -11,7 +11,6 @@ namespace Dml
         const onnxruntime::OpKernelInfo& info,
         std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
         const onnxruntime::Path& modelPath,
-        std::shared_ptr<std::vector<std::vector<std::string>>> inputDimParams,
         std::vector<std::shared_ptr<onnxruntime::Node>>&& subgraphNodes,
         std::vector<const onnxruntime::NodeArg*>&& subgraphInputs,
         std::vector<const onnxruntime::NodeArg*>&& subgraphOutputs,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp
index bd787dbfb4382..71ef8e7962f6c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp
@@ -369,9 +369,6 @@ namespace DmlRuntimeGraphFusionHelper
             subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName));
         }
 
-        // We store the input dim params that haven't been overriden yet so that we can map their value at runtime once the real inputs are provided
-        auto inputDimParams = std::make_shared<std::vector<std::vector<std::string>>>();
-
         // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph
         std::vector<ONNX_NAMESPACE::TensorProto> ownedInitializers;
         ownedInitializers.reserve(isInitializerTransferable.size());
@@ -393,7 +390,6 @@ namespace DmlRuntimeGraphFusionHelper
 
         // lamda captures for the kernel registration
         auto fused_kernel_func = [
-            inputDimParams,
             indexedSubGraph,
             &modelPath,
             nodesInfo = std::move(nodesInfo),
@@ -422,7 +418,6 @@ namespace DmlRuntimeGraphFusionHelper
                 info,
                 indexedSubGraph,
                 modelPath,
-                std::move(inputDimParams),
                 std::move(subgraphNodes),
                 std::move(subgraphInputs),
                 std::move(subgraphOutputs),
@@ -453,26 +448,6 @@ namespace DmlRuntimeGraphFusionHelper
         auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name);
         fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);
 
-        inputDimParams->resize(fusedNode.InputDefs().size());
-
-        for (int inputIndex = 0; inputIndex < fusedNode.InputDefs().size(); ++inputIndex)
-        {
-            const onnxruntime::NodeArg* inputDef = fusedNode.InputDefs()[inputIndex];
-
-            ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type());
-            const auto& tensorShape = inputDef->TypeAsProto()->tensor_type().shape();
-
-            (*inputDimParams)[inputIndex].resize(tensorShape.dim_size());
-
-            for (int i = 0; i < tensorShape.dim_size(); ++i)
-            {
-                if (tensorShape.dim(i).has_dim_param())
-                {
-                    (*inputDimParams)[inputIndex][i] = tensorShape.dim(i).dim_param();
-                }
-            }
-        }
-
         graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode);
     }
}
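
Note on the binding change above: the persistent-resource binding moves from a local in the initialization path to a `mutable std::optional` member, so it survives initialization and can be reused on later executions. The snippet below is a minimal, self-contained sketch of that caching pattern only; `Kernel`, `Resource`, `BufferBinding`, and `EnsurePersistentBinding` are hypothetical stand-ins, not the DML or D3D12 API.

// Sketch of caching an optional binding in a mutable member so a const
// initialization path can populate it once and later calls can reuse it.
#include <cstdint>
#include <optional>

struct Resource { std::uint64_t size = 0; };                 // stand-in for ID3D12Resource
struct BufferBinding                                          // stand-in for DML_BUFFER_BINDING
{
    const Resource* buffer;
    std::uint64_t offset;
    std::uint64_t sizeInBytes;
};

class Kernel
{
public:
    // Called from a const execution path; returns nullptr when no persistent
    // resource is required, mirroring the "binding ? &*binding : nullptr" usage.
    const BufferBinding* EnsurePersistentBinding(std::uint64_t persistentSize) const
    {
        if (persistentSize > 0 && !m_persistentBinding)
        {
            m_persistentResource.size = persistentSize;       // allocate once
            m_persistentBinding = BufferBinding{ &m_persistentResource, 0, persistentSize };
        }
        return m_persistentBinding ? &*m_persistentBinding : nullptr;
    }

private:
    // mutable: populated during a const call, then reused by later executions.
    mutable Resource m_persistentResource;
    mutable std::optional<BufferBinding> m_persistentBinding;
};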