Skip to content

Commit

Permalink
Add command list reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Nov 5, 2023
1 parent 208bab4 commit 025695e
Showing 1 changed file with 241 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ namespace Dml
void TranslateAndCompileGraph(
const onnxruntime::OpKernelInfo& kernelInfo,
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>>& initializeResourceRefs,
std::vector<DML_BUFFER_BINDING> initInputBindings) const
std::vector<DML_BUFFER_BINDING> initInputBindings,
bool reuseCommandList) const
{
// Allocate a persistent resource and initialize the operator
UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize;
Expand All @@ -89,6 +90,11 @@ namespace Dml
initializeResourceRefs.end(),
[&](ComPtr<ID3D12Resource>& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); }
);

if (reuseCommandList)
{
BuildReusableCommandList();
}
}

onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override
Expand Down Expand Up @@ -152,6 +158,9 @@ namespace Dml

if (recompileNeeded)
{
m_tempBindingAllocId = 0;
m_completionValue = 0;

// Go through all the node args and replace their shapes with the real ones
for (auto& nodeArg : m_intermediateNodeArgs)
{
Expand Down Expand Up @@ -213,42 +222,51 @@ namespace Dml
TranslateAndCompileGraph(
Info(),
initializeResourceRefs,
initInputBindings);
initInputBindings,
graphDesc.reuseCommandList);
}

// Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator
OpKernelContextWrapper contextWrapper(
kernelContext,
Info().GetExecutionProvider(),
true,
nullptr);
if (!m_graphicsCommandList ||
(m_fence != nullptr && m_fence->GetCompletedValue() < m_completionValue))
{
// Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator
OpKernelContextWrapper contextWrapper(
kernelContext,
Info().GetExecutionProvider(),
true,
nullptr);

ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());
ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());

// Get input resources for execution, excluding those which were specified as owned by DML and provided
// at initialization instead.
std::vector<ComPtr<IMLOperatorTensor>> inputTensors(kernelContext->InputCount());
std::vector<ID3D12Resource*> inputPtrs(kernelContext->InputCount());
// Get input resources for execution, excluding those which were specified as owned by DML and provided
// at initialization instead.
std::vector<ComPtr<IMLOperatorTensor>> inputTensors(kernelContext->InputCount());
std::vector<ID3D12Resource*> inputPtrs(kernelContext->InputCount());

for (int i = 0; i < kernelContext->InputCount(); ++i)
{
if (!m_inputsUsed[i])
for (int i = 0; i < kernelContext->InputCount(); ++i)
{
continue;
if (!m_inputsUsed[i])
{
continue;
}

ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf()));
inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get());
}

ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf()));
inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get());
}
auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes);
ExecuteOperator(
m_compiledExecutionPlanOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
inputPtrs,
outputTensors);

auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes);
ExecuteOperator(
m_compiledExecutionPlanOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
inputPtrs,
outputTensors);

ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());
ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());
}
else
{
ExecuteReusableCommandList(kernelContext);
}

return onnxruntime::Status::OK();
}
Expand Down Expand Up @@ -317,6 +335,185 @@ namespace Dml
}

private:
void BuildReusableCommandList() const
{
ComPtr<IDMLDevice> device;
ORT_THROW_IF_FAILED(m_provider->GetDmlDevice(device.GetAddressOf()));

DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties();

D3D12_DESCRIPTOR_HEAP_DESC desc = {};
desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
desc.NumDescriptors = execBindingProps.RequiredDescriptorCount;
desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;

ComPtr<ID3D12Device> d3dDevice;
ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf()));

ORT_THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(m_heap.ReleaseAndGetAddressOf())));

// Create a binding table for execution.
DML_BINDING_TABLE_DESC bindingTableDesc = {};
bindingTableDesc.Dispatchable = m_compiledExecutionPlanOperator.Get();
bindingTableDesc.CPUDescriptorHandle = m_heap->GetCPUDescriptorHandleForHeapStart();
bindingTableDesc.GPUDescriptorHandle = m_heap->GetGPUDescriptorHandleForHeapStart();
bindingTableDesc.SizeInDescriptors = execBindingProps.RequiredDescriptorCount;

ORT_THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&m_bindingTable)));

