Skip to content

Commit

Permalink
Use thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Apr 5, 2024
1 parent 23af506 commit 6c4a3d5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,13 @@ DmlCommandRecorder::DmlCommandRecorder(
{
ORT_THROW_IF_FAILED(dmlDevice->CreateOperatorInitializer(0, nullptr, IID_PPV_ARGS(&m_initializer)));
ORT_THROW_IF_FAILED(dmlDevice->CreateCommandRecorder(IID_PPV_ARGS(&m_recorder)));
}

DmlCommandRecorder::~DmlCommandRecorder()
{
// Detach the threads to avoid crashes when terminating the program
for (auto& resetThread : m_resetThreads)
{
if (resetThread)
{
resetThread->detach();
}
}
m_threadPool = std::make_unique<onnxruntime::concurrency::ThreadPool>(
&onnxruntime::Env::Default(),
onnxruntime::ThreadOptions(),
ORT_TSTR("CommandListPool"),
threadPoolSize,
true);
}

void DmlCommandRecorder::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator)
Expand Down Expand Up @@ -274,7 +269,7 @@ void DmlCommandRecorder::ExecuteCommandList(
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(&commandList), 1));

// The fence value at which the current command allocator may be re-used will now be higher
m_allocatorRing.back().completionEvent = m_queue->GetNextCompletionEvent();
m_currentCommandListInfo->completionEvent = m_queue->GetNextCompletionEvent();

// Fail early if something horrifying happens
ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason());
Expand Down Expand Up @@ -324,62 +319,35 @@ void DmlCommandRecorder::Open()
{
assert(m_currentDescriptorHeap == nullptr);

if (m_currentCommandList)
if (m_availableCommandLists.empty())
{
if (m_resetThreads.front())
{
m_resetThreads.front()->join();
}

// Rotate the reset threads to the left
for (uint32_t i = 0; i < m_resetThreads.size() - 1; ++i) {
m_resetThreads[i] = std::move(m_resetThreads[i + 1]);
}

// Rotate the allocators to the left
auto firstAllocator = std::move(m_allocatorRing.front());
for (uint32_t i = 0; i < m_allocatorRing.size() - 1; ++i)
{
m_allocatorRing[i] = std::move(m_allocatorRing[i + 1]);
}
m_allocatorRing.back() = std::move(firstAllocator);

// Rotate the command lists to the left
auto firstCommandList = std::move(m_commandListRing.front());
for (uint32_t i = 0; i < m_commandListRing.size() - 1; ++i)
{
m_commandListRing[i] = std::move(m_commandListRing[i + 1]);
}
m_commandListRing.back() = std::move(firstCommandList);

// The newest dirty allocator is now located before the last element in the ring buffer, so start resetting it
m_resetThreads.back() = std::thread([cachedAllocator = m_allocatorRing[m_allocatorRing.size() - 2], cachedCommandList = m_commandListRing[m_commandListRing.size() - 2]]() {
cachedAllocator.completionEvent.WaitForSignal();
ORT_THROW_IF_FAILED(cachedAllocator.allocator->Reset());
ORT_THROW_IF_FAILED(cachedCommandList->Reset(cachedAllocator.allocator.Get(), nullptr));
});
ComPtr<ID3D12CommandAllocator> allocator;
ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandAllocator(
m_queue->GetType(),
IID_GRAPHICS_PPV_ARGS(allocator.ReleaseAndGetAddressOf())));

ComPtr<ID3D12GraphicsCommandList> commandList;
ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList(
0,
m_queue->GetType(),
allocator.Get(),
nullptr,
IID_GRAPHICS_PPV_ARGS(commandList.ReleaseAndGetAddressOf())));

auto commandListInfo = std::make_shared<CommandListInfo>();

