Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add graph capture and shaders to improve DirectML performance #272

Merged
merged 47 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
bdf41a0
Add DML support
PatriceVignola Apr 16, 2024
67d00e2
Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai…
PatriceVignola Apr 16, 2024
63c3ddf
elif -> elseif in cmake file
PatriceVignola Apr 16, 2024
d20b3cc
revert seqlen_k/total_seq_len in model builder
wangyems Apr 17, 2024
eb247c8
genai tool: remove seqlen_k/total_seq_len and add static buffer for a…
wangyems Apr 17, 2024
0eadc9f
readme
wangyems Apr 18, 2024
953e90c
Update README.md
wangyems Apr 18, 2024
ee8a9f3
update
wangyems Apr 18, 2024
fecca46
Merge branch 'wangye/attn_mask' of github.com:microsoft/onnxruntime-g…
wangyems Apr 18, 2024
c5c862a
pass device allocator
wangyems Apr 18, 2024
c926336
WIP
PatriceVignola Apr 18, 2024
d3af609
Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai…
PatriceVignola Apr 18, 2024
ae462dd
WIP (not working, messy)
PatriceVignola Apr 18, 2024
af3a013
Fix attention mask
PatriceVignola Apr 18, 2024
b801134
Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai…
PatriceVignola Apr 18, 2024
9ba5770
Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai…
PatriceVignola Apr 18, 2024
ba515bc
WIP
PatriceVignola Apr 19, 2024
e4a1664
WIP
PatriceVignola Apr 19, 2024
248bba1
Fix crash when closing application
PatriceVignola Apr 19, 2024
9118e66
Remove test.py
PatriceVignola Apr 19, 2024
c48ece0
Cleanup
PatriceVignola Apr 19, 2024
6f64a40
Rename variables from dml files
PatriceVignola Apr 19, 2024
164d189
Split dml_helpers
PatriceVignola Apr 19, 2024
e25ece3
Fixes
PatriceVignola Apr 19, 2024
d8c1d74
Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai…
PatriceVignola Apr 19, 2024
8ca3490
Disable generated files formatting
PatriceVignola Apr 19, 2024
8f99087
Add missing barriers
PatriceVignola Apr 19, 2024
a20bf62
Remove unneeded code
PatriceVignola Apr 19, 2024
2a0658c
Support cuda graphs with more than 1 generators per session
PatriceVignola Apr 20, 2024
bea3148
Fix build break
PatriceVignola Apr 20, 2024
a005c65
Fix build
PatriceVignola Apr 20, 2024
80a38d4
Fix build
PatriceVignola Apr 20, 2024
f608dc1
Fix build
PatriceVignola Apr 20, 2024
5098ab2
Fix test failures
PatriceVignola Apr 20, 2024
4fa1964
Add some DML tests
PatriceVignola Apr 20, 2024
3939361
Fix issue with different max lengths
PatriceVignola Apr 21, 2024
a076257
Fix e2e script
PatriceVignola Apr 21, 2024
d95bedf
Remove debug info from example
PatriceVignola Apr 21, 2024
84a888c
Fix crash on Nvidia
PatriceVignola Apr 21, 2024
399cae5
Fix crash on AMD
PatriceVignola Apr 21, 2024
02cff23
Revert change
PatriceVignola Apr 21, 2024
bbfe6a0
Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai…
PatriceVignola Apr 22, 2024
aba3ee4
Revert bad changes
PatriceVignola Apr 22, 2024
261b0b7
Change target name to onnxruntime-genai-directml
PatriceVignola Apr 22, 2024
d33f5ef
Address PR comments
PatriceVignola Apr 22, 2024
a499d6a
Address PR comments
PatriceVignola Apr 22, 2024
27c834c
Merge branch 'main' of https://github.com/microsoft/onnxruntime-genai…
PatriceVignola Apr 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ endif()
if(USE_DML)
if(WIN32)
  add_compile_definitions(USE_DML=1)
  # Windows headers define min/max macros that break std::min/std::max; NOMINMAX disables them.
  add_compile_definitions(NOMINMAX)

  # Collect all DML-specific sources. CONFIGURE_DEPENDS makes the build system re-run the
  # glob when files are added or removed under src/dml, at the cost of a check per build.
  file(GLOB dml_srcs CONFIGURE_DEPENDS
    "${PROJECT_SOURCE_DIR}/src/dml/*.h"
    "${PROJECT_SOURCE_DIR}/src/dml/*.cpp"
  )

  list(APPEND generator_srcs ${dml_srcs})
