diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
index 5c7b7bff1e370..e555c8b344548 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
@@ -64,7 +64,8 @@ namespace Dml
         void TranslateAndCompileGraph(
             const onnxruntime::OpKernelInfo& kernelInfo,
             std::vector<ComPtr<ID3D12Resource>>& initializeResourceRefs,
-            std::vector<DML_BUFFER_BINDING> initInputBindings) const
+            std::vector<DML_BUFFER_BINDING> initInputBindings,
+            bool reuseCommandList) const
         {
             // Allocate a persistent resource and initialize the operator
             UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize;
@@ -89,6 +90,11 @@ namespace Dml
                 initializeResourceRefs.end(),
                 [&](ComPtr<ID3D12Resource>& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); }
             );
+
+            if (reuseCommandList)
+            {
+                BuildReusableCommandList();
+            }
         }

         onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override
@@ -152,6 +158,9 @@

             if (recompileNeeded)
             {
+                m_tempBindingAllocId = 0;
+                m_completionValue = 0;
+
                 // Go through all the node args and replace their shapes with the real ones
                 for (auto& nodeArg : m_intermediateNodeArgs)
                 {
@@ -213,42 +222,51 @@
                 TranslateAndCompileGraph(
                     Info(),
                     initializeResourceRefs,
-                    initInputBindings);
+                    initInputBindings,
+                    graphDesc.reuseCommandList);
             }

-            // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator
-            OpKernelContextWrapper contextWrapper(
-                kernelContext,
-                Info().GetExecutionProvider(),
-                true,
-                nullptr);
+            if (!m_graphicsCommandList ||
+                (m_fence != nullptr && m_fence->GetCompletedValue() < m_completionValue))
+            {
+                // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator
+                OpKernelContextWrapper contextWrapper(
+                    kernelContext,
+                    Info().GetExecutionProvider(),
+                    true,
+                    nullptr);

-            ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());
+                ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());

-            // Get input resources for execution, excluding those which were specified as owned by DML and provided
-            // at initialization instead.
-            std::vector<ComPtr<IMLOperatorTensor>> inputTensors(kernelContext->InputCount());
-            std::vector<ID3D12Resource*> inputPtrs(kernelContext->InputCount());
+                // Get input resources for execution, excluding those which were specified as owned by DML and provided
+                // at initialization instead.
+                std::vector<ComPtr<IMLOperatorTensor>> inputTensors(kernelContext->InputCount());
+                std::vector<ID3D12Resource*> inputPtrs(kernelContext->InputCount());

-            for (int i = 0; i < kernelContext->InputCount(); ++i)
-            {
-                if (!m_inputsUsed[i])
+                for (int i = 0; i < kernelContext->InputCount(); ++i)
                 {
-                    continue;
+                    if (!m_inputsUsed[i])
+                    {
+                        continue;
+                    }
+
+                    ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf()));
+                    inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get());
                 }

-                ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf()));
-                inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get());
-            }

+                auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes);
+                ExecuteOperator(
+                    m_compiledExecutionPlanOperator.Get(),
+                    m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
+                    inputPtrs,
+                    outputTensors);

-            auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes);
-            ExecuteOperator(
-                m_compiledExecutionPlanOperator.Get(),
-                m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
-                inputPtrs,
-                outputTensors);
-
-            ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());
+                ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());
+            }
+            else
+            {
+                ExecuteReusableCommandList(kernelContext);
+            }

             return onnxruntime::Status::OK();
         }
@@ -317,6 +335,185 @@ namespace Dml
         }

     private:
+        void BuildReusableCommandList() const
+        {
+            ComPtr<IDMLDevice> device;
+            ORT_THROW_IF_FAILED(m_provider->GetDmlDevice(device.GetAddressOf()));
+
+            DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties();
+
+            D3D12_DESCRIPTOR_HEAP_DESC desc = {};
+            desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
+            desc.NumDescriptors = execBindingProps.RequiredDescriptorCount;
+            desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
+
+            ComPtr<ID3D12Device> d3dDevice;
+            ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf()));
+
+            ORT_THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(m_heap.ReleaseAndGetAddressOf())));
+
+            // Create a binding table for execution.
+            DML_BINDING_TABLE_DESC bindingTableDesc = {};
+            bindingTableDesc.Dispatchable = m_compiledExecutionPlanOperator.Get();
+            bindingTableDesc.CPUDescriptorHandle = m_heap->GetCPUDescriptorHandleForHeapStart();
+            bindingTableDesc.GPUDescriptorHandle = m_heap->GetGPUDescriptorHandleForHeapStart();
+            bindingTableDesc.SizeInDescriptors = execBindingProps.RequiredDescriptorCount;
+
+            ORT_THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&m_bindingTable)));
+
+            ORT_THROW_IF_FAILED(d3dDevice->CreateCommandAllocator(
+                m_provider->GetCommandListTypeForQueue(),
+                IID_GRAPHICS_PPV_ARGS(m_commandAllocator.ReleaseAndGetAddressOf())));
+
+            ORT_THROW_IF_FAILED(d3dDevice->CreateCommandList(
+                0,
+                m_provider->GetCommandListTypeForQueue(),
+                m_commandAllocator.Get(),
+                nullptr,
+                IID_GRAPHICS_PPV_ARGS(m_graphicsCommandList.ReleaseAndGetAddressOf())));
+
+            if (m_persistentResource)
+            {
+                DML_BINDING_DESC persistentResourceBindingDesc =
+                    { DML_BINDING_TYPE_BUFFER, m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr };
+                m_bindingTable->BindPersistentResource(&persistentResourceBindingDesc);
+            }
+
+            ID3D12DescriptorHeap* descriptorHeaps[] = { m_heap.Get() };
+            m_graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps);
+
+            ComPtr<IDMLCommandRecorder> recorder;
+            ORT_THROW_IF_FAILED(device->CreateCommandRecorder(IID_PPV_ARGS(recorder.GetAddressOf())));
+
+            recorder->RecordDispatch(m_graphicsCommandList.Get(), m_compiledExecutionPlanOperator.Get(), m_bindingTable.Get());
+
+            ORT_THROW_IF_FAILED(m_graphicsCommandList->Close());
+        }
+
+        void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext) const
+        {
+            DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties();
+
+            std::vector<DML_BUFFER_BINDING> inputBindings(kernelContext->InputCount());
+            std::vector<DML_BINDING_DESC> inputBindingDescs(kernelContext->InputCount());
+
+            OpKernelContextWrapper contextWrapper(
+                kernelContext,
+                Info().GetExecutionProvider(),
+                true,
+                nullptr);
+
+            // Populate input bindings, excluding those which were specified as owned by DML and provided
+            // at initialization instead.
+            m_inputBindingAllocIds.resize(inputBindings.size());
+            bool inputBindingsChanged = false;
+
+            for (uint32_t i = 0; i < inputBindings.size(); ++i)
+            {
+                if (m_inputsUsed[i])
+                {
+                    assert(kernelContext->InputType(gsl::narrow_cast<int>(i))->IsTensorType());
+                    const onnxruntime::Tensor* tensor = kernelContext->Input<onnxruntime::Tensor>(gsl::narrow_cast<int>(i));
+
+                    uint64_t allocId;
+                    DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId);
+                    inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId);
+                    inputBindings[i].Buffer->Release(); // Avoid holding an additional reference
+                    inputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
+                    inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]};
+                    m_inputBindingAllocIds[i] = allocId;
+                }
+            }
+
+            if (inputBindingsChanged)
+            {
+                m_bindingTable->BindInputs(gsl::narrow_cast<uint32_t>(inputBindingDescs.size()), inputBindingDescs.data());
+            }
+
+            // Populate Output bindings
+            std::vector<DML_BUFFER_BINDING> outputBindings(kernelContext->OutputCount());
+            std::vector<DML_BINDING_DESC> outputBindingDescs(kernelContext->OutputCount());
+
+            m_outputBindingAllocIds.resize(outputBindings.size());
+            bool outputBindingsChanged = false;
+
+            for (uint32_t i = 0; i < outputBindings.size(); ++i)
+            {
+                std::vector<int64_t> outputDims;
+                outputDims.reserve(m_outputShapes.GetShape(i).size());
+                for (uint32_t dimSize : m_outputShapes.GetShape(i))
+                {
+                    outputDims.push_back(dimSize);
+                }
+
+                onnxruntime::Tensor* tensor = kernelContext->Output(
+                    static_cast<int>(i),
+                    onnxruntime::TensorShape::FromExistingBuffer(outputDims)
+                );
+
+                uint64_t allocId;
+                DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &outputBindings[i].Buffer, &allocId);
+                outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId);
+                outputBindings[i].Buffer->Release(); // Avoid holding an additional reference
+                outputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
+                outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]};
+                m_outputBindingAllocIds[i] = allocId;
+            }
+
+            if (outputBindingsChanged)
+            {
+                m_bindingTable->BindOutputs(gsl::narrow_cast<uint32_t>(outputBindingDescs.size()), outputBindingDescs.data());
+            }
+
+            if (execBindingProps.TemporaryResourceSize > 0)
+            {
+                // Allocate temporary data which will automatically be freed when the GPU work
+                // which is scheduled up to the point that this method returns has completed.
+                ComPtr<IUnknown> tempAlloc;
+                uint64_t tempAllocId = 0;
+                ORT_THROW_IF_FAILED(contextWrapper.AllocateTemporaryData(static_cast<size_t>(execBindingProps.TemporaryResourceSize), tempAlloc.GetAddressOf(), &tempAllocId));
+
+                ComPtr<IUnknown> tempResourceUnk;
+                m_winmlProvider->GetABIDataInterface(false, tempAlloc.Get(), &tempResourceUnk);
+
+                // Bind the temporary resource.
+                ComPtr<ID3D12Resource> tempResource;
+                ORT_THROW_IF_FAILED(tempResourceUnk->QueryInterface(tempResource.GetAddressOf()));
+                DML_BUFFER_BINDING tempBufferBinding = {tempResource.Get(), 0, execBindingProps.TemporaryResourceSize};
+                DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding };
+
+                if (!tempAllocId || m_tempBindingAllocId != tempAllocId)
+                {
+                    m_bindingTable->BindTemporaryResource(&tempBindingDesc);
+                }
+
+                m_tempBindingAllocId = tempAllocId;
+            }
+
+            // Execute the command list and if it succeeds, update the fence value at which this command may be
+            // re-used.
+            ComPtr<ID3D12Fence> fence;
+            uint64_t completionValue;
+            HRESULT hr = m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue);
+
+            if (hr == DXGI_ERROR_DEVICE_REMOVED)
+            {
+                ComPtr<ID3D12Device> device;
+                ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(&device));
+                ORT_THROW_IF_FAILED(device->GetDeviceRemovedReason());
+            }
+
+            ORT_THROW_IF_FAILED(hr);
+            m_fence = fence;
+            m_completionValue = completionValue;
+
+            // Queue references to objects which must be kept alive until resulting GPU work completes
+            m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_graphicsCommandList).Get());
+            m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_heap).Get());
+            m_winmlProvider->QueueReference(m_bindingTable.Get());
+            m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get());
+        }
+
         ComPtr<IWinmlExecutionProvider> m_winmlProvider;
         ComPtr<Dml::IExecutionProvider> m_provider;

@@ -341,6 +538,22 @@ namespace Dml
         mutable ComPtr<IUnknown> m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator
         mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes;
         mutable std::unordered_map<std::string, onnxruntime::TensorShape> m_inferredInputShapes;
+
+        // For command list reuse
+        mutable ComPtr<ID3D12DescriptorHeap> m_heap;
+        mutable ComPtr<IDMLBindingTable> m_bindingTable;
+        mutable ComPtr<ID3D12CommandAllocator> m_commandAllocator;
+        mutable ComPtr<ID3D12GraphicsCommandList> m_graphicsCommandList;
+
+        // Bindings from previous executions of a re-used command list
+        mutable std::vector<uint64_t> m_inputBindingAllocIds;
+        mutable std::vector<uint64_t> m_outputBindingAllocIds;
+        mutable uint64_t m_tempBindingAllocId = 0;
+
+        // Fence tracking the status of the command list's last execution, and whether its descriptor heap
+        // can safely be updated.
+        mutable ComPtr<ID3D12Fence> m_fence;
+        mutable uint64_t m_completionValue = 0;
     };

     onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel(
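
Note on the dispatch gate in Compute(): the fast ExecuteReusableCommandList() path is taken only when a command list has already been recorded and its previous submission has retired on the GPU; otherwise the kernel falls back to the slower ExecuteOperator() path, since rewriting the shared descriptor heap and binding table while the GPU may still be reading them would race. A minimal standalone sketch of that check follows; the helper name CanReuseCommandList is hypothetical and not part of this change:

    // Sketch: when may a previously recorded command list be re-submitted?
    #include <cstdint>
    #include <d3d12.h>

    bool CanReuseCommandList(
        ID3D12GraphicsCommandList* recordedList, // null until BuildReusableCommandList() has run
        ID3D12Fence* fence,                      // fence signaled by the last ExecuteCommandList call
        uint64_t completionValue)                // fence value reached when that submission completes
    {
        if (recordedList == nullptr)
        {
            return false; // nothing recorded yet; use the ExecuteOperator() path
        }
        if (fence != nullptr && fence->GetCompletedValue() < completionValue)
        {
            return false; // previous submission still in flight on the GPU
        }
        return true; // safe to update bindings and re-submit the recorded list
    }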
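
Note on the binding updates in ExecuteReusableCommandList(): bindings are rewritten into the DML binding table only when a tensor's underlying allocation has changed since the previous submission, as reported by the allocation id from UnwrapTensor; an id of 0 marks an untracked allocation and always forces a rebind. A condensed sketch of that dirty check; BindingsChanged is a hypothetical helper, not part of this change:

    #include <cstdint>
    #include <vector>

    // Returns true when the binding table must be refreshed (e.g. via
    // IDMLBindingTable::BindInputs). newAllocIds holds each tensor's allocation
    // id for this execution; lastAllocIds holds the ids seen on the previous
    // execution and is updated in place, mirroring m_inputBindingAllocIds and
    // m_outputBindingAllocIds in the patch above.
    bool BindingsChanged(const std::vector<uint64_t>& newAllocIds, std::vector<uint64_t>& lastAllocIds)
    {
        lastAllocIds.resize(newAllocIds.size());
        bool changed = false;
        for (size_t i = 0; i < newAllocIds.size(); ++i)
        {
            // An id of 0 marks an untracked allocation, so assume it changed.
            changed = changed || (newAllocIds[i] == 0) || (lastAllocIds[i] != newAllocIds[i]);
            lastAllocIds[i] = newAllocIds[i];
        }
        return changed;
    }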