From 53e3833aee70fe4bb31b42e008f2711e110659ec Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Sat, 20 Jul 2024 18:04:00 +0300 Subject: [PATCH] feat: add directml support --- CMakeLists.txt | 9 +++ cmake/onnxruntime-win-x64-directml.cmake | 72 ++++++++++++++++++++++++ cmake/onnxruntime.cmake | 5 +- sherpa-onnx/csrc/provider.cc | 2 + sherpa-onnx/csrc/provider.h | 13 +++-- sherpa-onnx/csrc/session.cc | 22 ++++++++ 6 files changed, 116 insertions(+), 7 deletions(-) create mode 100644 cmake/onnxruntime-win-x64-directml.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 661647020..4e0ebd577 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) +option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF) option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF) option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF) option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF) @@ -160,6 +161,14 @@ else() add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0) endif() +if(SHERPA_ONNX_ENABLE_DIRECTML) + message(STATUS "DirectML is enabled") + add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1) +else() + message(WARNING "DirectML is disabled") + add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0) +endif() + if(SHERPA_ONNX_ENABLE_WASM_TTS) if(NOT SHERPA_ONNX_ENABLE_TTS) message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS") diff --git a/cmake/onnxruntime-win-x64-directml.cmake b/cmake/onnxruntime-win-x64-directml.cmake new file mode 100644 index 000000000..235571522 --- /dev/null +++ b/cmake/onnxruntime-win-x64-directml.cmake @@ -0,0 +1,72 @@ +# Copyright (c) 2022-2023 Xiaomi Corporation +message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}") + +if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows) + message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}") +endif() + +if(NOT (CMAKE_VS_PLATFORM_NAME STREQUAL X64 OR CMAKE_VS_PLATFORM_NAME STREQUAL x64)) + message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}") +endif() + +if(BUILD_SHARED_LIBS) + message(FATAL_ERROR "This file is for building static libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}") +endif() + +if(NOT CMAKE_BUILD_TYPE STREQUAL Release) + message(FATAL_ERROR "This file is for building a release version on Windows x64") +endif() + +set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg") +set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg") +set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a") + +# If you don't have access to the Internet, +# please download onnxruntime to one of the following locations. +# You can add more if you want. +set(possible_file_locations + $ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg + ${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg + ${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg + /tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg +) + +foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(onnxruntime_URL "${f}") + file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL) + message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}") + set(onnxruntime_URL2) + break() + endif() +endforeach() + +FetchContent_Declare(onnxruntime + URL + ${onnxruntime_URL} + ${onnxruntime_URL2} + URL_HASH ${onnxruntime_HASH} +) + +FetchContent_GetProperties(onnxruntime) +if(NOT onnxruntime_POPULATED) + message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}") + FetchContent_Populate(onnxruntime) +endif() +message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") + +# for static libraries, we use onnxruntime_lib_files directly below +include_directories(${onnxruntime_SOURCE_DIR}/build/native/include) + +file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/*.lib") + +set(onnxruntime_lib_files ${onnxruntime_lib_files} PARENT_SCOPE) + +message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}") +if(SHERPA_ONNX_ENABLE_PYTHON) + install(FILES ${onnxruntime_lib_files} DESTINATION ..) +else() + install(FILES ${onnxruntime_lib_files} DESTINATION lib) +endif() \ No newline at end of file diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index d1c4dc851..6655b45cd 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -95,7 +95,10 @@ function(download_onnxruntime) include(onnxruntime-win-arm64) else() # for 64-bit windows (x64) - if(BUILD_SHARED_LIBS) + if(SHERPA_ONNX_ENABLE_DIRECTML) + message(STATUS "Use DirectML") + include(onnxruntime-win-x64-directml) + elseif(BUILD_SHARED_LIBS) message(STATUS "Use dynamic onnxruntime libraries") if(SHERPA_ONNX_ENABLE_GPU) include(onnxruntime-win-x64-gpu) diff --git a/sherpa-onnx/csrc/provider.cc b/sherpa-onnx/csrc/provider.cc index 19d585976..3baed32c1 100644 --- a/sherpa-onnx/csrc/provider.cc +++ b/sherpa-onnx/csrc/provider.cc @@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) { return Provider::kNNAPI; } else if (s == "trt") { return Provider::kTRT; + } else if (s == "directml") { + return Provider::kDirectML; } else { SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); return Provider::kCPU; diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h index 712006f2b..2b85b8a2e 100644 --- a/sherpa-onnx/csrc/provider.h +++ b/sherpa-onnx/csrc/provider.h @@ -14,12 +14,13 @@ namespace sherpa_onnx { // https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java // for a list of available providers enum class Provider { - kCPU = 0, // CPUExecutionProvider - kCUDA = 1, // CUDAExecutionProvider - kCoreML = 2, // CoreMLExecutionProvider - kXnnpack = 3, // XnnpackExecutionProvider - kNNAPI = 4, // NnapiExecutionProvider - kTRT = 5, // TensorRTExecutionProvider + kCPU = 0, // CPUExecutionProvider + kCUDA = 1, // CUDAExecutionProvider + kCoreML = 2, // CoreMLExecutionProvider + kXnnpack = 3, // XnnpackExecutionProvider + kNNAPI = 4, // NnapiExecutionProvider + kTRT = 5, // TensorRTExecutionProvider + kDirectML = 6, // DmlExecutionProvider }; /** diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index fb5932c47..24fa5973e 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -19,6 +19,10 @@ #include "nnapi_provider_factory.h" // NOLINT #endif +#if defined(_WIN32) && defined(SHERPA_ONNX_ENABLE_DIRECTML) +#include "dml_provider_factory.h" // NOLINT +#endif + namespace sherpa_onnx { static void OrtStatusFailure(OrtStatus *status, const char *s) { @@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl( } break; } + case Provider::kDirectML: { +#if defined(_WIN32) + sess_opts.DisableMemPattern(); + sess_opts.SetExecutionMode(ORT_SEQUENTIAL); + int32_t device_id = 0; + OrtStatus *status = + OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id); + if (status) { + const auto &api = Ort::GetApi(); + const char *msg = api.GetErrorMessage(status); + SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg); + api.ReleaseStatus(status); + } +#else + SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!"); +#endif + break; + } case Provider::kCoreML: { #if defined(__APPLE__) uint32_t coreml_flags = 0;