Check warning on line 337 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_shared<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp:337: Add #include <memory> for make_shared<> [build/include_what_you_use] [4]
commandListInfo->allocator = std::move(allocator);
commandListInfo->commandList = std::move(commandList);
m_currentCommandListInfo = std::move(commandListInfo);
}
else
{
assert(m_commandListRing.size() == m_allocatorRing.size());

for (uint32_t i = 0; i < m_commandListRing.size(); ++i)
{
ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandAllocator(
m_queue->GetType(),
IID_GRAPHICS_PPV_ARGS(m_allocatorRing[i].allocator.ReleaseAndGetAddressOf())));

ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList(
0,
m_queue->GetType(),
m_allocatorRing[i].allocator.Get(),
nullptr,
IID_GRAPHICS_PPV_ARGS(m_commandListRing[i].ReleaseAndGetAddressOf())));
}
std::unique_lock lock(m_mutex);
m_currentCommandListInfo = m_availableCommandLists.back();
m_availableCommandLists.pop_back();
}

m_currentCommandList = m_commandListRing.back();
m_allocatorRing.back().completionEvent = m_queue->GetNextCompletionEvent();
m_currentCommandList = m_currentCommandListInfo->commandList;
m_currentCommandListInfo->completionEvent = m_queue->GetNextCompletionEvent();
}

void DmlCommandRecorder::CloseAndExecute()
Expand All @@ -391,6 +359,17 @@ void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* com
{
ORT_THROW_IF_FAILED(m_currentCommandList->Close());

onnxruntime::concurrency::ThreadPool::Schedule(m_threadPool.get(), [this, currentCommandListInfo = m_currentCommandListInfo]() {
currentCommandListInfo->completionEvent.WaitForSignal();
ORT_THROW_IF_FAILED(currentCommandListInfo->allocator->Reset());
ORT_THROW_IF_FAILED(currentCommandListInfo->commandList->Reset(currentCommandListInfo->allocator.Get(), nullptr));

{
std::unique_lock lock(m_mutex);
m_availableCommandLists.push_back(std::move(currentCommandListInfo));

Check warning on line 369 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp:369: Add #include <utility> for move [build/include_what_you_use] [4]
}
});

ID3D12GraphicsCommandList* commandListsToExecute[2] = {};
uint32_t commandListsToExecuteCount = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ namespace Dml
IDMLDevice* device,
std::shared_ptr<CommandQueue> commandQueue);

~DmlCommandRecorder();

void InitializeOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
Expand Down Expand Up @@ -69,8 +67,9 @@ namespace Dml
}

private:
struct CommandAllocatorInfo
struct CommandListInfo
{
ComPtr<ID3D12GraphicsCommandList> commandList;
ComPtr<ID3D12CommandAllocator> allocator;

// The event which will be signaled when the last command list submitted using this allocator
Expand Down Expand Up @@ -98,19 +97,14 @@ namespace Dml

// The command list currently being recorded into, and whether any commands have been recorded yet.
ComPtr<ID3D12GraphicsCommandList> m_currentCommandList;
std::shared_ptr<CommandListInfo> m_currentCommandListInfo;
bool m_operationsRecordedInCurrentCommandList = false;

static constexpr int commandListCount = 3;

// We use enough command lists and allocators to allow command lists to be reset in a different thread while
// there is another command list ready to receive commands. When we execute and close a command list, we start
// the resetting process on a different thread and set m_currentCommandList to the next available one.
std::array<ComPtr<ID3D12GraphicsCommandList>, commandListCount> m_commandListRing;
std::array<CommandAllocatorInfo, commandListCount> m_allocatorRing;
static constexpr int threadPoolSize = 8;

// We should always have 1 less reset thread than command lists since we always need a clean command list, but
// the other ones can all be in the process of getting reset
std::array<std::optional<std::thread>, commandListCount - 1> m_resetThreads;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> m_threadPool;
std::mutex m_mutex;
std::list<std::shared_ptr<CommandListInfo>> m_availableCommandLists;

Check warning on line 107 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <list> for list<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h:107: Add #include <list> for list<> [build/include_what_you_use] [4]

Check warning on line 107 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for shared_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h:107: Add #include <memory> for shared_ptr<> [build/include_what_you_use] [4]

void SetDescriptorHeap(ID3D12DescriptorHeap* descriptorHeap);
};
Expand Down

0 comments on commit 6c4a3d5

Please sign in to comment.