else()
  # DirectML is a Windows-only API; fail configuration early on other platforms.
  message(FATAL_ERROR "USE_DML is ON but this isn't windows.")
endif()
Expand Down Expand Up @@ -135,6 +143,39 @@ endif()
file(GLOB onnxruntime_libs "${ORT_LIB_DIR}/${ONNXRUNTIME_FILES}")
if(USE_DML)
  # Ship DirectML.dll alongside the other onnxruntime binaries.
  list(APPEND onnxruntime_libs "${ORT_LIB_DIR}/DirectML.dll")

  # WIL and DirectX-Headers are header-only; propagate their include dirs to both targets.
  # The extra "/directx" entry allows unprefixed includes of the DirectX headers.
  target_include_directories(onnxruntime-genai PRIVATE $<TARGET_PROPERTY:${WIL_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
  target_include_directories(onnxruntime-genai PRIVATE $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>/directx)
  target_include_directories(onnxruntime-genai PRIVATE $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
  target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${WIL_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
  target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>/directx)
  target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
  target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib)
  target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib)

  # Restore the DXC shader compiler via nuget into the build tree's _deps directory.
  get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps ABSOLUTE)
  # NOTE(review): version is pinned here and in packages.config — keep the two in sync.
  set(DXC_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.Direct3D.DXC.1.7.2308.12)
  set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/nuget.config)
  set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/packages.config)

  # Re-run the restore whenever either config file changes; the dxc.exe path is the
  # stamp that tells the build the restore has already happened.
  # NOTE(review): path is hardcoded to the x64 dxc binary — confirm for ARM64 hosts.
  add_custom_command(
    OUTPUT
    ${DXC_PACKAGE_DIR}/build/native/bin/x64/dxc.exe
    DEPENDS
    ${PACKAGES_CONFIG}
    ${NUGET_CONFIG}
    COMMAND ${CMAKE_CURRENT_BINARY_DIR}/nuget/src/nuget restore ${PACKAGES_CONFIG} -PackagesDirectory ${PACKAGES_DIR} -ConfigFile ${NUGET_CONFIG}
    VERBATIM
  )

  add_custom_target(
    RESTORE_PACKAGES ALL
    DEPENDS
    ${DXC_PACKAGE_DIR}/build/native/bin/x64/dxc.exe
  )

  # nuget itself is fetched by ExternalProject (see onnxruntime_external_deps.cmake),
  # so the restore must wait for it; both genai targets wait for the restore.
  add_dependencies(RESTORE_PACKAGES nuget)
  add_dependencies(onnxruntime-genai RESTORE_PACKAGES)
  add_dependencies(onnxruntime-genai-static RESTORE_PACKAGES)
endif()

if(NO_TOKENIZEROOT)
Expand Down
16 changes: 14 additions & 2 deletions benchmark/python/benchmark_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
from tqdm import tqdm

# Use input model to generate prompt
def generate_prompt(model, tokenizer, prompt_length) -> str:
def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
temperature = 1.0
prompt = "a"
tokens = tokenizer.encode(prompt)
params=og.GeneratorParams(model)
params.set_search_options(do_sample=True, top_k=5, temperature=temperature, max_length=prompt_length, min_length=prompt_length+1)
params.input_ids = tokens

if use_graph_capture:
params.try_use_cuda_graph_with_max_batch_size(1)

generator=og.Generator(model, params)
while not generator.is_done():
generator.compute_logits()
Expand Down Expand Up @@ -63,13 +67,16 @@ def main(args):
tokenizer = og.Tokenizer(model)

