diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h index 570f62aac8105..2eee9c9a9e5a3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandAllocatorRing.h @@ -47,6 +47,14 @@ namespace Dml return m_commandAllocators[m_currentCommandAllocator].Get(); } + // Updates the completion event of the current allocator to a different value. This is used when the caller + // decides to issue an unrelated call to the queue such as ExecuteCommandLists which updates its fence between calling + // GetNextAllocator and executing the work which it recorded using the allocator it received. + void UpdateCurrentAllocatorCompletionEvent(GpuEvent nextCompletionEvent) + { + m_commandAllocators[m_currentCommandAllocator].completionEvent = nextCompletionEvent; + } + private: struct CommandAllocatorInfo { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp index 98345f37b68d4..5254b23f56376 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp @@ -262,6 +262,9 @@ void DmlCommandRecorder::ExecuteCommandList( m_queue->ExecuteCommandLists( gsl::span(reinterpret_cast(&commandList), 1)); + // The fence value at which the current command allocator may be re-used will now be higher + m_commandAllocatorRing.UpdateCurrentAllocatorCompletionEvent(m_queue->GetNextCompletionEvent()); + // Fail early if something horrifying happens ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason()); ORT_THROW_IF_FAILED(m_d3dDevice->GetDeviceRemovedReason()); @@ -269,42 +272,20 @@ void DmlCommandRecorder::ExecuteCommandList( return; } - ORT_THROW_IF_FAILED(m_currentCommandList->Close()); - - if (m_operationsRecordedInCurrentCommandList) - { - m_pendingCommandLists.push_back(m_currentCommandList.Get()); - m_pendingCommandListsCacheable.push_back(true); - } - else - { - m_cachedCommandLists.push_back(m_currentCommandList.Get()); - } - - m_currentCommandList = nullptr; - m_operationsRecordedInCurrentCommandList = false; - - m_pendingCommandLists.push_back(commandList); - m_pendingCommandListsCacheable.push_back(false); - - // Remember the descriptor heap and apply it to the next command list + // Remember the descriptor heap and apply it to the next command list. This avoids unnecessarily setting it onto + // the D3D object lazily at a point when the operation may not be parallelized with GPU work. auto heap = m_currentDescriptorHeap; - m_currentDescriptorHeap = nullptr; - Open(); - - // The caller can re-use relevant resources after the next set of work to be - // flushed has completed. Its command list hasn't been executed yet, just batched. - GpuEvent gpuEvent = m_queue->GetNextCompletionEvent(); - gpuEvent.fence.CopyTo(fence); - *completionValue = gpuEvent.fenceValue; - // Trigger a flush of the command list, with the assumption that it contains enough GPU work that this - // will help parallelize GPU work with subsequent CPU work. This policy is related to the choice of - // minNodeCountToReuseCommandList within FusedGraphKernel, so both should be tuned together. - CloseAndExecute(); + // Execute work in the current command list plus provided command list while closing the recorder. + CloseAndExecute(commandList); Open(); + // Reset the descriptor heap opportunistically per above comment SetDescriptorHeap(heap); + + GpuEvent gpuEvent = m_queue->GetCurrentCompletionEvent(); + gpuEvent.fence.CopyTo(fence); + *completionValue = gpuEvent.fenceValue; } ComPtr DmlCommandRecorder::GetCommandList() @@ -334,7 +315,7 @@ void DmlCommandRecorder::Open() ID3D12CommandAllocator* allocator = m_commandAllocatorRing.GetNextAllocator(m_queue->GetNextCompletionEvent()); - if (m_cachedCommandLists.empty()) + if (!m_cachedCommandList) { ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList( 0, @@ -345,47 +326,43 @@ void DmlCommandRecorder::Open() } else { - m_currentCommandList = m_cachedCommandLists.front(); - m_cachedCommandLists.pop_front(); + m_currentCommandList = m_cachedCommandList; + m_cachedCommandList = nullptr; ORT_THROW_IF_FAILED(m_currentCommandList->Reset(allocator, nullptr)); } } void DmlCommandRecorder::CloseAndExecute() { + CloseAndExecute(nullptr); +} + +void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList) +{ ORT_THROW_IF_FAILED(m_currentCommandList->Close()); + ID3D12GraphicsCommandList* commandListsToExecute[2] = {}; + uint32_t commandListsToExecuteCount = 0; + if (m_operationsRecordedInCurrentCommandList) { - m_pendingCommandLists.push_back(m_currentCommandList.Get()); - m_pendingCommandListsCacheable.push_back(true); + commandListsToExecute[commandListsToExecuteCount++] = m_currentCommandList.Get(); } - else + + if (commandList) { - m_cachedCommandLists.push_back(m_currentCommandList.Get()); + commandListsToExecute[commandListsToExecuteCount++] = commandList; } - m_currentCommandList = nullptr; - m_operationsRecordedInCurrentCommandList = false; - - if (!m_pendingCommandLists.empty()) + if (commandListsToExecuteCount > 0) { - // Close and execute the command list m_queue->ExecuteCommandLists( - gsl::span(reinterpret_cast(m_pendingCommandLists.data()), m_pendingCommandLists.size())); - - assert(m_pendingCommandLists.size() == m_pendingCommandListsCacheable.size()); - for (size_t i = 0; i < m_pendingCommandLists.size(); ++i) - { - if (m_pendingCommandListsCacheable[i]) - { - m_cachedCommandLists.push_back(m_pendingCommandLists[i]); - } - } - - m_pendingCommandLists.clear(); - m_pendingCommandListsCacheable.clear(); + gsl::span(reinterpret_cast(commandListsToExecute), commandListsToExecuteCount)); } + + m_cachedCommandList = m_currentCommandList; + m_currentCommandList = nullptr; + m_operationsRecordedInCurrentCommandList = false; // The descriptor heap must be set on the command list the next time it's opened. m_currentDescriptorHeap = nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h index 7ad7032317d77..83051c8ca4ff9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h @@ -58,7 +58,7 @@ namespace Dml bool HasUnsubmittedWork() override { - return m_operationsRecordedInCurrentCommandList || !m_pendingCommandLists.empty(); + return m_operationsRecordedInCurrentCommandList; } // Forces the descriptor heap to be reset to D3D before executing future operations @@ -68,7 +68,8 @@ namespace Dml } private: - + void CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList); + std::shared_ptr m_queue; ComPtr m_d3dDevice; ComPtr m_dmlDevice; @@ -89,15 +90,8 @@ namespace Dml ComPtr m_currentCommandList; bool m_operationsRecordedInCurrentCommandList = false; - // Command lists which have been batched up for execution. The values in - // m_pendingCommandListsCacheable indicate whether they can be moved into this - // class's cache after execution, versus if they belong to the caller and were - // passed to ExecuteCommandList. - std::vector> m_pendingCommandLists; - std::vector m_pendingCommandListsCacheable; - - // A pool of cached command lists which may be re-used. - std::deque> m_cachedCommandLists; + // A cached command list which may be re-used. + ComPtr m_cachedCommandList; void SetDescriptorHeap(ID3D12DescriptorHeap* descriptorHeap); };