[DML] Hide command list reset latency with multiple threads #20067

Closed · wants to merge 4 commits
This file was deleted.

onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp
@@ -15,11 +15,17 @@
: m_queue(std::move(commandQueue)),
m_d3dDevice(d3dDevice),
m_dmlDevice(dmlDevice),
-    m_descriptorPool(d3dDevice, 2048),
-    m_commandAllocatorRing(d3dDevice, m_queue->GetType(), m_queue->GetCurrentCompletionEvent())
+    m_descriptorPool(d3dDevice, 2048)
{
ORT_THROW_IF_FAILED(dmlDevice->CreateOperatorInitializer(0, nullptr, IID_PPV_ARGS(&m_initializer)));
ORT_THROW_IF_FAILED(dmlDevice->CreateCommandRecorder(IID_PPV_ARGS(&m_recorder)));

m_threadPool = std::make_unique<onnxruntime::concurrency::ThreadPool>(
Contributor:
I think we want to re-use other thread pool instances and in doing so respect API options for thread count. May be a good idea to get input from ORT devs on this.
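The suggestion above — sharing one appropriately sized pool instead of a per-recorder pool of eight threads — can be sketched without any ORT types. A minimal sketch, assuming a hand-rolled pool: everything here (`WorkerPool`, `Schedule`) is hypothetical illustration, not the `onnxruntime::concurrency::ThreadPool` API, and the thread count would come from session options in a real integration.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical shared worker pool whose size is configurable, so one instance
// can be reused by every recorder instead of each owning its own threads.
class WorkerPool {
public:
    explicit WorkerPool(int threadCount) {
        for (int i = 0; i < threadCount; ++i) {
            m_threads.emplace_back([this]() {
                for (;;) {
                    std::function<void()> task;
                    {
                        std::unique_lock<std::mutex> lock(m_mutex);
                        m_cv.wait(lock, [this]() { return m_stop || !m_tasks.empty(); });
                        // Drain remaining work before exiting on shutdown.
                        if (m_stop && m_tasks.empty()) return;
                        task = std::move(m_tasks.front());
                        m_tasks.pop();
                    }
                    task();
                }
            });
        }
    }

    ~WorkerPool() {
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            m_stop = true;
        }
        m_cv.notify_all();
        for (auto& t : m_threads) t.join();
    }

    // Enqueue a task; any idle worker may pick it up.
    void Schedule(std::function<void()> task) {
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            m_tasks.push(std::move(task));
        }
        m_cv.notify_one();
    }

private:
    std::mutex m_mutex;
    std::condition_variable m_cv;
    std::queue<std::function<void()>> m_tasks;
    std::vector<std::thread> m_threads;
    bool m_stop = false;
};
```

The destructor drains the queue before joining, so scheduled resets are never dropped on shutdown.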

Contributor:

Is it possible to bypass the thread pool entirely when the only submissions come from fused partitions? I'm assuming the issue is the fixed overhead of Reset, and that CPU/GPU parallelism is still sufficient in that case.

Contributor (author):

Currently this is still needed for fused partitions (LLMs only use fused partitions after the first iteration and still benefit from it). The problem isn't executing multiple different DML command lists, but executing multiple GPU->CPU copies. Those copies are not big enough to hide latencies, so we end up waiting on the reset.
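The explanation above reduces to a simple pattern: acquire a list, submit, and recycle it on a worker thread once the GPU is done, so Reset's fixed cost stays off the submission path. A minimal sketch of that recycle pattern, with `FakeCommandList` standing in for the D3D12 allocator/list pair — all names here are hypothetical, and a real version would first wait on the completion fence before calling Reset:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

// Stand-in for a command list + allocator pair; Reset() just counts calls
// instead of touching a real GPU.
struct FakeCommandList {
    int resetCount = 0;
    void Reset() { ++resetCount; }
};

class CommandListPool {
public:
    // Hand out a recycled list if one is ready, otherwise create a new one.
    std::shared_ptr<FakeCommandList> Acquire() {
        std::lock_guard<std::mutex> lock(m_mutex);
        if (m_available.empty()) return std::make_shared<FakeCommandList>();
        auto list = m_available.back();
        m_available.pop_back();
        return list;
    }

    // Reset on a background thread so the submitting thread never blocks on
    // it; returns the thread so the caller (or a pool) can join later. The
    // real change also waits on the GPU completion event before Reset().
    std::thread RecycleAsync(std::shared_ptr<FakeCommandList> list) {
        return std::thread([this, list = std::move(list)]() {
            list->Reset();  // the fixed-cost work, moved off the hot path
            std::lock_guard<std::mutex> lock(m_mutex);
            m_available.push_back(list);
        });
    }

    std::size_t AvailableCount() {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_available.size();
    }

private:
    std::mutex m_mutex;
    std::vector<std::shared_ptr<FakeCommandList>> m_available;
};
```

Once the worker finishes, the already-reset list is back in the free list, so the next Acquire pays nothing — which is the point when the GPU->CPU copies are too small to hide the reset.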

&onnxruntime::Env::Default(),
onnxruntime::ThreadOptions(),
ORT_TSTR("CommandListPool"),
threadPoolSize,
true);
}

void DmlCommandRecorder::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator)
@@ -263,7 +269,7 @@
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(&commandList), 1));

// The fence value at which the current command allocator may be re-used will now be higher
-    m_commandAllocatorRing.UpdateCurrentAllocatorCompletionEvent(m_queue->GetNextCompletionEvent());
+    m_currentCommandListInfo->completionEvent = m_queue->GetNextCompletionEvent();

// Fail early if something horrifying happens
ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason());
@@ -313,23 +319,35 @@
{
    assert(m_currentDescriptorHeap == nullptr);

-    ID3D12CommandAllocator* allocator = m_commandAllocatorRing.GetNextAllocator(m_queue->GetNextCompletionEvent());
-
-    if (!m_cachedCommandList)
+    if (m_availableCommandLists.empty())
    {
+        ComPtr<ID3D12CommandAllocator> allocator;
+        ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandAllocator(
+            m_queue->GetType(),
+            IID_GRAPHICS_PPV_ARGS(allocator.ReleaseAndGetAddressOf())));
+
+        ComPtr<ID3D12GraphicsCommandList> commandList;
        ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList(
            0,
            m_queue->GetType(),
-            allocator,
+            allocator.Get(),
            nullptr,
-            IID_GRAPHICS_PPV_ARGS(m_currentCommandList.ReleaseAndGetAddressOf())));
+            IID_GRAPHICS_PPV_ARGS(commandList.ReleaseAndGetAddressOf())));
+
+        auto commandListInfo = std::make_shared<CommandListInfo>();

[cpplint] DmlCommandRecorder.cpp:337: Add #include <memory> for make_shared<>

+        commandListInfo->allocator = std::move(allocator);
+        commandListInfo->commandList = std::move(commandList);
+        m_currentCommandListInfo = std::move(commandListInfo);
    }
    else
    {
-        m_currentCommandList = m_cachedCommandList;
-        m_cachedCommandList = nullptr;
-        ORT_THROW_IF_FAILED(m_currentCommandList->Reset(allocator, nullptr));
+        std::unique_lock lock(m_mutex);
+        m_currentCommandListInfo = m_availableCommandLists.back();
+        m_availableCommandLists.pop_back();
    }

+    m_currentCommandList = m_currentCommandListInfo->commandList;
+    m_currentCommandListInfo->completionEvent = m_queue->GetNextCompletionEvent();
}

void DmlCommandRecorder::CloseAndExecute()
@@ -338,7 +356,7 @@
}

void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList)
{
ORT_THROW_IF_FAILED(m_currentCommandList->Close());

ID3D12GraphicsCommandList* commandListsToExecute[2] = {};
@@ -347,6 +365,17 @@
if (m_operationsRecordedInCurrentCommandList)
{
commandListsToExecute[commandListsToExecuteCount++] = m_currentCommandList.Get();

onnxruntime::concurrency::ThreadPool::Schedule(m_threadPool.get(), [this, currentCommandListInfo = m_currentCommandListInfo]() {
currentCommandListInfo->completionEvent.WaitForSignal();
ORT_THROW_IF_FAILED(currentCommandListInfo->allocator->Reset());
ORT_THROW_IF_FAILED(currentCommandListInfo->commandList->Reset(currentCommandListInfo->allocator.Get(), nullptr));

{
std::unique_lock lock(m_mutex);
m_availableCommandLists.push_back(std::move(currentCommandListInfo));

[cpplint] DmlCommandRecorder.cpp:376: Add #include <utility> for move
}
});
}

if (commandList)
@@ -359,9 +388,7 @@
m_queue->ExecuteCommandLists(
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(commandListsToExecute), commandListsToExecuteCount));
}

-    m_cachedCommandList = m_currentCommandList;
-    m_currentCommandList = nullptr;

m_operationsRecordedInCurrentCommandList = false;

// The descriptor heap must be set on the command list the next time it's opened.
onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h
@@ -4,7 +4,6 @@
#pragma once

#include "ICommandRecorder.h"
-#include "CommandAllocatorRing.h"

namespace Dml
{
@@ -16,7 +15,7 @@
public:
DmlCommandRecorder(
ID3D12Device* d3dDevice,
IDMLDevice* device,
std::shared_ptr<CommandQueue> commandQueue);

void InitializeOperator(
@@ -47,13 +46,13 @@
_Out_ uint64_t* completionValue);

ComPtr<ID3D12GraphicsCommandList> GetCommandList();

void ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers);
void AddUAVBarrier();

void Open() final;
void CloseAndExecute() final;

void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);

bool HasUnsubmittedWork() override
@@ -68,8 +67,20 @@
}

private:
struct CommandListInfo
{
ComPtr<ID3D12GraphicsCommandList> commandList;
ComPtr<ID3D12CommandAllocator> allocator;

// The event which will be signaled when the last command list submitted using this allocator
// completes execution on the GPU.
GpuEvent completionEvent = {};

ID3D12CommandAllocator* Get() const { return allocator.Get(); }
};

void CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList);

std::shared_ptr<CommandQueue> m_queue;
ComPtr<ID3D12Device> m_d3dDevice;
ComPtr<IDMLDevice> m_dmlDevice;
@@ -84,14 +95,16 @@
// The weak pointer avoids a circular reference from context->recorder->allocator->context
std::weak_ptr<BucketizedBufferAllocator> m_bufferAllocator;

-    CommandAllocatorRing<2> m_commandAllocatorRing;
-
    // The command list currently being recorded into, and whether any command have been recorded yet.
    ComPtr<ID3D12GraphicsCommandList> m_currentCommandList;
+    std::shared_ptr<CommandListInfo> m_currentCommandListInfo;
bool m_operationsRecordedInCurrentCommandList = false;

-    // A cached command list which may be re-used.
-    ComPtr<ID3D12GraphicsCommandList> m_cachedCommandList;
+    static constexpr int threadPoolSize = 8;
+
+    std::unique_ptr<onnxruntime::concurrency::ThreadPool> m_threadPool;
+    std::mutex m_mutex;
+    std::list<std::shared_ptr<CommandListInfo>> m_availableCommandLists;

[cpplint] DmlCommandRecorder.h:107: Add #include <list> for list<>
[cpplint] DmlCommandRecorder.h:107: Add #include <memory> for shared_ptr<>

void SetDescriptorHeap(ID3D12DescriptorHeap* descriptorHeap);
};
Expand Down