ORT_THROW_IF_FAILED(d3dDevice->CreateCommandAllocator(
m_provider->GetCommandListTypeForQueue(),
IID_GRAPHICS_PPV_ARGS(m_commandAllocator.ReleaseAndGetAddressOf())));

ORT_THROW_IF_FAILED(d3dDevice->CreateCommandList(
0,
m_provider->GetCommandListTypeForQueue(),
m_commandAllocator.Get(),
nullptr,
IID_GRAPHICS_PPV_ARGS(m_graphicsCommandList.ReleaseAndGetAddressOf())));

if (m_persistentResource)
{
DML_BINDING_DESC persistentResourceBindingDesc =
{ DML_BINDING_TYPE_BUFFER, m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr };
m_bindingTable->BindPersistentResource(&persistentResourceBindingDesc);
}

ID3D12DescriptorHeap* descriptorHeaps[] = { m_heap.Get() };
m_graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps);

ComPtr<IDMLCommandRecorder> recorder;
ORT_THROW_IF_FAILED(device->CreateCommandRecorder(IID_PPV_ARGS(recorder.GetAddressOf())));

recorder->RecordDispatch(m_graphicsCommandList.Get(), m_compiledExecutionPlanOperator.Get(), m_bindingTable.Get());

ORT_THROW_IF_FAILED(m_graphicsCommandList->Close());
}