# Generate prompt
prompt = [generate_prompt(model, tokenizer, prompt_length)] * batch_size
prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
tokens = tokenizer.encode_batch(prompt)

params = og.GeneratorParams(model)
params.input_ids = tokens
params.set_search_options(do_sample=True, top_k=args.top_k, top_p=args.top_p, temperature=temperature, max_length=max_length, min_length=max_length)

if args.use_graph_capture:
params.try_use_cuda_graph_with_max_batch_size(batch_size)

if args.verbose: print("Running warmup runs...")
for _ in tqdm(range(args.warmup)):
generator = og.Generator(model, params)
Expand Down Expand Up @@ -100,6 +107,10 @@ def main(args):
params = og.GeneratorParams(model)
params.input_ids = tokens
params.set_search_options(max_length=max_length, min_length=max_length)

if args.use_graph_capture:
params.try_use_cuda_graph_with_max_batch_size(batch_size)

generator = og.Generator(model, params)

# Measure prompt processing
Expand Down Expand Up @@ -199,5 +210,6 @@ def main(args):
parser.add_argument('-o', '--output', type=str, default='genai_e2e', help='Output CSV file name or path (with .csv extension)')
parser.add_argument('-v', '--verbose', action='store_true', help='Print extra information')
parser.add_argument('-mo', '--print_model_output', action='store_true', help='Print model output')
parser.add_argument('-gc', '--use_graph_capture', action='store_true', help='Use the graph capture feature for CUDA or DML')
args = parser.parse_args()
main(args)
2 changes: 2 additions & 0 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
# NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI.
pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6aa67a77f17a770960f604b727645b6f6a13
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e

