Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into prathikrao/slice-op-webgpu…
Browse files Browse the repository at this point in the history
…-native
  • Loading branch information
prathikr committed Dec 16, 2024
2 parents a3963d2 + ae97068 commit bf852c4
Show file tree
Hide file tree
Showing 15 changed files with 272 additions and 74 deletions.
13 changes: 13 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware accelera
option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF)
option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF)
option(onnxruntime_CUSTOM_DAWN_SRC_PATH "Path to custom Dawn src dir.")
option(onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY "Build Dawn as a monolithic library" OFF)
# The following 2 options are only for Windows
option(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN "Enable Vulkan backend for Dawn (on Windows)" OFF)
option(onnxruntime_ENABLE_DAWN_BACKEND_D3D12 "Enable D3D12 backend for Dawn (on Windows)" ON)

# Options related to reducing the binary size produced by the build
# XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON
Expand Down Expand Up @@ -955,9 +959,18 @@ if (onnxruntime_USE_WEBGPU)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu)
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
list(APPEND ORT_PROVIDER_FLAGS -DBUILD_DAWN_MONOLITHIC_LIBRARY=1)
endif()
if (onnxruntime_USE_EXTERNAL_DAWN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_EXTERNAL_DAWN=1)
endif()
if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN)
list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_VULKAN=1)
endif()
if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1)
endif()
endif()
if (onnxruntime_USE_CANN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)
Expand Down
41 changes: 33 additions & 8 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -635,10 +635,19 @@ if (onnxruntime_USE_WEBGPU)
)
endif()

# use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size
set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE)
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE)
set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE)

if (onnxruntime_USE_EXTERNAL_DAWN)
message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.")
endif()
else()
# use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size
set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE)
endif()
set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE)
set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE)

Expand Down Expand Up @@ -667,18 +676,34 @@ if (onnxruntime_USE_WEBGPU)
set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE)
set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE)

# Vulkan may optionally be included in a Windows build. Exclude until we have an explicit use case that requires it.
set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE)
if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12))
message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.")
endif()
if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN)
set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE)
set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE)
else()
set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE)
endif()
if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE)
else()
set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE)
endif()
# We are currently always using the D3D12 backend.
set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE)
endif()

onnxruntime_fetchcontent_makeavailable(dawn)

if (NOT onnxruntime_USE_EXTERNAL_DAWN)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native)
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::webgpu_dawn)
else()
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native)
endif()
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
endif()
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
endif()

set(onnxruntime_LINK_DIRS)
Expand Down
22 changes: 19 additions & 3 deletions cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,25 @@
onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu
onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native)

if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn)

if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS)
list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll")
endif()

# Copy webgpu_dawn.dll to the output directory
add_custom_command(
TARGET onnxruntime_providers_webgpu
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different "$<TARGET_FILE:dawn::webgpu_dawn>" "$<TARGET_FILE_DIR:onnxruntime_providers_webgpu>"
VERBATIM )
else()
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native)
endif()
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc)
endif()
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc)

set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime")
124 changes: 66 additions & 58 deletions include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -1467,11 +1467,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
status = ThreadStatus::Spinning;
}