// Re-submits the command list recorded by BuildReusableCommandList, updating only the bindings
// (inputs, outputs, temporary resource) whose backing allocations changed since the previous
// execution, then tracks the fence/completion value so callers know when the list may be reused.
void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext) const
{
    DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties();

    // Value-initialized descriptors are left in place for unused inputs (binding type is the
    // zero value — presumably DML_BINDING_TYPE_NONE; confirm against DirectML headers).
    std::vector<DML_BUFFER_BINDING> inputBindings(kernelContext->InputCount());
    std::vector<DML_BINDING_DESC> inputBindingDescs(kernelContext->InputCount());

    OpKernelContextWrapper contextWrapper(
        kernelContext,
        Info().GetExecutionProvider(),
        true,
        nullptr);

    // Populate input bindings, excluding those which were specified as owned by DML and provided
    // at initialization instead.
    m_inputBindingAllocIds.resize(inputBindings.size());
    bool inputBindingsChanged = false;

    for (uint32_t i = 0; i < inputBindings.size(); ++i)
    {
        if (m_inputsUsed[i])
        {
            assert(kernelContext->InputType(gsl::narrow_cast<int>(i))->IsTensorType());
            const onnxruntime::Tensor* tensor = kernelContext->Input<onnxruntime::Tensor>(gsl::narrow_cast<int>(i));

            uint64_t allocId;
            DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId);
            // An allocId of 0 forces a rebind — it appears to mean the allocation is not from the
            // pooled allocator and thus cannot be assumed stable across runs (TODO confirm).
            inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId);
            inputBindings[i].Buffer->Release(); // Avoid holding an additional reference
            // Buffer binding sizes are rounded up to a multiple of 4 bytes; presumably a DML
            // buffer-alignment requirement — verify AlignToPow2 semantics.
            inputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
            inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]};
            m_inputBindingAllocIds[i] = allocId;
        }
    }

    // Only touch the binding table when at least one input allocation actually moved.
    if (inputBindingsChanged)
    {
        m_bindingTable->BindInputs(gsl::narrow_cast<uint32_t>(inputBindingDescs.size()), inputBindingDescs.data());
    }

    // Populate Output bindings
    std::vector<DML_BUFFER_BINDING> outputBindings(kernelContext->OutputCount());
    std::vector<DML_BINDING_DESC> outputBindingDescs(kernelContext->OutputCount());

    m_outputBindingAllocIds.resize(outputBindings.size());
    bool outputBindingsChanged = false;

    for (uint32_t i = 0; i < outputBindings.size(); ++i)
    {
        // Convert the cached output shape (uint32 dims) to int64 dims for the kernel context.
        std::vector<int64_t> outputDims;
        outputDims.reserve(m_outputShapes.GetShape(i).size());
        for (uint32_t dimSize : m_outputShapes.GetShape(i))
        {
            outputDims.push_back(dimSize);
        }

        onnxruntime::Tensor* tensor = kernelContext->Output(
            static_cast<int>(i),
            onnxruntime::TensorShape::FromExistingBuffer(outputDims)
        );

        uint64_t allocId;
        DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &outputBindings[i].Buffer, &allocId);
        // Same rebind rule as inputs: allocId 0 or a changed id invalidates the cached binding.
        outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId);
        outputBindings[i].Buffer->Release(); // Avoid holding an additional reference
        outputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
        outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]};
        m_outputBindingAllocIds[i] = allocId;
    }

    if (outputBindingsChanged)
    {
        m_bindingTable->BindOutputs(gsl::narrow_cast<uint32_t>(outputBindingDescs.size()), outputBindingDescs.data());
    }

    if (execBindingProps.TemporaryResourceSize > 0)
    {
        // Allocate temporary data which will automatically be freed when the GPU work
        // which is scheduled up to the point that this method returns has completed.
        ComPtr<IUnknown> tempAlloc;
        uint64_t tempAllocId = 0;
        ORT_THROW_IF_FAILED(contextWrapper.AllocateTemporaryData(static_cast<size_t>(execBindingProps.TemporaryResourceSize), tempAlloc.GetAddressOf(), &tempAllocId));

        ComPtr<IUnknown> tempResourceUnk;
        m_winmlProvider->GetABIDataInterface(false, tempAlloc.Get(), &tempResourceUnk);

        // Bind the temporary resource.
        ComPtr<ID3D12Resource> tempResource;
        ORT_THROW_IF_FAILED(tempResourceUnk->QueryInterface(tempResource.GetAddressOf()));
        DML_BUFFER_BINDING tempBufferBinding = {tempResource.Get(), 0, execBindingProps.TemporaryResourceSize};
        DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding };

        // Rebind only when the temporary allocation differs from the one bound last time
        // (or when it is unpooled, signalled by tempAllocId == 0).
        if (!tempAllocId || m_tempBindingAllocId != tempAllocId)
        {
            m_bindingTable->BindTemporaryResource(&tempBindingDesc);
        }

        m_tempBindingAllocId = tempAllocId;
    }

    // Execute the command list and if it succeeds, update the fence value at which this command may be
    // re-used.
    ComPtr<ID3D12Fence> fence;
    uint64_t completionValue;
    HRESULT hr = m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue);

    if (hr == DXGI_ERROR_DEVICE_REMOVED)
    {
        // Surface the underlying removal reason instead of the generic DXGI error; if the device
        // has a recorded reason, GetDeviceRemovedReason() returns it as a failing HRESULT.
        ComPtr<ID3D12Device> device;
        ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(&device));
        ORT_THROW_IF_FAILED(device->GetDeviceRemovedReason());
    }

    ORT_THROW_IF_FAILED(hr);
    // Record the fence/value pair; Compute() compares the fence's completed value against
    // m_completionValue to decide whether the list is still in flight and must not be reused yet.
    m_fence = fence;
    m_completionValue = completionValue;

    // Queue references to objects which must be kept alive until resulting GPU work completes
    m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_graphicsCommandList).Get());
    m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_heap).Get());
    m_winmlProvider->QueueReference(m_bindingTable.Get());
    m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get());
}

ComPtr<IWinmlExecutionProvider> m_winmlProvider;
ComPtr<Dml::IExecutionProvider> m_provider;

Expand All @@ -341,6 +538,22 @@ namespace Dml
mutable ComPtr<IUnknown> m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator
mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes;
mutable std::unordered_map<std::string, onnxruntime::TensorShape> m_inferredInputShapes;

// For command list reuse
mutable ComPtr<ID3D12DescriptorHeap> m_heap;
mutable ComPtr<IDMLBindingTable> m_bindingTable;
mutable ComPtr<ID3D12CommandAllocator> m_commandAllocator;
mutable ComPtr<ID3D12GraphicsCommandList> m_graphicsCommandList;

// Bindings from previous executions of a re-used command list
mutable std::vector<uint64_t> m_inputBindingAllocIds;
mutable std::vector<uint64_t> m_outputBindingAllocIds;
mutable uint64_t m_tempBindingAllocId = 0;

// Fence tracking the status of the command list's last execution, and whether its descriptor heap
// can safely be updated.
mutable ComPtr<ID3D12Fence> m_fence;
mutable uint64_t m_completionValue = 0;
};

onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel(
Expand Down

0 comments on commit 025695e

Please sign in to comment.