Skip to content

Commit

Permalink
Improve GPU utilization
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Jan 17, 2024
1 parent 90af70c commit 28af091
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ namespace Dml

void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);

ExecutionContext* GetContext() { return m_context.get(); }

public: // onnxruntime::IAllocator
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
void* Alloc(size_t size) final;
Expand Down Expand Up @@ -91,6 +93,7 @@ namespace Dml

std::shared_ptr<ExecutionContext> m_context;
std::unique_ptr<DmlSubAllocator> m_subAllocator;
bool m_logAllocs = false;

#ifndef NDEBUG
// Useful for debugging; keeps track of all allocations that haven't been freed yet
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,24 @@ DmlCommandRecorder::DmlCommandRecorder(
: m_queue(std::move(commandQueue)),
m_d3dDevice(d3dDevice),
m_dmlDevice(dmlDevice),
m_descriptorPool(d3dDevice, 2048),
m_commandAllocatorRing(d3dDevice, m_queue->GetType(), m_queue->GetCurrentCompletionEvent())
m_descriptorPool(d3dDevice, 2048)
{
ORT_THROW_IF_FAILED(dmlDevice->CreateOperatorInitializer(0, nullptr, IID_PPV_ARGS(&m_initializer)));
ORT_THROW_IF_FAILED(dmlDevice->CreateCommandRecorder(IID_PPV_ARGS(&m_recorder)));
}

DmlCommandRecorder::~DmlCommandRecorder()
{
    // Deliberately detach (rather than join) any in-flight reset threads:
    // joining during process teardown can crash, so the threads are left to
    // finish on their own.
    for (size_t i = 0; i < m_resetThreads.size(); ++i)
    {
        std::optional<std::thread>& maybeThread = m_resetThreads[i];
        if (maybeThread.has_value())
        {
            maybeThread->detach();
        }
    }
}

void DmlCommandRecorder::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator)
{
m_bufferAllocator = allocator;
Expand Down Expand Up @@ -263,7 +274,7 @@ void DmlCommandRecorder::ExecuteCommandList(
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(&commandList), 1));

// The fence value at which the current command allocator may be re-used will now be higher
m_commandAllocatorRing.UpdateCurrentAllocatorCompletionEvent(m_queue->GetNextCompletionEvent());
m_allocatorRing.back().completionEvent = m_queue->GetNextCompletionEvent();

// Fail early if something horrifying happens
ORT_THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason());
Expand Down Expand Up @@ -313,23 +324,62 @@ void DmlCommandRecorder::Open()
{
assert(m_currentDescriptorHeap == nullptr);

ID3D12CommandAllocator* allocator = m_commandAllocatorRing.GetNextAllocator(m_queue->GetNextCompletionEvent());

if (!m_cachedCommandList)
if (m_currentCommandList)
{
ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList(
0,
m_queue->GetType(),
allocator,
nullptr,
IID_GRAPHICS_PPV_ARGS(m_currentCommandList.ReleaseAndGetAddressOf())));
if (m_resetThreads.front())
{
m_resetThreads.front()->join();
}

// Rotate the reset threads to the left
for (uint32_t i = 0; i < m_resetThreads.size() - 1; ++i) {
m_resetThreads[i] = std::move(m_resetThreads[i + 1]);
}

// Rotate the allocators to the left
auto firstAllocator = std::move(m_allocatorRing.front());
for (uint32_t i = 0; i < m_allocatorRing.size() - 1; ++i)
{
m_allocatorRing[i] = std::move(m_allocatorRing[i + 1]);
}
m_allocatorRing.back() = std::move(firstAllocator);

// Rotate the command lists to the left
auto firstCommandList = std::move(m_commandListRing.front());
for (uint32_t i = 0; i < m_commandListRing.size() - 1; ++i)
{
m_commandListRing[i] = std::move(m_commandListRing[i + 1]);
}
m_commandListRing.back() = std::move(firstCommandList);

// The newest dirty allocator is now located before the last element in the ring buffer, so start resetting it
m_resetThreads.back() = std::thread([cachedAllocator = m_allocatorRing[m_allocatorRing.size() - 2], cachedCommandList = m_commandListRing[m_commandListRing.size() - 2]]() {
cachedAllocator.completionEvent.WaitForSignal();
ORT_THROW_IF_FAILED(cachedAllocator.allocator->Reset());
ORT_THROW_IF_FAILED(cachedCommandList->Reset(cachedAllocator.allocator.Get(), nullptr));
});
}
else
{
m_currentCommandList = m_cachedCommandList;
m_cachedCommandList = nullptr;
ORT_THROW_IF_FAILED(m_currentCommandList->Reset(allocator, nullptr));
assert(m_commandListRing.size() == m_allocatorRing.size());

for (uint32_t i = 0; i < m_commandListRing.size(); ++i)
{
ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandAllocator(
m_queue->GetType(),
IID_GRAPHICS_PPV_ARGS(m_allocatorRing[i].allocator.ReleaseAndGetAddressOf())));

ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommandList(
0,
m_queue->GetType(),
m_allocatorRing[i].allocator.Get(),
nullptr,
IID_GRAPHICS_PPV_ARGS(m_commandListRing[i].ReleaseAndGetAddressOf())));
}
}

