Skip to content

Commit

Permalink
[DML EP] Improve memory usage and fix memory leak in graph capture (#…
Browse files Browse the repository at this point in the history
…20879)

Phi-3 vision loads 3 models in memory, which means that we have 3
different sessions, 3 different execution providers and 3 different
allocators all loaded at the same time. Because the DML EP uses a
bucketized allocator, this results in a lot of memory fragmentation
across all 3 models: memory pooled by one model's allocator can only be
reused by that model itself.

To fix that, we can disable the memory arena (term for any kind of
allocator that reuses memory in ORT) as an opt-in option. In the case of
LLMs, we essentially never need to reallocate memory after the initial
graphs have been captured, which means that we gain nothing by using the
bucketized allocator, and it causes unnecessary fragmentation.

---------

Co-authored-by: Patrice Vignola <[email protected]>
  • Loading branch information
PatriceVignola and Patrice Vignola authored Jun 1, 2024
1 parent ad769f1 commit 50ee1b0
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ namespace Dml
Dml::ExecutionContext* execution_context,
bool enableMetacommands,
bool enableGraphCapture,
bool enableCpuSyncSpinning);
bool enableCpuSyncSpinning,
bool disableMemoryArena);

ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
void FlushContext(onnxruntime::IExecutionProvider* provider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,16 @@ namespace Dml
}
else
{
// Free the underlying allocation once queued work has completed.
#ifdef _GAMING_XBOX
m_context->QueueReference(WRAP_GRAPHICS_UNKNOWN(allocInfo->GetResource()).Get());
#else
m_context->QueueReference(allocInfo->GetResource());
#endif
if (!m_context->IsClosed())
{
// Free the underlying allocation once queued work has completed.
#ifdef _GAMING_XBOX
m_context->QueueReference(WRAP_GRAPHICS_UNKNOWN(allocInfo->GetResource()).Get());
#else
m_context->QueueReference(allocInfo->GetResource());
#endif
}

allocInfo->DetachResourceWrapper();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,8 @@ namespace DmlGraphFusionHelper
const Windows::AI::MachineLearning::Adapter::EdgeShapes& outputShapes,
IWinmlExecutionProvider* winmlProvider,
IExecutionProvider* provider,
IUnknown* persistentResourceAllocatorUnknown)
IUnknown* persistentResourceAllocatorUnknown,
bool keepTemporaryResourceAlive)
{
DML_BINDING_PROPERTIES execBindingProps = compiledExecutionPlanOperator->GetBindingProperties();

Expand Down Expand Up @@ -1042,6 +1043,11 @@ namespace DmlGraphFusionHelper
}

commandListState.tempBindingAllocId = tempAllocId;

if (keepTemporaryResourceAlive)
{
commandListState.temporaryResource = std::move(tempResource);
}
}

// Execute the command list and if it succeeds, update the fence value at which this command may be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ namespace DmlGraphFusionHelper
const Windows::AI::MachineLearning::Adapter::EdgeShapes& outputShapes,
IWinmlExecutionProvider* winmlProvider,
IExecutionProvider* provider,
IUnknown* persistentResourceAllocatorUnknown);
IUnknown* persistentResourceAllocatorUnknown,
bool keepTemporaryResourceAlive);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace Dml
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> heap;
Microsoft::WRL::ComPtr<IDMLBindingTable> bindingTable;
Microsoft::WRL::ComPtr<ID3D12Resource> persistentResource;
Microsoft::WRL::ComPtr<ID3D12Resource> temporaryResource;
Microsoft::WRL::ComPtr<IUnknown> persistentResourceAllocatorUnknown;

// Bindings from previous executions of a re-used command list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ namespace Dml
}
}

void TranslateAndCompileGraph(
const onnxruntime::OpKernelInfo& kernelInfo,
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>>& initializeResourceRefs,
std::vector<DML_BUFFER_BINDING> initInputBindings) const
void TranslateAndCompileGraph(const onnxruntime::OpKernelInfo& kernelInfo, std::vector<DML_BUFFER_BINDING> initInputBindings) const
{
// Allocate a persistent resource and initialize the operator
UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize;
Expand All @@ -84,16 +81,14 @@ namespace Dml
m_compiledExecutionPlanOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(initInputBindings)));

std::for_each(
initializeResourceRefs.begin(),
initializeResourceRefs.end(),
[&](ComPtr<ID3D12Resource>& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); }
);
}

onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override
{
// Release the references from the previous execution since Flush() isn't called for reusable command lists
auto providerImpl = static_cast<ExecutionProviderImpl*>(m_provider.Get());
providerImpl->ReleaseCompletedReferences();

ORT_THROW_HR_IF(E_UNEXPECTED, static_cast<ptrdiff_t>(m_subgraphInputs.size()) != kernelContext->InputCount());

bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr;
Expand Down Expand Up @@ -173,7 +168,6 @@ namespace Dml

// Populate input bindings for operator initialization
const uint32_t fusedNodeInputCount = gsl::narrow_cast<uint32_t>(m_indexedSubGraph->GetMetaDef()->inputs.size());
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> initializeResourceRefs; // For lifetime control
std::vector<DML_BUFFER_BINDING> initInputBindings(fusedNodeInputCount);
std::vector<uint8_t> isInputsUploadedByDmlEP(fusedNodeInputCount);
auto providerImpl = static_cast<const ExecutionProvider*>(Info().GetExecutionProvider())->GetImpl();
Expand Down Expand Up @@ -233,19 +227,14 @@ namespace Dml
// Queue references to objects which must be kept alive until resulting GPU work completes
m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get());

TranslateAndCompileGraph(
Info(),
initializeResourceRefs,
initInputBindings);
TranslateAndCompileGraph(Info(), initInputBindings);

std::vector<DML_BUFFER_BINDING> inputBindings(kernelContext->InputCount());
std::vector<DML_BINDING_DESC> inputBindingDescs(kernelContext->InputCount());

m_reusedCommandLists.clear();
}

auto providerImpl = static_cast<ExecutionProviderImpl*>(m_provider.Get());

// When we are capturing a graph, we don't pool the command list and instead transfer it to the execution provider. Captured graph
// have the same bindings for their entire lifetime.
if (providerImpl->GraphCaptureEnabled() && providerImpl->GetCurrentGraphAnnotationId() != -1 && !providerImpl->GraphCaptured(providerImpl->GetCurrentGraphAnnotationId()))
Expand All @@ -259,6 +248,11 @@ namespace Dml
reusableCommandList->persistentResource = m_persistentResource;
reusableCommandList->persistentResourceAllocatorUnknown = m_persistentResourceAllocatorUnknown;

// Keep the temporary resource alive since we won't call ExecuteReusableCommandList again, but will merely replay
// the graph in the future. Therefore, all executions of the graph will use the same temporary resource that was
// allocated here the first time.
constexpr bool keepTemporaryResourceAlive = true;

DmlGraphFusionHelper::ExecuteReusableCommandList(
kernelContext,
*reusableCommandList,
Expand All @@ -270,7 +264,8 @@ namespace Dml
m_outputShapes,
m_winmlProvider.Get(),
m_provider.Get(),
m_persistentResourceAllocatorUnknown.Get());
m_persistentResourceAllocatorUnknown.Get(),
keepTemporaryResourceAlive);

providerImpl->AppendCapturedGraph(providerImpl->GetCurrentGraphAnnotationId(), std::move(reusableCommandList));
}
Expand All @@ -288,6 +283,10 @@ namespace Dml
m_reusedCommandLists.push_front(std::move(reusableCommandList));
}

// We don't need to keep a reference on the temporary resource once we have recorded into the command list, so the
// memory can be reused by the allocator
constexpr bool keepTemporaryResourceAlive = false;

DmlGraphFusionHelper::ExecuteReusableCommandList(
kernelContext,
*m_reusedCommandLists.front(),
Expand All @@ -299,7 +298,8 @@ namespace Dml
m_outputShapes,
m_winmlProvider.Get(),
m_provider.Get(),
m_persistentResourceAllocatorUnknown.Get());
m_persistentResourceAllocatorUnknown.Get(),
keepTemporaryResourceAlive);

m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front()));
m_reusedCommandLists.pop_front();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ namespace Dml

D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; }
bool IsClosed() const { return m_closed; }

private:
Microsoft::WRL::ComPtr<ID3D12Device> m_d3dDevice;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ namespace Dml
Dml::ExecutionContext* executionContext,
bool enableMetacommands,
bool enableGraphCapture,
bool enableSyncSpinning) :
bool enableSyncSpinning,
bool disableMemoryArena) :
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
{
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
Expand All @@ -85,7 +86,7 @@ namespace Dml
ComPtr<ID3D12Device> device;
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));

m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning);
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
}

