From 50ee1b056c31d1364feb3066255d058142072085 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 31 May 2024 17:24:50 -0700 Subject: [PATCH] [DML EP] Improve memory usage and fix memory leak in graph capture (#20879) Phi-3 vision loads 3 models in memory, which means that we have 3 different sessions, 3 different execution providers and 3 different allocators all loaded at the same time. Since the DML EP uses a bucketized allocator, this results in a lot of memory fragmentation across all 3 models that can only be used by the model itself. To fix that, we can disable the memory arena (term for any kind of allocator that reuses memory in ORT) as an opt-in option. In the case of LLMs, we essentially never need to reallocate memory after the initial graphs have been captured, which means that we gain nothing by using the bucketized allocator, and it causes unnecessary fragmentation. --------- Co-authored-by: Patrice Vignola --- .../inc/DmlExecutionProvider.h | 3 +- .../src/BucketizedBufferAllocator.cpp | 16 +++++--- .../src/DmlGraphFusionHelper.cpp | 8 +++- .../src/DmlGraphFusionHelper.h | 3 +- .../src/DmlReusedCommandListState.h | 1 + .../src/DmlRuntimeFusedGraphKernel.cpp | 38 +++++++++---------- .../src/ExecutionContext.h | 1 + .../src/ExecutionProvider.cpp | 23 +++++++---- .../src/ExecutionProvider.h | 7 +++- .../src/FusedGraphKernel.cpp | 9 ++++- .../providers/dml/dml_provider_factory.cc | 4 +- .../dml/dml_session_options_config_keys.h | 1 + 12 files changed, 73 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 8de2a29fe45ba..ecef48dc6d480 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -34,7 +34,8 @@ namespace Dml Dml::ExecutionContext* execution_context, bool 
enableMetacommands, bool enableGraphCapture, - bool enableCpuSyncSpinning); + bool enableCpuSyncSpinning, + bool disableMemoryArena); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 6261918267d7d..db45908a2dda4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -186,12 +186,16 @@ namespace Dml } else { - // Free the underlying allocation once queued work has completed. -#ifdef _GAMING_XBOX - m_context->QueueReference(WRAP_GRAPHICS_UNKNOWN(allocInfo->GetResource()).Get()); -#else - m_context->QueueReference(allocInfo->GetResource()); -#endif + if (!m_context->IsClosed()) + { + // Free the underlying allocation once queued work has completed. 
+ #ifdef _GAMING_XBOX + m_context->QueueReference(WRAP_GRAPHICS_UNKNOWN(allocInfo->GetResource()).Get()); + #else + m_context->QueueReference(allocInfo->GetResource()); + #endif + } + allocInfo->DetachResourceWrapper(); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 90942f4eef353..8f67174157b8a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -935,7 +935,8 @@ namespace DmlGraphFusionHelper const Windows::AI::MachineLearning::Adapter::EdgeShapes& outputShapes, IWinmlExecutionProvider* winmlProvider, IExecutionProvider* provider, - IUnknown* persistentResourceAllocatorUnknown) + IUnknown* persistentResourceAllocatorUnknown, + bool keepTemporaryResourceAlive) { DML_BINDING_PROPERTIES execBindingProps = compiledExecutionPlanOperator->GetBindingProperties(); @@ -1042,6 +1043,11 @@ namespace DmlGraphFusionHelper } commandListState.tempBindingAllocId = tempAllocId; + + if (keepTemporaryResourceAlive) + { + commandListState.temporaryResource = std::move(tempResource); + } } // Execute the command list and if it succeeds, update the fence value at which this command may be diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index a0e2456334e9f..1a810ca58cf45 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -120,6 +120,7 @@ namespace DmlGraphFusionHelper const Windows::AI::MachineLearning::Adapter::EdgeShapes& outputShapes, IWinmlExecutionProvider* winmlProvider, IExecutionProvider* provider, - IUnknown* persistentResourceAllocatorUnknown); + IUnknown* 
persistentResourceAllocatorUnknown, + bool keepTemporaryResourceAlive); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h index 5c46021a215a6..6c3c2fb8c7094 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h @@ -15,6 +15,7 @@ namespace Dml Microsoft::WRL::ComPtr heap; Microsoft::WRL::ComPtr bindingTable; Microsoft::WRL::ComPtr persistentResource; + Microsoft::WRL::ComPtr temporaryResource; Microsoft::WRL::ComPtr persistentResourceAllocatorUnknown; // Bindings from previous executions of a re-used command list diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index f2c5ae8418332..10b8b7fe42f86 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -62,10 +62,7 @@ namespace Dml } } - void TranslateAndCompileGraph( - const onnxruntime::OpKernelInfo& kernelInfo, - std::vector>& initializeResourceRefs, - std::vector initInputBindings) const + void TranslateAndCompileGraph(const onnxruntime::OpKernelInfo& kernelInfo, std::vector initInputBindings) const { // Allocate a persistent resource and initialize the operator UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; @@ -84,16 +81,14 @@ namespace Dml m_compiledExecutionPlanOperator.Get(), m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr, gsl::make_span(initInputBindings))); - - std::for_each( - initializeResourceRefs.begin(), - initializeResourceRefs.end(), - [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } - ); } onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override { + // Release the references from the previous execution since Flush() isn't called for reusable command lists + auto providerImpl = static_cast(m_provider.Get()); + providerImpl->ReleaseCompletedReferences(); + ORT_THROW_HR_IF(E_UNEXPECTED, static_cast(m_subgraphInputs.size()) != kernelContext->InputCount()); bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; @@ -173,7 +168,6 @@ namespace Dml // Populate input bindings for operator initialization const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); - std::vector> initializeResourceRefs; // For lifetime control std::vector initInputBindings(fusedNodeInputCount); std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); @@ -233,10 +227,7 @@ namespace Dml // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); - TranslateAndCompileGraph( - Info(), - initializeResourceRefs, - initInputBindings); + TranslateAndCompileGraph(Info(), initInputBindings); std::vector inputBindings(kernelContext->InputCount()); std::vector inputBindingDescs(kernelContext->InputCount()); @@ -244,8 +235,6 @@ namespace Dml m_reusedCommandLists.clear(); } - auto providerImpl = static_cast(m_provider.Get()); - // When we are capturing a graph, we don't pool the command list and instead transfer it to the execution provider. Captured graph // have the same bindings for their entire lifetime. 
if (providerImpl->GraphCaptureEnabled() && providerImpl->GetCurrentGraphAnnotationId() != -1 && !providerImpl->GraphCaptured(providerImpl->GetCurrentGraphAnnotationId())) @@ -259,6 +248,11 @@ namespace Dml reusableCommandList->persistentResource = m_persistentResource; reusableCommandList->persistentResourceAllocatorUnknown = m_persistentResourceAllocatorUnknown; + // Keep the temporary resource alive since we won't call ExecuteReusableCommandList again, but will merely replay + // the graph in the future. Therefore, all executions of the graph will use the same temporary resource that was + // allocated here the first time. + constexpr bool keepTemporaryResourceAlive = true; + DmlGraphFusionHelper::ExecuteReusableCommandList( kernelContext, *reusableCommandList, @@ -270,7 +264,8 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get()); + m_persistentResourceAllocatorUnknown.Get(), + keepTemporaryResourceAlive); providerImpl->AppendCapturedGraph(providerImpl->GetCurrentGraphAnnotationId(), std::move(reusableCommandList)); } @@ -288,6 +283,10 @@ namespace Dml m_reusedCommandLists.push_front(std::move(reusableCommandList)); } + // We don't need to keep a reference on the temporary resource once we have recorded into the command list, so the + // memory can be reused by the allocator + constexpr bool keepTemporaryResourceAlive = false; + DmlGraphFusionHelper::ExecuteReusableCommandList( kernelContext, *m_reusedCommandLists.front(), @@ -299,7 +298,8 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get()); + m_persistentResourceAllocatorUnknown.Get(), + keepTemporaryResourceAlive); m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front())); m_reusedCommandLists.pop_front(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h index 9ba2f13bf2751..e7a6fa3d07296 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h @@ -87,6 +87,7 @@ namespace Dml D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const; bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; } + bool IsClosed() const { return m_closed; } private: Microsoft::WRL::ComPtr m_d3dDevice; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 949cd216d5ad9..fb0275deceed6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -72,7 +72,8 @@ namespace Dml Dml::ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, - bool enableSyncSpinning) : + bool enableSyncSpinning, + bool disableMemoryArena) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue(); @@ -85,7 +86,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena); } std::vector> @@ -149,12 +150,13 @@ namespace Dml } } - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, bool 
enableCpuSyncSpinning) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning, bool disableMemoryArena) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), m_areMetacommandsEnabled(enableMetacommands), m_graphCaptureEnabled(enableGraphCapture), m_cpuSyncSpinningEnabled(enableCpuSyncSpinning), + m_memoryArenaDisabled(disableMemoryArena), m_context(executionContext) { D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; @@ -1181,9 +1183,13 @@ namespace Dml m_context->ReleaseCompletedReferences(); m_uploadHeap->Trim(); - // Allocations after this point are potentially transient and their sizes are - // rounded to enable pooling. - m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled); + if (!m_memoryArenaDisabled) + { + // Allocations after this point are potentially transient and their sizes are + // rounded to enable pooling. 
+ m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled); + } + m_sessionInitialized = true; return onnxruntime::common::Status::OK(); @@ -1242,9 +1248,10 @@ namespace Dml Dml::ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, - bool enableCpuSyncSpinning) + bool enableCpuSyncSpinning, + bool disableMemoryArena) { - return std::make_unique(dmlDevice, executionContext, enableMetacommands, enableGraphCapture, enableCpuSyncSpinning); + return std::make_unique(dmlDevice, executionContext, enableMetacommands, enableGraphCapture, enableCpuSyncSpinning, disableMemoryArena); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 70bc99184b9bc..29961288a51c5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -38,7 +38,8 @@ namespace Dml Dml::ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, - bool enableCpuSyncSpinning); + bool enableCpuSyncSpinning, + bool disableMemoryArena); void ReleaseCompletedReferences(); @@ -207,6 +208,7 @@ namespace Dml std::unordered_set m_graphCapturingDone; bool m_sessionInitialized = false; bool m_cpuSyncSpinningEnabled = false; + bool m_memoryArenaDisabled = false; ComPtr m_context; std::unique_ptr m_uploadHeap; std::unique_ptr m_readbackHeap; @@ -260,7 +262,8 @@ namespace Dml Dml::ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, - bool enableSyncSpinning + bool enableSyncSpinning, + bool disableMemoryArena ); std::unique_ptr GetDataTransfer() const final override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index 0cbf817a8200d..53538cfb79ab5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -164,6 +164,10 @@ namespace Dml m_reusedCommandLists.push_front(std::move(reusableCommandList)); } + // We don't need to keep a reference on the temporary resource once we have recorded into the command list, so the + // memory can be reused by the allocator + constexpr bool keepTemporaryResourceAlive = false; + DmlGraphFusionHelper::ExecuteReusableCommandList( kernelContext, *m_reusedCommandLists.front(), @@ -175,7 +179,8 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get()); + m_persistentResourceAllocatorUnknown.Get(), + keepTemporaryResourceAlive); m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front())); m_reusedCommandLists.pop_front(); @@ -245,7 +250,7 @@ namespace Dml persistentResourceBinding, inputBindings, outputBindings)); - } + } private: ComPtr m_compiledExecutionPlanOperator; diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 72961b8181a72..e8fe235fc1d46 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -56,6 +56,7 @@ struct DMLProviderFactory : IExecutionProviderFactory { python_api_(python_api) { graph_capture_enabled_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableGraphCapture, "0")); cpu_sync_spinning_enabled_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableCpuSyncSpinning, "0")); + disable_memory_arena_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisableMemoryArena, "0")); } ~DMLProviderFactory() override {} @@ -70,6 
+71,7 @@ struct DMLProviderFactory : IExecutionProviderFactory { bool metacommands_enabled_ = true; bool graph_capture_enabled_ = false; bool cpu_sync_spinning_enabled_ = false; + bool disable_memory_arena_ = false; bool python_api_ = false; }; @@ -91,7 +93,7 @@ std::unique_ptr DMLProviderFactory::CreateProvider() { execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false); } - auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_); return provider; } diff --git a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h index 6418366342161..8cad03c5e85b4 100644 --- a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h +++ b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h @@ -24,3 +24,4 @@ static const char* const kOrtSessionOptionsConfigDisableDmlGraphFusion = "ep.dml static const char* const kOrtSessionOptionsConfigEnableGraphSerialization = "ep.dml.enable_graph_serialization"; static const char* const kOrtSessionOptionsConfigEnableGraphCapture = "ep.dml.enable_graph_capture"; static const char* const kOrtSessionOptionsConfigEnableCpuSyncSpinning = "ep.dml.enable_cpu_sync_spinning"; +static const char* const kOrtSessionOptionsConfigDisableMemoryArena = "ep.dml.disable_memory_arena";