void SetBlocked(std::function<bool()> should_block,
bool SetBlocked(std::function<bool()> should_block,
std::function<void()> post_block) {
std::unique_lock<std::mutex> lk(mutex);
assert(GetStatus() == ThreadStatus::Spinning);
status.store(ThreadStatus::Blocking, std::memory_order_relaxed);
auto old_status = status.exchange(ThreadStatus::Blocking, std::memory_order_seq_cst);
if (old_status != ThreadStatus::Spinning) {
// Encountered a logical error
return false;
}
if (should_block()) {
status.store(ThreadStatus::Blocked, std::memory_order_relaxed);
do {
Expand All @@ -1480,6 +1483,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
post_block();
}
status.store(ThreadStatus::Spinning, std::memory_order_relaxed);
return true;
}

private:
Expand Down Expand Up @@ -1558,62 +1562,66 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

// Attempt to block
if (!t) {
td.SetBlocked( // Pre-block test
[&]() -> bool {
bool should_block = true;
// Check whether work was pushed to us while attempting to block. We make
// this test while holding the per-thread status lock, and after setting
// our status to ThreadStatus::Blocking.
//
// This synchronizes with ThreadPool::Schedule which pushes work to the queue
// and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake):
//
// Main thread: Worker:
// #1 Push work #A Set status blocking
// #2 Read worker status #B Check queue
// #3 Wake if blocking/blocked
//
// If #A is before #2 then main sees worker blocked and wakes
//
// If #A if after #2 then #B will see #1, and we abandon blocking
assert(!t);
t = q.PopFront();
if (t) {
should_block = false;
}

// No work pushed to us, continue attempting to block. The remaining
// test is to synchronize with termination requests. If we are
// shutting down and all worker threads blocked without work, that's
// we are done.
if (should_block) {
blocked_++;
if (done_ && blocked_ == num_threads_) {
should_block = false;
// Almost done, but need to re-check queues.
// Consider that all queues are empty and all worker threads are preempted
// right after incrementing blocked_ above. Now a free-standing thread
// submits work and calls destructor (which sets done_). If we don't
// re-check queues, we will exit leaving the work unexecuted.
if (NonEmptyQueueIndex() != -1) {
// Note: we must not pop from queues before we decrement blocked_,
// otherwise the following scenario is possible. Consider that instead
// of checking for emptiness we popped the only element from queues.
// Now other worker threads can start exiting, which is bad if the
// work item submits other work. So we just check emptiness here,
// which ensures that all worker threads exit at the same time.
blocked_--;
} else {
should_exit = true;
if (!td.SetBlocked( // Pre-block test
[&]() -> bool {
bool should_block = true;
// Check whether work was pushed to us while attempting to block. We make
// this test while holding the per-thread status lock, and after setting
// our status to ThreadStatus::Blocking.
//
// This synchronizes with ThreadPool::Schedule which pushes work to the queue
// and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake):
//
// Main thread: Worker:
// #1 Push work #A Set status blocking
// #2 Read worker status #B Check queue
// #3 Wake if blocking/blocked
//
// If #A is before #2 then main sees worker blocked and wakes
//
// If #A if after #2 then #B will see #1, and we abandon blocking
assert(!t);
t = q.PopFront();
if (t) {
should_block = false;
}

// No work pushed to us, continue attempting to block. The remaining
// test is to synchronize with termination requests. If we are
// shutting down and all worker threads blocked without work, that's
// we are done.
if (should_block) {
blocked_++;
if (done_ && blocked_ == num_threads_) {
should_block = false;
// Almost done, but need to re-check queues.
// Consider that all queues are empty and all worker threads are preempted
// right after incrementing blocked_ above. Now a free-standing thread
// submits work and calls destructor (which sets done_). If we don't
// re-check queues, we will exit leaving the work unexecuted.
if (NonEmptyQueueIndex() != -1) {
// Note: we must not pop from queues before we decrement blocked_,
// otherwise the following scenario is possible. Consider that instead
// of checking for emptiness we popped the only element from queues.
// Now other worker threads can start exiting, which is bad if the
// work item submits other work. So we just check emptiness here,
// which ensures that all worker threads exit at the same time.
blocked_--;
} else {
should_exit = true;
}
}
}
}
}
return should_block;
},
// Post-block update (executed only if we blocked)
[&]() {
blocked_--;
});
return should_block;
},
// Post-block update (executed only if we blocked)
[&]() {
blocked_--;
})) {
// Encountered a fatal logic error in SetBlocked
should_exit = true;
break;
}
// Thread just unblocked. Unless we picked up work while
// blocking, or are exiting, then either work was pushed to
// us, or it was pushed to an overloaded queue
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ struct OrtVitisAIEpAPI {
const std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>& eps,
const char* const* keys,
const char* const* values, size_t kv_len) = nullptr;
void (*profiler_collect)(
std::vector<EventInfo>& api_events,
std::vector<EventInfo>& kernel_events);
void Ensure() {
if (handle_)
return;
Expand All @@ -81,6 +84,7 @@ struct OrtVitisAIEpAPI {
}
std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version",
(void**)&vaip_get_version);
std::ignore = env.GetSymbolFromLibrary(handle_, "profiler_collect", (void**)&profiler_collect);
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "create_ep_context_nodes", (void**)&create_ep_context_nodes));
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start));
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options));
Expand All @@ -97,6 +101,14 @@ static vaip_core::OrtApiForVaip the_global_api;
std::shared_ptr<KernelRegistry> get_kernel_registry_vitisaiep() { return s_kernel_registry_vitisaiep; }
const std::vector<OrtCustomOpDomain*>& get_domains_vitisaiep() { return s_domains_vitisaiep; }

