diff --git a/CMakeLists.txt b/CMakeLists.txt index b325c2763..a91c7f871 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,14 @@ endif() if(USE_DML) if(WIN32) add_compile_definitions(USE_DML=1) + add_compile_definitions(NOMINMAX) + + file(GLOB dml_srcs CONFIGURE_DEPENDS + "${PROJECT_SOURCE_DIR}/src/dml/*.h" + "${PROJECT_SOURCE_DIR}/src/dml/*.cpp" + ) + + list(APPEND generator_srcs ${dml_srcs}) else() message(FATAL_ERROR "USE_DML is ON but this isn't windows.") endif() @@ -135,6 +143,39 @@ endif() file(GLOB onnxruntime_libs "${ORT_LIB_DIR}/${ONNXRUNTIME_FILES}") if(USE_DML) list(APPEND onnxruntime_libs "${ORT_LIB_DIR}/DirectML.dll") + target_include_directories(onnxruntime-genai PRIVATE $) + target_include_directories(onnxruntime-genai PRIVATE $/directx) + target_include_directories(onnxruntime-genai PRIVATE $) + target_include_directories(onnxruntime-genai-static PUBLIC $) + target_include_directories(onnxruntime-genai-static PUBLIC $/directx) + target_include_directories(onnxruntime-genai-static PUBLIC $) + target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib) + target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib) + + get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps ABSOLUTE) + set(DXC_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.Direct3D.DXC.1.7.2308.12) + set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/nuget.config) + set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/packages.config) + + add_custom_command( + OUTPUT + ${DXC_PACKAGE_DIR}/build/native/bin/x64/dxc.exe + DEPENDS + ${PACKAGES_CONFIG} + ${NUGET_CONFIG} + COMMAND ${CMAKE_CURRENT_BINARY_DIR}/nuget/src/nuget restore ${PACKAGES_CONFIG} -PackagesDirectory ${PACKAGES_DIR} -ConfigFile ${NUGET_CONFIG} + VERBATIM + ) + + add_custom_target( + RESTORE_PACKAGES ALL + DEPENDS + ${DXC_PACKAGE_DIR}/build/native/bin/x64/dxc.exe + ) + + add_dependencies(RESTORE_PACKAGES nuget) + add_dependencies(onnxruntime-genai RESTORE_PACKAGES) + add_dependencies(onnxruntime-genai-static RESTORE_PACKAGES) endif() if(NO_TOKENIZEROOT) diff --git a/benchmark/python/benchmark_e2e.py b/benchmark/python/benchmark_e2e.py index 04b4d0d14..d18f6acc9 100644 --- a/benchmark/python/benchmark_e2e.py +++ b/benchmark/python/benchmark_e2e.py @@ -13,13 +13,17 @@ from tqdm import tqdm # Use input model to generate prompt -def generate_prompt(model, tokenizer, prompt_length) -> str: +def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str: temperature = 1.0 prompt = "a" tokens = tokenizer.encode(prompt) params=og.GeneratorParams(model) params.set_search_options(do_sample=True, top_k=5, temperature=temperature, max_length=prompt_length, min_length=prompt_length+1) params.input_ids = tokens + + if use_graph_capture: + params.try_use_cuda_graph_with_max_batch_size(1) + generator=og.Generator(model, params) while not generator.is_done(): generator.compute_logits() @@ -63,13 +67,16 @@ def main(args): tokenizer = og.Tokenizer(model) # Generate prompt - prompt = [generate_prompt(model, tokenizer, prompt_length)] * batch_size + prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size tokens = tokenizer.encode_batch(prompt) params = og.GeneratorParams(model) params.input_ids = tokens params.set_search_options(do_sample=True, top_k=args.top_k, top_p=args.top_p, temperature=temperature, max_length=max_length, min_length=max_length) + if args.use_graph_capture: + params.try_use_cuda_graph_with_max_batch_size(batch_size) + if args.verbose: print("Running warmup runs...") for _ in 
tqdm(range(args.warmup)): generator = og.Generator(model, params) @@ -100,6 +107,10 @@ def main(args): params = og.GeneratorParams(model) params.input_ids = tokens params.set_search_options(max_length=max_length, min_length=max_length) + + if args.use_graph_capture: + params.try_use_cuda_graph_with_max_batch_size(batch_size) + generator = og.Generator(model, params) # Measure prompt processing @@ -199,5 +210,6 @@ def main(args): parser.add_argument('-o', '--output', type=str, default='genai_e2e', help='Output CSV file name or path (with .csv extension)') parser.add_argument('-v', '--verbose', action='store_true', help='Print extra information') parser.add_argument('-mo', '--print_model_output', action='store_true', help='Print model output') + parser.add_argument('-gc', '--use_graph_capture', action='store_true', help='Use the graph capture feature for CUDA or DML') args = parser.parse_args() main(args) diff --git a/cmake/deps.txt b/cmake/deps.txt index 99f7dc86a..c9a84529f 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,4 +12,6 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6aa67a77f17a770960f604b727645b6f6a13 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 +microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 +directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 77967b8c8..90b9b77af 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -44,3 +44,37 @@ FetchContent_Declare( onnxruntime_fetchcontent_makeavailable(googletest) +if(USE_DML) + set(WIL_BUILD_PACKAGING OFF CACHE BOOL "" FORCE) + set(WIL_BUILD_TESTS OFF CACHE BOOL "" FORCE) + + FetchContent_Declare( + microsoft_wil + URL ${DEP_URL_microsoft_wil} + URL_HASH SHA1=${DEP_SHA1_microsoft_wil} + FIND_PACKAGE_ARGS NAMES wil + ) + + onnxruntime_fetchcontent_makeavailable(microsoft_wil) + set(WIL_TARGET "WIL::WIL") + + FetchContent_Declare( + directx_headers + URL ${DEP_URL_directx_headers} + URL_HASH SHA1=${DEP_SHA1_directx_headers} + ) + + onnxruntime_fetchcontent_makeavailable(directx_headers) + set(DIRECTX_HEADERS_TARGET "DirectX-Headers") + + include(ExternalProject) + ExternalProject_Add(nuget + PREFIX nuget + URL "https://dist.nuget.org/win-x86-commandline/v5.3.0/nuget.exe" + DOWNLOAD_NO_EXTRACT 1 + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + UPDATE_COMMAND "" + INSTALL_COMMAND "" + ) +endif() \ No newline at end of file diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index ed1e02058..80da5093e 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -27,6 +27,7 @@ def main(args): input_tokens = tokenizer.encode(args.system_prompt + text) params = og.GeneratorParams(model) + params.try_use_cuda_graph_with_max_batch_size(1) params.set_search_options(do_sample=args.do_random_sampling, max_length=args.max_length, min_length=args.min_length, top_p=args.top_p, top_k=args.top_k, temperature=args.temperature, repetition_penalty=args.repetition_penalty) params.input_ids = input_tokens generator = og.Generator(model, params) diff 
--git a/generate_dml_shaders.bat b/generate_dml_shaders.bat new file mode 100644 index 000000000..62e298d6f --- /dev/null +++ b/generate_dml_shaders.bat @@ -0,0 +1,4 @@ +.\build\_deps\Microsoft.Direct3D.DXC.1.7.2308.12\build\native\bin\x64\dxc.exe src\models\dml\dml_shaders\dml_update_attention_mask.hlsl -E CSMain -T cs_6_2 -DT=int32_t -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh src\models\dml\generated_dml_shaders\update_mask_int32.h +.\build\_deps\Microsoft.Direct3D.DXC.1.7.2308.12\build\native\bin\x64\dxc.exe src\models\dml\dml_shaders\dml_update_attention_mask.hlsl -E CSMain -T cs_6_2 -DT=int64_t -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh src\models\dml\generated_dml_shaders\update_mask_int64.h +.\build\_deps\Microsoft.Direct3D.DXC.1.7.2308.12\build\native\bin\x64\dxc.exe src\models\dml\dml_shaders\dml_increment_values.hlsl -E CSMain -T cs_6_2 -DT=int32_t -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh src\models\dml\generated_dml_shaders\increment_values_int32.h +.\build\_deps\Microsoft.Direct3D.DXC.1.7.2308.12\build\native\bin\x64\dxc.exe src\models\dml\dml_shaders\dml_increment_values.hlsl -E CSMain -T cs_6_2 -DT=int64_t -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh src\models\dml\generated_dml_shaders\increment_values_int64.h diff --git a/nuget.config b/nuget.config new file mode 100644 index 000000000..3e0389a52 --- /dev/null +++ b/nuget.config @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/packages.config b/packages.config new file mode 100644 index 000000000..d4902c51f --- /dev/null +++ b/packages.config @@ -0,0 +1,4 @@ + + + + diff --git a/src/dml/dml_command_allocator_ring.h b/src/dml/dml_command_allocator_ring.h new file mode 100644 index 000000000..74722dc68 --- /dev/null +++ b/src/dml/dml_command_allocator_ring.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "dml_gpu_event.h" + +// A fixed-size ring of command allocators. Each time an allocator is retrieved, the allocator will +// be reset if its previously recorded commands have finished executing on the GPU. +template +class DmlCommandAllocatorRing { + public: + DmlCommandAllocatorRing( + ID3D12Device* device, + D3D12_COMMAND_LIST_TYPE commandListType, + DmlGpuEvent initialEvent) { + for (auto& info : command_allocators_) { + THROW_IF_FAILED(device->CreateCommandAllocator( + commandListType, + IID_PPV_ARGS(info.allocator.ReleaseAndGetAddressOf()))); + + info.completion_event = initialEvent; + } + } + + ID3D12CommandAllocator* GetNextAllocator(DmlGpuEvent next_completion_event) { + size_t earliest_other_allocator = (current_command_allocator_ + 1) % AllocatorCount; + + assert(!command_allocators_[current_command_allocator_].completion_event.IsSignaled() || + command_allocators_[earliest_other_allocator].completion_event.IsSignaled()); + + if (command_allocators_[earliest_other_allocator].completion_event.IsSignaled()) { + THROW_IF_FAILED(command_allocators_[earliest_other_allocator].Get()->Reset()); + current_command_allocator_ = earliest_other_allocator; + } + + // Set the completion event for the current allocator so it can be reset eventually. + command_allocators_[current_command_allocator_].completion_event = next_completion_event; + + return command_allocators_[current_command_allocator_].Get(); + } + + // Updates the completion event of the current allocator to a different value. 
This is used when the caller + // decides to issue an unrelated call to the queue such as ExecuteCommandLists which updates its fence between calling + // GetNextAllocator and executing the work which it recorded using the allocator it received. + void UpdateCurrentAllocatorCompletionEvent(DmlGpuEvent next_completion_event) { + command_allocators_[current_command_allocator_].completion_event = next_completion_event; + } + + private: + struct CommandAllocatorInfo { + ComPtr allocator; + + // The event which will be signaled when the last command list submitted using this allocator + // completes execution on the GPU. + DmlGpuEvent completion_event = {}; + + ID3D12CommandAllocator* Get() const { return allocator.Get(); } + }; + + std::array command_allocators_; + size_t current_command_allocator_ = 0; +}; \ No newline at end of file diff --git a/src/dml/dml_command_queue.cpp b/src/dml/dml_command_queue.cpp new file mode 100644 index 000000000..bb430c1cd --- /dev/null +++ b/src/dml/dml_command_queue.cpp @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "dml_command_queue.h" + +DmlCommandQueue::DmlCommandQueue(ID3D12CommandQueue* existing_queue) + : queue_(existing_queue), type_(existing_queue->GetDesc().Type) { + ComPtr device; + THROW_IF_FAILED(queue_->GetDevice(IID_PPV_ARGS(device.GetAddressOf()))); + THROW_IF_FAILED(device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(fence_.ReleaseAndGetAddressOf()))); +} + +void DmlCommandQueue::ExecuteCommandList(ID3D12CommandList* command_list) { + ExecuteCommandLists(std::span(&command_list, 1)); +} + +void DmlCommandQueue::ExecuteCommandLists(std::span command_lists) { + queue_->ExecuteCommandLists(static_cast(command_lists.size()), command_lists.data()); + + ++last_fence_value_; + THROW_IF_FAILED(queue_->Signal(fence_.Get(), last_fence_value_)); +} + +void DmlCommandQueue::Wait(ID3D12Fence* fence, uint64_t value) { + THROW_IF_FAILED(queue_->Wait(fence, value)); + + ++last_fence_value_; + THROW_IF_FAILED(queue_->Signal(fence_.Get(), last_fence_value_)); +} + +DmlGpuEvent DmlCommandQueue::GetCurrentCompletionEvent() { + return DmlGpuEvent{last_fence_value_, fence_}; +} + +DmlGpuEvent DmlCommandQueue::GetNextCompletionEvent() { + return DmlGpuEvent{last_fence_value_ + 1, fence_}; +} + +void DmlCommandQueue::QueueReference(IUnknown* object, bool wait_for_unsubmitted_work) { + // If the DmlCommandQueue is closing, then queued_references_ is being cleared -- it is not OK + // to queue additional references at this time, since those references would be leaked. This + // affects any objects in queued_references_ whose destructors indirectly call QueueReference; + // for example, an allocation from BucketizedBufferAllocator attempts to queue a reference + // to its underlying D3D resource when freed. Furthermore, these references are unnecessary + // since Close() already blocks for scheduled GPU work before clearing queued_references_. + if (!closing_) { + QueuedReference queued_reference = {GetLastFenceValue(), object}; + + // If something has been recorded into a command list but not submitted yet, it means that the *next* fence + // value is the one to signal completion. 
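For orientation, the fence bookkeeping used by the queue can be reduced to a short sketch. This is an illustrative reduction only, not code from the change; the helper names below are invented for the example, and the rule mirrors QueueReference/ReleaseCompletedReferences: a reference queued while work is still unsubmitted must wait on the *next* fence value, and it becomes releasable once the fence has advanced past that value.

// Illustrative sketch only (names invented for the example).
struct PendingReference {
  uint64_t fence_value;                      // value at which the object is safe to release
  Microsoft::WRL::ComPtr<IUnknown> object;
};

// Work recorded but not yet submitted is covered by the *next* signal.
uint64_t FenceValueForReference(uint64_t last_signaled_value, bool has_unsubmitted_work) {
  return has_unsubmitted_work ? last_signaled_value + 1 : last_signaled_value;
}

// A queued reference may be dropped once the fence has reached its value.
bool IsReleasable(const PendingReference& ref, ID3D12Fence* fence) {
  return fence->GetCompletedValue() >= ref.fence_value;
}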
+ if (wait_for_unsubmitted_work) { + ++queued_reference.fence_value; + } + + queued_references_.push_back(queued_reference); + } +} + +void DmlCommandQueue::Close() { + // Wait for flushed work: + assert(!closing_); + closing_ = true; + DmlGpuEvent event = GetCurrentCompletionEvent(); + event.WaitForSignal(); + queued_references_.clear(); + closing_ = false; +} + +void DmlCommandQueue::ReleaseCompletedReferences() { + uint64_t completed_value = GetFence()->GetCompletedValue(); + while (!queued_references_.empty() && queued_references_.front().fence_value <= completed_value) { + queued_references_.pop_front(); + } +} \ No newline at end of file diff --git a/src/dml/dml_command_queue.h b/src/dml/dml_command_queue.h new file mode 100644 index 000000000..cbafdef99 --- /dev/null +++ b/src/dml/dml_command_queue.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "../span.h" +#include "dml_gpu_event.h" + +// Manages a D3D12 command queue and provides a waitable fence which is signaled with a monotonically increasing +// value once each execute completes on the GPU. +class DmlCommandQueue { + public: + // Creates a DmlCommandQueue object that wraps an existing D3D12 queue. + DmlCommandQueue(ID3D12CommandQueue* existing_queue); + + D3D12_COMMAND_LIST_TYPE GetType() const { return type_; } + ComPtr GetFence() const { return fence_; } + uint64_t GetLastFenceValue() const { return last_fence_value_; } + + void ExecuteCommandList(ID3D12CommandList* command_list); + void ExecuteCommandLists(std::span command_lists); + + // Queues a wait to block the GPU until the specified fence is signaled to a given value. + void Wait(ID3D12Fence* fence, uint64_t value); + + // Returns an event that will become signaled when everything submitted to the queue thus far has + // completed execution on the GPU. + DmlGpuEvent GetCurrentCompletionEvent(); + + // Returns an event that will become signaled after the next ExecuteCommandLists call. + DmlGpuEvent GetNextCompletionEvent(); + + void QueueReference(IUnknown* object, bool wait_for_unsubmitted_work); + + void Close(); + void ReleaseCompletedReferences(); + + private: + struct QueuedReference { + uint64_t fence_value; + ComPtr object; + }; + + std::deque queued_references_; + + ComPtr queue_; + D3D12_COMMAND_LIST_TYPE type_; + + ComPtr fence_; + uint64_t last_fence_value_ = 0; + bool closing_ = false; +}; \ No newline at end of file diff --git a/src/dml/dml_command_recorder.cpp b/src/dml/dml_command_recorder.cpp new file mode 100644 index 000000000..3f2fa97dd --- /dev/null +++ b/src/dml/dml_command_recorder.cpp @@ -0,0 +1,281 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include "dml_command_recorder.h" +#include "dml_command_queue.h" +#include "../models/onnxruntime_api.h" + +DmlCommandRecorder::DmlCommandRecorder( + ID3D12Device* d3d_device, + IDMLDevice* dml_device, + std::shared_ptr command_queue, + Ort::Allocator& device_allocator, + const OrtDmlApi* ort_dml_api) + : queue_(std::move(command_queue)), + d3d_device_(d3d_device), + dml_device_(dml_device), + descriptor_pool_(d3d_device, 2048), + command_allocator_ring_(d3d_device, queue_->GetType(), queue_->GetCurrentCompletionEvent()), + device_allocator_(device_allocator), + ort_dml_api_(ort_dml_api) { + THROW_IF_FAILED(dml_device->CreateOperatorInitializer(0, nullptr, IID_PPV_ARGS(&initializer_))); + THROW_IF_FAILED(dml_device->CreateCommandRecorder(IID_PPV_ARGS(&recorder_))); +} + +void DmlCommandRecorder::CopyBufferRegion( + ID3D12Resource* dst_buffer, + uint64_t dst_offset, + ID3D12Resource* src_buffer, + uint64_t src_offset, + uint64_t byte_count) { + current_command_list_->CopyBufferRegion(dst_buffer, dst_offset, src_buffer, src_offset, byte_count); + operations_recorded_in_current_command_list = true; +} + +void DmlCommandRecorder::ExecuteCommandList( + ID3D12GraphicsCommandList* command_list, + _Outptr_ ID3D12Fence** fence, + _Out_ uint64_t* completion_value) { + if (!operations_recorded_in_current_command_list) { + // The caller can re-use relevant resources after the next set of work to be + // flushed has completed. Its command list hasn't been executed yet, just batched. + DmlGpuEvent gpu_event = queue_->GetNextCompletionEvent(); + gpu_event.fence.CopyTo(fence); + *completion_value = gpu_event.fence_value; + + queue_->ExecuteCommandLists(std::span(reinterpret_cast(&command_list), 1)); + + // The fence value at which the current command allocator may be re-used will now be higher + command_allocator_ring_.UpdateCurrentAllocatorCompletionEvent(queue_->GetNextCompletionEvent()); + + // Fail early if something horrifying happens + THROW_IF_FAILED(d3d_device_->GetDeviceRemovedReason()); + + return; + } + + // Remember the descriptor heap and apply it to the next command list. This avoids unnecessarily setting it onto + // the D3D object lazily at a point when the operation may not be parallelized with GPU work. + auto heap = current_descriptor_heap_; + + // Execute work in the current command list plus provided command list while closing the recorder. 
+ CloseAndExecute(command_list); + Open(); + + // Reset the descriptor heap opportunistically per above comment + SetDescriptorHeap(heap); + + DmlGpuEvent gpu_event = queue_->GetCurrentCompletionEvent(); + gpu_event.fence.CopyTo(fence); + *completion_value = gpu_event.fence_value; +} + +ComPtr DmlCommandRecorder::GetCommandList() { + // Assume operations are added by the caller after this returns + operations_recorded_in_current_command_list = true; + return current_command_list_; +} + +void DmlCommandRecorder::ResourceBarrier(std::span barriers) { + current_command_list_->ResourceBarrier(static_cast(barriers.size()), barriers.data()); + operations_recorded_in_current_command_list = true; +} + +void DmlCommandRecorder::AddUAVBarrier() { +#pragma warning(suppress : 6387) + auto barrier = CD3DX12_RESOURCE_BARRIER::UAV(nullptr); + current_command_list_->ResourceBarrier(1, &barrier); + operations_recorded_in_current_command_list = true; +} + +void DmlCommandRecorder::Open() { + assert(current_descriptor_heap_ == nullptr); + + ID3D12CommandAllocator* allocator = command_allocator_ring_.GetNextAllocator(queue_->GetNextCompletionEvent()); + + if (!cached_command_list_) { + THROW_IF_FAILED(d3d_device_->CreateCommandList( + 0, + queue_->GetType(), + allocator, + nullptr, + IID_PPV_ARGS(current_command_list_.ReleaseAndGetAddressOf()))); + } else { + current_command_list_ = cached_command_list_; + cached_command_list_ = nullptr; + THROW_IF_FAILED(current_command_list_->Reset(allocator, nullptr)); + } +} + +void DmlCommandRecorder::CloseAndExecute() { + CloseAndExecute(nullptr); +} + +void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* command_list) { + THROW_IF_FAILED(current_command_list_->Close()); + + ID3D12GraphicsCommandList* command_lists_to_execute[2] = {}; + uint32_t command_lists_to_execute_count = 0; + + if (operations_recorded_in_current_command_list) { + command_lists_to_execute[command_lists_to_execute_count++] = current_command_list_.Get(); + } + + if (command_list) { + command_lists_to_execute[command_lists_to_execute_count++] = command_list; + } + + if (command_lists_to_execute_count > 0) { + queue_->ExecuteCommandLists(std::span(reinterpret_cast(command_lists_to_execute), command_lists_to_execute_count)); + } + + cached_command_list_ = current_command_list_; + current_command_list_ = nullptr; + operations_recorded_in_current_command_list = false; + + // The descriptor heap must be set on the command list the next time it's opened. + current_descriptor_heap_ = nullptr; + + // Fail early if something horrifying happens + THROW_IF_FAILED(d3d_device_->GetDeviceRemovedReason()); +} + +void DmlCommandRecorder::SetDescriptorHeap(ID3D12DescriptorHeap* descriptor_heap) { + if (descriptor_heap != nullptr && descriptor_heap != current_descriptor_heap_) { + current_descriptor_heap_ = descriptor_heap; + + ID3D12DescriptorHeap* descriptor_heaps[] = {descriptor_heap}; + current_command_list_->SetDescriptorHeaps(ARRAYSIZE(descriptor_heaps), descriptor_heaps); + } +} + +void DmlCommandRecorder::InitializeOperator( + IDMLCompiledOperator* op, + const DML_BINDING_DESC& persistent_resource_binding, + const DML_BINDING_DESC& input_array_binding) { + // Reset the initializer to reference the input operator. 
+ IDMLCompiledOperator* ops[] = {op}; + THROW_IF_FAILED(initializer_->Reset(ARRAYSIZE(ops), ops)); + + DML_BINDING_PROPERTIES init_binding_props = initializer_->GetBindingProperties(); + + const uint32_t num_descriptors = init_binding_props.RequiredDescriptorCount; + DmlDescriptorRange descriptor_range = descriptor_pool_.AllocDescriptors( + num_descriptors, + queue_->GetNextCompletionEvent()); + + // Create a binding table for initialization. + DML_BINDING_TABLE_DESC binding_table_desc = {}; + binding_table_desc.Dispatchable = initializer_.Get(); + binding_table_desc.CPUDescriptorHandle = descriptor_range.cpuHandle; + binding_table_desc.GPUDescriptorHandle = descriptor_range.gpuHandle; + binding_table_desc.SizeInDescriptors = num_descriptors; + + ComPtr binding_table; + THROW_IF_FAILED(dml_device_->CreateBindingTable(&binding_table_desc, IID_PPV_ARGS(&binding_table))); + + // Create a temporary resource for initializing the op, if it's required. + uint64_t temporary_resource_size = init_binding_props.TemporaryResourceSize; + if (temporary_resource_size > 0) { + // Allocate and immediately free a temporary buffer. The buffer resource will still be + // alive (managed by the pool); freeing allows the resource to be shared with other operators. + std::array temporary_resource_shape = {static_cast(temporary_resource_size)}; + + ComPtr buffer; + auto temp_resource = OrtValue::CreateTensor(device_allocator_, temporary_resource_shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); + Ort::ThrowOnError(ort_dml_api_->GetD3D12ResourceFromAllocation(&device_allocator_, temp_resource->GetTensorMutableRawData(), &buffer)); + + // Bind the temporary resource. + DML_BUFFER_BINDING buffer_binding = {buffer.Get(), 0, temporary_resource_size}; + DML_BINDING_DESC binding_desc = {DML_BINDING_TYPE_BUFFER, &buffer_binding}; + binding_table->BindTemporaryResource(&binding_desc); + } + + // Bind inputs, if provided. + if (input_array_binding.Type != DML_BINDING_TYPE_NONE) { + // An operator with inputs to bind MUST use a BUFFER_ARRAY. + assert(input_array_binding.Type == DML_BINDING_TYPE_BUFFER_ARRAY); + binding_table->BindInputs(1, &input_array_binding); + } + + // Bind the persistent resource, which is an output of initialization. + if (persistent_resource_binding.Type != DML_BINDING_TYPE_NONE) { + // Persistent resources MUST be bound as buffers. + assert(persistent_resource_binding.Type == DML_BINDING_TYPE_BUFFER); + binding_table->BindOutputs(1, &persistent_resource_binding); + } + + // Record the initialization work. + SetDescriptorHeap(descriptor_range.heap); + recorder_->RecordDispatch(current_command_list_.Get(), initializer_.Get(), binding_table.Get()); + operations_recorded_in_current_command_list = true; + + // Barrier if there's an output (i.e. persistent resource), or if any temps are used. + if ((persistent_resource_binding.Type != DML_BINDING_TYPE_NONE) || + (temporary_resource_size > 0)) { + auto uav = CD3DX12_RESOURCE_BARRIER::UAV(nullptr); + current_command_list_->ResourceBarrier(1, &uav); + } +} + +void DmlCommandRecorder::ExecuteOperator( + IDMLCompiledOperator* op, + const DML_BINDING_DESC& persistent_resource_binding, + std::span input_bindings, + std::span output_bindings) { + DML_BINDING_PROPERTIES exec_binding_props = op->GetBindingProperties(); + + const uint32_t num_descriptors = exec_binding_props.RequiredDescriptorCount; + DmlDescriptorRange descriptor_range = descriptor_pool_.AllocDescriptors( + num_descriptors, + queue_->GetNextCompletionEvent()); + + // Create a binding table for execution. 
+ DML_BINDING_TABLE_DESC binding_table_desc = {}; + binding_table_desc.Dispatchable = op; + binding_table_desc.CPUDescriptorHandle = descriptor_range.cpuHandle; + binding_table_desc.GPUDescriptorHandle = descriptor_range.gpuHandle; + binding_table_desc.SizeInDescriptors = num_descriptors; + + ComPtr binding_table; + THROW_IF_FAILED(dml_device_->CreateBindingTable(&binding_table_desc, IID_PPV_ARGS(&binding_table))); + + // Create a temporary resource for executing the op, if it's required. + uint64_t temporary_resource_size = exec_binding_props.TemporaryResourceSize; + if (temporary_resource_size > 0) { + // Allocate and immediately free a temporary buffer. The buffer resource will still be + // alive (managed by the pool); freeing allows the resource to be shared with other operators. + std::array temporary_resource_shape = {static_cast(temporary_resource_size)}; + + ComPtr buffer; + auto temp_resource = OrtValue::CreateTensor(device_allocator_, temporary_resource_shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); + Ort::ThrowOnError(ort_dml_api_->GetD3D12ResourceFromAllocation(&device_allocator_, temp_resource->GetTensorMutableRawData(), &buffer)); + + // Bind the temporary resource. + DML_BUFFER_BINDING buffer_binding = {buffer.Get(), 0, temporary_resource_size}; + DML_BINDING_DESC binding_desc = {DML_BINDING_TYPE_BUFFER, &buffer_binding}; + binding_table->BindTemporaryResource(&binding_desc); + } + + if (persistent_resource_binding.Type != DML_BINDING_TYPE_NONE) { + binding_table->BindPersistentResource(&persistent_resource_binding); + } + + binding_table->BindInputs(static_cast(input_bindings.size()), input_bindings.data()); + binding_table->BindOutputs(static_cast(output_bindings.size()), output_bindings.data()); + + // Record the execution work. + SetDescriptorHeap(descriptor_range.heap); + recorder_->RecordDispatch(current_command_list_.Get(), op, binding_table.Get()); + operations_recorded_in_current_command_list = true; + +// Barrier all outputs. +#pragma warning(push) +#pragma warning(disable : 6387) + auto uav = CD3DX12_RESOURCE_BARRIER::UAV(nullptr); + current_command_list_->ResourceBarrier(1, &uav); +#pragma warning(pop) +} \ No newline at end of file diff --git a/src/dml/dml_command_recorder.h b/src/dml/dml_command_recorder.h new file mode 100644 index 000000000..656d04f38 --- /dev/null +++ b/src/dml/dml_command_recorder.h @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include "../span.h" +#include "dml_command_allocator_ring.h" +#include "dml_descriptor_pool.h" +#include "dml_command_queue.h" +#include "dml_descriptor_pool.h" +#include "dml_provider_factory.h" + +struct OrtDmlApi; + +namespace Ort { +struct Allocator; +} + +class DmlCommandRecorder { + public: + DmlCommandRecorder( + ID3D12Device* d3d_device, + IDMLDevice* dml_device, + std::shared_ptr command_queue, + Ort::Allocator& device_allocator, + const OrtDmlApi* ort_dml_api); + + void InitializeOperator( + IDMLCompiledOperator* op, + const DML_BINDING_DESC& persistent_resource_binding, + const DML_BINDING_DESC& input_array_binding); + + void ExecuteOperator( + IDMLCompiledOperator* op, + const DML_BINDING_DESC& persistent_resource_binding, + std::span input_bindings, + std::span output_bindings); + + void CopyBufferRegion( + ID3D12Resource* dst_buffer, + uint64_t dst_offset, + ID3D12Resource* src_buffer, + uint64_t src_offset, + uint64_t byte_count); + + void ExecuteCommandList( + ID3D12GraphicsCommandList* command_list, + _Outptr_ ID3D12Fence** fence, + _Out_ uint64_t* completion_value); + + ComPtr GetCommandList(); + + void ResourceBarrier(std::span barriers); + void AddUAVBarrier(); + + void Open(); + void CloseAndExecute(); + + bool HasUnsubmittedWork() { + return operations_recorded_in_current_command_list; + } + + // Forces the descriptor heap to be reset to D3D before executing future operations + void InvalidateDescriptorHeap() { + current_descriptor_heap_ = nullptr; + } + + private: + void CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* command_list); + + std::shared_ptr queue_; + ComPtr d3d_device_; + Microsoft::WRL::ComPtr dml_device_; + Microsoft::WRL::ComPtr initializer_; + Microsoft::WRL::ComPtr recorder_; + + // Descriptors are allocated from a pool. The current heap pointer is only used to avoid redundantly + // setting the same heap; it does not have ownership of the heap object. + DescriptorPool descriptor_pool_; + ID3D12DescriptorHeap* current_descriptor_heap_ = nullptr; + + DmlCommandAllocatorRing<2> command_allocator_ring_; + + // The command list currently being recorded into, and whether any command have been recorded yet. + ComPtr current_command_list_; + bool operations_recorded_in_current_command_list = false; + + // A cached command list which may be re-used. + ComPtr cached_command_list_; + + Ort::Allocator& device_allocator_; + const OrtDmlApi* ort_dml_api_; + + void SetDescriptorHeap(ID3D12DescriptorHeap* descriptor_heap); +}; \ No newline at end of file diff --git a/src/dml/dml_descriptor_pool.cpp b/src/dml/dml_descriptor_pool.cpp new file mode 100644 index 000000000..ff88752e1 --- /dev/null +++ b/src/dml/dml_descriptor_pool.cpp @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
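As a quick orientation to the recorder interface above, the typical cycle is open, record, submit. This is a minimal usage sketch under the assumption that a DmlCommandRecorder named recorder and valid operator bindings already exist; in this change the DmlExecutionContext drives the recorder rather than callers doing so directly.

// Minimal usage sketch (assumed names; DmlExecutionContext normally drives this).
recorder.Open();                                            // acquire an allocator, reset the command list
recorder.ExecuteOperator(op, persistent_binding,
                         input_bindings, output_bindings);  // record a DML dispatch
recorder.CloseAndExecute();                                 // close the list and submit it to the wrapped queue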
+ +#include +#include +#include +#include "dml_descriptor_pool.h" + +DmlDescriptorHeap::DmlDescriptorHeap(ID3D12DescriptorHeap* heap) : heap_(heap), + capacity_(heap->GetDesc().NumDescriptors), + head_cpu_handle_(heap->GetCPUDescriptorHandleForHeapStart()), + head_gpu_handle_(heap->GetGPUDescriptorHandleForHeapStart()), + heap_flags_(heap->GetDesc().Flags) { + ComPtr device; + THROW_IF_FAILED(heap->GetDevice(IID_PPV_ARGS(device.GetAddressOf()))); + handle_increment_size_ = device->GetDescriptorHandleIncrementSize(heap->GetDesc().Type); +} + +std::optional DmlDescriptorHeap::TryAllocDescriptors( + uint32_t num_descriptors, + DmlGpuEvent completion_event, + D3D12_DESCRIPTOR_HEAP_FLAGS heap_flags) { + // Bail if the desired heap creation flags are incompatible with the existing heap. + if (heap_flags_ != heap_flags) { + return std::nullopt; + } + + if ((completion_event_.fence != nullptr) && (completion_event_.IsSignaled())) { + // This class always allocates descriptors from the end of the heap. + // If the most recent completion event is signaled, then all previous + // allocations have completed; the entire capacity is available to use. + size_ = 0; + head_cpu_handle_ = heap_->GetCPUDescriptorHandleForHeapStart(); + head_gpu_handle_ = heap_->GetGPUDescriptorHandleForHeapStart(); + } + + // The caller will need to create a new heap if there is no space left in this one. + uint32_t space_remaining = capacity_ - size_; + if (space_remaining < num_descriptors) { + return std::nullopt; + } + + DmlDescriptorRange range = {heap_.Get(), head_cpu_handle_, head_gpu_handle_}; + + size_ += num_descriptors; + completion_event_ = completion_event; + head_cpu_handle_.Offset(num_descriptors, handle_increment_size_); + head_gpu_handle_.Offset(num_descriptors, handle_increment_size_); + + return range; +} + +DescriptorPool::DescriptorPool(ID3D12Device* device, uint32_t initial_capacity) : device_(device), + initial_heap_capacity_(initial_capacity) { + CreateHeap(initial_capacity, D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE); +} + +DmlDescriptorRange DescriptorPool::AllocDescriptors( + uint32_t num_descriptors, + DmlGpuEvent completion_event, + D3D12_DESCRIPTOR_HEAP_FLAGS heap_flags) { + // Attempt to allocate from an existing heap. + for (DmlDescriptorHeap& heap : heaps_) { + auto descriptor_range = heap.TryAllocDescriptors(num_descriptors, completion_event, heap_flags); + if (descriptor_range.has_value()) { + return descriptor_range.value(); + } + } + + // A new descriptor heap must be created. + uint32_t new_heap_capacity = std::max(num_descriptors, initial_heap_capacity_); + CreateHeap(new_heap_capacity, heap_flags); + auto descriptor_range = heaps_.back().TryAllocDescriptors(num_descriptors, completion_event, heap_flags); + assert(descriptor_range.has_value()); + return descriptor_range.value(); +} + +void DescriptorPool::Trim() { + // Remove any heaps that are not pending execution. + auto it = std::remove_if(heaps_.begin(), heaps_.end(), [](const DmlDescriptorHeap& heap) { + auto completion_event = heap.GetLastCompletionEvent(); + return !completion_event.fence || completion_event.IsSignaled(); + }); + + heaps_.erase(it, heaps_.end()); +} + +void DescriptorPool::CreateHeap(uint32_t num_descriptors, D3D12_DESCRIPTOR_HEAP_FLAGS heap_flags) { + // This pool only manages CBV/SRV/UAV descriptors. 
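To make the pool's recycling rule concrete: a caller reserves descriptors against the event that will be signaled when the recorded work finishes, so the heap space can be reused once that event is signaled. The snippet below is an illustrative call pattern with assumed variable names, condensed from how the command recorder in this change uses the pool.

// Illustrative allocation pattern (names assumed).
DmlDescriptorRange range = descriptor_pool.AllocDescriptors(
    binding_props.RequiredDescriptorCount,
    queue->GetNextCompletionEvent());          // space is recycled once this event signals
binding_table_desc.CPUDescriptorHandle = range.cpuHandle;
binding_table_desc.GPUDescriptorHandle = range.gpuHandle;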
+ D3D12_DESCRIPTOR_HEAP_DESC desc = {}; + desc.Flags = heap_flags; + desc.NumDescriptors = num_descriptors; + desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + + ComPtr heap; + THROW_IF_FAILED(device_->CreateDescriptorHeap(&desc, IID_PPV_ARGS(heap.GetAddressOf()))); + + heaps_.push_back(DmlDescriptorHeap{heap.Get()}); +} + +uint32_t DescriptorPool::GetTotalCapacity() const { + uint32_t capacity = 0; + + for (auto& heap : heaps_) { + capacity += heap.GetCapacity(); + } + + return capacity; +} \ No newline at end of file diff --git a/src/dml/dml_descriptor_pool.h b/src/dml/dml_descriptor_pool.h new file mode 100644 index 000000000..e297a034e --- /dev/null +++ b/src/dml/dml_descriptor_pool.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include "dml_gpu_event.h" + +// A contiguous range of descriptors. +struct DmlDescriptorRange { + ID3D12DescriptorHeap* heap; + D3D12_CPU_DESCRIPTOR_HANDLE cpuHandle; + D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle; +}; + +// Wraps an ID3D12DescriptorHeap to allocate descriptor ranges. +class DmlDescriptorHeap { + public: + // Wraps an existing heap. + explicit DmlDescriptorHeap(ID3D12DescriptorHeap* heap); + + // Reserves descriptors from the end of the heap. Returns nullopt if there is + // no space left in the heap. + std::optional TryAllocDescriptors( + uint32_t num_descriptors, + DmlGpuEvent completion_event, + D3D12_DESCRIPTOR_HEAP_FLAGS heap_flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE); + + DmlGpuEvent GetLastCompletionEvent() const { + return completion_event_; + } + + uint32_t GetCapacity() const { + return capacity_; + } + + private: + ComPtr heap_; + uint32_t capacity_ = 0; + uint32_t size_ = 0; + uint32_t handle_increment_size_ = 0; + CD3DX12_CPU_DESCRIPTOR_HANDLE head_cpu_handle_; + CD3DX12_GPU_DESCRIPTOR_HANDLE head_gpu_handle_; + D3D12_DESCRIPTOR_HEAP_FLAGS heap_flags_ = D3D12_DESCRIPTOR_HEAP_FLAG_NONE; + + // Most recent GPU completion event. Allocations are always done at the end, + // so there is no fragmentation of the heap. + DmlGpuEvent completion_event_; +}; + +// Manages a pool of CBV/SRV/UAV descriptors. +class DescriptorPool { + public: + DescriptorPool(ID3D12Device* device, uint32_t initial_capacity); + + // Reserves a contiguous range of descriptors from a single descriptor heap. The + // lifetime of the referenced descriptor heap is managed by the DescriptorPool class. + // The caller must supply a DmlGpuEvent that informs the pool when the reserved descriptors + // are no longer required. + DmlDescriptorRange AllocDescriptors( + uint32_t num_descriptors, + DmlGpuEvent completion_event, + D3D12_DESCRIPTOR_HEAP_FLAGS heap_flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE); + + // Releases all descriptor heaps that contain only descriptors which have completed + // their work on the GPU. + void Trim(); + + // Returns the total capacity of all heaps. + uint32_t GetTotalCapacity() const; + + private: + ComPtr device_; + std::vector heaps_; + const uint32_t initial_heap_capacity_; + + void CreateHeap(uint32_t num_descriptors, D3D12_DESCRIPTOR_HEAP_FLAGS heap_flags); +}; \ No newline at end of file diff --git a/src/dml/dml_execution_context.cpp b/src/dml/dml_execution_context.cpp new file mode 100644 index 000000000..05d8b75fc --- /dev/null +++ b/src/dml/dml_execution_context.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include "dml_execution_context.h" +#include "dml_command_queue.h" + +DmlExecutionContext::DmlExecutionContext( + ID3D12Device* d3d12_device, + IDMLDevice* dml_device, + ID3D12CommandQueue* queue, + Ort::Allocator& device_allocator, + const OrtDmlApi* ort_dml_api) + : queue_(std::make_shared(queue)), dml_recorder_(d3d12_device, dml_device, queue_, device_allocator, ort_dml_api) { +} + +void DmlExecutionContext::CopyBufferRegion( + ID3D12Resource* dst_buffer, + uint64_t dst_offset, + D3D12_RESOURCE_STATES dst_state, + ID3D12Resource* src_buffer, + uint64_t src_offset, + D3D12_RESOURCE_STATES src_state, + uint64_t byte_count) { + assert(!closed_); + + SetCommandRecorder(&dml_recorder_); + + std::vector barriers; + + if (!(dst_state & D3D12_RESOURCE_STATE_COPY_DEST)) { + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(dst_buffer, dst_state, D3D12_RESOURCE_STATE_COPY_DEST)); + } + if (!(src_state & D3D12_RESOURCE_STATE_COPY_SOURCE)) { + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(src_buffer, src_state, D3D12_RESOURCE_STATE_COPY_SOURCE)); + } + + if (!barriers.empty()) { + dml_recorder_.ResourceBarrier(barriers); + } + + dml_recorder_.CopyBufferRegion(dst_buffer, dst_offset, src_buffer, src_offset, byte_count); + + // Reset barrier state + if (!barriers.empty()) { + for (auto& barrier : barriers) { + std::swap(barrier.Transition.StateBefore, barrier.Transition.StateAfter); + } + + dml_recorder_.ResourceBarrier(barriers); + } +} + +void DmlExecutionContext::InitializeOperator( + IDMLCompiledOperator* op, + const DML_BINDING_DESC& persistent_resource_binding, + const DML_BINDING_DESC& input_array_binding) { + assert(!closed_); + SetCommandRecorder(&dml_recorder_); + + dml_recorder_.InitializeOperator(op, persistent_resource_binding, input_array_binding); +} + +void DmlExecutionContext::ExecuteCommandList( + ID3D12GraphicsCommandList* command_list, + _Outptr_ ID3D12Fence** fence, + _Out_ uint64_t* completion_value) { + assert(!closed_); + + SetCommandRecorder(&dml_recorder_); + dml_recorder_.ExecuteCommandList(command_list, fence, completion_value); +} + +void DmlExecutionContext::AddUAVBarrier() { + assert(!closed_); + SetCommandRecorder(&dml_recorder_); + + dml_recorder_.AddUAVBarrier(); +} + +void DmlExecutionContext::ResourceBarrier(std::span barriers) { + assert(!closed_); + SetCommandRecorder(&dml_recorder_); + + dml_recorder_.ResourceBarrier(barriers); +} + +void DmlExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** command_list) { + assert(!closed_); + SetCommandRecorder(&dml_recorder_); + + // Ensure the descriptor heap is reset to D3D as something external may change it before recording + dml_recorder_.InvalidateDescriptorHeap(); + + dml_recorder_.GetCommandList().CopyTo(command_list); +} + +void DmlExecutionContext::SetCommandRecorder(DmlCommandRecorder* new_recorder) { + assert(!closed_); + + // If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct + // ordering of operations on the command queue. 
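The execution context batches work until it is flushed, so a caller that needs results on the CPU must flush and then wait on the completion event. The sketch below is a minimal usage example built only from the methods declared in this change (CopyBufferRegion, Flush, GetCurrentCompletionEvent, ReleaseCompletedReferences); the function name and resource states are assumptions for illustration.

// Minimal usage sketch (assumed function name and resource states).
void CopyAndWait(DmlExecutionContext& context, ID3D12Resource* dst, ID3D12Resource* src, uint64_t byte_count) {
  context.CopyBufferRegion(dst, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
                           src, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, byte_count);
  context.Flush();                                       // submit the batched work to the queue
  context.GetCurrentCompletionEvent().WaitForSignal();   // block until the GPU reaches the fence value
  context.ReleaseCompletedReferences();                  // drop references whose fence values are done
}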
+ if (current_recorder_ != new_recorder) { + Flush(); + current_recorder_ = new_recorder; + + if (current_recorder_ != nullptr) { + current_recorder_->Open(); + } + } +} + +void DmlExecutionContext::Flush() { + assert(!closed_); + + if (!current_recorder_ || !current_recorder_->HasUnsubmittedWork()) { + // Nothing to flush + return; + } + + current_recorder_->CloseAndExecute(); + ReleaseCompletedReferences(); + + // Pre-emptively set the DML command recorder. It's the only command recorder right now, + // and doing this here causes work and allocations resetting the command list to occur at + // a point where it's going to be parallelized with GPU work. + current_recorder_ = nullptr; + SetCommandRecorder(&dml_recorder_); +} + +void DmlExecutionContext::QueueReference(IUnknown* object) { + assert(!closed_); + // If something has been recorded into a command list but not submitted yet, it means that the *next* fence + // value is the one to signal completion. + bool wait_for_unsubmitted_work = (current_recorder_ != nullptr); + queue_->QueueReference(object, wait_for_unsubmitted_work); +} + +void DmlExecutionContext::Close() { + assert(!closed_); + + // Discard unflushed work and clear queued references. This prevents the circular reference: + // Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel + queue_->Close(); + current_recorder_ = nullptr; + closed_ = true; +} + +DmlGpuEvent DmlExecutionContext::GetCurrentCompletionEvent() { + assert(!closed_); + + DmlGpuEvent event = queue_->GetCurrentCompletionEvent(); + + // If something has been recorded into a command list but not submitted yet, it means that the *next* fence + // value is the one to signal completion. + const bool unflushed_work_exists = (current_recorder_ != nullptr) && current_recorder_->HasUnsubmittedWork(); + if (unflushed_work_exists) { + ++event.fence_value; + } + + return event; +} + +void DmlExecutionContext::ReleaseCompletedReferences() { + assert(!closed_); + queue_->ReleaseCompletedReferences(); +} + +D3D12_COMMAND_LIST_TYPE DmlExecutionContext::GetCommandListTypeForQueue() const { + assert(!closed_); + return queue_->GetType(); +} \ No newline at end of file diff --git a/src/dml/dml_execution_context.h b/src/dml/dml_execution_context.h new file mode 100644 index 000000000..363e9f1bc --- /dev/null +++ b/src/dml/dml_execution_context.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "dml_command_recorder.h" +#include "dml_gpu_event.h" +#include "../models/onnxruntime_api.h" + +// Asynchronously performs GPU work, and automatically manages command list recording and submission to queues. +// Work submitted to the DmlExecutionContext is typically recorded onto a command list and may not immediately begin +// execution on the GPU. Call Flush() to force all recorded work to be submitted to the command queue for execution +// on the GPU. +class DmlExecutionContext { + public: + // Constructs an DmlExecutionContext that executes on the supplied queue. + DmlExecutionContext( + ID3D12Device* d3d12_device, + IDMLDevice* dml_device, + ID3D12CommandQueue* queue, + Ort::Allocator& device_allocator, + const OrtDmlApi* ort_dml_api); + + // Waits for flushed work, discards unflushed work, and discards associated references to + // prevent circular references. Must be the last call on the object before destruction. 
+ void Close(); + + // Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition + // barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and + // COPY_DEST if necessary. + void CopyBufferRegion( + ID3D12Resource* dst_buffer, + uint64_t dst_offset, + D3D12_RESOURCE_STATES dst_state, + ID3D12Resource* src_buffer, + uint64_t src_offset, + D3D12_RESOURCE_STATES src_state, + uint64_t byte_count); + + void InitializeOperator( + IDMLCompiledOperator* op, + const DML_BINDING_DESC& persistent_resource_binding, + const DML_BINDING_DESC& input_array_binding); + + void ExecuteCommandList( + ID3D12GraphicsCommandList* command_list, + _Outptr_ ID3D12Fence** fence, + _Out_ uint64_t* completion_value); + + void AddUAVBarrier(); + void ResourceBarrier(std::span barriers); + + void GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** command_list); + + // Forces all queued work to begin executing on the GPU. This method returns immediately and does not wait + // for the submitted work to complete execution on the GPU. + void Flush(); + + // Returns an event which will become signaled when everything submitted to the execution context thus far has + // completed execution on the GPU, including work that has yet to be flushed to the queue. + DmlGpuEvent GetCurrentCompletionEvent(); + + // Adds a reference which will be released when queued GPU work is completed + void QueueReference(IUnknown* object); + + // Release any accumulated references who corresponding GPU fence values have + // been reached. + void ReleaseCompletedReferences(); + + D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const; + + private: + void SetCommandRecorder(DmlCommandRecorder* new_recorder); + + std::shared_ptr queue_; + + DmlCommandRecorder* current_recorder_ = nullptr; + + // Up to one of these is active at a time + DmlCommandRecorder dml_recorder_; + + bool closed_ = false; +}; \ No newline at end of file diff --git a/src/dml/dml_gpu_event.h b/src/dml/dml_gpu_event.h new file mode 100644 index 000000000..7ca873c75 --- /dev/null +++ b/src/dml/dml_gpu_event.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +using Microsoft::WRL::ComPtr; + +// Represents a fence which will be signaled at some point (usually by the GPU). +struct DmlGpuEvent { + uint64_t fence_value; + ComPtr fence; + + bool IsSignaled() const { + return fence->GetCompletedValue() >= fence_value; + } + + // Blocks until IsSignaled returns true. 
+ void WaitForSignal() const { + if (IsSignaled()) { + return; // early-out + } + + while (!IsSignaled()) { +#if defined(_M_AMD64) || defined(__x86_64__) + _mm_pause(); +#endif + } + } +}; \ No newline at end of file diff --git a/src/dml/dml_helpers.cpp b/src/dml/dml_helpers.cpp new file mode 100644 index 000000000..9741438a2 --- /dev/null +++ b/src/dml/dml_helpers.cpp @@ -0,0 +1,333 @@ +#pragma once + +#include +#include +#include "dml_helpers.h" + +namespace DmlHelpers { + +DmlObjects CreateDmlObjects() { + D3D12_COMMAND_QUEUE_DESC command_queue_description = { + D3D12_COMMAND_LIST_TYPE_COMPUTE, + 0, + D3D12_COMMAND_QUEUE_FLAG_NONE, + 0, + }; + + DmlObjects dml_objects; + + THROW_IF_FAILED(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device))); + THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandQueue(&command_queue_description, IID_PPV_ARGS(&dml_objects.command_queue))); + THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&dml_objects.command_allocator))); + THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, dml_objects.command_allocator.Get(), nullptr, IID_PPV_ARGS(&dml_objects.command_list))); + return dml_objects; +} + +DmlReusedCommandListState BuildReusableCommandList( + IDMLDevice* dml_device, + IDMLCompiledOperator* compiled_operator, + ID3D12Resource* persistent_resource, + std::optional persistent_resource_binding) { + DmlReusedCommandListState command_list_state{}; + + DML_BINDING_PROPERTIES exec_binding_props = compiled_operator->GetBindingProperties(); + + D3D12_DESCRIPTOR_HEAP_DESC desc = {}; + desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + desc.NumDescriptors = exec_binding_props.RequiredDescriptorCount; + desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + + ComPtr d3d_device; + THROW_IF_FAILED(dml_device->GetParentDevice(IID_PPV_ARGS(&d3d_device))); + + THROW_IF_FAILED(d3d_device->CreateDescriptorHeap(&desc, IID_PPV_ARGS(command_list_state.heap.ReleaseAndGetAddressOf()))); + + // Create a binding table for execution. + DML_BINDING_TABLE_DESC binding_table_desc = {}; + binding_table_desc.Dispatchable = compiled_operator; + binding_table_desc.CPUDescriptorHandle = command_list_state.heap->GetCPUDescriptorHandleForHeapStart(); + binding_table_desc.GPUDescriptorHandle = command_list_state.heap->GetGPUDescriptorHandleForHeapStart(); + binding_table_desc.SizeInDescriptors = exec_binding_props.RequiredDescriptorCount; + + THROW_IF_FAILED(dml_device->CreateBindingTable(&binding_table_desc, IID_PPV_ARGS(&command_list_state.binding_table))); + + THROW_IF_FAILED(d3d_device->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_COMPUTE, + IID_PPV_ARGS(command_list_state.command_allocator.ReleaseAndGetAddressOf()))); + + THROW_IF_FAILED(d3d_device->CreateCommandList( + 0, + D3D12_COMMAND_LIST_TYPE_COMPUTE, + command_list_state.command_allocator.Get(), + nullptr, + IID_PPV_ARGS(command_list_state.graphics_command_list.ReleaseAndGetAddressOf()))); + + if (persistent_resource) { + DML_BINDING_DESC persistent_resource_binding_desc = {DML_BINDING_TYPE_BUFFER, persistent_resource_binding ? 
&*persistent_resource_binding : nullptr}; + command_list_state.binding_table->BindPersistentResource(&persistent_resource_binding_desc); + command_list_state.persistent_resource = persistent_resource; + } + + ID3D12DescriptorHeap* descriptor_heaps[] = {command_list_state.heap.Get()}; + command_list_state.graphics_command_list->SetDescriptorHeaps(ARRAYSIZE(descriptor_heaps), descriptor_heaps); + + ComPtr recorder; + THROW_IF_FAILED(dml_device->CreateCommandRecorder(IID_PPV_ARGS(recorder.GetAddressOf()))); + + recorder->RecordDispatch(command_list_state.graphics_command_list.Get(), compiled_operator, command_list_state.binding_table.Get()); + command_list_state.compiled_operator = compiled_operator; + + THROW_IF_FAILED(command_list_state.graphics_command_list->Close()); + + return command_list_state; +} + +void ExecuteReusableCommandList( + DmlExecutionContext* execution_context, + DmlReusedCommandListState& command_list_state, + OrtAllocator& allocator, + const OrtDmlApi* ort_dml_api, + std::span input_resources, + std::span input_sizes, + std::span output_resources, + std::span output_sizes, + bool bindings_changed) { + assert(input_resources.size() == input_sizes.size()); + assert(output_resources.size() == output_sizes.size()); + + DML_BINDING_PROPERTIES exec_binding_props = command_list_state.compiled_operator->GetBindingProperties(); + + std::vector input_bindings(input_resources.size()); + std::vector input_binding_descs(output_resources.size()); + + std::vector output_bindings(output_resources.size()); + std::vector output_binding_descs(output_resources.size()); + + if (bindings_changed) { + // Bind the inputs + for (uint32_t i = 0; i < input_bindings.size(); ++i) { + input_bindings[i].Buffer = input_resources[i]; + input_bindings[i].SizeInBytes = input_sizes[i]; + input_binding_descs[i] = {DML_BINDING_TYPE_BUFFER, &input_bindings[i]}; + } + + command_list_state.binding_table->BindInputs(static_cast(input_binding_descs.size()), input_binding_descs.data()); + + // Bind the outputs + for (uint32_t i = 0; i < output_bindings.size(); ++i) { + output_bindings[i].Buffer = output_resources[i]; + output_bindings[i].SizeInBytes = output_sizes[i]; + output_binding_descs[i] = {DML_BINDING_TYPE_BUFFER, &output_bindings[i]}; + } + + command_list_state.binding_table->BindOutputs(static_cast(output_binding_descs.size()), output_binding_descs.data()); + + // Create the temporary resource + if (exec_binding_props.TemporaryResourceSize > 0) { + ComPtr temporary_resource; + std::array persistent_resource_shape = {static_cast(exec_binding_props.TemporaryResourceSize)}; + auto persistent_tensor = OrtValue::CreateTensor(allocator, persistent_resource_shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); + Ort::ThrowOnError(ort_dml_api->GetD3D12ResourceFromAllocation(&allocator, persistent_tensor->GetTensorMutableRawData(), &temporary_resource)); + } + } + + // Execute the command list and if it succeeds, update the fence value at which this command may be + // re-used. 
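The reusable command list is intended to be built once per compiled operator and then replayed every step, with bindings refreshed only when the bound resources actually change. A call-pattern sketch, with assumed variable names and loop structure, follows; only the two helper functions themselves come from this change.

// Illustrative call pattern only (variable names and loop are assumed).
DmlReusedCommandListState state = DmlHelpers::BuildReusableCommandList(
    dml_device, compiled_op.Get(), persistent_resource.Get(), persistent_binding);

for (int step = 0; step < token_steps; ++step) {
  const bool bindings_changed = (step == 0);  // resources assumed stable after the first step
  DmlHelpers::ExecuteReusableCommandList(
      execution_context, state, allocator, ort_dml_api,
      input_resources, input_sizes, output_resources, output_sizes, bindings_changed);
}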
+ ComPtr fence; + uint64_t completion_value; + execution_context->ExecuteCommandList(command_list_state.graphics_command_list.Get(), fence.GetAddressOf(), &completion_value); +} + +static uint64_t DataTypeSizeInBytes(DML_TENSOR_DATA_TYPE dml_data_type) { + switch (dml_data_type) { + case DML_TENSOR_DATA_TYPE_FLOAT16: + return sizeof(Ort::Float16_t); + case DML_TENSOR_DATA_TYPE_FLOAT32: + return sizeof(float); + case DML_TENSOR_DATA_TYPE_FLOAT64: + return sizeof(double); + case DML_TENSOR_DATA_TYPE_UINT8: + return sizeof(uint8_t); + case DML_TENSOR_DATA_TYPE_UINT16: + return sizeof(uint16_t); + case DML_TENSOR_DATA_TYPE_UINT32: + return sizeof(uint32_t); + case DML_TENSOR_DATA_TYPE_UINT64: + return sizeof(uint64_t); + case DML_TENSOR_DATA_TYPE_INT8: + return sizeof(int8_t); + case DML_TENSOR_DATA_TYPE_INT16: + return sizeof(int16_t); + case DML_TENSOR_DATA_TYPE_INT32: + return sizeof(int32_t); + case DML_TENSOR_DATA_TYPE_INT64: + return sizeof(int64_t); + default: + THROW_HR(E_NOTIMPL); + } +} + +ComPtr CreateCastOperator( + IDMLDevice* dml_device, + uint32_t num_elements, + DML_TENSOR_DATA_TYPE source_data_type, + DML_TENSOR_DATA_TYPE target_data_type) { + // Create the input tensor desc + DML_BUFFER_TENSOR_DESC input_buffer_desc{}; + input_buffer_desc.Sizes = &num_elements; + input_buffer_desc.DimensionCount = 1; + input_buffer_desc.DataType = source_data_type; + input_buffer_desc.TotalTensorSizeInBytes = num_elements * DataTypeSizeInBytes(source_data_type); + DML_TENSOR_DESC input_tensor_desc = {DML_TENSOR_TYPE_BUFFER, &input_buffer_desc}; + + // Create the output tensor desc + DML_BUFFER_TENSOR_DESC output_buffer_desc{}; + output_buffer_desc.Sizes = &num_elements; + output_buffer_desc.DimensionCount = 1; + output_buffer_desc.DataType = target_data_type; + output_buffer_desc.TotalTensorSizeInBytes = num_elements * DataTypeSizeInBytes(target_data_type); + DML_TENSOR_DESC output_tensor_desc = {DML_TENSOR_TYPE_BUFFER, &output_buffer_desc}; + + DML_CAST_OPERATOR_DESC cast_op_desc{}; + cast_op_desc.InputTensor = &input_tensor_desc; + cast_op_desc.OutputTensor = &output_tensor_desc; + DML_OPERATOR_DESC cast_op_dml_desc = {DML_OPERATOR_CAST, &cast_op_desc}; + + ComPtr cast_op; + THROW_IF_FAILED(dml_device->CreateOperator(&cast_op_dml_desc, IID_PPV_ARGS(&cast_op))); + + ComPtr compiled_cast_op; + THROW_IF_FAILED(dml_device->CompileOperator(cast_op.Get(), DML_EXECUTION_FLAG_DESCRIPTORS_VOLATILE, IID_PPV_ARGS(&compiled_cast_op))); + + return compiled_cast_op; +} + +void GetNextDispatchSize( + uint32_t element_count, + uint32_t num_threads, + uint32_t& dispatch, + uint32_t& pending_element_count) { + // Max threads per workgroup is 2^10 (1024). Max dispatch per dimension is 2^16. Taken together, we can dispatch a maximum of + // 2^26 (268,435,456) threads along a single dimension. This should suffice for a majority of the workload. Therefore, even + // though it is possible to dispatch up to (2^16)^3 workgroups simultaneously, we stick to the simpler 1D dispatch alternative. 
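To make the dispatch arithmetic concrete: D3D12 allows at most 1,024 threads per group and 65,535 groups per dispatch dimension, so a single 1-D dispatch covers at most 1,024 x 65,535 (about 6.7e7) threads. A worked pass through GetNextDispatchSize with assumed inputs:

// Worked example (assumed inputs): num_threads = 256, element_count = 10'000'000.
// max_threads_per_dispatch = 256 * 65'535           = 16'776'960
// available_thread_count   = min(10'000'000, above) = 10'000'000
// workgroup_count_1d       = ceil(10'000'000 / 256) = 39'063   (<= 65'535, so one dispatch suffices)
// dispatched_element_count = 39'063 * 256           = 10'000'128
// pending_element_count    = 0
//
// With element_count = 20'000'000 instead, the dispatch is capped at 65'535 groups
// (16'776'960 elements) and pending_element_count = 3'223'040 remains for the next pass.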
+ assert(num_threads <= D3D12_CS_THREAD_GROUP_MAX_THREADS_PER_GROUP); + + const uint32_t max_threads_per_dispatch = num_threads * D3D12_CS_DISPATCH_MAX_THREAD_GROUPS_PER_DIMENSION; + + // Compute max dispatchable elements + const uint32_t available_thread_count = std::min(element_count, max_threads_per_dispatch); + + // Compute required thread group count + uint32_t workgroup_count_1d = (available_thread_count + num_threads - 1) / num_threads; + + // Compute min dispatch size + dispatch = workgroup_count_1d; + + // With the dispatch size computed, compute the dispatched element count + const uint32_t dispatched_element_count = workgroup_count_1d * num_threads; + + // Update the pending element count + pending_element_count = (dispatched_element_count < element_count) ? element_count - dispatched_element_count : 0; +} + +DML_TENSOR_DATA_TYPE OrtToDmlDataType(ONNXTensorElementDataType ort_dtype) { + switch (ort_dtype) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return DML_TENSOR_DATA_TYPE_FLOAT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return DML_TENSOR_DATA_TYPE_FLOAT32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return DML_TENSOR_DATA_TYPE_FLOAT64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return DML_TENSOR_DATA_TYPE_UINT8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return DML_TENSOR_DATA_TYPE_UINT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return DML_TENSOR_DATA_TYPE_UINT32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return DML_TENSOR_DATA_TYPE_UINT64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return DML_TENSOR_DATA_TYPE_INT8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return DML_TENSOR_DATA_TYPE_INT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return DML_TENSOR_DATA_TYPE_INT32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return DML_TENSOR_DATA_TYPE_INT64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return DML_TENSOR_DATA_TYPE_UINT8; + default: + THROW_HR(E_NOTIMPL); + } +} + +void DmlCastInputToOutput( + DmlExecutionContext* execution_context, + OrtAllocator& allocator, + OrtValue& in, + std::unique_ptr& p_out, + IDMLDevice* dml_device, + const OrtDmlApi* ort_dml_api, + DmlReusedCommandListState& command_list_state) { + auto shape_info = in.GetTensorTypeAndShapeInfo(); + auto shape = shape_info->GetShape(); + + bool allocate_p_out = p_out == nullptr; + if (p_out) { + auto out_shape_info = p_out->GetTensorTypeAndShapeInfo(); + auto out_shape = out_shape_info->GetShape(); + allocate_p_out = shape != out_shape; + } + + if (allocate_p_out) { + p_out = OrtValue::CreateTensor(allocator, shape); + } + + int element_count = static_cast(shape_info->GetElementCount()); + auto dml_from_type = DmlHelpers::OrtToDmlDataType(in.GetTensorTypeAndShapeInfo()->GetElementType()); + auto dml_to_type = DmlHelpers::OrtToDmlDataType(p_out->GetTensorTypeAndShapeInfo()->GetElementType()); + + bool rebind = command_list_state.previousOutput != p_out.get(); + + // If the sizes change, we need to recompile the operator and rebuild the command lists. It should only happen + // once after the very first iteration. 
+ if (rebind) { + auto compiled_cast_operator = DmlHelpers::CreateCastOperator(dml_device, element_count, dml_from_type, dml_to_type); + + ComPtr persistent_resource; + uint64_t persistent_resource_size = compiled_cast_operator->GetBindingProperties().PersistentResourceSize; + + std::optional persistent_resource_binding; + + if (persistent_resource_size > 0) { + std::array persistent_resource_shape = {static_cast(persistent_resource_size)}; + auto persistent_tensor = OrtValue::CreateTensor(allocator, persistent_resource_shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); + Ort::ThrowOnError(ort_dml_api->GetD3D12ResourceFromAllocation(&allocator, persistent_tensor->GetTensorMutableRawData(), &persistent_resource)); + persistent_resource_binding = DML_BUFFER_BINDING{persistent_resource.Get(), 0, persistent_resource_size}; + } + + DML_BINDING_DESC persistent_resource_bindingDesc = persistent_resource_binding + ? DML_BINDING_DESC{DML_BINDING_TYPE_BUFFER, &*persistent_resource_binding} + : DML_BINDING_DESC{DML_BINDING_TYPE_NONE, nullptr}; + + DML_BINDING_DESC input_array_binding_desc = DML_BINDING_DESC{DML_BINDING_TYPE_NONE, nullptr}; + execution_context->InitializeOperator(compiled_cast_operator.Get(), persistent_resource_bindingDesc, input_array_binding_desc); + command_list_state = DmlHelpers::BuildReusableCommandList(dml_device, compiled_cast_operator.Get(), persistent_resource.Get(), persistent_resource_binding); + command_list_state.previousOutput = p_out.get(); + } + + ComPtr source_resource; + Ort::ThrowOnError(ort_dml_api->GetD3D12ResourceFromAllocation(&allocator, in.GetTensorMutableData(), &source_resource)); + + ComPtr target_resource; + Ort::ThrowOnError(ort_dml_api->GetD3D12ResourceFromAllocation(&allocator, p_out->GetTensorMutableData(), &target_resource)); + + std::array input_resources = {source_resource.Get()}; + std::array input_sizes = {element_count * DataTypeSizeInBytes(dml_from_type)}; + + std::array output_resources = {target_resource.Get()}; + std::array output_sizes = {element_count * DataTypeSizeInBytes(dml_to_type)}; + + DmlHelpers::ExecuteReusableCommandList(execution_context, command_list_state, allocator, ort_dml_api, input_resources, input_sizes, output_resources, output_sizes, rebind); +} +} // namespace DmlHelpers diff --git a/src/dml/dml_helpers.h b/src/dml/dml_helpers.h new file mode 100644 index 000000000..22da63254 --- /dev/null +++ b/src/dml/dml_helpers.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include +#include +#include "dml_execution_context.h" + +using Microsoft::WRL::ComPtr; + +struct DmlReusedCommandListState { + // Re-usable command list, supporting descriptor heap, and DML binding table to update that heap. 
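+ // previousInput / previousOutput cache the OrtValue pointers that were bound when this command list was last
+ // recorded; callers such as DmlHelpers::DmlCastInputToOutput compare against them to decide whether the
+ // recorded list can simply be replayed or has to be rebuilt with new bindings.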
+ Microsoft::WRL::ComPtr compiled_operator; + Microsoft::WRL::ComPtr graphics_command_list; + Microsoft::WRL::ComPtr command_allocator; + Microsoft::WRL::ComPtr heap; + Microsoft::WRL::ComPtr binding_table; + Microsoft::WRL::ComPtr persistent_resource; + OrtValue* previousInput = nullptr; + OrtValue* previousOutput = nullptr; +}; + +struct DmlObjects { + ComPtr d3d12_device; + ComPtr command_queue; + ComPtr command_allocator; + ComPtr command_list; + ComPtr upload_buffer; +}; + +namespace DmlHelpers { +DmlObjects CreateDmlObjects(); + +DmlReusedCommandListState BuildReusableCommandList( + IDMLDevice* dml_device, + IDMLCompiledOperator* compiled_operator, + ID3D12Resource* persistent_resource, + std::optional persistent_resource_binding); + +void ExecuteReusableCommandList( + DmlExecutionContext* execution_context, + DmlReusedCommandListState& command_list_state, + OrtAllocator& allocator, + const OrtDmlApi* ort_dml_api, + std::span input_resources, + std::span input_sizes, + std::span output_resources, + std::span output_sizes, + bool bindings_changed); + +ComPtr CreateCastOperator( + IDMLDevice* dml_device, + uint32_t num_elements, + DML_TENSOR_DATA_TYPE source_data_type, + DML_TENSOR_DATA_TYPE target_data_type); + +void GetNextDispatchSize( + uint32_t element_count, + uint32_t num_threads, + uint32_t& dispatch, + uint32_t& pending_element_count); + +DML_TENSOR_DATA_TYPE OrtToDmlDataType(ONNXTensorElementDataType ort_dtype); + +void DmlCastInputToOutput( + DmlExecutionContext* execution_context, + OrtAllocator& allocator, + OrtValue& in, + std::unique_ptr& p_out, + IDMLDevice* dml_device, + const OrtDmlApi* ort_dml_api, + DmlReusedCommandListState& command_list_state); +} // namespace DmlHelpers diff --git a/src/dml/dml_increment_values_kernel.cpp b/src/dml/dml_increment_values_kernel.cpp new file mode 100644 index 000000000..989da6288 --- /dev/null +++ b/src/dml/dml_increment_values_kernel.cpp @@ -0,0 +1,122 @@ +#include +#include +#include +#include +#include "dml_increment_values_kernel.h" +#include "dml_helpers.h" + +namespace DmlIncrementValues_Int32 { +#include "generated_dml_shaders/increment_values_int32.h" +} + +namespace DmlIncrementValues_Int64 { +#include "generated_dml_shaders/increment_values_int64.h" +} + +DmlIncrementValuesKernel::DmlIncrementValuesKernel( + ID3D12Device* d3d12_device, + DmlExecutionContext* execution_context, + uint32_t element_count, + ONNXTensorElementDataType dtype, + ID3D12Resource* values_resource) + : device_(d3d12_device), + execution_context_(execution_context), + dtype_(dtype), + values_resource_(values_resource) { + constants_.element_count = element_count; + total_element_count_ = element_count; + + // Compute root signature. 
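+ // The layout is one root UAV per bound buffer (uav_count_ entries, here just the values buffer at u0),
+ // followed by a single root-constants parameter at b0 carrying the Constants struct (element_count, start_index).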
+ std::vector root_parameters; + root_parameters.resize(uav_count_ + 1); + + for (UINT i = 0; i < uav_count_; i++) { + root_parameters[i].InitAsUnorderedAccessView(i); + } + + root_parameters[uav_count_].InitAsConstants(constant_count_, 0); + + CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC desc; + desc.Init_1_1(static_cast(root_parameters.size()), root_parameters.data()); + + ComPtr root_signature_blob; + ComPtr root_signature_error_blob; + THROW_IF_FAILED(D3D12SerializeVersionedRootSignature( + &desc, + root_signature_blob.GetAddressOf(), + root_signature_error_blob.GetAddressOf())); + + THROW_IF_FAILED(device_->CreateRootSignature( + 0, + root_signature_blob->GetBufferPointer(), + root_signature_blob->GetBufferSize(), + IID_PPV_ARGS(&root_signature_))); + + D3D12_COMPUTE_PIPELINE_STATE_DESC compute_pso_desc = {}; + compute_pso_desc.pRootSignature = root_signature_.Get(); + + if (dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { + compute_pso_desc.CS = CD3DX12_SHADER_BYTECODE(DmlIncrementValues_Int32::g_CSMain, sizeof(DmlIncrementValues_Int32::g_CSMain)); + } else if (dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + compute_pso_desc.CS = CD3DX12_SHADER_BYTECODE(DmlIncrementValues_Int64::g_CSMain, sizeof(DmlIncrementValues_Int64::g_CSMain)); + } else { + THROW_HR(E_NOTIMPL); + } + + THROW_IF_FAILED(device_->CreateComputePipelineState(&compute_pso_desc, IID_PPV_ARGS(&pipeline_state_))); + + THROW_IF_FAILED(d3d12_device->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_COMPUTE, + IID_PPV_ARGS(command_allocator_.ReleaseAndGetAddressOf()))); + + THROW_IF_FAILED(d3d12_device->CreateCommandList( + 0, + D3D12_COMMAND_LIST_TYPE_COMPUTE, + command_allocator_.Get(), + nullptr, + IID_PPV_ARGS(graphics_command_list_.ReleaseAndGetAddressOf()))); + + D3D12_DESCRIPTOR_HEAP_DESC heap_desc = {}; + heap_desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + heap_desc.NumDescriptors = uav_count_; + heap_desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + + THROW_IF_FAILED(d3d12_device->CreateDescriptorHeap(&heap_desc, IID_PPV_ARGS(heap_.ReleaseAndGetAddressOf()))); + + ID3D12DescriptorHeap* descriptor_heaps[] = {heap_.Get()}; + graphics_command_list_->SetDescriptorHeaps(ARRAYSIZE(descriptor_heaps), descriptor_heaps); + + // Set the root signature and pipeline state + graphics_command_list_->SetComputeRootSignature(root_signature_.Get()); + graphics_command_list_->SetPipelineState(pipeline_state_.Get()); + graphics_command_list_->SetComputeRootUnorderedAccessView(0, values_resource_->GetGPUVirtualAddress()); + + auto pending_element_count = total_element_count_; + auto constants = constants_; + + // Dispatch up to the maximum number of threads per iteration until + // all elements are completed + while (pending_element_count > 0) { + constants.start_index = total_element_count_ - pending_element_count; + + uint32_t dispatch_size_x; + + DmlHelpers::GetNextDispatchSize( + pending_element_count, + 256, + dispatch_size_x, + pending_element_count); + + // Set root constants + graphics_command_list_->SetComputeRoot32BitConstants( + uav_count_, // root parameter index + constant_count_, // Constant count + &constants, + 0 // offset + ); + + graphics_command_list_->Dispatch(dispatch_size_x, 1, 1); + } + + graphics_command_list_->Close(); +} \ No newline at end of file diff --git a/src/dml/dml_increment_values_kernel.h b/src/dml/dml_increment_values_kernel.h new file mode 100644 index 000000000..b7b096bf8 --- /dev/null +++ b/src/dml/dml_increment_values_kernel.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include 
+#include +#include +#include "dml_execution_context.h" + +using Microsoft::WRL::ComPtr; + +class DmlIncrementValuesKernel { + public: + DmlIncrementValuesKernel( + ID3D12Device* d3d12_device, + DmlExecutionContext* execution_context, + uint32_t element_count, + ONNXTensorElementDataType dtype, + ID3D12Resource* values_resource); + + ID3D12GraphicsCommandList* GetCommandList() { return graphics_command_list_.Get(); } + + private: + struct Constants { + uint32_t element_count; + uint32_t start_index; + }; + + ComPtr device_; + ComPtr root_signature_; + ComPtr pipeline_state_; + Constants constants_; + DmlExecutionContext* execution_context_; + + ComPtr graphics_command_list_; + ComPtr command_allocator_; + ComPtr heap_; + + ONNXTensorElementDataType dtype_; + ComPtr values_resource_; + uint32_t total_element_count_; + + constexpr static uint32_t constant_count_ = sizeof(Constants) / sizeof(uint32_t); + constexpr static uint32_t uav_count_ = 1; +}; \ No newline at end of file diff --git a/src/dml/dml_pooled_upload_heap.cpp b/src/dml/dml_pooled_upload_heap.cpp new file mode 100644 index 000000000..db9380f23 --- /dev/null +++ b/src/dml/dml_pooled_upload_heap.cpp @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "dml_pooled_upload_heap.h" +#include "dml_execution_context.h" + +DmlPooledUploadHeap::DmlPooledUploadHeap(ID3D12Device* device, DmlExecutionContext* execution_context) + : device_(device), execution_context_(execution_context) { +} + +static size_t Align(size_t offset, size_t alignment) { + assert(alignment != 0); + return (offset + alignment - 1) & ~(alignment - 1); +} + +std::optional DmlPooledUploadHeap::FindOffsetForAllocation(const Chunk& chunk, size_t size_in_bytes) { + assert(size_in_bytes != 0); + + if (chunk.capacity_in_bytes < size_in_bytes) { + // This chunk isn't even big enough to accommodate this allocation + return std::nullopt; + } + + if (chunk.allocations.empty()) { + // The entire chunk is empty - allocate from the beginning + return 0; + } + + // Chunks are used as ring buffers, which means this allocation should go after the most recent previous + // allocation + + const auto& last_allocation = chunk.allocations.back(); + size_t new_allocation_begin = last_allocation.offset_in_chunk + last_allocation.size_in_bytes; + new_allocation_begin = Align(new_allocation_begin, c_allocation_alignment); + + if (new_allocation_begin + size_in_bytes < new_allocation_begin) { + // Overflow + return std::nullopt; + } + + const auto& first_allocation = chunk.allocations.front(); + if (first_allocation.offset_in_chunk <= last_allocation.offset_in_chunk) { + // This is the case where there's potentially free space at the beginning and end of the chunk, but not + // the middle: + // e.g. 
+ // |------XXXXYYYZZ------| + // ^^^^ ^^ + // first last + + if (new_allocation_begin + size_in_bytes <= chunk.capacity_in_bytes) { + // There's enough space between the end of the last allocation and the end of the chunk + return new_allocation_begin; + } else { + // Otherwise there's not enough space at the end of the chunk - try the beginning of the chunk instead + new_allocation_begin = 0; + if (new_allocation_begin + size_in_bytes <= first_allocation.offset_in_chunk) { + // There was enough space between the start of the buffer, and the start of the first allocation + return new_allocation_begin; + } + } + } else { + // This is the case where there's potentially free space in the middle of the chunk, but not at the edges + // e.g. + // |YYYZZ---------XXXX-| + // ^^ ^^^^ + // last first + + if (new_allocation_begin + size_in_bytes <= first_allocation.offset_in_chunk) { + // There's enough space between the end of the last allocation, and the start of the first one + return new_allocation_begin; + } + } + + // Not enough space in this chunk to accommodate the requested allocation + return std::nullopt; +} + +/* static */ DmlPooledUploadHeap::Chunk DmlPooledUploadHeap::CreateChunk(ID3D12Device* device, size_t size_in_bytes) { + ComPtr upload_buffer; + auto heap = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); + auto buffer = CD3DX12_RESOURCE_DESC::Buffer(size_in_bytes); + + THROW_IF_FAILED(device->CreateCommittedResource( + &heap, + D3D12_HEAP_FLAG_NONE, + &buffer, + D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, + IID_PPV_ARGS(upload_buffer.ReleaseAndGetAddressOf()))); + + return Chunk{size_in_bytes, std::move(upload_buffer)}; +} + +std::pair DmlPooledUploadHeap::Reserve(size_t size_in_bytes) { + // Try to find a chunk with enough free space to accommodate the requested allocation size + for (Chunk& chunk : chunks_) { + std::optional offset_for_allocation = FindOffsetForAllocation(chunk, size_in_bytes); + if (offset_for_allocation) { + // There's enough space in this chunk - return + return std::make_pair(&chunk, *offset_for_allocation); + } + } + + // No chunks were able to accommodate the allocation - create a new chunk and return that instead + + // At least double the capacity of the pool + const size_t new_chunk_size = std::max({total_capacity_, c_min_chunk_size, size_in_bytes}); + chunks_.push_back(CreateChunk(device_.Get(), new_chunk_size)); + total_capacity_ += new_chunk_size; + + // Allocate from the beginning of the new chunk + return std::make_pair(&chunks_.back(), 0); +} + +void DmlPooledUploadHeap::ReclaimAllocations() { + for (Chunk& chunk : chunks_) { + auto* allocs = &chunk.allocations; + + // Remove all allocations which have had their fences signaled - this indicates that they are no longer + // being used by the GPU. We can stop as soon as we find an allocation which is still in use, because we + // only use a single command queue and executions always complete in the order they were submitted. 
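+ // Allocations are stored oldest-first, so popping from the front until we hit an unsignaled event frees
+ // exactly the prefix of allocations whose GPU work has already completed.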
+ while (!allocs->empty() && allocs->front().done_event.IsSignaled()) { + allocs->pop_front(); + } + } +} + +DmlGpuEvent DmlPooledUploadHeap::BeginUploadToGpu( + ID3D12Resource* dst, + uint64_t dst_offset, + D3D12_RESOURCE_STATES dst_state, + std::span src) { + assert(!src.empty()); + assert(dst->GetDesc().Dimension == D3D12_RESOURCE_DIMENSION_BUFFER); + + InvariantChecker checker(this); + + ReclaimAllocations(); + + // Allocate space from the upload heap + Chunk* chunk = nullptr; + size_t offset_in_chunk = 0; + std::tie(chunk, offset_in_chunk) = Reserve(src.size()); + + assert(chunk != nullptr); + assert(offset_in_chunk + src.size() <= chunk->capacity_in_bytes); + + // Map the upload heap and copy the source data into it at the specified offset + void* upload_heap_data = nullptr; + THROW_IF_FAILED(chunk->resource->Map(0, nullptr, &upload_heap_data)); + memcpy(static_cast(upload_heap_data) + offset_in_chunk, src.data(), src.size()); + chunk->resource->Unmap(0, nullptr); + + // Copy from the upload heap into the destination resource + execution_context_->CopyBufferRegion( + dst, + dst_offset, + dst_state, + chunk->resource.Get(), + offset_in_chunk, + D3D12_RESOURCE_STATE_GENERIC_READ, + src.size()); + + DmlGpuEvent done_event = execution_context_->GetCurrentCompletionEvent(); + + execution_context_->Flush(); + done_event.WaitForSignal(); + + // Add an allocation entry to the chunk + chunk->allocations.push_back(Allocation{static_cast(src.size()), offset_in_chunk, done_event}); + + return done_event; +} + +void DmlPooledUploadHeap::Trim() { + InvariantChecker checker(this); + + ReclaimAllocations(); + + // Release any chunks which have no allocations + auto it = std::remove_if(chunks_.begin(), chunks_.end(), [](const Chunk& c) { + return c.allocations.empty(); + }); + chunks_.erase(it, chunks_.end()); + + // Re-calculate total capacity + total_capacity_ = 0; + for (const auto& chunk : chunks_) { + total_capacity_ += chunk.capacity_in_bytes; + } +} + +void DmlPooledUploadHeap::AssertInvariants() { +#ifdef _DEBUG + + auto chunk_capacity_comparer = [](const Chunk& lhs, const Chunk& rhs) { + return lhs.capacity_in_bytes < rhs.capacity_in_bytes; + }; + + // Chunks should be sorted by ascending capacity + assert(std::is_sorted(chunks_.begin(), chunks_.end(), chunk_capacity_comparer)); + + // Allocations in a chunk should be sorted by ascending fence value + for (const auto& chunk : chunks_) { + auto alloc_fence_value_comparer = [](const Allocation& lhs, const Allocation& rhs) { + return lhs.done_event.fence_value < rhs.done_event.fence_value; + }; + assert(std::is_sorted(chunk.allocations.begin(), chunk.allocations.end(), alloc_fence_value_comparer)); + } + + // Validate chunk properties + for (const auto& chunk : chunks_) { + assert(chunk.resource != nullptr); + assert(chunk.capacity_in_bytes == chunk.resource->GetDesc().Width); + } + + // Validate allocation properties + for (const auto& chunk : chunks_) { + for (const auto& alloc : chunk.allocations) { + assert(alloc.offset_in_chunk + alloc.size_in_bytes <= chunk.capacity_in_bytes); + assert(alloc.offset_in_chunk % c_allocation_alignment == 0); // Validate alignment + } + } + + // Validate no overlapping allocations + for (const auto& chunk : chunks_) { + auto alloc_offset_comparer = [](const Allocation& lhs, const Allocation& rhs) { + return lhs.offset_in_chunk < rhs.offset_in_chunk; + }; + + std::vector allocations_sorted_by_offset(chunk.allocations.begin(), chunk.allocations.end()); + std::sort(allocations_sorted_by_offset.begin(), 
allocations_sorted_by_offset.end(), alloc_offset_comparer); + + for (size_t i = 1; i < allocations_sorted_by_offset.size(); ++i) { + const auto& alloc = allocations_sorted_by_offset[i - 1]; + const auto& next_alloc = allocations_sorted_by_offset[i]; + assert(alloc.offset_in_chunk + alloc.size_in_bytes <= next_alloc.offset_in_chunk); + } + } + + // Validate total capacity of pool + size_t calculated_capacity = 0; + for (const auto& chunk : chunks_) { + calculated_capacity += chunk.capacity_in_bytes; + } + assert(calculated_capacity == total_capacity_); + +#endif // #ifdef _DEBUG +} \ No newline at end of file diff --git a/src/dml/dml_pooled_upload_heap.h b/src/dml/dml_pooled_upload_heap.h new file mode 100644 index 000000000..817f817a4 --- /dev/null +++ b/src/dml/dml_pooled_upload_heap.h @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include "../span.h" +#include "dml_gpu_event.h" +#include "dml_execution_context.h" + +// Implements a non-blocking, ring-buffer style upload heap for copying CPU data to GPU resources. +class DmlPooledUploadHeap { + public: + DmlPooledUploadHeap(ID3D12Device* device, DmlExecutionContext* execution_context); + + // Makes a copy of the source data and begins copying it into the destination resource, and returns a GpuEvent + // which will become signaled when the copy is complete. The destination resource must be a default or readback + // buffer. + DmlGpuEvent BeginUploadToGpu( + ID3D12Resource* dst, + uint64_t dst_offset, + D3D12_RESOURCE_STATES dst_state, + std::span src); + + // Releases unused capacity. + void Trim(); + + size_t Capacity() const { return total_capacity_; } + + private: + static constexpr size_t c_min_chunk_size = 1024 * 1024; // 1MB + static constexpr size_t c_allocation_alignment = 512; // In bytes; as per D3D12 requirement for buffers + + // A suballoction from a chunk + struct Allocation { + size_t size_in_bytes; + + // The offset, in bytes, from the beginning of the chunk to the beginning of this allocation + size_t offset_in_chunk; + + // The event that will be signaled to when the GPU is done executing work that uses this allocation + DmlGpuEvent done_event; + }; + + // Represents a single contiguous upload heap from which we carve out suballocations. Ranges are suballocated + // from the upload heap in a ring-buffer fashion. + struct Chunk { + size_t capacity_in_bytes; // The total size of the upload heap, in bytes + ComPtr resource; + + // Allocations are sorted by ascending fence value - that is, least to most recently allocated + std::list allocations; + }; + + // Calls AssertInvariants on construction and again on destruction + class InvariantChecker { + public: + InvariantChecker(DmlPooledUploadHeap* parent) + : parent_(parent) { + parent_->AssertInvariants(); + } + + ~InvariantChecker() { + parent_->AssertInvariants(); + } + + private: + DmlPooledUploadHeap* parent_; + }; + + // Attempts to find enough unused space in the supplied chunk to accommodate the given allocation size. + // Returns the offset of that memory if successful, null if there wasn't enough space. + static std::optional FindOffsetForAllocation(const Chunk& chunk, size_t size_in_bytes); + + static Chunk CreateChunk(ID3D12Device* device, size_t size_in_bytes); + + // Finds or creates a chunk with enough space to accommodate an allocation of the given size, and returns a + // pointer to the chunk and allocation offset. 
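+ // If no existing chunk has room, a new chunk of size max(total_capacity_, c_min_chunk_size, size_in_bytes)
+ // is appended, so the pool at least doubles whenever it has to grow.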
+ std::pair Reserve(size_t size_in_bytes); + + void ReclaimAllocations(); // Frees all allocations which are no longer being used by the GPU. + void AssertInvariants(); + + ComPtr device_; + DmlExecutionContext* execution_context_; + + std::vector chunks_; // sorted ascending by capacity (upload heap size) + size_t total_capacity_ = 0; // Total size of all chunks, in bytes +}; \ No newline at end of file diff --git a/src/dml/dml_readback_heap.cpp b/src/dml/dml_readback_heap.cpp new file mode 100644 index 000000000..2726fe565 --- /dev/null +++ b/src/dml/dml_readback_heap.cpp @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include "dml_readback_heap.h" +#include "dml_execution_context.h" + +static ComPtr CreateReadbackHeap(ID3D12Device* device, size_t size) { + ComPtr readback_heap; + auto heap = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_READBACK); + auto buffer = CD3DX12_RESOURCE_DESC::Buffer(size); + + THROW_IF_FAILED(device->CreateCommittedResource( + &heap, + D3D12_HEAP_FLAG_NONE, + &buffer, + D3D12_RESOURCE_STATE_COPY_DEST, + nullptr, + IID_PPV_ARGS(readback_heap.ReleaseAndGetAddressOf()))); + + return readback_heap; +} + +DmlReadbackHeap::DmlReadbackHeap(ID3D12Device* device, DmlExecutionContext* execution_context) + : device_(device), + execution_context_(execution_context) { +} + +static size_t ComputeNewCapacity(size_t existing_capacity, size_t desired_capacity) { + size_t new_capacity = existing_capacity; + + while (new_capacity < desired_capacity) { + if (new_capacity >= std::numeric_limits::max() / 2) { + // Overflow; there's no way we can satisfy this allocation request + THROW_HR(E_OUTOFMEMORY); + } + + new_capacity *= 2; // geometric growth + } + + return new_capacity; +} + +void DmlReadbackHeap::EnsureReadbackHeap(size_t size) { + if (!readback_heap_) { + // Initialize the readback heap for the first time + assert(capacity_ == 0); + capacity_ = ComputeNewCapacity(c_initial_capacity, size); + readback_heap_ = CreateReadbackHeap(device_.Get(), capacity_); + } else if (capacity_ < size) { + // Ensure there's sufficient capacity + capacity_ = ComputeNewCapacity(capacity_, size); + + readback_heap_ = nullptr; + readback_heap_ = CreateReadbackHeap(device_.Get(), capacity_); + } + + assert(readback_heap_->GetDesc().Width >= size); +} + +void DmlReadbackHeap::ReadbackFromGpu( + std::span dst, + ID3D12Resource* src, + uint64_t src_offset, + D3D12_RESOURCE_STATES src_state) { + assert(!dst.empty()); + + EnsureReadbackHeap(dst.size()); + + // Copy from the source resource into the readback heap + execution_context_->CopyBufferRegion( + readback_heap_.Get(), + 0, + D3D12_RESOURCE_STATE_COPY_DEST, + src, + src_offset, + src_state, + dst.size()); + + // Wait for completion and map the result + execution_context_->Flush(); + execution_context_->GetCurrentCompletionEvent().WaitForSignal(); + execution_context_->ReleaseCompletedReferences(); + + // Map the readback heap and copy it into the destination + void* readback_heap_data = nullptr; + THROW_IF_FAILED(readback_heap_->Map(0, nullptr, &readback_heap_data)); + memcpy(dst.data(), readback_heap_data, dst.size()); + readback_heap_->Unmap(0, nullptr); +} + +void DmlReadbackHeap::ReadbackFromGpu( + std::span dst, + std::span dst_sizes, + std::span src, + D3D12_RESOURCE_STATES src_state) { + assert(dst.size() == src.size()); + assert(dst_sizes.size() == src.size()); + + if (dst.empty()) { + return; + } + + uint32_t total_size = 0; + for (auto size 
: dst_sizes) { + total_size += size; + } + + EnsureReadbackHeap(total_size); + + // Copy from the source resource into the readback heap + uint32_t offset = 0; + for (uint32_t i = 0; i < dst.size(); ++i) { + execution_context_->CopyBufferRegion( + readback_heap_.Get(), + offset, + D3D12_RESOURCE_STATE_COPY_DEST, + src[i], + 0, + src_state, + dst_sizes[i]); + + offset += dst_sizes[i]; + } + + // Wait for completion and map the result + execution_context_->Flush(); + execution_context_->GetCurrentCompletionEvent().WaitForSignal(); + execution_context_->ReleaseCompletedReferences(); + + // Map the readback heap and copy it into the destination + void* readback_heap_data = nullptr; + THROW_IF_FAILED(readback_heap_->Map(0, nullptr, &readback_heap_data)); + + // Copy from the source resource into the readback heap + offset = 0; + for (uint32_t i = 0; i < dst.size(); ++i) { + memcpy(dst[i], static_cast(readback_heap_data) + offset, dst_sizes[i]); + offset += dst_sizes[i]; + } + + readback_heap_->Unmap(0, nullptr); +} \ No newline at end of file diff --git a/src/dml/dml_readback_heap.h b/src/dml/dml_readback_heap.h new file mode 100644 index 000000000..1d6777185 --- /dev/null +++ b/src/dml/dml_readback_heap.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "dml_execution_context.h" + +// Because we never perform more than one readback at a time, we don't need anything fancy for managing the +// readback heap - just maintain a single resource and reallocate it if it's not big enough. +class DmlReadbackHeap { + public: + DmlReadbackHeap(ID3D12Device* device, DmlExecutionContext* execution_context); + + // Copies data from the specified GPU resource into CPU memory pointed-to by the span. This method will block + // until the copy is complete. + void ReadbackFromGpu( + std::span dst, + ID3D12Resource* src, + uint64_t src_offset, + D3D12_RESOURCE_STATES src_state); + + // Overload supporting batching + void ReadbackFromGpu( + std::span dst, + std::span dst_sizes, + std::span src, + D3D12_RESOURCE_STATES src_state); + + private: + void EnsureReadbackHeap(size_t size); + + static constexpr size_t c_initial_capacity = 1024 * 1024; // 1MB + + ComPtr device_; + DmlExecutionContext* execution_context_; + + ComPtr readback_heap_; + size_t capacity_ = 0; +}; \ No newline at end of file diff --git a/src/dml/dml_shaders/dml_increment_values.hlsl b/src/dml/dml_shaders/dml_increment_values.hlsl new file mode 100644 index 000000000..ad31240fe --- /dev/null +++ b/src/dml/dml_shaders/dml_increment_values.hlsl @@ -0,0 +1,28 @@ + +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// +//------------------------------------------------------------------------------ + +#define ROOT_SIG_DEF "DescriptorTable(UAV(u0, numDescriptors=1, flags=DATA_VOLATILE | DESCRIPTORS_VOLATILE)), RootConstants(num32BitConstants=1, b0)" +#define NUM_THREADS 256 + +RWStructuredBuffer values : register(u0); + +cbuffer Constants +{ + uint element_count; + uint start_index; +}; + +[RootSignature(ROOT_SIG_DEF)] +[numthreads(NUM_THREADS, 1, 1)] +void CSMain(uint3 dispatch_thread_id : SV_DispatchThreadID) +{ + uint global_index = dispatch_thread_id.x + start_index; + if (global_index < element_count) + { + ++values[global_index]; + } +} diff --git a/src/dml/dml_shaders/dml_update_attention_mask.hlsl b/src/dml/dml_shaders/dml_update_attention_mask.hlsl new file mode 100644 index 000000000..22e7c77c9 --- /dev/null +++ b/src/dml/dml_shaders/dml_update_attention_mask.hlsl @@ -0,0 +1,41 @@ + +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#define ROOT_SIG_DEF "DescriptorTable(UAV(u0, numDescriptors=2, flags=DATA_VOLATILE | DESCRIPTORS_VOLATILE)), RootConstants(num32BitConstants=1, b0)" +#define NUM_THREADS 256 + +RWStructuredBuffer input_mask : register(u0); +RWStructuredBuffer output_mask : register(u1); + +cbuffer Constants +{ + uint max_seq_len; + uint seq_len; + uint element_count; + uint start_index; +}; + +[RootSignature(ROOT_SIG_DEF)] +[numthreads(NUM_THREADS, 1, 1)] +void CSMain(uint3 dispatch_thread_id : SV_DispatchThreadID) +{ + uint global_index = dispatch_thread_id.x + start_index; + if (global_index < element_count) + { + uint sequence_index = global_index % max_seq_len; + + if (seq_len > 1) + { + const T value = sequence_index < seq_len ? 1 : 0; + output_mask[global_index] = value; + } + else + { + output_mask[global_index] = (sequence_index == 0 || input_mask[sequence_index] == 1 || input_mask[sequence_index - 1] == 1) ? 1 : 0; + } + } +} diff --git a/src/dml/dml_smart_container.h b/src/dml/dml_smart_container.h new file mode 100644 index 000000000..c025f7018 --- /dev/null +++ b/src/dml/dml_smart_container.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include "../models/onnxruntime_api.h" + +// Allows objects to be added to a D3D12 object via SetPrivateDataInterface and extend its lifetime beyond the life of the model. For +// example, we can put the DML allocator on the D3D12 device (which is a unique singleton for each adapter) and be sure that the allocator won't be +// destroyed until nothing holds on to the device anymore. 
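+//
+// Illustrative usage sketch (the GUID name below is hypothetical; any caller-chosen GUID works):
+//   auto container = Microsoft::WRL::Make<DmlSmartContainer>(std::move(memory_info), std::move(allocator));
+//   THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_allocator_guid, container.Get()));
+// The device then holds a reference to the container (and the allocator it owns) until the device itself is released.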
+class DmlSmartContainer : public Microsoft::WRL::RuntimeClass, IUnknown> { + public: + DmlSmartContainer(std::unique_ptr&& memory_info, std::unique_ptr&& allocator) + : memory_info_(std::move(memory_info)), allocator_(std::move(allocator)) {} + + const OrtMemoryInfo* GetMemoryInfo() const { return memory_info_.get(); } + Ort::Allocator* GetAllocator() const { return allocator_.get(); } + + private: + std::unique_ptr memory_info_; + std::unique_ptr allocator_; +}; \ No newline at end of file diff --git a/src/dml/dml_update_mask_kernel.cpp b/src/dml/dml_update_mask_kernel.cpp new file mode 100644 index 000000000..a927b78d9 --- /dev/null +++ b/src/dml/dml_update_mask_kernel.cpp @@ -0,0 +1,153 @@ +#include +#include +#include +#include +#include "dml_update_mask_kernel.h" +#include "dml_helpers.h" + +namespace DmlUpdateMask_Int32 { +#include "generated_dml_shaders/update_mask_int32.h" +} + +namespace DmlUpdateMask_Int64 { +#include "generated_dml_shaders/update_mask_int64.h" +} + +DmlUpdateMaskKernel::DmlUpdateMaskKernel( + ID3D12Device* d3d12_device, + DmlExecutionContext* execution_context, + uint32_t batch_size, + uint32_t max_seq_len, + ONNXTensorElementDataType dtype, + uint32_t seq_len, + ID3D12Resource* attention_mask_resource, + ID3D12Resource* attention_mask_next_resource) + : device_(d3d12_device), + execution_context_(execution_context), + dtype_(dtype), + attention_mask_resource_(attention_mask_resource), + attention_mask_next_resource_(attention_mask_next_resource) { + constants_.element_count = batch_size * max_seq_len; + constants_.max_seq_len = max_seq_len; + constants_.seq_len = seq_len; + total_element_count_ = batch_size * max_seq_len; + + // Compute root signature. + std::vector root_parameters; + root_parameters.resize(uav_count_ + 1); + + for (UINT i = 0; i < uav_count_; i++) { + root_parameters[i].InitAsUnorderedAccessView(i); + } + + root_parameters[uav_count_].InitAsConstants(constant_count_, 0); + + CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC desc; + desc.Init_1_1(static_cast(root_parameters.size()), root_parameters.data()); + + ComPtr root_signature_blob; + ComPtr root_signature_error_blob; + THROW_IF_FAILED(D3D12SerializeVersionedRootSignature( + &desc, + root_signature_blob.GetAddressOf(), + root_signature_error_blob.GetAddressOf())); + + THROW_IF_FAILED(device_->CreateRootSignature( + 0, + root_signature_blob->GetBufferPointer(), + root_signature_blob->GetBufferSize(), + IID_PPV_ARGS(&root_signature_))); + + D3D12_COMPUTE_PIPELINE_STATE_DESC compute_pso_desc = {}; + compute_pso_desc.pRootSignature = root_signature_.Get(); + + if (dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { + compute_pso_desc.CS = CD3DX12_SHADER_BYTECODE(DmlUpdateMask_Int32::g_CSMain, sizeof(DmlUpdateMask_Int32::g_CSMain)); + } else if (dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + compute_pso_desc.CS = CD3DX12_SHADER_BYTECODE(DmlUpdateMask_Int64::g_CSMain, sizeof(DmlUpdateMask_Int64::g_CSMain)); + } else { + THROW_HR(E_NOTIMPL); + } + + THROW_IF_FAILED(device_->CreateComputePipelineState(&compute_pso_desc, IID_PPV_ARGS(&pipeline_state_))); + + THROW_IF_FAILED(d3d12_device->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_COMPUTE, + IID_PPV_ARGS(command_allocator_.ReleaseAndGetAddressOf()))); + + THROW_IF_FAILED(d3d12_device->CreateCommandList( + 0, + D3D12_COMMAND_LIST_TYPE_COMPUTE, + command_allocator_.Get(), + nullptr, + IID_PPV_ARGS(graphics_command_list_.ReleaseAndGetAddressOf()))); + + D3D12_DESCRIPTOR_HEAP_DESC heap_desc = {}; + heap_desc.Flags = 
D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + heap_desc.NumDescriptors = uav_count_; + heap_desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + + THROW_IF_FAILED(d3d12_device->CreateDescriptorHeap(&heap_desc, IID_PPV_ARGS(heap_.ReleaseAndGetAddressOf()))); + + ID3D12DescriptorHeap* descriptor_heaps[] = {heap_.Get()}; + graphics_command_list_->SetDescriptorHeaps(ARRAYSIZE(descriptor_heaps), descriptor_heaps); + + // Set the root signature and pipeline state + graphics_command_list_->SetComputeRootSignature(root_signature_.Get()); + graphics_command_list_->SetPipelineState(pipeline_state_.Get()); + graphics_command_list_->SetComputeRootUnorderedAccessView(0, attention_mask_resource_->GetGPUVirtualAddress()); + graphics_command_list_->SetComputeRootUnorderedAccessView(1, attention_mask_next_resource_->GetGPUVirtualAddress()); + + auto pending_element_count = total_element_count_; + auto constants = constants_; + + // Dispatch up to the maximum number of threads per iteration until + // all elements are completed + while (pending_element_count > 0) { + constants.start_index = total_element_count_ - pending_element_count; + + uint32_t dispatch_size_x; + + DmlHelpers::GetNextDispatchSize( + pending_element_count, + 256, + dispatch_size_x, + pending_element_count); + + // Set root constants + graphics_command_list_->SetComputeRoot32BitConstants( + uav_count_, // root parameter index + constant_count_, // Constant count + &constants, + 0 // offset + ); + + graphics_command_list_->Dispatch(dispatch_size_x, 1, 1); + } + + // Barrier before doing the copy + std::array before_copy_barriers = { + CD3DX12_RESOURCE_BARRIER::UAV(attention_mask_next_resource_.Get()), + CD3DX12_RESOURCE_BARRIER::Transition(attention_mask_resource_.Get(), D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_DEST), + CD3DX12_RESOURCE_BARRIER::Transition(attention_mask_next_resource_.Get(), D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE), + }; + graphics_command_list_->ResourceBarrier(static_cast(before_copy_barriers.size()), before_copy_barriers.data()); + + // Copy the next mask to the current mask for next iteration + if (dtype_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { + graphics_command_list_->CopyBufferRegion(attention_mask_resource_.Get(), 0, attention_mask_next_resource_.Get(), 0, constants_.element_count * sizeof(int32_t)); + } else if (dtype_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + graphics_command_list_->CopyBufferRegion(attention_mask_resource_.Get(), 0, attention_mask_next_resource_.Get(), 0, constants_.element_count * sizeof(int64_t)); + } else { + THROW_HR(E_NOTIMPL); + } + + // Barrier after doing the copy + std::array after_copy_barriers = { + CD3DX12_RESOURCE_BARRIER::Transition(attention_mask_resource_.Get(), D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_UNORDERED_ACCESS), + CD3DX12_RESOURCE_BARRIER::Transition(attention_mask_next_resource_.Get(), D3D12_RESOURCE_STATE_COPY_SOURCE, D3D12_RESOURCE_STATE_UNORDERED_ACCESS), + }; + graphics_command_list_->ResourceBarrier(static_cast(after_copy_barriers.size()), after_copy_barriers.data()); + + graphics_command_list_->Close(); +} \ No newline at end of file diff --git a/src/dml/dml_update_mask_kernel.h b/src/dml/dml_update_mask_kernel.h new file mode 100644 index 000000000..a0e2e779a --- /dev/null +++ b/src/dml/dml_update_mask_kernel.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include +#include "dml_execution_context.h" + +using Microsoft::WRL::ComPtr; + +class DmlUpdateMaskKernel { + public: 
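+ // Records a reusable compute command list: the update_mask shader fills attention_mask_next for all
+ // batch_size * max_seq_len elements (initializing it from seq_len during prompt processing, or extending the
+ // existing mask by one position when seq_len == 1), then copies attention_mask_next back over attention_mask
+ // so the list can simply be replayed on each token iteration.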
+ DmlUpdateMaskKernel( + ID3D12Device* d3d12_device, + DmlExecutionContext* execution_context, + uint32_t batch_size, + uint32_t max_seq_len, + ONNXTensorElementDataType dtype, + uint32_t seq_len, + ID3D12Resource* attention_mask_resource, + ID3D12Resource* attention_mask_next_resource); + + ID3D12GraphicsCommandList* GetCommandList() { return graphics_command_list_.Get(); } + + private: + struct Constants { + uint32_t max_seq_len; + uint32_t seq_len; + uint32_t element_count; + uint32_t start_index; + }; + + ComPtr device_; + ComPtr root_signature_; + ComPtr pipeline_state_; + Constants constants_; + DmlExecutionContext* execution_context_; + + ComPtr graphics_command_list_; + ComPtr command_allocator_; + ComPtr heap_; + + ONNXTensorElementDataType dtype_; + ComPtr attention_mask_resource_; + ComPtr attention_mask_next_resource_; + uint32_t total_element_count_; + + constexpr static uint32_t constant_count_ = sizeof(Constants) / sizeof(uint32_t); + constexpr static uint32_t uav_count_ = 2; +}; \ No newline at end of file diff --git a/src/dml/generated_dml_shaders/.clang-format b/src/dml/generated_dml_shaders/.clang-format new file mode 100644 index 000000000..57fe428dc --- /dev/null +++ b/src/dml/generated_dml_shaders/.clang-format @@ -0,0 +1,3 @@ +--- +DisableFormat: true +... diff --git a/src/dml/generated_dml_shaders/increment_values_int32.h b/src/dml/generated_dml_shaders/increment_values_int32.h new file mode 100644 index 000000000..6072b1966 --- /dev/null +++ b/src/dml/generated_dml_shaders/increment_values_int32.h @@ -0,0 +1,269 @@ +#if 0 +; +; Input signature: +; +; Name Index Mask Register SysValue Format Used +; -------------------- ----- ------ -------- -------- ------- ------ +; no parameters +; +; Output signature: +; +; Name Index Mask Register SysValue Format Used +; -------------------- ----- ------ -------- -------- ------- ------ +; no parameters +; shader hash: 9f7e9ca05d93206ff785399ea6b26de6 +; +; Pipeline Runtime Information: +; +; Compute Shader +; NumThreads=(256,1,1) +; +; +; Buffer Definitions: +; +; cbuffer +; { +; +; [8 x i8] (type annotation not present) +; +; } +; +; Resource bind info for +; { +; +; [4 x i8] (type annotation not present) +; +; } +; +; +; Resource Bindings: +; +; Name Type Format Dim ID HLSL Bind Count +; ------------------------------ ---------- ------- ----------- ------- -------------- ------ +; cbuffer NA NA CB0 cb0 1 +; UAV struct r/w U0 u0 1 +; +target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64" +target triple = "dxil-ms-dx" + +%dx.types.Handle = type { i8* } +%dx.types.CBufRet.i32 = type { i32, i32, i32, i32 } +%dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 } +%"class.RWStructuredBuffer" = type { i32 } +%Constants = type { i32, i32 } + +define void @CSMain() { + %1 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) + %2 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 2, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) + %3 = call i32 @dx.op.threadId.i32(i32 93, i32 0) ; ThreadId(component) + %4 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %2, i32 0) ; CBufferLoadLegacy(handle,regIndex) + %5 = extractvalue %dx.types.CBufRet.i32 %4, 1 + %6 = add i32 %5, %3 + %7 = extractvalue %dx.types.CBufRet.i32 %4, 0 + %8 = icmp ult i32 %6, %7 + br i1 %8, label %9, label %13 + +;