m_currentCommandList = m_commandListRing.back();
m_allocatorRing.back().completionEvent = m_queue->GetNextCompletionEvent();
}

void DmlCommandRecorder::CloseAndExecute()
Expand All @@ -338,7 +388,7 @@ void DmlCommandRecorder::CloseAndExecute()
}

void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList)
{
{
ORT_THROW_IF_FAILED(m_currentCommandList->Close());

ID3D12GraphicsCommandList* commandListsToExecute[2] = {};
Expand All @@ -359,9 +409,7 @@ void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* com
m_queue->ExecuteCommandLists(
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(commandListsToExecute), commandListsToExecuteCount));
}

m_cachedCommandList = m_currentCommandList;
m_currentCommandList = nullptr;

m_operationsRecordedInCurrentCommandList = false;

// The descriptor heap must be set on the command list the next time it's opened.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#pragma once

#include "ICommandRecorder.h"
#include "CommandAllocatorRing.h"

namespace Dml
{
Expand All @@ -16,9 +15,11 @@ namespace Dml
public:
DmlCommandRecorder(
ID3D12Device* d3dDevice,
IDMLDevice* device,
IDMLDevice* device,
std::shared_ptr<CommandQueue> commandQueue);

~DmlCommandRecorder();

void InitializeOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
Expand Down Expand Up @@ -47,13 +48,13 @@ namespace Dml
_Out_ uint64_t* completionValue);

ComPtr<ID3D12GraphicsCommandList> GetCommandList();

void ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers);
void AddUAVBarrier();

void Open() final;
void CloseAndExecute() final;

void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);

bool HasUnsubmittedWork() override
Expand All @@ -68,8 +69,19 @@ namespace Dml
}

private:
// Pairs a D3D12 command allocator with the GPU completion event that gates
// when the allocator may safely be reset and reused.
struct CommandAllocatorInfo
{
    ComPtr<ID3D12CommandAllocator> allocator;

    // The event which will be signaled when the last command list submitted using this allocator
    // completes execution on the GPU.
    GpuEvent completionEvent = {};

    // Convenience accessor for the raw allocator pointer; does not transfer ownership.
    ID3D12CommandAllocator* Get() const { return allocator.Get(); }
};

void CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList);

std::shared_ptr<CommandQueue> m_queue;
ComPtr<ID3D12Device> m_d3dDevice;
ComPtr<IDMLDevice> m_dmlDevice;
Expand All @@ -84,14 +96,21 @@ namespace Dml
// The weak pointer avoids a circular reference from context->recorder->allocator->context
std::weak_ptr<BucketizedBufferAllocator> m_bufferAllocator;

CommandAllocatorRing<2> m_commandAllocatorRing;

// The command list currently being recorded into, and whether any command have been recorded yet.
ComPtr<ID3D12GraphicsCommandList> m_currentCommandList;
bool m_operationsRecordedInCurrentCommandList = false;