void profiler_collect(
std::vector<EventInfo>& api_events,
std::vector<EventInfo>& kernel_events) {
if (s_library_vitisaiep.profiler_collect) {
s_library_vitisaiep.profiler_collect(api_events, kernel_events);
}
}

vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(
const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) {
auto model_path = graph_viewer.ModelPath().string();
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/vitisai/include/vaip/global_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,18 @@ int vitisai_ep_set_ep_dynamic_options(
const std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>& eps,
const char* const* keys,
const char* const* values, size_t kv_len);
/**
* Replace EventRecord with std::tuple<std::string, int ,int, long long, long long>,
* because EventRecord is defined in profiler_common.h which is used inside onnxruntime.
* However, profiler_collect function will call vitis ep which can't include profiler_common.h.
*/
using EventInfo = std::tuple<
std::string, // name
int, // pid
int, // tid
long long, // timestamp
long long // duration
>;
void profiler_collect(
std::vector<EventInfo>& api_events,
std::vector<EventInfo>& kernel_events);
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.
#include "vitisai_execution_provider.h"
#include "vitisai_profiler.h"

// Standard headers/libs.
#include <cassert>
Expand Down Expand Up @@ -135,4 +136,8 @@ common::Status VitisAIExecutionProvider::SetEpDynamicOptions(gsl::span<const cha
}
return Status::OK();
}

std::unique_ptr<profiling::EpProfiler> VitisAIExecutionProvider::GetProfiler() {
return std::make_unique<profiling::VitisaiProfiler>();
}
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class VitisAIExecutionProvider : public IExecutionProvider {
std::vector<NodeComputeInfo>& node_compute_funcs) override;
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;

std::unique_ptr<profiling::EpProfiler> GetProfiler() override;

// This method is called after both `GetComputeCapabilityOps()` and `Compile()`.
// This timing is required to work with both compliation-based EPs and non-compilation-based EPs.
const InlinedVector<const Node*> GetEpContextNodes() const override;
Expand Down
49 changes: 49 additions & 0 deletions onnxruntime/core/providers/vitisai/vitisai_profiler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.

#include "vitisai_profiler.h"

namespace onnxruntime {
namespace profiling {

#if defined(USE_VITISAI)

bool VitisaiProfiler::StartProfiling(TimePoint tp) {
return true;
}

void VitisaiProfiler::EndProfiling(TimePoint tp, Events& events) {
auto time_point =
std::chrono::duration_cast<std::chrono::microseconds>(tp.time_since_epoch()).count();

std::vector<EventInfo> api_events;
std::vector<EventInfo> kernel_events;
profiler_collect(api_events, kernel_events);

std::unordered_map<std::string, std::string> event_args;

for (auto& a : api_events) {
events.emplace_back(EventCategory::API_EVENT,
std::get<1>(a), // pid
std::get<2>(a), // tid
std::get<0>(a), // name
std::get<3>(a) - time_point, // timestamp
std::get<4>(a), // duration
event_args);
}

for (auto& k : kernel_events) {
events.emplace_back(EventCategory::KERNEL_EVENT,
std::get<1>(k),
std::get<2>(k),
std::get<0>(k),
std::get<3>(k) - time_point,
std::get<4>(k),
event_args);
}
}

#endif

} // namespace profiling
} // namespace onnxruntime
Loading

0 comments on commit bf852c4

Please sign in to comment.