std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
Expand Down Expand Up @@ -149,12 +150,13 @@ namespace Dml
}
}

ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning)
ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning, bool disableMemoryArena)
: m_d3d12Device(d3d12Device),
m_dmlDevice(dmlDevice),
m_areMetacommandsEnabled(enableMetacommands),
m_graphCaptureEnabled(enableGraphCapture),
m_cpuSyncSpinningEnabled(enableCpuSyncSpinning),
m_memoryArenaDisabled(disableMemoryArena),
m_context(executionContext)
{
D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {};
Expand Down Expand Up @@ -1181,9 +1183,13 @@ namespace Dml
m_context->ReleaseCompletedReferences();
m_uploadHeap->Trim();

// Allocations after this point are potentially transient and their sizes are
// rounded to enable pooling.
m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled);
if (!m_memoryArenaDisabled)
{
// Allocations after this point are potentially transient and their sizes are
// rounded to enable pooling.
m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled);
}

m_sessionInitialized = true;

return onnxruntime::common::Status::OK();
Expand Down Expand Up @@ -1242,9 +1248,10 @@ namespace Dml
Dml::ExecutionContext* executionContext,
bool enableMetacommands,
bool enableGraphCapture,
bool enableCpuSyncSpinning)
bool enableCpuSyncSpinning,
bool disableMemoryArena)
{
return std::make_unique<Dml::ExecutionProvider>(dmlDevice, executionContext, enableMetacommands, enableGraphCapture, enableCpuSyncSpinning);
return std::make_unique<Dml::ExecutionProvider>(dmlDevice, executionContext, enableMetacommands, enableGraphCapture, enableCpuSyncSpinning, disableMemoryArena);
}

ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ namespace Dml
Dml::ExecutionContext* executionContext,
bool enableMetacommands,
bool enableGraphCapture,
bool enableCpuSyncSpinning);
bool enableCpuSyncSpinning,
bool disableMemoryArena);

void ReleaseCompletedReferences();

Expand Down Expand Up @@ -207,6 +208,7 @@ namespace Dml
std::unordered_set<int> m_graphCapturingDone;
bool m_sessionInitialized = false;
bool m_cpuSyncSpinningEnabled = false;
bool m_memoryArenaDisabled = false;
ComPtr<ExecutionContext> m_context;
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
std::unique_ptr<ReadbackHeap> m_readbackHeap;
Expand Down Expand Up @@ -260,7 +262,8 @@ namespace Dml
Dml::ExecutionContext* executionContext,
bool enableMetacommands,
bool enableGraphCapture,
bool enableSyncSpinning
bool enableSyncSpinning,
bool disableMemoryArena
);

std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const final override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ namespace Dml
m_reusedCommandLists.push_front(std::move(reusableCommandList));
}

// We don't need to keep a reference on the temporary resource once we have recorded into the command list, so the
// memory can be reused by the allocator
constexpr bool keepTemporaryResourceAlive = false;

DmlGraphFusionHelper::ExecuteReusableCommandList(
kernelContext,
*m_reusedCommandLists.front(),
Expand All @@ -175,7 +179,8 @@ namespace Dml
m_outputShapes,
m_winmlProvider.Get(),
m_provider.Get(),
m_persistentResourceAllocatorUnknown.Get());
m_persistentResourceAllocatorUnknown.Get(),
keepTemporaryResourceAlive);

m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front()));
m_reusedCommandLists.pop_front();
Expand Down Expand Up @@ -245,7 +250,7 @@ namespace Dml
persistentResourceBinding,
inputBindings,
outputBindings));
}
}

private:
ComPtr<IDMLCompiledOperator> m_compiledExecutionPlanOperator;
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct DMLProviderFactory : IExecutionProviderFactory {
python_api_(python_api) {
graph_capture_enabled_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableGraphCapture, "0"));
cpu_sync_spinning_enabled_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableCpuSyncSpinning, "0"));
disable_memory_arena_ = ConfigValueIsTrue(config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisableMemoryArena, "0"));
}

~DMLProviderFactory() override {}
Expand All @@ -70,6 +71,7 @@ struct DMLProviderFactory : IExecutionProviderFactory {
bool metacommands_enabled_ = true;
bool graph_capture_enabled_ = false;
bool cpu_sync_spinning_enabled_ = false;
bool disable_memory_arena_ = false;
bool python_api_ = false;
};

Expand All @@ -91,7 +93,7 @@ std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false);
}

auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_);
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_);
return provider;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ static const char* const kOrtSessionOptionsConfigDisableDmlGraphFusion = "ep.dml
static const char* const kOrtSessionOptionsConfigEnableGraphSerialization = "ep.dml.enable_graph_serialization";
static const char* const kOrtSessionOptionsConfigEnableGraphCapture = "ep.dml.enable_graph_capture";
static const char* const kOrtSessionOptionsConfigEnableCpuSyncSpinning = "ep.dml.enable_cpu_sync_spinning";
static const char* const kOrtSessionOptionsConfigDisableMemoryArena = "ep.dml.disable_memory_arena";

0 comments on commit 50ee1b0

Please sign in to comment.