// A cached command list which may be re-used.
ComPtr<ID3D12GraphicsCommandList> m_cachedCommandList;
static constexpr int commandListCount = 3;

// We use enough command lists and allocators to allow command lists to be reset in a different thread while
// there is another command list ready to receive commands. When we execute and close a command list, we start
// the resetting process on a different thread and set m_currentCommandList to the next available one.
std::array<ComPtr<ID3D12GraphicsCommandList>, commandListCount> m_commandListRing;
std::array<CommandAllocatorInfo, commandListCount> m_allocatorRing;

// We should always have 1 less reset thread than command lists since we always need a clean command list, but
// the other ones can all be in the process of getting reset
std::array<std::optional<std::thread>, commandListCount - 1> m_resetThreads;

void SetDescriptorHeap(ID3D12DescriptorHeap* descriptorHeap);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ namespace Dml
m_context = std::make_shared<ExecutionContext>(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);

m_uploadHeap = std::make_unique<PooledUploadHeap>(m_d3d12Device.Get(), m_context);
m_readbackHeap = std::make_unique<ReadbackHeap>(m_d3d12Device.Get(), m_context);
m_readbackHeap = std::make_unique<ReadbackHeap>(m_d3d12Device.Get());

CreateDmlKernelRegistry(&m_kernelRegistry, &m_internalRegInfoMap);

Expand Down Expand Up @@ -474,7 +474,12 @@ namespace Dml
const auto dstState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state

m_uploadHeap->BeginUploadToGpu(dstData, dstOffset, dstState, AsByteSpan(srcData, dataSizeInBytes));
FlushUploadsIfReady();

// Continuously upload memory located in upload heaps during session initialization to avoid running out of it
if (!m_sessionInitialized)
{
FlushUploadsIfReady();
}
}
else if (!src->IsCpuData() && dst->IsCpuData())
{
Expand All @@ -491,7 +496,7 @@ namespace Dml
const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state

// Performs a blocking call to synchronize and read back data from the GPU into the destination buffer
m_readbackHeap->ReadbackFromGpu(AsByteSpan(dstData, dataSizeInBytes), srcData, srcOffset, srcState);
m_readbackHeap->ReadbackFromGpu(m_context.get(), AsByteSpan(dstData, dataSizeInBytes), srcData, srcOffset, srcState);
}
else if (!src->IsCpuData() && !dst->IsCpuData())
{
Expand Down Expand Up @@ -560,7 +565,7 @@ namespace Dml
const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state

// Performs a blocking call to synchronize and read back data from the GPU into the destination buffer
m_readbackHeap->ReadbackFromGpu(dstDatas, dataSizesInBytes, srcDatas, srcState);
m_readbackHeap->ReadbackFromGpu(m_context.get(), dstDatas, dataSizesInBytes, srcDatas, srcState);

return S_OK;
}
Expand Down Expand Up @@ -978,7 +983,7 @@ namespace Dml
const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state

// Performs a blocking call to synchronize and read back data from the GPU into the destination buffer
m_readbackHeap->ReadbackFromGpu(dstDatas, dataSizesInBytes, srcDatas, srcState);
m_readbackHeap->ReadbackFromGpu(m_context.get(), dstDatas, dataSizesInBytes, srcDatas, srcState);

return onnxruntime::common::Status::OK();
}
Expand Down Expand Up @@ -1159,6 +1164,7 @@ namespace Dml
// Allocations after this point are potentially transient and their sizes are
// rounded to enable pooling.
m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled);
m_sessionInitialized = true;

return onnxruntime::common::Status::OK();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ namespace Dml
bool m_areMetacommandsEnabled = true;
bool m_dynamicGraphFusionEnabled = false;
bool m_native16BitShaderOpsSupported = false;
bool m_sessionInitialized = false;
std::shared_ptr<ExecutionContext> m_context;
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
std::unique_ptr<ReadbackHeap> m_readbackHeap;
Expand Down
Loading

0 comments on commit 28af091

Please sign in to comment.