diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 8de2a29fe45ba..ecef48dc6d480 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -34,7 +34,8 @@ namespace Dml Dml::ExecutionContext* execution_context, bool enableMetacommands, bool enableGraphCapture, - bool enableCpuSyncSpinning); + bool enableCpuSyncSpinning, + bool disableMemoryArena); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 6261918267d7d..db45908a2dda4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -186,12 +186,16 @@ namespace Dml } else { - // Free the underlying allocation once queued work has completed. -#ifdef _GAMING_XBOX - m_context->QueueReference(WRAP_GRAPHICS_UNKNOWN(allocInfo->GetResource()).Get()); -#else - m_context->QueueReference(allocInfo->GetResource()); -#endif + if (!m_context->IsClosed()) + { + // Free the underlying allocation once queued work has completed. + #ifdef _GAMING_XBOX + m_context->QueueReference(WRAP_GRAPHICS_UNKNOWN(allocInfo->GetResource()).Get()); + #else + m_context->QueueReference(allocInfo->GetResource()); + #endif + } + allocInfo->DetachResourceWrapper(); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 90942f4eef353..8f67174157b8a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -935,7 +935,8 @@ namespace DmlGraphFusionHelper const Windows::AI::MachineLearning::Adapter::EdgeShapes& outputShapes, IWinmlExecutionProvider* winmlProvider, IExecutionProvider* provider, - IUnknown* persistentResourceAllocatorUnknown) + IUnknown* persistentResourceAllocatorUnknown, + bool keepTemporaryResourceAlive) { DML_BINDING_PROPERTIES execBindingProps = compiledExecutionPlanOperator->GetBindingProperties(); @@ -1042,6 +1043,11 @@ namespace DmlGraphFusionHelper } commandListState.tempBindingAllocId = tempAllocId; + + if (keepTemporaryResourceAlive) + { + commandListState.temporaryResource = std::move(tempResource); + } } // Execute the command list and if it succeeds, update the fence value at which this command may be diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index a0e2456334e9f..1a810ca58cf45 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -120,6 +120,7 @@ namespace DmlGraphFusionHelper const Windows::AI::MachineLearning::Adapter::EdgeShapes& outputShapes, IWinmlExecutionProvider* winmlProvider, IExecutionProvider* provider, - IUnknown* persistentResourceAllocatorUnknown); + IUnknown* persistentResourceAllocatorUnknown, + bool keepTemporaryResourceAlive); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h index 5c46021a215a6..6c3c2fb8c7094 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h @@ -15,6 +15,7 @@ namespace Dml Microsoft::WRL::ComPtr heap; Microsoft::WRL::ComPtr bindingTable; Microsoft::WRL::ComPtr persistentResource; + Microsoft::WRL::ComPtr temporaryResource; Microsoft::WRL::ComPtr persistentResourceAllocatorUnknown; // Bindings from previous executions of a re-used command list diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index f2c5ae8418332..10b8b7fe42f86 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -62,10 +62,7 @@ namespace Dml } } - void TranslateAndCompileGraph( - const onnxruntime::OpKernelInfo& kernelInfo, - std::vector>& initializeResourceRefs, - std::vector initInputBindings) const + void TranslateAndCompileGraph(const onnxruntime::OpKernelInfo& kernelInfo, std::vector initInputBindings) const { // Allocate a persistent resource and initialize the operator UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; @@ -84,16 +81,14 @@ namespace Dml m_compiledExecutionPlanOperator.Get(), m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, gsl::make_span(initInputBindings))); - - std::for_each( - initializeResourceRefs.begin(), - initializeResourceRefs.end(), - [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } - ); } onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override { + // Release the references from the previous execution since Flush() isn't called for reusable command lists + auto providerImpl = static_cast(m_provider.Get()); + providerImpl->ReleaseCompletedReferences(); + ORT_THROW_HR_IF(E_UNEXPECTED, static_cast(m_subgraphInputs.size()) != kernelContext->InputCount()); bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; @@ -173,7 +168,6 @@ namespace Dml // Populate input bindings for operator initialization const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); - std::vector> initializeResourceRefs; // For lifetime control std::vector initInputBindings(fusedNodeInputCount); std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); @@ -233,10 +227,7 @@ namespace Dml // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); - TranslateAndCompileGraph( - Info(), - initializeResourceRefs, - initInputBindings); + TranslateAndCompileGraph(Info(), initInputBindings); std::vector inputBindings(kernelContext->InputCount()); std::vector inputBindingDescs(kernelContext->InputCount()); @@ -244,8 +235,6 @@ namespace Dml m_reusedCommandLists.clear(); } - auto providerImpl = static_cast(m_provider.Get()); - // When we are capturing a graph, we don't pool the command list and instead transfer it to the execution provider. Captured graph // have the same bindings for their entire lifetime. if (providerImpl->GraphCaptureEnabled() && providerImpl->GetCurrentGraphAnnotationId() != -1 && !providerImpl->GraphCaptured(providerImpl->GetCurrentGraphAnnotationId())) @@ -259,6 +248,11 @@ namespace Dml reusableCommandList->persistentResource = m_persistentResource; reusableCommandList->persistentResourceAllocatorUnknown = m_persistentResourceAllocatorUnknown; + // Keep the temporary resource alive since we won't call ExecuteReusableCommandList again, but will merely replay + // the graph in the future. Therefore, all executions of the graph will use the same temporary resource that was + // allocated here the first time. + constexpr bool keepTemporaryResourceAlive = true; + DmlGraphFusionHelper::ExecuteReusableCommandList( kernelContext, *reusableCommandList, @@ -270,7 +264,8 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get()); + m_persistentResourceAllocatorUnknown.Get(), + keepTemporaryResourceAlive); providerImpl->AppendCapturedGraph(providerImpl->GetCurrentGraphAnnotationId(), std::move(reusableCommandList)); } @@ -288,6 +283,10 @@ namespace Dml m_reusedCommandLists.push_front(std::move(reusableCommandList)); } + // We don't need to keep a reference on the temporary resource once we have recorded into the command list, so the + // memory can be reused by the allocator + constexpr bool keepTemporaryResourceAlive = false; + DmlGraphFusionHelper::ExecuteReusableCommandList( kernelContext, *m_reusedCommandLists.front(), @@ -299,7 +298,8 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get()); + m_persistentResourceAllocatorUnknown.Get(), + keepTemporaryResourceAlive); m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front())); m_reusedCommandLists.pop_front(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h index 9ba2f13bf2751..e7a6fa3d07296 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h @@ -87,6 +87,7 @@ namespace Dml D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const; bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; } + bool IsClosed() const { return m_closed; } private: Microsoft::WRL::ComPtr m_d3dDevice; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 949cd216d5ad9..fb0275deceed6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -72,7 +72,8 @@ namespace Dml Dml::ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, - bool enableSyncSpinning) : + bool enableSyncSpinning, + bool disableMemoryArena) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue(); @@ -85,7 +86,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena); } std::vector> @@ -149,12 +150,13 @@ namespace Dml } } - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning, bool disableMemoryArena) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), m_areMetacommandsEnabled(enableMetacommands), m_graphCaptureEnabled(enableGraphCapture), m_cpuSyncSpinningEnabled(enableCpuSyncSpinning), + m_memoryArenaDisabled(disableMemoryArena), m_context(executionContext) { D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; @@ -1181,9 +1183,13 @@ namespace Dml m_context->ReleaseCompletedReferences(); m_uploadHeap->Trim(); - // Allocations after this point are potentially transient and their sizes are - // rounded to enable pooling. - m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled); + if (!m_memoryArenaDisabled) + { + // Allocations after this point are potentially transient and their sizes are + // rounded to enable pooling. + m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled); + } + m_sessionInitialized = true; return onnxruntime::common::Status::OK(); @@ -1242,9 +1248,10 @@ namespace Dml Dml::ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, - bool enableCpuSyncSpinning) + bool enableCpuSyncSpinning, + bool disableMemoryArena) { - return std::make_unique(dmlDevice, executionContext, enableMetacommands, enableGraphCapture, enableCpuSyncSpinning); + return std::make_unique(dmlDevice, executionContext, enableMetacommands, enableGraphCapture, enableCpuSyncSpinning, disableMemoryArena); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 70bc99184b9bc..29961288a51c5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -38,7 +38,8 @@ namespace Dml Dml::ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, - bool enableCpuSyncSpinning); + bool enableCpuSyncSpinning, + bool disableMemoryArena); void ReleaseCompletedReferences(); @@ -207,6 +208,7 @@ namespace Dml std::unordered_set m_graphCapturingDone; bool m_sessionInitialized = false; bool m_cpuSyncSpinningEnabled = false; + bool m_memoryArenaDisabled = false; ComPtr m_context; std::unique_ptr m_uploadHeap; std::unique_ptr m_readbackHeap; @@ -260,7 +262,8 @@ namespace Dml Dml::ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, - bool enableSyncSpinning + bool enableSyncSpinning, + bool disableMemoryArena ); std::unique_ptr GetDataTransfer() const final override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index 0cbf817a8200d..53538cfb79ab5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -164,6 +164,10 @@ namespace Dml m_reusedCommandLists.push_front(std::move(reusableCommandList)); } + // We don't need to keep a reference on the temporary resource once we have recorded into the command list, so the + // memory can be reused by the allocator + constexpr bool keepTemporaryResourceAlive = false; + DmlGraphFusionHelper::ExecuteReusableCommandList( kernelContext, *m_reusedCommandLists.front(), @@ -175,7 +179,8 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get()); + m_persistentResourceAllocatorUnknown.Get(), + keepTemporaryResourceAlive); m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front())); m_reusedCommandLists.pop_front(); @@ -245,7 +250,7 @@ namespace Dml persistentResourceBinding, inputBindings, outputBindings)); - } + } private: ComPtr m_compiledExecutionPlanOperator; diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 72961b8181a72..e8fe235fc1d46 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -56,6 +56,7 @@ struct DMLProviderFactory : IExecutionProviderFactory { python_api_(python_api) { graph_capture_enabled_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableGraphCapture, "0")); cpu_sync_spinning_enabled_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableCpuSyncSpinning, "0")); + disable_memory_arena_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisableMemoryArena, "0")); } ~DMLProviderFactory() override {} @@ -70,6 +71,7 @@ struct DMLProviderFactory : IExecutionProviderFactory { bool metacommands_enabled_ = true; bool graph_capture_enabled_ = false; bool cpu_sync_spinning_enabled_ = false; + bool disable_memory_arena_ = false; bool python_api_ = false; }; @@ -91,7 +93,7 @@ std::unique_ptr DMLProviderFactory::CreateProvider() { execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false); } - auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_); return provider; } diff --git a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h index 6418366342161..8cad03c5e85b4 100644 --- a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h +++ b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h @@ -24,3 +24,4 @@ static const char* const kOrtSessionOptionsConfigDisableDmlGraphFusion = "ep.dml static const char* const kOrtSessionOptionsConfigEnableGraphSerialization = "ep.dml.enable_graph_serialization"; static const char* const kOrtSessionOptionsConfigEnableGraphCapture = "ep.dml.enable_graph_capture"; static const char* const kOrtSessionOptionsConfigEnableCpuSyncSpinning = "ep.dml.enable_cpu_sync_spinning"; +static const char* const kOrtSessionOptionsConfigDisableMemoryArena = "ep.dml.disable_memory_arena";