34 changes: 34 additions & 0 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,37 @@ FetchContent_Declare(

onnxruntime_fetchcontent_makeavailable(googletest)

if(USE_DML)
  # Only consume WIL's headers; skip building its packaging and test projects.
  set(WIL_BUILD_PACKAGING OFF CACHE BOOL "" FORCE)
  set(WIL_BUILD_TESTS OFF CACHE BOOL "" FORCE)

  # Windows Implementation Library (error-handling helpers such as THROW_IF_FAILED).
  FetchContent_Declare(
    microsoft_wil
    URL ${DEP_URL_microsoft_wil}
    URL_HASH SHA1=${DEP_SHA1_microsoft_wil}
    FIND_PACKAGE_ARGS NAMES wil
  )

  onnxruntime_fetchcontent_makeavailable(microsoft_wil)
  set(WIL_TARGET "WIL::WIL")

  # Up-to-date D3D12/DXGI headers, independent of the installed Windows SDK.
  FetchContent_Declare(
    directx_headers
    URL ${DEP_URL_directx_headers}
    URL_HASH SHA1=${DEP_SHA1_directx_headers}
  )

  onnxruntime_fetchcontent_makeavailable(directx_headers)
  set(DIRECTX_HEADERS_TARGET "DirectX-Headers")

  # nuget.exe is a standalone download, not an archive: disable extraction and all
  # configure/build/install steps. It is only invoked later to restore the DXC package.
  include(ExternalProject)
  ExternalProject_Add(nuget
    PREFIX nuget
    URL "https://dist.nuget.org/win-x86-commandline/v5.3.0/nuget.exe"
    DOWNLOAD_NO_EXTRACT 1
    CONFIGURE_COMMAND ""
    BUILD_COMMAND ""
    UPDATE_COMMAND ""
    INSTALL_COMMAND ""
  )
endif()
1 change: 1 addition & 0 deletions examples/python/model-qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def main(args):
input_tokens = tokenizer.encode(args.system_prompt + text)

params = og.GeneratorParams(model)
params.try_use_cuda_graph_with_max_batch_size(1)
params.set_search_options(do_sample=args.do_random_sampling, max_length=args.max_length, min_length=args.min_length, top_p=args.top_p, top_k=args.top_k, temperature=args.temperature, repetition_penalty=args.repetition_penalty)
params.input_ids = input_tokens
generator = og.Generator(model, params)
Expand Down
4 changes: 4 additions & 0 deletions generate_dml_shaders.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
rem Compiles the DML HLSL compute shaders into embedded C headers using the DXC compiler
rem restored by the nuget package (see packages.config). Each shader is compiled twice,
rem once per mask/value element type (int32_t and int64_t).
rem Common flags: CSMain entry point, shader model 6.2, optimized, with reflection/debug
rem info and root signature stripped from the embedded blob.
setlocal
set "DXC=.\build\_deps\Microsoft.Direct3D.DXC.1.7.2308.12\build\native\bin\x64\dxc.exe"
set "SHADER_DIR=src\models\dml\dml_shaders"
set "OUT_DIR=src\models\dml\generated_dml_shaders"
set "FLAGS=-E CSMain -T cs_6_2 -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature"

rem Fail fast (exit /b 1) as soon as any compilation fails instead of silently continuing.
"%DXC%" "%SHADER_DIR%\dml_update_attention_mask.hlsl" -DT=int32_t %FLAGS% -Fh "%OUT_DIR%\update_mask_int32.h" || exit /b 1
"%DXC%" "%SHADER_DIR%\dml_update_attention_mask.hlsl" -DT=int64_t %FLAGS% -Fh "%OUT_DIR%\update_mask_int64.h" || exit /b 1
"%DXC%" "%SHADER_DIR%\dml_increment_values.hlsl" -DT=int32_t %FLAGS% -Fh "%OUT_DIR%\increment_values_int32.h" || exit /b 1
"%DXC%" "%SHADER_DIR%\dml_increment_values.hlsl" -DT=int64_t %FLAGS% -Fh "%OUT_DIR%\increment_values_int64.h" || exit /b 1
13 changes: 13 additions & 0 deletions nuget.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<solution>
<add key="disableSourceControlIntegration" value="true" />
</solution>
<packageSources>
<clear />
<add key="NuGet Official" value="https://api.nuget.org/v3/index.json" />
</packageSources>
<disabledPackageSources>
<clear />
</disabledPackageSources>
</configuration>
4 changes: 4 additions & 0 deletions packages.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Microsoft.Direct3D.DXC" version="1.7.2308.12" targetFramework="native" />
</packages>
64 changes: 64 additions & 0 deletions src/dml/dml_command_allocator_ring.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <array>
#include "dml_gpu_event.h"

// A fixed-size ring of command allocators. Each time an allocator is retrieved, the allocator will
// be reset if its previously recorded commands have finished executing on the GPU.
template <size_t AllocatorCount>
class DmlCommandAllocatorRing {
 public:
  // Creates AllocatorCount D3D12 command allocators of the given command-list type. Every slot
  // starts out tagged with `initialEvent` so the ring can later decide when a slot is safe to reset.
  DmlCommandAllocatorRing(
      ID3D12Device* device,
      D3D12_COMMAND_LIST_TYPE commandListType,
      DmlGpuEvent initialEvent) {
    for (auto& info : command_allocators_) {
      THROW_IF_FAILED(device->CreateCommandAllocator(
          commandListType,
          IID_PPV_ARGS(info.allocator.ReleaseAndGetAddressOf())));

      info.completion_event = initialEvent;
    }
  }

  // Returns an allocator that is safe to record into. `next_completion_event` is the event that
  // will be signaled once the work recorded with the returned allocator finishes on the GPU.
  ID3D12CommandAllocator* GetNextAllocator(DmlGpuEvent next_completion_event) {
    // The ring advances in order, so the slot right after the current one holds the
    // oldest outstanding work.
    size_t earliest_other_allocator = (current_command_allocator_ + 1) % AllocatorCount;

    // Invariant: if the current slot's work has completed, the older slot's work must have
    // completed too (work is submitted and signaled in ring order).
    assert(!command_allocators_[current_command_allocator_].completion_event.IsSignaled() ||
           command_allocators_[earliest_other_allocator].completion_event.IsSignaled());

    // Advance to the older slot only once its GPU work has finished (Reset requires that);
    // otherwise keep handing out the current allocator.
    if (command_allocators_[earliest_other_allocator].completion_event.IsSignaled()) {
      THROW_IF_FAILED(command_allocators_[earliest_other_allocator].Get()->Reset());
      current_command_allocator_ = earliest_other_allocator;
    }

    // Set the completion event for the current allocator so it can be reset eventually.
    command_allocators_[current_command_allocator_].completion_event = next_completion_event;

    return command_allocators_[current_command_allocator_].Get();
  }

  // Updates the completion event of the current allocator to a different value. This is used when the caller
  // decides to issue an unrelated call to the queue such as ExecuteCommandLists which updates its fence between calling
  // GetNextAllocator and executing the work which it recorded using the allocator it received.
  void UpdateCurrentAllocatorCompletionEvent(DmlGpuEvent next_completion_event) {
    command_allocators_[current_command_allocator_].completion_event = next_completion_event;
  }

 private:
  struct CommandAllocatorInfo {
    ComPtr<ID3D12CommandAllocator> allocator;

    // The event which will be signaled when the last command list submitted using this allocator
    // completes execution on the GPU.
    DmlGpuEvent completion_event = {};

    ID3D12CommandAllocator* Get() const { return allocator.Get(); }
  };

  // Ring storage; current_command_allocator_ indexes the slot currently being recorded into.
  std::array<CommandAllocatorInfo, AllocatorCount> command_allocators_;
  size_t current_command_allocator_ = 0;
};
76 changes: 76 additions & 0 deletions src/dml/dml_command_queue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <assert.h>
#include <wil/result.h>
#include "dml_command_queue.h"

// Wraps an existing D3D12 queue and creates a dedicated fence (starting at 0) that this
// wrapper signals after each submission to track GPU completion.
DmlCommandQueue::DmlCommandQueue(ID3D12CommandQueue* existing_queue)
    : queue_(existing_queue), type_(existing_queue->GetDesc().Type) {
  ComPtr<ID3D12Device> d3d_device;
  THROW_IF_FAILED(queue_->GetDevice(IID_PPV_ARGS(d3d_device.GetAddressOf())));
  THROW_IF_FAILED(d3d_device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(fence_.ReleaseAndGetAddressOf())));
}

