feat: add directml support
thewh1teagle committed Jul 20, 2024
1 parent 25f0a10 commit 53e3833
Showing 6 changed files with 116 additions and 7 deletions.
9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI interface" OFF)
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build websocket server/client" ON)
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
@@ -160,6 +161,14 @@ else()
  add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
endif()

if(SHERPA_ONNX_ENABLE_DIRECTML)
  message(STATUS "DirectML is enabled")
  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
else()
  message(WARNING "DirectML is disabled")
  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
endif()

if(SHERPA_ONNX_ENABLE_WASM_TTS)
  if(NOT SHERPA_ONNX_ENABLE_TTS)
    message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")
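For reference, a minimal configure sketch assuming the Visual Studio generator on x64; the option name comes from this diff, and the static Release settings mirror the FATAL_ERROR checks in the new onnxruntime-win-x64-directml.cmake below:

  cmake -B build -A x64 -DSHERPA_ONNX_ENABLE_DIRECTML=ON -DBUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release
  cmake --build build --config Release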
72 changes: 72 additions & 0 deletions cmake/onnxruntime-win-x64-directml.cmake
@@ -0,0 +1,72 @@
# Copyright (c) 2022-2023 Xiaomi Corporation
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}")

if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows)
  message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}")
endif()

if(NOT (CMAKE_VS_PLATFORM_NAME STREQUAL X64 OR CMAKE_VS_PLATFORM_NAME STREQUAL x64))
  message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}")
endif()

if(BUILD_SHARED_LIBS)
  message(FATAL_ERROR "This file is for building static libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

if(NOT CMAKE_BUILD_TYPE STREQUAL Release)
  message(FATAL_ERROR "This file is for building a release version on Windows x64")
endif()

set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a")

# If you don't have access to the Internet,
# please download onnxruntime to one of the following locations.
# You can add more if you want.
set(possible_file_locations
  $ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
  ${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
  ${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
  /tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
)

foreach(f IN LISTS possible_file_locations)
  if(EXISTS ${f})
    set(onnxruntime_URL "${f}")
    file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
    message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
    set(onnxruntime_URL2)
    break()
  endif()
endforeach()

FetchContent_Declare(onnxruntime
  URL
    ${onnxruntime_URL}
    ${onnxruntime_URL2}
  URL_HASH ${onnxruntime_HASH}
)

FetchContent_GetProperties(onnxruntime)
if(NOT onnxruntime_POPULATED)
  message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
  FetchContent_Populate(onnxruntime)
endif()
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")

# for static libraries, we use onnxruntime_lib_files directly below
include_directories(${onnxruntime_SOURCE_DIR}/build/native/include)

file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/*.lib")

set(onnxruntime_lib_files ${onnxruntime_lib_files} PARENT_SCOPE)

message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
if(SHERPA_ONNX_ENABLE_PYTHON)
  install(FILES ${onnxruntime_lib_files} DESTINATION ..)
else()
  install(FILES ${onnxruntime_lib_files} DESTINATION lib)
endif()
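For offline builds, the comment in this file suggests pre-downloading the package into one of the listed locations. A hypothetical fetch using the primary URL above, dropping the nupkg into the source tree:

  curl -L -o microsoft.ml.onnxruntime.directml.1.14.1.nupkg https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg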
5 changes: 4 additions & 1 deletion cmake/onnxruntime.cmake
@@ -95,7 +95,10 @@ function(download_onnxruntime)
  include(onnxruntime-win-arm64)
else()
  # for 64-bit windows (x64)
  if(SHERPA_ONNX_ENABLE_DIRECTML)
    message(STATUS "Use DirectML")
    include(onnxruntime-win-x64-directml)
  elseif(BUILD_SHARED_LIBS)
    message(STATUS "Use dynamic onnxruntime libraries")
    if(SHERPA_ONNX_ENABLE_GPU)
      include(onnxruntime-win-x64-gpu)
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/provider.cc
@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
    return Provider::kNNAPI;
  } else if (s == "trt") {
    return Provider::kTRT;
  } else if (s == "directml") {
    return Provider::kDirectML;
  } else {
    SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
    return Provider::kCPU;
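A minimal sketch of the new mapping, using only Provider and StringToProvider as declared in provider.h; the asserted value follows directly from the branch above:

#include <cassert>

#include "sherpa-onnx/csrc/provider.h"  // Provider, StringToProvider

int main() {
  // "directml" resolves to the enum value added in this commit;
  // unrecognized strings log a warning and fall back to Provider::kCPU.
  assert(sherpa_onnx::StringToProvider("directml") ==
         sherpa_onnx::Provider::kDirectML);
  return 0;
}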
13 changes: 7 additions & 6 deletions sherpa-onnx/csrc/provider.h
@@ -14,12 +14,13 @@ namespace sherpa_onnx {
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
// for a list of available providers
enum class Provider {
  kCPU = 0,       // CPUExecutionProvider
  kCUDA = 1,      // CUDAExecutionProvider
  kCoreML = 2,    // CoreMLExecutionProvider
  kXnnpack = 3,   // XnnpackExecutionProvider
  kNNAPI = 4,     // NnapiExecutionProvider
  kTRT = 5,       // TensorRTExecutionProvider
  kDirectML = 6,  // DmlExecutionProvider
};

/**
22 changes: 22 additions & 0 deletions sherpa-onnx/csrc/session.cc
@@ -19,6 +19,10 @@
#include "nnapi_provider_factory.h" // NOLINT
#endif

#if defined(_WIN32) && defined(SHERPA_ONNX_ENABLE_DIRECTML)
#include "dml_provider_factory.h" // NOLINT
#endif

namespace sherpa_onnx {

static void OrtStatusFailure(OrtStatus *status, const char *s) {
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
      }
      break;
    }
    case Provider::kDirectML: {
#if defined(_WIN32)
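      // The DirectML execution provider requires memory-pattern optimization
      // to be disabled and sequential execution, hence the two calls below.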
      sess_opts.DisableMemPattern();
      sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
      int32_t device_id = 0;
      OrtStatus *status =
          OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
      if (status) {
        const auto &api = Ort::GetApi();
        const char *msg = api.GetErrorMessage(status);
        SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
        api.ReleaseStatus(status);
      }
#else
      SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
#endif
      break;
    }
    case Provider::kCoreML: {
#if defined(__APPLE__)
      uint32_t coreml_flags = 0;
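End to end, the provider is picked by name at the user-facing layer. A hypothetical invocation, assuming the --provider flag of the existing sherpa-onnx example binaries and placeholder model paths:

  build\bin\Release\sherpa-onnx-offline.exe --provider=directml --tokens=tokens.txt --encoder=encoder.onnx --decoder=decoder.onnx --joiner=joiner.onnx test.wav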
