-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DML] Hide command list reset latency with multiple threads #20067
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,11 +15,17 @@ | |
: m_queue(std::move(commandQueue)), | ||
m_d3dDevice(d3dDevice), | ||
m_dmlDevice(dmlDevice), | ||
m_descriptorPool(d3dDevice, 2048), | ||
m_commandAllocatorRing(d3dDevice, m_queue->GetType(), m_queue->GetCurrentCompletionEvent()) | ||
m_descriptorPool(d3dDevice, 2048) | ||
{ | ||
ORT_THROW_IF_FAILED(dmlDevice->CreateOperatorInitializer(0, nullptr, IID_PPV_ARGS(&m_initializer))); | ||
ORT_THROW_IF_FAILED(dmlDevice->CreateCommandRecorder(IID_PPV_ARGS(&m_recorder))); | ||
|
||
m_threadPool = std::make_unique<onnxruntime::concurrency::ThreadPool>( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it possible to bypass thread pool work if the only submissions occur from fused partitions? I'm assuming that the issue is the fixed overhead of Reset, and that CPU/GPU parallelism is still sufficient in that case. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Currently this is still needed for fused partitions (LLMs only use fused partitions after the first iteration and still benefit from it). The problem isn't executing multiple different DML command lists, but executing multiple GPU->CPU copies. Those copies are not big enough to hide latencies, so we end up waiting on the reset. |
||
&onnxruntime::Env::Default(), | ||
onnxruntime::ThreadOptions(), | ||
ORT_TSTR("CommandListPool"), | ||
threadPoolSize, | ||
true); | ||
} | ||
|
||
void DmlCommandRecorder::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator) | ||
|
@@ -263,7 +269,7 @@ | |
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(&commandList), 1)); | ||
|
||
// The fence value at which the current command allocator may be re-used will now be higher | ||
m_commandAllocatorRing.UpdateCurrentAllocatorCompletionEvent(m_queue->GetNextCompletionEvent()); | ||
m_currentCommandListInfo->completionEvent = m_queue->GetNextCompletionEvent(); | ||
|
||
// Fail early if something horrifying happens | ||
ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason()); | ||
|
@@ -313,23 +319,35 @@ | |
{ | ||
assert(m_currentDescriptorHeap == nullptr); | ||
|
||
ID3D12CommandAllocator* allocator = m_commandAllocatorRing.GetNextAllocator(m_queue->GetNextCompletionEvent()); | ||
|
||
if (!m_cachedCommandList) | ||
if (m_availableCommandLists.empty()) | ||
{ | ||
ComPtr<ID3D12CommandAllocator> allocator; | ||
ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandAllocator( | ||
m_queue->GetType(), | ||
IID_GRAPHICS_PPV_ARGS(allocator.ReleaseAndGetAddressOf()))); | ||
|
||
ComPtr<ID3D12GraphicsCommandList> commandList; | ||
ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList( | ||
0, | ||
m_queue->GetType(), | ||
allocator, | ||
allocator.Get(), | ||
nullptr, | ||
IID_GRAPHICS_PPV_ARGS(m_currentCommandList.ReleaseAndGetAddressOf()))); | ||
IID_GRAPHICS_PPV_ARGS(commandList.ReleaseAndGetAddressOf()))); | ||
|
||
auto commandListInfo = std::make_shared<CommandListInfo>(); | ||
Check warning on line 337 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp GitHub Actions / Lint C++
|
||
commandListInfo->allocator = std::move(allocator); | ||
commandListInfo->commandList = std::move(commandList); | ||
m_currentCommandListInfo = std::move(commandListInfo); | ||
} | ||
else | ||
{ | ||
m_currentCommandList = m_cachedCommandList; | ||
m_cachedCommandList = nullptr; | ||
ORT_THROW_IF_FAILED(m_currentCommandList->Reset(allocator, nullptr)); | ||
std::unique_lock lock(m_mutex); | ||
m_currentCommandListInfo = m_availableCommandLists.back(); | ||
m_availableCommandLists.pop_back(); | ||
} | ||
|
||
m_currentCommandList = m_currentCommandListInfo->commandList; | ||
m_currentCommandListInfo->completionEvent = m_queue->GetNextCompletionEvent(); | ||
} | ||
|
||
void DmlCommandRecorder::CloseAndExecute() | ||
|
@@ -338,7 +356,7 @@ | |
} | ||
|
||
void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList) | ||
{ | ||
{ | ||
ORT_THROW_IF_FAILED(m_currentCommandList->Close()); | ||
|
||
ID3D12GraphicsCommandList* commandListsToExecute[2] = {}; | ||
|
@@ -347,6 +365,17 @@ | |
if (m_operationsRecordedInCurrentCommandList) | ||
{ | ||
commandListsToExecute[commandListsToExecuteCount++] = m_currentCommandList.Get(); | ||
|
||
onnxruntime::concurrency::ThreadPool::Schedule(m_threadPool.get(), [this, currentCommandListInfo = m_currentCommandListInfo]() { | ||
currentCommandListInfo->completionEvent.WaitForSignal(); | ||
ORT_THROW_IF_FAILED(currentCommandListInfo->allocator->Reset()); | ||
ORT_THROW_IF_FAILED(currentCommandListInfo->commandList->Reset(currentCommandListInfo->allocator.Get(), nullptr)); | ||
|
||
{ | ||
std::unique_lock lock(m_mutex); | ||
m_availableCommandLists.push_back(std::move(currentCommandListInfo)); | ||
Check warning on line 376 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp GitHub Actions / Lint C++
|
||
} | ||
}); | ||
} | ||
|
||
if (commandList) | ||
|
@@ -359,9 +388,7 @@ | |
m_queue->ExecuteCommandLists( | ||
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(commandListsToExecute), commandListsToExecuteCount)); | ||
} | ||
|
||
m_cachedCommandList = m_currentCommandList; | ||
m_currentCommandList = nullptr; | ||
|
||
m_operationsRecordedInCurrentCommandList = false; | ||
|
||
// The descriptor heap must be set on the command list the next time it's opened. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we want to re-use other thread pool instances and in doing so respect API options for thread count. May be a good idea to get input from ORT devs on this.