// Single-list convenience overload; forwards to the span-based implementation.
void DmlCommandQueue::ExecuteCommandList(ID3D12CommandList* command_list) {
  ID3D12CommandList* lists[] = {command_list};
  ExecuteCommandLists(std::span<ID3D12CommandList*>(lists));
}

// Submits the lists to the GPU, then signals the tracking fence with the next
// monotonically increasing value so completion of this batch can be awaited.
void DmlCommandQueue::ExecuteCommandLists(std::span<ID3D12CommandList*> command_lists) {
  const auto list_count = static_cast<uint32_t>(command_lists.size());
  queue_->ExecuteCommandLists(list_count, command_lists.data());

  const uint64_t fence_value = ++last_fence_value_;
  THROW_IF_FAILED(queue_->Signal(fence_.Get(), fence_value));
}

// Queues a GPU-side wait until `fence` reaches `value`, then bumps and signals our own
// fence so that the wait itself is tracked like any other submission.
void DmlCommandQueue::Wait(ID3D12Fence* fence, uint64_t value) {
  THROW_IF_FAILED(queue_->Wait(fence, value));

  const uint64_t signaled_value = ++last_fence_value_;
  THROW_IF_FAILED(queue_->Signal(fence_.Get(), signaled_value));
}

// Event signaled once everything submitted so far has executed on the GPU.
DmlGpuEvent DmlCommandQueue::GetCurrentCompletionEvent() {
  return {last_fence_value_, fence_};
}

// Event signaled once the *next* submission (not yet executed) completes on the GPU.
DmlGpuEvent DmlCommandQueue::GetNextCompletionEvent() {
  return {last_fence_value_ + 1, fence_};
}

// Holds a reference to `object` until the GPU reaches the current fence value (or the next
// one, when wait_for_unsubmitted_work indicates the relevant work is recorded but not yet
// submitted).
void DmlCommandQueue::QueueReference(IUnknown* object, bool wait_for_unsubmitted_work) {
  // While Close() is draining queued_references_, adding new entries would leak them.
  // Destructors of queued objects can re-enter this method (e.g. a freed allocation queuing
  // its underlying D3D resource), so silently drop those requests: Close() has already
  // blocked on all scheduled GPU work, making the extra references unnecessary.
  if (closing_) {
    return;
  }

  uint64_t release_fence_value = GetLastFenceValue();

  // Work recorded but not yet submitted completes at the *next* fence value.
  if (wait_for_unsubmitted_work) {
    ++release_fence_value;
  }

  queued_references_.push_back({release_fence_value, object});
}

void DmlCommandQueue::Close() {
// Wait for flushed work:
assert(!closing_);
closing_ = true;
DmlGpuEvent event = GetCurrentCompletionEvent();
event.WaitForSignal();
queued_references_.clear();
closing_ = false;
}

// Pops queued references whose fence value the GPU has already reached. Entries are queued
// in fence-value order, so popping stops at the first still-pending one.
void DmlCommandQueue::ReleaseCompletedReferences() {
  const uint64_t completed_value = GetFence()->GetCompletedValue();
  while (!queued_references_.empty()) {
    if (queued_references_.front().fence_value > completed_value) {
      break;
    }
    queued_references_.pop_front();
  }
}
54 changes: 54 additions & 0 deletions src/dml/dml_command_queue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <d3d12.h>
#include <deque>
#include "../span.h"
#include "dml_gpu_event.h"

// Manages a D3D12 command queue and provides a waitable fence which is signaled with a monotonically increasing
// value once each execute completes on the GPU.
class DmlCommandQueue {
 public:
  // Creates a DmlCommandQueue object that wraps an existing D3D12 queue.
  DmlCommandQueue(ID3D12CommandQueue* existing_queue);

  // Type of the wrapped queue (direct/compute/copy), captured at construction.
  D3D12_COMMAND_LIST_TYPE GetType() const { return type_; }
  // Fence used to track completion. NOTE: returned as a ComPtr by value, so each call
  // performs an AddRef/Release pair.
  ComPtr<ID3D12Fence> GetFence() const { return fence_; }
  // Value the fence will reach once all work submitted so far has completed.
  uint64_t GetLastFenceValue() const { return last_fence_value_; }

  // Submits one (or several) command lists and signals the fence afterwards.
  void ExecuteCommandList(ID3D12CommandList* command_list);
  void ExecuteCommandLists(std::span<ID3D12CommandList*> command_lists);

  // Queues a wait to block the GPU until the specified fence is signaled to a given value.
  void Wait(ID3D12Fence* fence, uint64_t value);

  // Returns an event that will become signaled when everything submitted to the queue thus far has
  // completed execution on the GPU.
  DmlGpuEvent GetCurrentCompletionEvent();

  // Returns an event that will become signaled after the next ExecuteCommandLists call.
  DmlGpuEvent GetNextCompletionEvent();

  // Keeps `object` alive until the GPU reaches the current (or, when
  // wait_for_unsubmitted_work is true, the next) fence value. No-op while Close() runs.
  void QueueReference(IUnknown* object, bool wait_for_unsubmitted_work);

  // Blocks for all flushed GPU work, then drops every queued reference.
  void Close();
  // Releases references whose fence value the GPU has already reached.
  void ReleaseCompletedReferences();

 private:
  // Pairs an object with the fence value at which it becomes safe to release.
  struct QueuedReference {
    uint64_t fence_value;
    ComPtr<IUnknown> object;
  };

  // Queued in fence-value order, so completed entries can be popped from the front.
  std::deque<QueuedReference> queued_references_;

  ComPtr<ID3D12CommandQueue> queue_;
  D3D12_COMMAND_LIST_TYPE type_;

  ComPtr<ID3D12Fence> fence_;
  // Last value signaled (or scheduled to be signaled) on fence_.
  uint64_t last_fence_value_ = 0;
  // True only while Close() is draining queued_references_; guards QueueReference.
  bool closing_ = false;
};
Loading
Loading