diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index f9501253661a2..6b0e3659bd234 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -26,7 +26,7 @@ "component": { "type": "git", "git": { - "commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40", + "commitHash": "b86cc54efce19530fb953e4b21f57e6b3888534c", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "git submodule at cmake/external/onnx" @@ -192,16 +192,6 @@ "comments": "mp11" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6a20ba82b439ea1fd650da4d389e96b60a1dd828", - "repositoryUrl": "https://github.com/onnx/onnx.git" - }, - "comments": "onnx" - } - }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index f81a268d38dff..94181448fd21c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1282,14 +1282,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP16) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_U8) - add_definitions(-DOPENVINO_CONFIG_VPUX_U8=1) - endif() - if (onnxruntime_USE_OPENVINO_GPU_FP32_NP) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) @@ -1310,16 +1302,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP32_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP32=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_FP16_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - if (onnxruntime_USE_OPENVINO_HETERO) add_definitions(-DOPENVINO_CONFIG_HETERO=1) add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}") diff --git a/cmake/deps.txt b/cmake/deps.txt index 631d326e2ba5b..575df832ac8b0 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -3,10 +3,15 @@ #The columns are separated by ";" because a list in cmake is just a ";" separated group of strings. #Names should be in lower case. They will be used as variable names in cmake. #URLs can be either https URLs or local file paths in cmake-style(directory separator is a forward slash character). -#SHA1 hashes can be generated by running sha1sum command. +#SHA1 hashes can be generated by running sha1sum command on linux. PowerShell can also be used: +# (Get-FileHash -Algorithm SHA1 ).Hash.ToLower() #If you need to change abseil's version to a different one, you may also want to update external\abseil-cpp.natvis #since the file contains a version string: "lts_20230802". However, the file is for debugging purposes only and would #not affect built binaries. +# +# NOTE: You must run deps_update_and_upload.py when ready to test your changes in a CI. +# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 +# abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip;04271dfbfac59269b6939e1e9d5faf0d18a7ba91 cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 @@ -18,13 +23,13 @@ fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc -googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c +googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e +onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa @@ -35,7 +40,7 @@ protoc_linux_x86;https://github.com/protocolbuffers/protobuf/releases/download/v protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-linux-aarch_64.zip;df9d45470b0b8cf939dd2f0ec6b88e9cafc4d617 protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 -pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/1787867f6183f056420e532eec640cba25efafea.zip;e43e80781560c5ab404a4da20f34d846f5f5d101 +pthreadpool;https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip;07a0aa91dd9bf86f31b95497e00f31d8a261a4bd pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6aa67a77f17a770960f604b727645b6f6a13 pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8bf301845f2af720e0aa4.zip;85da3caa60eb2b148613b443fbc2bfdc30689965 re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374 diff --git a/cmake/external/onnx b/cmake/external/onnx index 6a20ba82b439e..b86cc54efce19 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 6a20ba82b439ea1fd650da4d389e96b60a1dd828 +Subproject commit b86cc54efce19530fb953e4b21f57e6b3888534c diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index 7455584f1a625..e661aa51bfc17 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -25,17 +25,23 @@ set(FXDIV_SOURCE_DIR ${fxdiv_SOURCE_DIR}) FetchContent_Declare(pthreadpool URL ${DEP_URL_pthreadpool} URL_HASH SHA1=${DEP_SHA1_pthreadpool}) onnxruntime_fetchcontent_makeavailable(pthreadpool) -FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack} -PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch) +FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch + ) onnxruntime_fetchcontent_makeavailable(googlexnnpack) set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR}) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool) + # the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # See source lists in _deps/googlexnnpack-src/BUILD.bazel for wasm_prod_microkernels + message("Adding WebAssembly Source Files to XNNPACK") + set(wasm_srcs "") + file(READ "${XNNPACK_DIR}/BUILD.bazel" xnnpack_bazel_config) # Replace newlines with semicolon so that it is treated as a list by CMake @@ -70,25 +76,23 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(${target_srcs} ${bazel_srcs} PARENT_SCOPE) endfunction() - GetSrcListFromBazel("PROD_SCALAR_WASM_MICROKERNEL_SRCS" prod_scalar_wasm_srcs) - GetSrcListFromBazel("ALL_WASM_MICROKERNEL_SRCS" all_wasm_srcs) - GetSrcListFromBazel("WASM32_ASM_MICROKERNEL_SRCS" wasm32_asm_srcs) + GetSrcListFromBazel("OPERATOR_SRCS" operator_srcs) + GetSrcListFromBazel("TABLE_SRCS" table_srcs) + list(APPEND wasm_srcs ${operator_srcs} ${table_srcs}) - message(DEBUG "prod_scalar_wasm_srcs: ${prod_scalar_wasm_srcs}\n") - message(DEBUG "all_wasm_srcs: ${all_wasm_srcs}\n") - message(DEBUG "wasm32_asm_srcs: ${wasm32_asm_srcs}\n") + # kernels + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/scalar.c) + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasm.c) - message("Adding WebAssembly Source Files to XNNPACK") - set(wasm_srcs "") - list(APPEND wasm_srcs ${prod_scalar_wasm_srcs}) - list(APPEND wasm_srcs ${all_wasm_srcs}) - list(APPEND wasm_srcs ${wasm32_asm_srcs}) + if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD) + list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c) + target_compile_options(XNNPACK PRIVATE "-msimd128") + endif() + message(DEBUG "wasm_srcs: ${wasm_srcs}\n") target_sources(XNNPACK PRIVATE ${wasm_srcs}) - if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD) - GetSrcListFromBazel("ALL_WASMSIMD_MICROKERNEL_SRCS" all_wasmsimd_srcs) - message(DEBUG "all_wasmsimd_srcs: ${all_wasmsimd_srcs}") - target_sources(XNNPACK PRIVATE ${all_wasmsimd_srcs}) - endif() + # add flags from BAZEL.build + target_compile_options(XNNPACK PRIVATE "-fno-fast-math") + target_compile_options(XNNPACK PRIVATE "-fno-math-errno") endif() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 6ccaf00499e95..9d9b006c595bb 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -282,44 +282,77 @@ endif() # Assemble the Apple static framework (iOS and macOS) if(onnxruntime_BUILD_APPLE_FRAMEWORK) + if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) + else() # macOS + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) + endif() + + # Setup the various directories required. Remove any existing ones so we start with a clean directory. set(STATIC_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/static_libraries) - file(MAKE_DIRECTORY ${STATIC_LIB_DIR}) + set(STATIC_LIB_TEMP_DIR ${STATIC_LIB_DIR}/temp) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E rm -rf ${STATIC_LIB_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_LIB_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_LIB_TEMP_DIR}) - # Remove the existing files in the STATIC_LIB_DIR folder - file(GLOB _OLD_STATIC_LIBS ${STATIC_LIB_DIR}/*.a) - file(REMOVE "${_OLD_STATIC_LIBS}") + set(STATIC_FRAMEWORK_DIR ${STATIC_FRAMEWORK_OUTPUT_DIR}/static_framework/onnxruntime.framework) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E rm -rf ${STATIC_FRAMEWORK_DIR}) + add_custom_command(TARGET onnxruntime PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${STATIC_FRAMEWORK_DIR}) + + # replicate XCode's Single Object Pre-Link + # link the internal onnxruntime .o files with the external .a files into a single relocatable object + # to enforce symbol visibility. doing it this way limits the symbols included from the .a files to symbols used + # by the ORT .o files. - # Go through all the static libraries, and create symbolic links - foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ${onnxruntime_EXTERNAL_LIBRARIES}) + # If it's an onnxruntime library, extract .o files to a separate directory for each library to avoid any clashes + # with filenames (e.g. utils.o) + foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ) GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E create_symlink $ ${STATIC_LIB_DIR}/$) + set(CUR_STATIC_LIB_OBJ_DIR ${STATIC_LIB_TEMP_DIR}/$) + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${CUR_STATIC_LIB_OBJ_DIR}) + + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ar ARGS -x $ + WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) endif() endforeach() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) - else() # macOS - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) - endif() + # for external libraries we create a symlink to the .a file + foreach(_LIB ${onnxruntime_EXTERNAL_LIBRARIES}) + GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) + if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E create_symlink + $ ${STATIC_LIB_DIR}/$) + endif() + endforeach() - # Assemble the static framework - set(STATIC_FRAMEWORK_DIR ${STATIC_FRAMEWORK_OUTPUT_DIR}/static_framework/onnxruntime.framework) - set(STATIC_FRAMEWORK_HEADER_DIR ${STATIC_FRAMEWORK_DIR}/Headers) - file(MAKE_DIRECTORY ${STATIC_FRAMEWORK_DIR}) - # Remove all files under STATIC_FRAMEWORK_DIR (if any) - file(GLOB_RECURSE _OLD_STATIC_FRAMEWORK ${STATIC_FRAMEWORK_DIR}/*.*) - file(REMOVE "${_OLD_STATIC_FRAMEWORK}") + # do the pre-link with `ld -r` to create a single relocatable object with correct symbol visibility + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ld ARGS -r -o ${STATIC_LIB_DIR}/prelinked_objects.o */*.o ../*.a + WORKING_DIRECTORY ${STATIC_LIB_TEMP_DIR}) + + # create the static library + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime prelinked_objects.o + WORKING_DIRECTORY ${STATIC_LIB_DIR}) + # Assemble the other pieces of the static framework + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E + copy_if_different ${INFO_PLIST_PATH} ${STATIC_FRAMEWORK_DIR}/Info.plist) + + # add the framework header files + set(STATIC_FRAMEWORK_HEADER_DIR ${STATIC_FRAMEWORK_DIR}/Headers) file(MAKE_DIRECTORY ${STATIC_FRAMEWORK_HEADER_DIR}) - # copy the header files one by one, and the Info.plist foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) get_filename_component(HEADER_NAME_ ${h_} NAME) - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${STATIC_FRAMEWORK_HEADER_DIR}/${HEADER_NAME_}) + add_custom_command(TARGET onnxruntime POST_BUILD + COMMAND ${CMAKE_COMMAND} -E + copy_if_different ${h_} ${STATIC_FRAMEWORK_HEADER_DIR}/${HEADER_NAME_}) endforeach() - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${INFO_PLIST_PATH} ${STATIC_FRAMEWORK_DIR}/Info.plist) - # link the static library - add_custom_command(TARGET onnxruntime POST_BUILD COMMAND libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime *.a WORKING_DIRECTORY ${STATIC_LIB_DIR}) endif() diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 735c86956ec4f..3f532ec2c3261 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -20,6 +20,8 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_deprecated_operators.cc" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.h" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc" + "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.h" + "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.cc" "${ONNXRUNTIME_ROOT}/core/graph/function_template.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.h" "${ONNXRUNTIME_ROOT}/core/graph/function_utils.cc" diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 043789c36c327..ce0c12804b08a 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -40,6 +40,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake index 30ae90e6564cc..9c00703ca0846 100644 --- a/cmake/onnxruntime_providers_xnnpack.cmake +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -15,7 +15,8 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_xnnpack ${onnxruntime_providers_xnnpack_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_xnnpack - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool flatbuffers::flatbuffers Boost::mp11 safeint_interface + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool + flatbuffers::flatbuffers Boost::mp11 safeint_interface ) add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) @@ -35,4 +36,4 @@ # there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline' if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=shorten-64-to-32) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 6ccf063c71290..4140eeee0d111 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -94,6 +94,11 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" + "bert/group_query_attention_helper.h" + "bert/group_query_attention.h" + "bert/group_query_attention.cc" + "bert/group_query_attention_impl.h" + "bert/group_query_attention_impl.cu" ) if (NOT onnxruntime_ENABLE_ATEN) @@ -109,6 +114,7 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") endif() set(provider_excluded_files diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index f5f98066675fb..bdb0230a8ebd0 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -41,7 +41,7 @@ function(AddTest) if (MSVC) target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6330>" "$<$>:/wd6330>") - #Abseil has a lot of C4127/C4324 warnings. + #Abseil has a lot of C4127/C4324 warnings. target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4127>" "$<$>:/wd4127>") target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4324>" @@ -201,8 +201,18 @@ function(AddTest) list(APPEND TEST_NODE_FLAGS "--experimental-wasm-simd") endif() + # prefer Node from emsdk so the version is more deterministic + if (DEFINED ENV{EMSDK_NODE}) + set(NODE_EXECUTABLE $ENV{EMSDK_NODE}) + else() + # warning as we don't know what node version is being used and whether things like the TEST_NODE_FLAGS + # will be valid. e.g. "--experimental-wasm-simd" is not valid with node v20 or later. + message(WARNING "EMSDK_NODE environment variable was not set. Falling back to system `node`.") + set(NODE_EXECUTABLE node) + endif() + add_test(NAME ${_UT_TARGET} - COMMAND node ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS} + COMMAND ${NODE_EXECUTABLE} ${TEST_NODE_FLAGS} ${_UT_TARGET}.js ${TEST_ARGS} WORKING_DIRECTORY $ ) endif() diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index c6510c97a617e..9014089cb6112 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -192,8 +192,13 @@ else() onnxruntime_util re2::re2 ) + + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'") + if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) + string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'") + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ALLOW_TABLE_GROWTH=1") endif() if(onnxruntime_USE_WEBNN) @@ -204,7 +209,6 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() - set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']") if (onnxruntime_USE_JSEP) set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName") else() @@ -212,7 +216,7 @@ else() endif() target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:-s EXPORTED_RUNTIME_METHODS=${EXPORTED_RUNTIME_METHODS}" + "SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]" "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" "SHELL:-s MAXIMUM_MEMORY=4294967296" "SHELL:-s EXIT_RUNTIME=0" diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index 37bdbf9fb53f6..460b4d97c499b 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,66 +1,27 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index d53c48aa1..77c3cf983 100755 +index dba9b4687..bcaa18ad7 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -105,22 +105,12 @@ ENDIF() - +@@ -122,7 +122,7 @@ ENDIF() + # ---[ Build flags IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") --ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS)$") -+ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android|Windows|CYGWIN|MSYS|Emscripten|iOS)$") - MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME = ${CMAKE_SYSTEM_NAME}") +-ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT)$") ++ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Android|Darwin|iOS|Linux|Windows|CYGWIN|MSYS|QURT|Emscripten|iOS)$") + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"") ENDIF() - - # ---[ Download deps - IF(NOT XNNPACK_USE_SYSTEM_LIBS) -- IF(NOT DEFINED CLOG_SOURCE_DIR) -- MESSAGE(STATUS "Downloading clog to ${CMAKE_BINARY_DIR}/clog-source (define CLOG_SOURCE_DIR to avoid it)") -- CONFIGURE_FILE(cmake/DownloadCLog.cmake "${CMAKE_BINARY_DIR}/clog-download/CMakeLists.txt") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . -- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download") -- EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . -- WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/clog-download") -- SET(CLOG_SOURCE_DIR "${CMAKE_BINARY_DIR}/clog-source" CACHE STRING "clog source directory") -- ENDIF() -- - IF(NOT DEFINED CPUINFO_SOURCE_DIR) - MESSAGE(STATUS "Downloading cpuinfo to ${CMAKE_BINARY_DIR}/cpuinfo-source (define CPUINFO_SOURCE_DIR to avoid it)") - CONFIGURE_FILE(cmake/DownloadCpuinfo.cmake "${CMAKE_BINARY_DIR}/cpuinfo-download/CMakeLists.txt") -@@ -7108,6 +7098,10 @@ IF(MSVC) - SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O2 >") - SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O2 >") - SET_PROPERTY(SOURCE ${COLD_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: /O1 >") -+ELSEIF(CMAKE_GENERATOR STREQUAL Xcode) -+ TARGET_COMPILE_OPTIONS(all_microkernels PRIVATE $<$>: -O2 >) -+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$>: -O2 >) -+ TARGET_COMPILE_OPTIONS(XNNPACK PRIVATE $<$>: -Os >) - ELSE() - SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: -O2 >") - SET_PROPERTY(SOURCE ${HOT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS "$<$>: -O2 >") -@@ -7142,26 +7136,6 @@ IF(LIBM) - TARGET_LINK_LIBRARIES(indirection PRIVATE ${LIBM}) + IF(CMAKE_SYSTEM_NAME MATCHES "Windows") +@@ -534,7 +534,12 @@ IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(operator-utils PRIVATE logging) + TARGET_LINK_LIBRARIES(post-operation PRIVATE logging) + TARGET_LINK_LIBRARIES(subgraph PRIVATE allocator logging memory mutex operators operator-run) +- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") ++ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation subgraph) ++ ELSE() ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) ++ ENDIF() + SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() - --# ---[ Configure clog --IF(NOT TARGET clog) -- IF(NOT XNNPACK_USE_SYSTEM_LIBS) -- SET(CLOG_BUILD_TESTS OFF CACHE BOOL "") -- SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "") -- ADD_SUBDIRECTORY( -- "${CLOG_SOURCE_DIR}/deps/clog" -- "${CMAKE_BINARY_DIR}/clog") -- # We build static version of clog but a dynamic library may indirectly depend on it -- SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON) -- ELSE() -- ADD_LIBRARY(clog STATIC IMPORTED) -- FIND_LIBRARY(CLOG_LIBRARY clog) -- IF(NOT CLOG_LIBRARY) -- MESSAGE(FATAL_ERROR "Cannot find clog") -- ENDIF() -- SET_PROPERTY(TARGET clog PROPERTY IMPORTED_LOCATION "${CLOG_LIBRARY}") -- ENDIF() --ENDIF() -- - # ---[ Configure cpuinfo - IF(NOT TARGET cpuinfo) - IF(NOT XNNPACK_USE_SYSTEM_LIBS) + IF(NOT MSVC) diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 0288d752d8749..69bfd9896f1e4 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -17,9 +17,13 @@ CMake creates a target to this project x64 false - false + true + true None + + true + .. ..\tools\nuget\generate_nuspec_for_native_nuget.py @@ -30,13 +34,15 @@ CMake creates a target to this project ..\build\Linux $(OnnxRuntimeBuildDirectory)\packages $(OnnxRuntimeBuildDirectory)\$(Configuration) + python3 - + ..\build\Windows $(OnnxRuntimeBuildDirectory)\packages $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) + python @@ -86,28 +92,48 @@ CMake creates a target to this project - + + + Properties="NoBuild=true;Platform=AnyCPU;PackageVersion=$(PackageVersion);OrtPackageId=$(OrtPackageId);IncludeMobileTargets=$(IncludeMobileTargets)"/> - - + + + - - + + + - + + + + - diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 29ccf55f081d5..0c74a23204d4f 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -4,66 +4,53 @@ Microsoft.ML.OnnxRuntime - - PreNet6 - netstandard2.0;netcoreapp3.1;net6.0 + true + netstandard2.0 + - - - xamarinios10;monoandroid11.0 + + + false - - monoandroid11.0 - + + NOTE: We include in a build of the managed package when creating Microsoft.ML.OnnxRuntime.Gpu as both + the CPU and GPU packaging pipelines can publish Microsoft.ML.OnnxRuntime.Managed, and we need the targets + to be consistent in both. + --> - net6.0;net6.0-android;net6.0-ios;net6.0-macos + '$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime.Gpu') AND + '$(IncludeMobileTargets)' == 'true' AND + Exists('$(MSBuildExtensionsPath)\Xamarin\Android') AND + Exists('$(MSBuildExtensionsPath)\Xamarin\iOS')"> + xamarinios10;monoandroid11.0 - - net6.0;net6.0-android + + monoandroid11.0 - - $(BaseTargets);$(XamarinTargets);$(XamarinTargetsForTraining) + + + $(MobileTargets);net6.0-android;net6.0-ios - - $(Net6Targets);$(Net6TargetsForTrainingPackage) + + $(MobileTargets);net6.0-android - - - $(BaseTargets);$(XamarinTargets);$(XamarinTargetsForTraining);$(Net6Targets);$(Net6TargetsForTrainingPackage) + + $(BaseTargets);$(MobileTargets) - AnyCPU;x86 default @@ -204,8 +191,9 @@ $(DefineConstants);$(OrtConstants) - + @@ -214,7 +202,6 @@ - --> + + + netstandard2.0 + $(OnnxRuntimeBuildDirectory)/NativeNuget.nuspec + + diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 890403556cc47..646465ef8b56f 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -95,6 +95,7 @@ Do not modify directly.* * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling + * com.microsoft.SkipGroupNorm * com.microsoft.SkipLayerNormalization * com.microsoft.SkipSimplifiedLayerNormalization * com.microsoft.Snpe @@ -2342,7 +2343,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation : int (required)
-
Activation after group normalization: 0 for None, 1 for Swish
+
Activation after group normalization: 0 for None, 1 for SiLU
channels_last : int
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
epsilon : float
@@ -2421,14 +2422,14 @@ This version of the operator has been available since version 1 of the 'com.micr
When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.
-#### Outputs (1 - 3) +#### Outputs
output : T
3D output tensor with shape (batch_size, sequence_length, hidden_size)
-
present_key (optional) : T
+
present_key : T
present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-
present_value (optional) : T
+
present_value : T
present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
@@ -2579,9 +2580,32 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's quantization constants or scales are specified by input 'absmax'. - Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. - Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] + + #### Version This version of the operator has been available since version 1 of the 'com.microsoft' operator set. @@ -2597,6 +2621,10 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.
quant_type : int (required)
quantization data type. 0 for FP4, 1 for NF4.
+
training_mode : int
+
Indicate if the ops run in training_mode, by default, False.
+
transB : int
+
Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.
#### Inputs @@ -5083,6 +5111,72 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.SkipGroupNorm** + + This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + + This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. + The num_channels must be divisible by num_groups. + The mean and standard-deviation of s are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for SiLU
+
channels_last : int
+
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs (4 - 5) + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is number of channels
+
skip : T
+
4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)
+
bias (optional) : T
+
1D bias tensor. Dimensions are (C), where C is number of channels
+
+ +#### Outputs (1 - 2) + +
+
Y : T
+
The output tensor of the same shape as X
+
S (optional) : T
+
The element-wise sum of input x, skip and bias tensors. It has the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X, skip, bias and output Y, S types to float tensors.
+
M : tensor(float16), tensor(float)
+
Constrain gamma and beta to float tensors.
+
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bfb7716dc5cea..bbfc33a915fae 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -801,7 +801,7 @@ Do not modify directly.* |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**|9|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|16+|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| +|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|16+|**B** = tensor(bool)
**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |||[9, 15]|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |Xor|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| | | @@ -861,6 +861,7 @@ Do not modify directly.* |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| diff --git a/docs/python/ReadMeOV.rst b/docs/python/ReadMeOV.rst index f12c01d278dca..6ef16e1378139 100644 --- a/docs/python/ReadMeOV.rst +++ b/docs/python/ReadMeOV.rst @@ -7,7 +7,6 @@ OpenVINOâ„¢ Execution Provider for ONNX Runtime accelerates inference across man - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs Installation ------------ @@ -22,7 +21,6 @@ This package supports: - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs ``pip3 install onnxruntime-openvino`` diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bbdee0c91b885..0515eadb6bf0b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -611,7 +611,7 @@ typedef struct OrtMIGraphXProviderOptions { typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus OrtOpenVINOProviderOptions() : device_type{}, - enable_vpu_fast_compile{}, + enable_npu_fast_compile{}, device_id{}, num_of_threads{}, cache_dir{}, @@ -624,7 +624,7 @@ typedef struct OrtOpenVINOProviderOptions { * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" */ const char* device_type; - unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled + unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; size_t num_of_threads; ///< 0 = Use default number of threads const char* cache_dir; // path is set to empty by default @@ -4614,6 +4614,10 @@ struct OrtCustomOp { OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); + + // Get start range + int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); + int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 53b9bead3f1e6..29ba3e1714630 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2229,6 +2229,8 @@ struct ShapeInferContext { using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); +#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 + template struct CustomOpBase : OrtCustomOp { CustomOpBase() { @@ -2281,6 +2283,14 @@ struct CustomOpBase : OrtCustomOp { } SetShapeInferFn(0); + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->end_ver_; + }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider @@ -2349,6 +2359,9 @@ struct CustomOpBase : OrtCustomOp { protected: // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index b12221e56b79f..443710884743a 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -773,8 +773,11 @@ struct OrtLiteCustomOp : public OrtCustomOp { PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ) OrtLiteCustomOp(const char* op_name, - const char* execution_provider) : op_name_(op_name), - execution_provider_(execution_provider) { + const char* execution_provider, + int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -837,6 +840,16 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtCustomOp::KernelCompute = {}; OrtCustomOp::InferOutputShapeFn = {}; + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->end_ver_; + }; } const std::string op_name_; @@ -844,6 +857,9 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -873,9 +889,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFn compute_fn, - ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_(compute_fn), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -911,9 +929,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFnReturnStatus compute_fn_return_status, - ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_return_status_(compute_fn_return_status), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -985,8 +1005,9 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { }; OrtLiteCustomStruct(const char* op_name, - const char* execution_provider) : OrtLiteCustomOp(op_name, - execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { @@ -1049,25 +1070,31 @@ template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, void (*custom_compute_fn)(Args...), - Status (*shape_infer_fn)(ShapeInferContext&) = {}) { + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, Status (*custom_compute_fn_v2)(Args...), - Status (*shape_infer_fn)(ShapeInferContext&) = {}) { + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, - const char* execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomStruct; - return std::make_unique(op_name, execution_provider).release(); + return std::make_unique(op_name, execution_provider, start_ver, end_ver).release(); } } // namespace Custom diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 47e67879e66ce..ee6d26b22b1f6 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -2,11 +2,18 @@ // Licensed under the MIT License. import {resolveBackend} from './backend-impl.js'; -import {TrainingSessionHandler} from './backend.js'; +import {SessionHandler, TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; +import {Tensor} from './tensor.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; type SessionOptions = InferenceSession.SessionOptions; +type FeedsType = InferenceSession.FeedsType; +type FetchesType = InferenceSession.FetchesType; +type ReturnType = InferenceSession.ReturnType; +type RunOptions = InferenceSession.RunOptions; + const noBackendErrMsg: string = 'Training backend could not be resolved. ' + 'Make sure you\'re using the correct configuration & WebAssembly files.'; @@ -42,21 +49,138 @@ export class TrainingSession implements TrainingSessionInterface { } } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + /** + * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from + * the given parameters to SessionHandler.FetchesType and RunOptions. + * + * @param feeds the required input + * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object + * @param arg2 optional RunOptions object. + * @returns + */ + typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): + [SessionHandler.FetchesType, RunOptions] { + const fetches: {[name: string]: OnnxValue|null} = {}; + let options: RunOptions = {}; + // check inputs + if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { + throw new TypeError( + '\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.'); + } + + let isFetchesEmpty = true; + // determine which override is being used + if (typeof arg1 === 'object') { + if (arg1 === null) { + throw new TypeError('Unexpected argument[1]: cannot be null.'); + } + if (arg1 instanceof Tensor) { + throw new TypeError('\'fetches\' cannot be a Tensor'); + } + + if (Array.isArray(arg1)) { + if (arg1.length === 0) { + throw new TypeError('\'fetches\' cannot be an empty array.'); + } + isFetchesEmpty = false; + // output names + for (const name of arg1) { + if (typeof name !== 'string') { + throw new TypeError('\'fetches\' must be a string array or an object.'); + } + if (this.outputNames.indexOf(name) === -1) { + throw new RangeError(`'fetches' contains invalid output name: ${name}.`); + } + fetches[name] = null; + } + + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + // decide whether arg1 is fetches or options + // if any output name is present and its value is valid OnnxValue, we consider it fetches + let isFetches = false; + const arg1Keys = Object.getOwnPropertyNames(arg1); + for (const name of this.outputNames) { + if (arg1Keys.indexOf(name) !== -1) { + const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; + if (v === null || v instanceof Tensor) { + isFetches = true; + isFetchesEmpty = false; + fetches[name] = v; + } + } + } + + if (isFetches) { + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + options = arg1 as RunOptions; + } + } + } else if (typeof arg1 !== 'undefined') { + throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.'); + } + + // check if all inputs are in feed + for (const name of this.inputNames) { + if (typeof feeds[name] === 'undefined') { + throw new Error(`input '${name}' is missing in 'feeds'.`); + } + } + + // if no fetches is specified, we use the full output names list + if (isFetchesEmpty) { + for (const name of this.outputNames) { + fetches[name] = null; + } + } + + return [fetches, options]; } - async getContiguousParameters(_trainableOnly: boolean): Promise { + /** + * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler + * and changes it into a map of Tensors. + * + * @param results + * @returns + */ + convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType { + const returnValue: {[name: string]: OnnxValue} = {}; + for (const key in results) { + if (Object.hasOwnProperty.call(results, key)) { + const result = results[key]; + if (result instanceof Tensor) { + returnValue[key] = result; + } else { + returnValue[key] = new Tensor(result.type, result.data, result.dims); + } + } + } + return returnValue; + } + + runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; + runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; + async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2); + const results = await this.handler.runTrainStep(feeds, fetches, options); + return this.convertHandlerReturnTypeToMapOfTensors(results); + } + + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } - runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined): - Promise; - runTrainStep( - feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions|undefined): Promise; - async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown): - Promise { + async getContiguousParameters(_trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 44003021293b0..0b82a9c031baa 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -20,15 +20,15 @@ Do not modify directly.* | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | -| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation | +| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | | Cast | ai.onnx(6-8,9-12,13-18,19+) | | | Ceil | ai.onnx(6-12,13+) | | | Clip | ai.onnx(6-10,11,12,13+) | | | Concat | ai.onnx(1-3,4-10,11-12,13+) | | -| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d is not supported; need implementing activation | -| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation | +| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; conv3d is not supported; need implementing activation | +| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation | | Cos | ai.onnx(7+) | | | Cosh | ai.onnx(9+) | | | Div | ai.onnx(7-12,13,14+) | | @@ -40,6 +40,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherElements | ai.onnx(11-12,13+) | | | Gelu | com.microsoft(1+) | | @@ -56,7 +57,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | -| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation | +| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | @@ -78,7 +79,7 @@ Do not modify directly.* | ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | | | Relu | ai.onnx(6-12,13,14+) | | | Reshape | ai.onnx(5-12,13,14+) | no GPU kernel | -| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | +| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | | Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | Sin | ai.onnx(7+) | | diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 5ea7de809a495..7176823c9bf13 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -5,7 +5,7 @@ import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; -import {OnnxjsSessionHandler} from './onnxjs/session-handler'; +import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference'; class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index 98e40807aa29c..09dac3a85311c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -4,7 +4,7 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; -import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 5740263583031..78edcc90f55f9 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -5,7 +5,7 @@ import {cpus} from 'node:os'; import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; -import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler'; +import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; /** * This function initializes all flags for WebAssembly. diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler-inference.ts similarity index 100% rename from js/web/lib/onnxjs/session-handler.ts rename to js/web/lib/onnxjs/session-handler-inference.ts diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 40309c1849bcc..a4d51e68b6a25 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -67,6 +67,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['Gelu', [unaryOps.gelu]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index 22b91d680a9b4..6481a6b21d723 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -41,12 +41,12 @@ export const activationFnSnippet = if (!activation) { return ''; } - // TODO: add implementations return ''; }; export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => ` ${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''} - ${activation ? 'value = activation(value, coords);' : ''} + // TODO uncomment the following line when activation is supported above. + // ${activation ? 'value = activation(value, coords);' : ''} `; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 01ddca520deed..fbb936a045b9c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -242,8 +242,9 @@ export const createConv2DMatMulProgramInfo = ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2], t)} + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, + attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1], + elementsSize[2], t)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 840360223c75a..a95d3830f34eb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -236,7 +236,9 @@ export const createConv2DTransposeMatMulProgramInfo = const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + conv2dTransposeCommonSnippet( + isChannelsLast, hasBias, attributes.activation.toLowerCase() as Activation, false, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 1032869412462..0a0f29db6a494 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo} from '../../types'; import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; -import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -440,7 +440,7 @@ export const createMatmulProgramInfo = const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, isVec4); // TODO: fine tune size const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; @@ -473,8 +473,8 @@ export const createMatmulProgramInfo = const dimBOuter: i32 = ${dimBOuter}; const dimInner: i32 = ${dimInner}; ${shaderHelper.declareVariables(...inputVariables, output)} - ${declareFunctions} ${activationFunction} + ${declareFunctions} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 7abf022928ade..8bfa722dd0909 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -7,7 +7,7 @@ import {ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActicationSnippet} from './fuse-utils'; +import {getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -22,7 +22,7 @@ export const createGroupedConvProgramInfo = const wShape = inputs[1].dims; const outputChannelsPerGroup = wShape[0] / attributes.group; - const {activationFunction, applyActivation} = getActicationSnippet(attributes); + const {activationFunction, applyActivation} = getActivationSnippet(attributes); const isChannelLast = attributes.format === 'NHWC'; const outputShape = calculateOutputShape( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 92105859a8c0e..956ef18eb5cfb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -10,24 +10,25 @@ export interface InternalActivationAttributes { readonly activationCacheKey: string; } -export const getActicationSnippet = - (attributes: InternalActivationAttributes): {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; - case 'Sigmoid': - return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; - case 'Clip': - return { - activationFunction: - `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, isVec4 = false): { + activationFunction: string; applyActivation: string; +} => { + switch (attributes.activation) { + case 'Relu': + return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; + case 'Sigmoid': + return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; + case 'Clip': + return { + activationFunction: `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, + applyActivation: isVec4 ? 'value = clamp(value, vec4(clip_min_), vec4(clip_max_));' : + 'value = clamp(value, clip_min_, clip_max_);' + }; + // TODO: adding other activations that can be fused. + default: + return {activationFunction: '', applyActivation: ''}; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts new file mode 100644 index 0000000000000..54a7414360c4e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createReduceAttributesFromInputs, ReduceAttributes} from './reduce'; +import {createTransposeProgramInfo} from './transpose'; + +const reduceOps: {[key: string]: string} = { + max: 'select(bestValue, candidate, candidate > bestValue)', + min: 'select(bestValue, candidate, candidate < bestValue)', + mean: 'bestValue + candidate', + sum: 'bestValue + candidate', + prod: 'bestValue * candidate', + sumSquare: 'bestValue + candidate * candidate', + logSumExp: 'bestValue + exp(candidate)', + l1: 'bestValue + abs(candidate)', + l2: 'bestValue + candidate * candidate', + logSum: 'bestValue + candidate' +}; + +const reduceSharedOps: {[key: string]: string} = { + max: 'select(bestValue, candidate, candidate > bestValue)', + min: 'select(bestValue, candidate, candidate < bestValue)', + mean: 'bestValue + candidate', + sum: 'bestValue + candidate', + prod: 'bestValue * candidate', + sumSquare: 'bestValue + candidate', + logSumExp: 'bestValue + candidate', + l1: 'bestValue + candidate', + l2: 'bestValue + candidate', + logSum: 'bestValue + candidate' +}; + +const reduceInitValues: {[key: string]: string} = { + max: '_A[offset]', + min: '_A[offset]', + mean: '0', + sum: '0', + prod: '1', + sumSquare: '0', + logSumExp: '0', + l1: '0', + l2: '0', + logSum: '0' +}; + +const reduceOutputValues: {[key: string]: string} = { + max: 'bestValue', + min: 'bestValue', + sum: 'bestValue', + prod: 'bestValue', + sumSquare: 'bestValue', + logSumExp: 'log(bestValue)', + l1: 'bestValue', + l2: 'sqrt(bestValue)', + logSum: 'log(bestValue)' +}; + +const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => { + const res = []; + for (let i = rank - numInnerAxes; i < rank; ++i) { + res.push(i); + } + return res; +}; + +const computeOutAndReduceShapes = (shape: readonly number[], axes: readonly number[]): [number[], number[]] => { + const outputShape = []; + const rank = shape.length; + for (let dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + outputShape.push(shape[dim]); + } + } + const reduceShape = axes.map(dim => shape[dim]); + return [outputShape, reduceShape]; +}; + +const expandShapeToKeepDim = (shape: number[], axes: number[]): number[] => { + const rank = shape.length + axes.length; + const expandShape = []; + let shapeIdx = 0; + for (let dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + expandShape.push(shape[shapeIdx++]); + } else { + expandShape.push(1); + } + } + return expandShape; +}; + +const areAxesInnerMostDims = (axes: number[], rank: number): boolean => { + for (let i = 0; i < axes.length; ++i) { + if (axes[axes.length - i - 1] !== rank - 1 - i) { + return false; + } + } + return true; +}; + +const getAxesPermutation = (axes: number[], rank: number): number[] => { + const res = []; + if (!areAxesInnerMostDims(axes, rank)) { + for (let i = 0; i < rank; ++i) { + if (axes.indexOf(i) === -1) { + res.push(i); + } + } + axes.forEach(axis => res.push(axis)); + } + return res; +}; + +export const createReduceSharedProgramInfo = + (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceType: string, + outputDataType: DataType, outputShape: number[], reduceShape: number[]): ProgramInfo => { + const inputShape = inputs[0].dims; + + const outputSize = ShapeUtil.size(outputShape); + const reduceSize = ShapeUtil.size(reduceShape); + + const input = inputVariable('_A', inputs[0].dataType, inputShape); + const output = outputVariable('output', outputDataType, outputShape); + + const workgroupSize = 32; + + const sharedMemorySnippet = ` + var aBestValues : array<${output.type.storage}, ${workgroupSize}>; + `; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)} + ${sharedMemorySnippet} + fn DIV_CEIL(a : u32, b : u32) -> u32 { + return ((a - 1u) / b + 1u); + } + ${shaderHelper.mainStart(workgroupSize)} + let local_idx = local_id.x; + + let outputIndex = global_idx / ${workgroupSize}; + let offset = outputIndex * uniforms.reduceSize; + + var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]}); + let Length = uniforms.reduceSize; + for (var k = local_idx; k < Length; k = k + ${workgroupSize}) { + let candidate = ${output.type.storage}(${input.getByOffset('offset + k')}); + bestValue = ${reduceOps[reduceType]}; + } + aBestValues[local_idx] = bestValue; + workgroupBarrier(); + + var reduceSize = min(Length, ${workgroupSize}u); + for (var currentSize = reduceSize / 2u; reduceSize > 1u; + currentSize = reduceSize / 2u) { + let interval = DIV_CEIL(reduceSize, 2u); + if (local_idx < currentSize) { + let candidate = aBestValues[local_idx + interval]; + bestValue = ${reduceSharedOps[reduceType]}; + aBestValues[local_idx] = bestValue; + } + reduceSize = interval; + workgroupBarrier(); + } + + if (local_idx == 0u) { + ${ + output.setByOffset( + 'outputIndex', + `${ + reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` : + `${reduceOutputValues[reduceType]}`}`)}; + } + }`; + + // One work group is responsible for only one element of output. + return { + name, + shaderCache, + getShaderSource, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: outputSize}, + programUniforms: [{type: 'uint32', data: reduceSize}] + }), + }; + }; + +const reduceCommon = + (context: ComputeContext, name: string, attributes: ReduceAttributes, + reduceType: 'sum'|'sumSquare'|'prod'|'min'|'max'|'mean'|'logSumExp'|'l1'|'l2'|'logSum'): void => { + const updatedAttributes: ReduceAttributes = + context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes); + + let updatedAxes = updatedAttributes.axes; + if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) { + updatedAxes = context.inputs[0].dims.map((s, i) => i); + } + const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length); + + let axes = normalizeAxes; + let input = context.inputs[0]; + const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length); + if (permutedAxes.length > 0) { + input = context.compute( + createTransposeProgramInfo(context.inputs[0], permutedAxes), {inputs: [0], outputs: [-1]})[0]; + axes = getInnerMostAxes(axes.length, input.dims.length); + } + + const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes); + let finalOutputShape = outputShape; + if (updatedAttributes.keepDims) { + finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes); + } + + context.compute( + createReduceSharedProgramInfo( + name, {hint: updatedAttributes.cacheKey, inputDependencies: ['type']}, [input], reduceType, + context.inputs[0].dataType, finalOutputShape, reduceShape), + {inputs: [input]}); + }; + +export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMeanShared', attributes, 'mean'); +}; + +export const reduceL1Shared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceL1Shared', attributes, 'l1'); +}; + +export const reduceL2Shared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceL2Shared', attributes, 'l2'); +}; + +export const reduceLogSumExpShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceLogSumExpShared', attributes, 'logSumExp'); +}; + +export const reduceMaxShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMaxShared', attributes, 'max'); +}; + +export const reduceMinShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMinShared', attributes, 'min'); +}; + +export const reduceProdShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceProdShared', attributes, 'prod'); +}; + +export const reduceSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceSumShared', attributes, 'sum'); +}; + +export const reduceSumSquareShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceSumSquareShared', attributes, 'sumSquare'); +}; + +export const reduceLogSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceLogSumShared', attributes, 'logSum'); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 44d6332852d2a..b5c956e57a9b1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -8,6 +8,7 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-w import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length === 0 || inputs.length > 2) { @@ -106,7 +107,7 @@ export const createReduceProgramInfo = }; }; -const createReduceAttributesFromInputs = +export const createReduceAttributesFromInputs = (inputs: readonly TensorView[], attributes: ReduceAttributes): ReduceAttributes => { const axes: number[] = []; if (inputs[1].dims[0] > 0) { @@ -131,7 +132,7 @@ const runReduceProgram = {inputs: [0]}); }; -export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -142,7 +143,7 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp); }; -export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -153,7 +154,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): runReduceProgram(context, 'ReduceL1', attributes, reduceOp); }; -export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, @@ -164,7 +165,7 @@ export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): runReduceProgram(context, 'ReduceL2', attributes, reduceOp); }; -export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -175,7 +176,7 @@ export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttri runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp); }; -export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, _output, axes) => { const idxZero = []; @@ -195,7 +196,7 @@ export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceMax', attributes, reduceOp); }; -export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output, axes) => { let size = 1.0; @@ -216,7 +217,7 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes runReduceProgram(context, 'ReduceMean', attributes, reduceOp); }; -export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, _output, axes) => { const idxZero = []; @@ -236,7 +237,7 @@ export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceMin', attributes, reduceOp); }; -export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(1);`, @@ -247,7 +248,7 @@ export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes runReduceProgram(context, 'ReduceProd', attributes, reduceOp); }; -export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -258,7 +259,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceSum', attributes, reduceOp); }; -export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, @@ -269,5 +270,107 @@ export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttri runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp); }; +const useNaiveReduceMethod = + (shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => { + if (axes.length === 0) { + return noopWithEmptyAxes ? true : false; + } + + let outputSize = 1; + let reduceSize = 1; + for (let dim = 0; dim < axes.length; dim++) { + if (axes.indexOf(dim) === -1) { + outputSize *= shape[dim]; + } else { + reduceSize *= shape[dim]; + } + } + + // The condition data is very rough, although considering the count of Execution Unit (EU), the potential + // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments + // on some machines. + return reduceSize < 32 && outputSize > 1024 ? true : false; + }; + +export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMeanNaive(context, attributes); + } else { + reduceMeanShared(context, attributes); + } +}; + +export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceL1Naive(context, attributes); + } else { + reduceL1Shared(context, attributes); + } +}; + +export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceL2Naive(context, attributes); + } else { + reduceL2Shared(context, attributes); + } +}; + +export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceLogSumExpNaive(context, attributes); + } else { + reduceLogSumExpShared(context, attributes); + } +}; + +export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMaxNaive(context, attributes); + } else { + reduceMaxShared(context, attributes); + } +}; + +export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMinNaive(context, attributes); + } else { + reduceMinShared(context, attributes); + } +}; + +export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceProdNaive(context, attributes); + } else { + reduceProdShared(context, attributes); + } +}; + +export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceSumNaive(context, attributes); + } else { + reduceSumShared(context, attributes); + } +}; + +export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceSumSquareNaive(context, attributes); + } else { + reduceSumSquareShared(context, attributes); + } +}; + +export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceLogSumNaive(context, attributes); + } else { + reduceLogSumShared(context, attributes); + } +}; + export const parseReduceAttributes = (attributes: Record): ReduceAttributes => createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index d4dbad79e613e..ec651ce34e8c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -10,7 +10,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {getMaxComponents, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { @@ -37,23 +37,39 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut const cols = shape[axis]; const rows = outputSize / cols; + const components = getMaxComponents(cols); + const packedCols = cols / components; + const valueType = components === 1 ? dataType : `vec${components}<${dataType}>`; + + const maxVector = (name: string, components: number) => { + if (components === 4) { + return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`; + } else if (components === 2) { + return `max(${name}.x, ${name}.y)`; + } else if (components === 3) { + return `max(max(${name}.x, ${name}.y), ${name}.z)`; + } + + return name; + }; // 6.2.4 in wgsl spec - const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;'; + const threadMaxDecl = + dataType === 'f32' ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (_shaderHelper: ShaderHelper) => ` - var rowMaxShared : ${dataType}; - var rowSumShared : ${dataType}; - var threadShared : array<${dataType}, ${WG}>; + var rowMaxShared : ${valueType}; + var rowSumShared : ${valueType}; + var threadShared : array<${valueType}, ${WG}>; - @group(0) @binding(0) var x : array<${dataType}>; - @group(0) @binding(1) var result : array<${dataType}>; + @group(0) @binding(0) var x : array<${valueType}>; + @group(0) @binding(1) var result : array<${valueType}>; - fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} { + fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} { let index = row * row_stride + col; return x[index]; } - fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) { + fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) { let index = row * row_stride + col; result[index] = value; } @@ -64,8 +80,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut let lindex = i32(local_id.x); const wg = ${WG}; let row = gindex / wg; - let cols = ${cols}; - let row_stride : i32 = ${cols}; + let cols = ${packedCols}; + let row_stride : i32 = ${packedCols}; // find the rows max ${threadMaxDecl} @@ -87,12 +103,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowMaxShared = threadShared[0]; + rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)}); } workgroupBarrier(); // find the rows sum - var threadSum: ${dataType} = 0.0; + var threadSum = ${valueType}(0.0); for (var col = lindex; col < cols; col += wg) { let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); threadSum += subExp; @@ -107,7 +123,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowSumShared = threadShared[0]; + rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)}); } workgroupBarrier(); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index bead3e72f63c7..4238449f9246f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -29,12 +29,12 @@ const createElementwiseProgramShader = const output = outputVariable('outputData', outputDataType, [vecSize], 4); return ` - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} ${additionalImplementation ?? ''} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} let a = ${input.getByOffset('global_idx')}; ${output.setByOffset('global_idx', expression)} @@ -45,13 +45,16 @@ const createElementwiseProgramInfo = (input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, cacheKey?: string, outputDataType: number = input.dataType): ProgramInfo => ({ name, - shaderCache: {hint: cacheKey}, + shaderCache: {hint: cacheKey, inputDependencies: ['type']}, getShaderSource: shaderHelper => createElementwiseProgramShader( shaderHelper, ShapeUtil.size(input.dims), input.dataType, outputDataType, funcCall, additionalImplementation), getRunData: (inputTensors) => ({ outputs: [{dims: input.dims, dataType: outputDataType}], dispatchGroup: - {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)} + {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, + programUniforms: [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + ], }) }); diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts deleted file mode 100644 index 83d133b9a5157..0000000000000 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common'; - -import {SerializableModeldata} from './proxy-messages'; -import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; - -export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - private sessionId: number; - private checkpointId: number; - - inputNames: string[]; - outputNames: string[]; - - inputEncodedNames: number[]; - outputEncodedNames: number[]; - - async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { - let buffer: Uint8Array; - if (typeof uriOrBuffer === 'string') { - const response = await fetch(uriOrBuffer); - const arrayBuffer = await response.arrayBuffer(); - buffer = new Uint8Array(arrayBuffer); - } else { - buffer = uriOrBuffer; - } - return createSessionAllocate(buffer); - } - - async createTrainingSession( - checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, - evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, - options: InferenceSession.SessionOptions) { - if (!isOrtEnvInitialized()) { - await initRuntime(env); - } - const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); - // 0 is supposed to be the nullptr - let evalModelData: SerializableModeldata = [0, 0]; - let optimizerModelData: SerializableModeldata = [0, 0]; - - if (evalModelUriOrBuffer !== '') { - evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); - } - if (optimizerModelUriOrBuffer !== '') { - optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); - } - - this.checkpointId = createCheckpointHandle(checkpointData); - [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = - createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); - } - - async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint( - this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); - } - - async runTrainStep( - _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, - _options: InferenceSession.RunOptions): Promise { - throw new Error('Method not implemented yet.'); - } -} diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler-inference.ts similarity index 96% rename from js/web/lib/wasm/session-handler.ts rename to js/web/lib/wasm/session-handler-inference.ts index a5017a920f38b..3ca34d957c572 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -10,7 +10,7 @@ import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitializationPromise: Promise|undefined; -const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { +export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { switch (tensor.location) { case 'cpu': return [tensor.type, tensor.dims, tensor.data, 'cpu']; @@ -21,7 +21,7 @@ const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMeta } }; -const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { +export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { switch (tensor[3]) { case 'cpu': return new Tensor(tensor[0], tensor[2], tensor[1]); diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts new file mode 100644 index 0000000000000..09d91591128d1 --- /dev/null +++ b/js/web/lib/wasm/session-handler-training.ts @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + /** + * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the + * corresponding name as a number referring to the index in the list of names provided. + * + * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType + * @param names either inputNames or outputNames + * @returns a tuple of a list of values and a list of indices. + */ + convertMapIntoValuesArrayAndIndicesArray( + feeds: {[name: string]: T}, names: string[], mapFunc: (val: T, index: number) => U): [T[], number[], U[]] { + const values: T[] = []; + const indices: number[] = []; + Object.entries(feeds).forEach(kvp => { + const name = kvp[0]; + const tensor = kvp[1]; + const index = names.indexOf(name); + if (index === -1) { + throw new Error(`invalid input '${name}`); + } + values.push(tensor); + indices.push(index); + }); + + const uList = values.map(mapFunc); + return [values, indices, uList]; + } + + /** + * Helper method that converts the TensorMetadata that the wasm-core functions return to the + * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the + * corresponding result. + * + * @param results used to populate the resultMap if there is no value for that outputName already + * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results + * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. + * @returns a map of output names and OnnxValues. + */ + convertTensorMetadataToReturnType( + results: TensorMetadata[], outputArray: Array, outputIndices: number[]): SessionHandler.ReturnType { + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); + } + return resultMap; + } + + async runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise { + const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( + feeds, this.inputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + + const [outputArray, outputIndices, outputs] = + this.convertMapIntoValuesArrayAndIndicesArray( + fetches, this.outputNames, + (t, i): TensorMetadata|null => + t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); + return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } +} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 947242945c665..3aacf8f4d90e0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -240,7 +240,7 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; -const prepareInputOutputTensor = +export const prepareInputOutputTensor = (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): void => { if (!tensor) { diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 4830b5d2b5e80..a35d285346db4 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -1,10 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession} from 'onnxruntime-common'; +import {InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; +import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; +import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; @@ -146,6 +149,170 @@ export const createTrainingSessionHandle = } }; +/** + * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the + * WASM tensors. + * + * @param trainingSessionId + * @param indices for each tensor, the index of the input or output name that the tensor corresponds with + * @param tensors list of TensorMetaData + * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting + * handles of the allocated tensors on the heap + * @param inputOutputAllocs modified in-place by this method + * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor + */ +const createAndAllocateTensors = + (trainingSessionId: number, indices: number[], tensors: Array, tensorHandles: number[], + inputOutputAllocs: number[], indexAdd: number) => { + const count = indices.length; + + // creates the tensors + for (let i = 0; i < count; i++) { + prepareInputOutputTensor( + tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); + } + + // moves to heap + const wasm = getInstance(); + const valuesOffset = wasm.stackAlloc(count * 4); + let valuesIndex = valuesOffset / 4; + for (let i = 0; i < count; i++) { + wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + } + + return valuesOffset; + }; + +/** + * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information + * associated with the tensor handle. + * + * @param outputValuesOffset + * @param outputCount + * @returns list of TensorMetadata retrieved from the output handles. + */ +const moveOutputToTensorMetadataArr = + (outputValuesOffset: number, outputCount: number, outputTensorHandles: number[], + outputTensors: Array) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. + output.push(outputTensors[i]!); + continue; + } + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + wasm._OrtReleaseTensor(tensor); + } + } + + return output; + }; + +export const runTrainStep = async( + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + + try { + // prepare parameters by moving them to heap + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + + if (wasm._OrtTrainingRunTrainStep) { + const errorCode = wasm._OrtTrainingRunTrainStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + + if (errorCode !== 0) { + checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); + } + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); + } finally { + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); + + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc new file mode 100644 index 0000000000000..812e9d7c2def0 --- /dev/null +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "conv without bias addition A", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "T[1]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv without bias addition A", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 11 }, + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "T[3]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 628e5408150f8..29acc07e118f9 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -164,7 +164,10 @@ async function initializeSession( session = await ort.InferenceSession.create(modelFilePath, sessionConfig); } } catch (e) { - Logger.error('TestRunner', `Failed to load model from file: ${modelFilePath}. Error: ${inspect(e)}`); + Logger.error( + 'TestRunner', + `Failed to load model from file: ${modelFilePath}. ` + + `Error: ${e.message} @ ${e.fileName}:${e.lineNumber}`); throw e; } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 5184dd99309b1..0fd8790e0d29d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -55,6 +55,7 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; + int num_splits; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; @@ -95,9 +96,9 @@ struct GroupQueryAttentionParameters { int head_size; int kv_hidden_size; int kv_num_heads; + int num_splits; // number of splits for splitkv bool is_unidirectional; // causal float scale; - int num_splits; // number of splits for splitkv AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; }; diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h deleted file mode 100644 index 11b5447d65ed2..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -namespace onnxruntime { -namespace contrib { - -#if defined(_MSC_VER) -#define FORCEINLINE __forceinline -#else -#define FORCEINLINE __attribute__((always_inline)) inline -#endif - -template -struct alignas(1) BlockwiseQuantBlock { - static_assert(block_size % 8 == 0); - - uint8_t blob_data[block_size / 8 * bits]; - - FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const; - FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const; - - FORCEINLINE void quant(const T* src, T& scale, int32_t k_idx, int32_t K, int32_t N); - FORCEINLINE void quant(const T* src, T& scale, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N); -}; - -template -struct alignas(1) BlockwiseQuantBlock { - static_assert(block_size % 8 == 0); - - uint8_t blob_data[block_size / 2]; - - FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const { - for (int i = 0; i < block_size; i += 2) { - T zp_t = static_cast(float(zp)); - if (k_idx + i < K) { - T x0 = static_cast(float(blob_data[i / 2] & 0xF)); - dst[i] = scale * (x0 - zp_t); - } - if (k_idx + i + 1 < K) { - T x1 = static_cast(float(blob_data[i / 2] >> 4)); - dst[i + 1] = scale * (x1 - zp_t); - } - } - } - - FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const { - constexpr uint8_t zp = 8; - dequant(dst, scale, zp, k_idx, K); - } - - FORCEINLINE void quant(const T* src, T& scale_block, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N) { - float min = static_cast(*src); - float max = static_cast(*src); - int32_t klen = std::min(block_size, K - k_idx); - for (int32_t kk = 0; kk < klen; kk++) { - const float v = static_cast(src[N * kk]); - if (v < min) min = v; - if (v > max) max = v; - } - min = std::min(min, 0.0f); - max = std::max(max, 0.0f); - - const float scale = (max - min) / ((1 << 4) - 1); - scale_block = static_cast(scale); - - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - float zero_point_fp = min; - if (scale != 0.0f) { - zero_point_fp = 0.f - min / scale; - } - - // Handle any clamping - if (zero_point_fp < 0.0f) { - zp = 0; - } else if (zero_point_fp > 15.0f) { - zp = 15; - } else { - zp = (uint8_t)roundf(zero_point_fp); - } - - for (int32_t kk = 0; kk < klen; kk += 2) { - const float v0 = static_cast(src[N * kk]); - const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); - - const float v1 = static_cast((kk + 1 < klen) ? src[N * (kk + 1)] : 0.f); - const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp))); - - blob_data[kk / 2] = vi0 | (vi1 << 4); - } - } - - FORCEINLINE void quant(const T* src, T& scale_block, int32_t k_idx, int32_t K, int32_t N) { - float amax = 0.0f; // abs(max) - float max = 0.0f; - - int32_t klen = std::min(block_size, K - k_idx); - - for (int32_t kk = 0; kk < klen; kk++) { - const float v = static_cast(src[N * kk]); - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float scale = max / (-8.f); - scale_block = static_cast(scale); - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - - for (int32_t kk = 0; kk < klen; kk += 2) { - const float v0 = src[N * kk] * reciprocal_scale; - const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 + 8.f))); - - const float v1 = (kk + 1 < klen) ? src[N * (kk + 1)] * reciprocal_scale : 0; - const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 + 8.f))); - - blob_data[kk / 2] = vi0 | (vi1 << 4); - } - } -}; - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h deleted file mode 100644 index 8811e5649fc19..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "blockwise_quant_block.h" - -#include - -#include "core/common/safeint.h" -#include "core/framework/float16.h" -#include "core/platform/threadpool.h" -#include - -namespace onnxruntime { -namespace contrib { - -template -void QuantizeBlockwise( - uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ] - const T* src, // shape: [K, N] - T* scale, // shape: [N * block_per_K] - uint8_t* zero_points, // shape: [N * block_per_K] if bits > 4 else [(N *block_per_K + 1) / 2] - int32_t N, - int32_t K, - onnxruntime::concurrency::ThreadPool* thread_pool) { - BlockwiseQuantBlock* dst_blob = - reinterpret_cast*>(dst); - - int32_t block_per_K = (K + block_size - 1) / block_size; - int32_t total_block_count = N * block_per_K; - - std::vector zero_points_tmp; // to avoid race condition - (void)zero_points_tmp; - uint8_t* zero_points_tmp_ptr = zero_points; - if (bits <= 4 && zero_points != nullptr) { - zero_points_tmp.resize(total_block_count, 0); - zero_points_tmp_ptr = zero_points_tmp.data(); - } - - concurrency::ThreadPool::TryBatchParallelFor( - thread_pool, - total_block_count, - [&](ptrdiff_t block_idx) { - int32_t n = static_cast(block_idx / block_per_K); - int32_t k_block_idx = static_cast(block_idx % block_per_K); - int32_t k = k_block_idx * block_size; - BlockwiseQuantBlock* blob_ptr = dst_blob + block_idx; - size_t offset = SafeInt(k) * N + n; - if (nullptr != zero_points_tmp_ptr) { - blob_ptr->quant(src + offset, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N); - } else { - blob_ptr->quant(src + offset, scale[block_idx], k, K, N); - } - }, - 0); - - if (bits <= 4 && zero_points != nullptr) { // compact zero points - for (int32_t zp_idx = 0; zp_idx < total_block_count / 2; zp_idx++) { - zero_points[zp_idx] = ((zero_points_tmp[zp_idx * 2]) | (zero_points_tmp[zp_idx * 2 + 1] << 4)); - } - if (total_block_count & 1) { - zero_points[total_block_count / 2] = (zero_points[total_block_count / 2] & 0xf0) | zero_points_tmp[total_block_count - 1]; - } - } -} - -#define QuantizeBlockwise4Bits(block_size) \ - QuantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool); - -template -void QuantizeBlockwise( - uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ] - const T* src, // shape: [K, N] - T* scale, // shape: [N, block_per_K] - uint8_t* zero_points, // shape: [N, block_per_K] - int32_t block_size, - int32_t bits, - int32_t N, - int32_t K, - onnxruntime::concurrency::ThreadPool* thread_pool) { - ORT_ENFORCE(bits == 4, "only 4 bits is supported now"); - - if (16 == block_size) { - QuantizeBlockwise4Bits(16); - } else if (32 == block_size) { - QuantizeBlockwise4Bits(32); - } else if (64 == block_size) { - QuantizeBlockwise4Bits(64); - } else if (128 == block_size) { - QuantizeBlockwise4Bits(128); - } else if (256 == block_size) { - QuantizeBlockwise4Bits(256); - } else { - ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); - } -} - -#undef QuantizeBlockwise4Bits - -template -void DequantizeBlockwise( - T* dst, // shape: [N, K] - const uint8_t* src, // shape: [N, block_per_K, block_blob_size] - const T* scale, // shape: [N, block_per_K] - const uint8_t* zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2] - int32_t N, - int32_t K, - onnxruntime::concurrency::ThreadPool* thread_pool) { - int32_t block_per_K = (K + block_size - 1) / block_size; - int32_t task_count = N * block_per_K; - - const BlockwiseQuantBlock* src_blob = - reinterpret_cast*>(src); - - concurrency::ThreadPool::TryBatchParallelFor( - thread_pool, - task_count, - [&](ptrdiff_t task_idx) { - int32_t n = static_cast(task_idx / block_per_K); - int32_t k_block_idx = static_cast(task_idx % block_per_K); - int32_t k = k_block_idx * block_size; - const BlockwiseQuantBlock* blob_ptr = src_blob + task_idx; - size_t offset = SafeInt(n) * K + k; - if (nullptr != zero_points) { - if constexpr (bits > 4) { // zero point is stored with a byte - blob_ptr->dequant(dst + offset, scale[task_idx], zero_points[task_idx], k, K); - } else { // zero points is stored with 4bits - uint8_t zp = zero_points[task_idx / 2]; - zp = (task_idx & 1) ? (zp >> 4) : (zp & 0xf); - blob_ptr->dequant(dst + offset, scale[task_idx], zp, k, K); - } - } else { - blob_ptr->dequant(dst + offset, scale[task_idx], k, K); - } - }, - 0); -} - -#define DequantizeBlockwise4Bits(block_size) \ - DequantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool); - -template -void DequantizeBlockwise( - T* dst, // [N, K] - const uint8_t* src, // [N, block_per_K, block_blob_size] - const T* scale, // [N, block_per_K] - const uint8_t* zero_points, // [N, block_per_K] - int32_t block_size, - int32_t bits, - int32_t N, - int32_t K, - onnxruntime::concurrency::ThreadPool* thread_pool) { - ORT_ENFORCE(bits == 4, "only 4 bits is supported now"); - - if (16 == block_size) { - DequantizeBlockwise4Bits(16); - } else if (32 == block_size) { - DequantizeBlockwise4Bits(32); - } else if (64 == block_size) { - DequantizeBlockwise4Bits(64); - } else if (128 == block_size) { - DequantizeBlockwise4Bits(128); - } else if (256 == block_size) { - DequantizeBlockwise4Bits(256); - } else { - ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); - } -} - -#undef DequantizeBlockwise4Bits - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc index 2f3ede49c3650..b898c956b6e6a 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -21,6 +21,8 @@ class MatMulBnb4 final : public OpKernel { ORT_ENFORCE( quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); } Status Compute(OpKernelContext* context) const override; @@ -30,6 +32,8 @@ class MatMulBnb4 final : public OpKernel { int64_t N_; int64_t block_size_; int64_t quant_type_; + bool is_training_mode_; + bool transB_; }; Status MatMulBnb4::Compute(OpKernelContext* ctx) const { @@ -58,7 +62,7 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const { thread_pool); constexpr bool transa = false; - constexpr bool transb = true; + const bool transb = transB_; TensorShape b_shape({N_, K_}); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 57aada94be39c..c72d811170a27 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -5,8 +5,7 @@ #include "core/framework/op_kernel.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" -#include "dequantize_blockwise.h" -#include "core/mlas/inc/mlas.h" +#include "core/mlas/inc/mlas_q4.h" namespace onnxruntime { namespace contrib { @@ -18,6 +17,9 @@ class MatMulNBits final : public OpKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); } Status Compute(OpKernelContext* context) const override; @@ -27,6 +29,7 @@ class MatMulNBits final : public OpKernel { int64_t N_; int64_t block_size_; int64_t nbits_; + bool column_wise_quant_{true}; }; Status MatMulNBits::Compute(OpKernelContext* ctx) const { @@ -46,15 +49,18 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { auto status = ctx->GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - DequantizeBlockwise(tmp_b_data_ptr.get(), - b_data, - scales_data, - zero_points_data, - static_cast(block_size_), - static_cast(nbits_), - static_cast(N_), - static_cast(K_), - thread_pool); + + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + zero_points_data, // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); #if 0 // for debug auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 0dc7de0e9e519..bf6431cf1afb2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -135,8 +135,24 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif if (!use_flash_attention) { @@ -279,6 +295,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index b4a4ae208ceb1..16ce3a899fb5e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -316,7 +316,9 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, + parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + true)); DUMP_TENSOR("flash attention output", data.output, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); @@ -372,6 +374,7 @@ Status EfficientAttention( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; @@ -393,6 +396,7 @@ Status EfficientAttention( p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d0a5fb51a25d6..3e78978c3cc43 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -88,6 +88,11 @@ struct AttentionData { T* v = nullptr; T* scratch = nullptr; AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index ed330b0fca332..51c3d3d3a458b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromTopLeft; + p.custom_mask_type = Attention::CausalFromBottomRight; } - // Input format is BxSxNxH, output is BxSxNxH - p.q_strideH = params.qk_head_size; - p.k_strideH = params.qk_head_size; - p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; - - p.q_strideM = params.num_heads * params.qk_head_size; - p.k_strideM = params.num_heads * params.qk_head_size; - p.v_strideM = params.num_heads * params.v_head_size; - p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; - - p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; - p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; - p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } } constexpr auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index f725be8d7cf89..f16567bb6f2b7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -14,10 +14,12 @@ namespace cuda { struct MemoryEfficientAttentionParams { int32_t sm; bool is_half; + bool is_kv_bsnh = true; int32_t batch_size; int32_t num_heads; int32_t sequence_length; int32_t kv_sequence_length; + int32_t max_sequence_length; int32_t qk_head_size; int32_t v_head_size; bool causal; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index ff7a22d253a5b..89a27c4d2b0d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -140,11 +140,10 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, - int max_splits, bool new_kv, bool is_sm8x) { + int max_splits) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)) - : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)); - const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n; + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. const int num_m_blocks = (seqlen_q + 64 - 1) / 64; @@ -190,6 +189,26 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea return 1; } +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 0a0328edb0059..58f4304251872 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -31,6 +31,7 @@ #if USE_FLASH_ATTENTION #include "core/providers/cuda/cuda_common.h" +#include namespace onnxruntime { namespace flash { @@ -99,10 +100,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, ); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); -size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q); -size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded); -int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x); +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 784335a124c75..82dfa59b8f8e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -123,17 +123,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { - bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; constexpr int kBlockM = 64; // Fixed for all head dimensions - if (!is_sm8x) { // A100, H100 - // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, - // and for headdim 192 with block size 64 x 128. - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); } template diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 65d19d4473872..8694dc998c7a8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -6,9 +6,8 @@ #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" #include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" -// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" -// #include "contrib_ops/cpu/utils/console_dumper.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -55,6 +54,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_flash_attention_ = true; #endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif } template @@ -92,18 +98,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { output_shape[2] = static_cast(parameters.hidden_size); Tensor* output = context->Output(0, output_shape); - std::vector present_dims; - if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - present_dims = { - parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; - } else { // BNSH - present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; - } - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, @@ -116,22 +110,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { size_t out_accum_bytes = 0; size_t seqlens_k_bytes = 0; if (use_flash_attention) { + // softmax buffer softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); - // split kv buffers - parameters.num_splits = onnxruntime::flash::num_splits_heuristic( + // split kv buffer + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, - parameters.head_size, device_prop.multiProcessorCount, 128, false, - device_prop.major == 8 && device_prop.minor > 0); - if (parameters.num_splits > 1) { - // softmax_lse_accum buffer - softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length); - // out_accum buffer - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(parameters.head_size, 32); - out_accum_bytes = onnxruntime::flash::get_out_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded); - } + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; // seqlens_k buffer if (past_key != nullptr) { seqlens_k_bytes = sizeof(int) * parameters.batch_size; @@ -149,8 +137,47 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto seqlens_k_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - // only kernel implemented for gqa right now - ORT_ENFORCE(use_flash_attention); +#if USE_MEMORY_EFFICIENT_ATTENTION + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + (parameters.head_size & 7) == 0 && + parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + // allocate buffers + size_t kv_buffer_bytes = 0; + // need a buffer if we must ungroup kv + const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); + if (use_memory_efficient_attention && needs_buff) { + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size); + } + size_t fmha_buffer_bytes = 0; + if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { + fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + } + auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); +#else + constexpr bool use_memory_efficient_attention = false; + auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); +#endif + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; + } + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); data.key = reinterpret_cast(key->Data()); @@ -161,6 +188,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); data.use_flash_attention = use_flash_attention; + data.use_memory_efficient_attention = use_memory_efficient_attention; if (softmax_lse_buffer != nullptr) { data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); } @@ -173,6 +201,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (seqlens_k_buffer != nullptr) { data.seqlens_k = reinterpret_cast(seqlens_k_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 72c9814fad670..a90418ec2243a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel { bool is_past_bsnh_; float scale_; bool disable_flash_attention_; + bool disable_memory_efficient_attention_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index be8f5ca0ae3e9..8c21de9ced058 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query, // query (Q) : (B, S, D) // key (K) : (B, S+, D_kv) // value (V) : (B, S+, D_kv) + ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = Q_K_V_BSNH; const auto& query_dims = query->Shape().GetDims(); const auto& key_dims = key->Shape().GetDims(); - const auto& value_dims = value->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query, int q_hidden_size = static_cast(query_dims[2]); int head_size = static_cast(q_hidden_size) / num_heads; - int kv_sequence_length = sequence_length; - int kv_hidden_size = (key_dims.size() == 3) - ? static_cast(key_dims[2]) - : (kv_num_heads * static_cast(key_dims[3])); + int kv_sequence_length = static_cast(key_dims[1]); + int kv_hidden_size = static_cast(key_dims[2]); int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { @@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - if (key_dims[2] != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else { + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing key tensor."); + "Input 'query' and 'key' shall have same dim 0 (batch size)"); } - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { + if (static_cast(kv_sequence_length) != value_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing value tensor."); + "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); + } + + if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); } // When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly. int32_t past_sequence_length = 0; - int present_sequence_length = 0; + int present_sequence_length = kv_sequence_length; if (past_seq_len != nullptr) { + if (past_key == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past KV must be present as share-buffer when using past_seq_len pointer."); + } if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "past_sequence_length tensor must be of one element when using past kv."); @@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query, } else { past_sequence_length = static_cast(*((*past_seq_len).template Data())); } + if (past_sequence_length + kv_sequence_length > max_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length"); + } present_sequence_length = max_sequence_length; } else if (past_key != nullptr) { past_sequence_length = max_sequence_length; // this is the length of past_key tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index ab3029ca34886..0455825c364a2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -37,6 +37,7 @@ limitations under the License. #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -47,6 +48,8 @@ namespace onnxruntime { namespace contrib { namespace cuda { +////////// Auxiliary Kernels for KV prep + // Kernel for seqlens_k __global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { int id = blockDim.x * blockIdx.x + threadIdx.x; @@ -75,7 +78,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int present_head_stride = is_bsnh ? H : present_seqlen * H; // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH + // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L const int past_seqlen = present_seqlen - new_seqlen; @@ -95,33 +98,32 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, } } +// Use when (H*)*num_heads > 1024 template __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int H, + const int num_heads, const T* past_kv, const T* new_kv, T* present_kv, const bool is_bsnh) { - // Use when (H*)*num_heads > 1024 - int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int present_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = present_seqlen - new_seqlen; - const int present_seqlen = gridDim.x; - const int num_heads = blockDim.y; - const int thread_stride = blockDim.x; - - const int present_batch_stride = present_seqlen * num_heads * H; - const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_seqlen * H; - - // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH - // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = present_seqlen - new_seqlen; - - while (h < H) { int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { const int past_batch_stride = past_seqlen * num_heads * H; @@ -135,133 +137,477 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; present_kv[out_offset] = new_kv[in_offset]; } - h += thread_stride; } } +// Concat new to past in present. Supports past BSNH or past BNSH template -Status QkvToContext( +Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int past_seqlen, + const int present_seqlen, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int past_seqlen, + const int present_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. +Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CUDA_CALL(cudaGetLastError()); +} + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +Status FlashAttention( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, + cudaStream_t stream, contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data) { - assert(data.use_flash_attention); + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; -#if USE_FLASH_ATTENTION - auto stream = static_cast(ort_stream->GetHandle()); + void* query = reinterpret_cast(const_cast(data.query)); + void* key = reinterpret_cast(const_cast(data.key)); + void* value = reinterpret_cast(const_cast(data.value)); + + bool is_causal = parameters.is_unidirectional; + + if (data.past_key != nullptr && data.past_key == data.present_key) { + // Share buffer case + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + // Launch kernel to copy seqlen + int thr_per_blk = 256; + int blk_in_grid = ceil(float(batch_size) / thr_per_blk); + repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), + reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, kv_sequence_length, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum))); + + } else { + // Not share buffer or no past (prompt generation) + // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), + batch_size, num_heads, kv_num_heads, head_size, + sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.kv_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; const int present_sequence_length = parameters.present_sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(head_size)) : parameters.scale; - if (data.use_flash_attention) { - assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - assert(parameters.num_heads % parameters.kv_num_heads == 0); - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - - bool is_causal = parameters.is_unidirectional; - - if (data.past_key == nullptr && data.present_key == nullptr) { - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.softmax_lse), - parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size, - parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum))); - - } else if (data.past_key == data.present_key) { - // Assume past and present kv share buffer. - assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - assert(parameters.past_sequence_length >= 0); - assert(data.past_value != nullptr); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); - - } else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) { - assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (head_size % 4 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4"); - } - const int H = head_size / 4; - if (H * kv_num_heads <= max_threads_per_block) { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(H, kv_num_heads, 1); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } else { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), - batch_size, num_heads, kv_num_heads, head_size, - sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + if (data.past_key != nullptr) { + // Past key case + // concatenate new kv to past kv + if (data.past_key == data.present_key) { + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); } + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + } else if (num_heads == kv_num_heads) { + // no past or present and no need to ungroup... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // intermediate buffer so q and kv have same num heads... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length, + kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream, + max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = past_sequence_length + kv_sequence_length; + p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr; + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? data.fmha_buffer + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +////////// API Functions + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif - return Status::OK(); +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); } #endif + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 0bad9eeb61231..8412631078e6a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -14,19 +14,28 @@ namespace cuda { template struct GroupQueryAttentionData { + // Input Tensors const T* query = nullptr; const T* key = nullptr; const T* value = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; + // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; int* seqlens_k = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors T* output = nullptr; T* present_key = nullptr; T* present_value = nullptr; + // Kernel Flags bool use_flash_attention = false; + bool use_memory_efficient_attention = false; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index e3f53ca6a63cb..ebd66d8c6528e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -153,8 +153,24 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif bool use_fused_cross_attention = !use_flash_attention && @@ -291,6 +307,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index aba0efdbd7d5f..d7aeef1501cd6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; + p.is_kv_bsnh = true; p.batch_size = parameters.batch_size; p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e09fd9e6b36e5..3fe9dbf8ed34a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -688,6 +688,7 @@ Status FusedAttentionCutlass( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -702,6 +703,7 @@ Status FusedAttentionCutlass( p.attn_bias = data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v))) : nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index e4b09b00f030c..973ef8d304e2e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -51,11 +51,11 @@ half maybe2half(float x) { // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. -constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048}; +constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192}; constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); constexpr int kMaxSize = kSizes[kNumOfSizes - 1]; constexpr int kMinBlockSize = 32; -constexpr int kMaxBlockSize = 256; +constexpr int kMaxBlockSize = 1024; int NextSize(int x) { for (size_t i = 0; i < kNumOfSizes; ++i) { @@ -63,14 +63,13 @@ int NextSize(int x) { return kSizes[i]; } } - return kMaxSize; + return kMaxSize + 1; } -template -bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias, - const T* gamma, const T* beta, const int ld, const int next_size) { - constexpr int alignment = std::alignment_of>::value; - return ld % NumUnroll == 0 && +bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias, + const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) { + int alignment = element_size * num_unroll; + return ld % num_unroll == 0 && reinterpret_cast(output) % alignment == 0 && reinterpret_cast(sum_output) % alignment == 0 && reinterpret_cast(input) % alignment == 0 && @@ -78,8 +77,8 @@ bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, cons reinterpret_cast(bias) % alignment == 0 && reinterpret_cast(gamma) % alignment == 0 && reinterpret_cast(beta) % alignment == 0 && - next_size / NumUnroll >= kMinBlockSize && - next_size / NumUnroll <= kMaxBlockSize; + next_size / num_unroll >= kMinBlockSize && + next_size / num_unroll <= kMaxBlockSize; } } // namespace @@ -187,8 +186,14 @@ void LaunchSkipLayerNormKernel( int ld, int row_count, int skip_size) { const int next_size = NextSize(ld); const int grid_size = row_count; - bool flag_vec2 = CanVectorized(output, sum_output, input, skip, bias, gamma, beta, ld, next_size); - bool flag_vec4 = CanVectorized(output, sum_output, input, skip, bias, gamma, beta, ld, next_size); + bool can_unroll_vec4 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 4, sizeof(T)); + bool can_unroll_vec8 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 8, sizeof(T)); #define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ SkipLayerNormKernelSmall<<>>( \ @@ -198,39 +203,42 @@ void LaunchSkipLayerNormKernel( SkipLayerNormKernel<<>>( \ output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) -#define CASE_NEXT_SIZE(next_size_value) \ - case next_size_value: { \ - static_assert(next_size_value > kSizes[0] && next_size_value < kMaxSize); \ - if (flag_vec4) { \ - constexpr int block_size = next_size_value / 4; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ - } else if (flag_vec2) { \ - constexpr int block_size = next_size_value / 2; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \ - } else { \ - if (next_size_value <= kMaxBlockSize) { \ - constexpr int block_size = next_size_value; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ - } else { \ - constexpr int block_size = 256; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ - } \ - } \ +#define CASE_NEXT_SIZE(next_size_value) \ + case next_size_value: { \ + static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ + if constexpr (next_size_value >= 8 * 256) { \ + if (can_unroll_vec8) { \ + constexpr int block_size = next_size_value / 8; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } else { \ + if (can_unroll_vec4) { \ + constexpr int block_size = next_size_value / 4; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ + } else { \ + if (next_size_value <= kMaxBlockSize) { \ + constexpr int block_size = next_size_value; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } \ + } \ } break switch (next_size) { - case kSizes[0]: { - constexpr int block_size = kSizes[0]; - // TODO: Add back the small TensorRT kernel for 32. No need to use vertorized kernel for such small size. - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); - break; - } + CASE_NEXT_SIZE(kSizes[0]); CASE_NEXT_SIZE(kSizes[1]); CASE_NEXT_SIZE(kSizes[2]); CASE_NEXT_SIZE(kSizes[3]); CASE_NEXT_SIZE(kSizes[4]); CASE_NEXT_SIZE(kSizes[5]); - // kMaxSize shall not run vectorized kernel since ld might be larger than kMaxSize. + CASE_NEXT_SIZE(kSizes[6]); + CASE_NEXT_SIZE(kSizes[7]); default: { constexpr int block_size = 256; LAUNCH_SKIP_LAYER_NORM_KERNEL(); diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc new file mode 100644 index 0000000000000..967f30a304ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -0,0 +1,175 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reduce.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/reduction/reduction_ops.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedReduceBase::DistributedReduceBase( + const OpKernelInfo& info, + cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) { + keepdims_ = info.GetAttrOrDefault("keepdims", 1); + cudnn_reduce_op_ = cudnn_reduce_op; +}; + +template +Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const { + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& axes_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(axes_sharding_spec.HasNoShard(), + "It's not worthy to shard axes tensor. " + "If sharding axes is needed, please submit a feature request."); + + const Tensor* input_tensor = context->Input(0); + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor."); + auto axes_span = axes_tensor->DataAsSpan(); + + // Case 1: empty axes means treating this reduction as an identity. + if (axes_span.empty()) { + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + auto* output_tensor = context->Output(0, input_tensor->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData(), input_tensor->Data(), input_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + // Case 2: this is a valid reduction. Let's prepare for it. + + bool sharding_on_reduced_axes = false; + for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) { + if (*axis_it == input_sharding_spec.GetPartitionAxis()) { + sharding_on_reduced_axes = true; + break; + } + } + + if (sharding_on_reduced_axes) { + // Case 2-1: sharding on reduced axes. + ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica."); + } else { + // Case 2-2: sharding on passing-through axes or no shard. + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + onnxruntime::cuda::PrepareReduceMetadata metadata; + ORT_RETURN_IF_ERROR( + onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)); + auto output_tensor = context->Output(0, metadata.squeezed_output_dims); + + // Fast reduction is not deterministic, so sometimes we want to turn it off. + const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); + return onnxruntime::cuda::ReduceComputeCore( + /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, + /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, + enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + } + return Status::OK(); +} + +template +DistributedReduceSum::DistributedReduceSum( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_ADD){}; + +template +DistributedReduceMean::DistributedReduceMean( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_AVG){}; + +template +DistributedReduceMax::DistributedReduceMax( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_MAX){}; + +// ReduceSum +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); + +// ReduceMean +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); + +// ReduceMax +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h new file mode 100644 index 0000000000000..2939852c75c60 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReduceBase : public DistributedKernel { + public: + explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + // ONNX attribute. If true, reduced axes are retained as dimensions with size one. + // Otherwise, drop reduced axes. + bool keepdims_; + cudnnReduceTensorOp_t cudnn_reduce_op_; +}; + +template +class DistributedReduceSum final : public DistributedReduceBase { + public: + explicit DistributedReduceSum(const OpKernelInfo& info); +}; + +template +class DistributedReduceMean final : public DistributedReduceBase { + public: + explicit DistributedReduceMean(const OpKernelInfo& info); +}; + +template +class DistributedReduceMax final : public DistributedReduceBase { + public: + explicit DistributedReduceMax(const OpKernelInfo& info); +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 2618fe4a238bd..8e157da6cb43f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -97,6 +97,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -174,6 +175,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean); #endif template <> @@ -269,6 +279,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -352,6 +363,15 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 301b2e76b1b2d..87e88ac31c998 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/diffusion/group_norm.h" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" @@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX( GroupNorm, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); +ONNX_OPERATOR_KERNEL_EX( + SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + using namespace ONNX_NAMESPACE; namespace { + template struct DispatchGroupNorm { Status operator()(cudaStream_t stream, Tensor* output, + Tensor* add_out, const Tensor* input, + const Tensor* skip, + const Tensor* bias, const Tensor* gamma, const Tensor* beta, void* workspace, @@ -32,12 +39,17 @@ struct DispatchGroupNorm { int height, int width, int num_groups, - bool use_swish_activation) { + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( stream, reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), gamma->Data(), beta->Data(), workspace, @@ -47,13 +59,21 @@ struct DispatchGroupNorm { height, width, num_groups, - use_swish_activation); + use_swish_activation, + broadcast_skip, + channels_per_block); } }; } // namespace GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + has_skip_ = false; + const std::string& op_name = op_info.GetKernelDef().OpName(); + if (op_name == "SkipGroupNorm") { + has_skip_ = true; + } + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(epsilon_ >= 0); @@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { use_swish_activation_ = (activation == 1); channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); + + channels_per_block_ = 0; +} + +Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + + // Compute and cache cPerBlock using number of channels from gamma tensor shape. + if (input_idx == 1) { + auto gamma_shape = tensor.Shape(); + if (gamma_shape.NumDimensions() == 1) { + channels_per_block_ = GetChannelsPerBlock(static_cast(gamma_shape[0]), num_groups_); + } + } + + return Status::OK(); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "only the channels_last layout is supported"); } + if (!gamma->IsDataType() || !beta->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm only supports gamma and beta in float type"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input is expected to have 4 dimensions, got ", input_dims.size()); } + // Only support NHWC format right now. + int batch_size = static_cast(input_dims[0]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + int num_channels = static_cast(input_dims[3]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + const auto& gamma_dims = gamma->Shape().GetDims(); if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[3]) { + if (gamma_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in gamma and input does not match"); } @@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[3]) { + if (beta_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in beta and input does not match"); } - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisiable by num_groups"); - } - if (context->GetUseDeterministicCompute()) { static std::once_flag log_warning; std::call_once(log_warning, []() { @@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { }); } - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + const Tensor* skip = nullptr; + const Tensor* bias = nullptr; + Tensor* add_out = nullptr; + + bool broadcast_skip = false; + if (has_skip_) { + skip = context->Input(3); + bias = context->Input(4); + add_out = context->Output(1, input->Shape()); + + if (bias != nullptr) { // Bias is optional + // If provided, bias has shape (C). + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != num_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in bias and input does not match"); + } + } + + // Check whether skip can be broadcasted to input shape. + if (skip->Shape() != input->Shape()) { + const auto& dims = skip->Shape().GetDims(); + // The shape of ship can be (N, C) or (N, 1, 1, C) for broadcast. + const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels); + const bool b4 = (dims.size() == 4 && dims[0] == batch_size && + dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels); + broadcast_skip = b2 || b4; + if (!broadcast_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)"); + } + } + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), + context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + gamma, beta, workspace.get(), epsilon_, batch_size, num_channels, height, width, num_groups_, - use_swish_activation_); + use_swish_activation_, + broadcast_skip, + channels_per_block_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 52c006e6bdb96..b408b3c1ee79b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel { GroupNorm(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: - bool use_swish_activation_; + bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization? float epsilon_; int num_groups_; bool channels_last_; + bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm + int channels_per_block_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 01ba078b4be77..48b161552ce0c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -16,18 +16,45 @@ */ // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +using namespace onnxruntime::cuda; + namespace onnxruntime { namespace contrib { namespace cuda { -static inline int32_t divUp(int32_t m, int32_t n) { +namespace { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} +} // namespace + +static inline int32_t DivUp(int32_t m, int32_t n) { return (m + n - 1) / n; } @@ -41,14 +68,14 @@ struct GroupSums { // The sum. float sum; // The sum of squares. - float sumSq; + float sum_sq; }; struct GroupSumsOp { inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { GroupSums dst; dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); dst.flag = a.flag + b.flag; return dst; } @@ -56,54 +83,85 @@ struct GroupSumsOp { template struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. + // The output buffer. Shape is (n, h, w, c). T* dst; - // The input buffer. Layout NHWC. + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + // The gamma scaling factor. float const* gamma; + // The beta term to add in GN. float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; // The number of instances in the batch. int32_t n; + // The height and width of each activation map. int32_t h; int32_t w; - // The number of channels. + + // Number of channels. int32_t c; - // The number of groups. + + // Number of groups. int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; + + // Do we apply the SiLU activation function? + bool use_silu; // Precomputed values and parameters to control the execution of the kernels. - // The number of activations per instance (h * w) and the number of - // activations per block. + // Number of activations per instance (h * w) int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; // The precomputed stride between instances. int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; // The precomputed number of groups per block. - int32_t groupsPerBlock; + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; }; template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); @@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[tTHREADS_PER_BLOCK]; + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t hw_begin = blockIdx.y * params.hw_per_block; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); // The sums. float sum = 0.F; - float sumSq = 0.F; + float sum_sq = 0.F; // Iterate over the activations to compute the sums. - if (ci < params.c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; - UpdateSum(params.src, offset, sum, sumSq); + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); } } - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * 2 / params.cPerGroup; - int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - // Do the segmented scan. + // Do the segmented scan. InclusiveScan is not deterministic. GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == params.cPerGroup - 2) { //2 channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); - // The global group index. - int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + if (threadIdx.x >= params.groups_per_block) { return; } - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); - atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } } template -void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the values are as we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); + // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); + // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCSumKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCSumKernel<<>>(params); break; - case 480: - groupNormNHWCSumKernel<<>>(params); + case 192: + GroupNormNHWCSumKernel<<>>(params); break; - case 256: - groupNormNHWCSumKernel<<>>(params); + case 160: + GroupNormNHWCSumKernel<<>>(params); break; case 128: - groupNormNHWCSumKernel<<>>(params); + GroupNormNHWCSumKernel<<>>(params); + break; + case 64: + GroupNormNHWCSumKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); template <> -__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo float2 f2 = __half22float2(h2); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo } template <> -__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - if (ci >= params.c) { +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { return; } // The instance in the batch. int32_t ni = blockIdx.z; - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / params.cPerGroup; + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; + float sum = 0.F, sum_sq = 0.F; if (gi < params.groups) { - sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; - sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; } - // Load gamma/beta. - float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.inv_hw_channels_per_group; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + float inv_std_dev = rsqrtf(var + params.epsilon); - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; - - // Fetch two channels per thread. - computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); } } template -void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCScaleKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCScaleKernel<<>>(params); break; - case 480: - groupNormNHWCScaleKernel<<>>(params); + case 192: + GroupNormNHWCScaleKernel<<>>(params); break; - case 256: - groupNormNHWCScaleKernel<<>>(params); + case 160: + GroupNormNHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + GroupNormNHWCScaleKernel<<>>(params); + break; + case 64: + GroupNormNHWCScaleKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; for (int32_t i = 1; i <= std::sqrt(n); i++) { if (n % i == 0) { int32_t divisor1 = n / i; int32_t divisor2 = i; - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; } } } - return maxDivisor; + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; } template Status LaunchGroupNormKernel( cudaStream_t stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -397,79 +588,94 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCParams params; - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } - GroupNormNHWCParams params; - int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; - switch (num_channels) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; + // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases + if (channels_per_block % channels_per_group != 0 || + channels_per_block > kMaxSize || + (channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in CUDA does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - params.withSwish = use_swish_activation; + params.use_silu = use_silu; params.dst = output; + params.add_out = add_out; params.src = input; + params.skip = skip; + params.bias = bias; params.gamma = gamma; params.beta = beta; - params.redBuffer = reinterpret_cast(workspace); + params.group_sum_buffer = reinterpret_cast(workspace); params.n = batch_size; params.h = height; params.w = width; params.c = num_channels; params.groups = num_groups; params.hw = params.h * params.w; - const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); - params.hwPerBlock = divUp(params.hw, blocksPerHW); - params.cPerBlock = cPerBlock; - params.cPerGroup = params.c / params.groups; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); + params.hw_per_block = DivUp(params.hw, blocks_per_hw); + + params.channels_per_block = channels_per_block; + params.channels_per_group = channels_per_group; params.hwc = params.hw * params.c; - params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); - params.groupsPerBlock = cPerBlock / params.cPerGroup; + params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); + params.groups_per_block = channels_per_block / params.channels_per_group; + params.epsilon = epsilon; + params.broadcast_skip = broadcast_skip; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("input", input, batch_size, num_channels, height * width); - DUMP_TENSOR("gamma", gamma, 1, num_channels); - DUMP_TENSOR("beta", beta, 1, num_channels); - cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); - groupNormNHWCSum(params, stream); - DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; + + params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); + + GroupNormNHWCSum(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - groupNormNHWCScale(params, stream); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups); + + GroupNormNHWCScale(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, + const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, + const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index c7e9245050ee6..9532aeecb2f57 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -12,29 +12,33 @@ namespace onnxruntime { namespace contrib { namespace cuda { -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { +constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) { // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; + return (sizeof(float) * 2) * batch_size * num_groups; } +int GetChannelsPerBlock(int num_channels, int num_groups); + template Status LaunchGroupNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization + T* output, // normalized output tensor. Shape is (n, h, w, c) + T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) + const T* input, // input tensor. Shape is (n, h, w, c) + const T* skip, // optional skip tensor. Shape is (n, h, w, c) + const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm + const float* gamma, // gamma (also known as weight or scale). Shape is (c) + const float* beta, // beta (also known as bias). Shape is (c) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization + bool broadcast_skip, // Whether skip need broadcast. When skip has shape (n, c) or (n, 1, 1, c), it need broadcast. + int channels_per_block // Pre-computed channels per block. ); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 8c328d00b44d0..7921315ab52e1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -18,6 +18,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { + __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); @@ -61,15 +62,19 @@ __global__ void Dequantize4BitsKernel( const T* scale_data, const uint8_t* zero_points, int block_size, + int blocks_per_K, int blocks_per_threadblock, int shift) { int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); + int n_idx = block_id / blocks_per_K; + int kb_idx = block_id % blocks_per_K; int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); uint8_t zp = 8; if (zero_points) { - zp = (block_id & 0x01) ? (zero_points[block_id / 2] >> 4) : (zero_points[block_id / 2] & 0x0f); + zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); } output = output + element_offset; @@ -100,6 +105,7 @@ Status Dequantize4Bits( scales_data, zero_points, block_size, + blocks_per_K, blocks_per_threadblock, shift); @@ -126,6 +132,244 @@ template Status Dequantize4Bits( int block_size, cudaStream_t stream); + +/////////////////////////////////////////////////////////////////////////////// +// A more general block-wise dequantization implementation that supports +// different block sizes and block orientations (row-wise/column-wise). + +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + +/** + * @brief Blockwise quantization constants + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlkQuantTraits { + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D; +}; + +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +__global__ +void dequantizeThread(ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int rows, + int columns, + int thrd_row_blks) { + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + // !! 4b specific code + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + const auto block_idx = blockIdx.x * blockDim.x + threadIdx.x; + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + const auto meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + // quantized matrix is stored in column major, packed by column + const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; + + int32_t r_blk_idx = static_cast(block_idx % thrd_row_blks); + int32_t c_blk_idx = static_cast(block_idx / thrd_row_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + // for 4b quant, kPackSize = 2, so we have 2 scales and 2 offsets + const ElementT scale_buf[2] = { + scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow], + ((r/QuantBlk::kRow) < (meta_rows - 1)) + ? scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow + 1] + : static_cast(0.0f)}; + const uint8_t zp_pair = (zero_points == nullptr) + ? 0x88 + : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; + const uint16_t zp_buf[2] = {(uint16_t)(zp_pair & 0x0f), (uint16_t)((zp_pair >> 4) & 0x0f)}; + const ElementT adjust_buf[2] = {(-scale_buf[0]) * static_cast(zp_buf[0]), + (-scale_buf[1]) * static_cast(zp_buf[1])}; + + for (int32_t j = c; j < c_end; ++j) { + const uint8_t* q_ptr = weights + j * q_rows; + for (int32_t i = r; i < (r_end - 1); i += 2) { + const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(i - r) / QuantBlk::kRow]; + + const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];; + const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow]; + + const auto vi = q_ptr[i / 2]; + + if constexpr (std::is_same::value) { + half2 scale_half2 = {scale0, scale1}; + half2 zp_adjust2 = {adjust0, adjust1}; + + half2 v = {__ushort2half_rn(vi & 0xF), __ushort2half_rn((vi >> 4) & 0xF)}; + half2 results = v * scale_half2 + zp_adjust2; + + dst[j * rows + i] = results.x; + dst[j * rows + (i + 1)] = results.y; + } else { + static_assert(std::is_same::value, "Only float and half are supported!"); + const uint8_t vi0 = vi & 0xf; + const uint8_t vi1 = vi >> 4; + dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0;; + dst[j * rows + (i + 1)] = static_cast(vi1) * scale1 + adjust1; + } + } + + if ((r_end & 1) && (r_end > r)) { + const auto scale0 = scale_buf[(r_end - 1 - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(r_end - 1 - r) / QuantBlk::kRow]; + + const auto vi = q_ptr[(r_end - 1) / 2]; + const uint8_t vi0 = vi & 0xf; + + dst[j * rows + (r_end - 1)] = static_cast(vi0) * scale0 + adjust0; + } + } +} + +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* scales, + const uint8_t* zero_points, int32_t rows, int32_t columns, + cudaStream_t stream) { + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto grids = (total_thrd_blks + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; + dequantizeThread<<>>( + dst, + weights, + scales, + zero_points, + rows, + columns, + thrd_row_blks); +} + + +template +Status +DequantizeBlockwise4b( + T* dst, + const uint8_t* src, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream) { + switch (block_size) { + case 16: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 32: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 64: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 128: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, + columns, stream); + } else { + dequantize(dst, src, scales, zero_points, + rows, columns, stream); + } + return Status::OK(); + case 256: + if (columnwise) { + dequantize(dst, src, scales, zero_points, rows, + columns, stream); + } else { + dequantize(dst, src, scales, zero_points, + rows, columns, stream); + } + return Status::OK(); + default: + // Only block size 16, 32, 64, 128, 256 are supported. + return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::FAIL, + "Unsupported block size for blockwise quantization."); + } +} + +template +Status DequantizeBlockwise4b( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +template +Status DequantizeBlockwise4b( + half* dst, + const uint8_t* src, + const half* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index 741ce1e735b42..f9c09c55fd893 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -18,6 +18,33 @@ Status Dequantize4Bits( int block_size, cudaStream_t stream); + +/** + * @brief Dequantize a block-wise quantized matrix, and store the result in a + * column major matrix for use in subsequent GEMM. This implementation supports + * columnwise and rowwise block orientation. + * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] qelements pointer to the quantized elements, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major layout + * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major + * scales + * @param[in] block_size size of the quantized block + * @param[in] columnwise whether the quantized matrix is columnwise or rowwise quantized + * @param[in] rows + * @param[in] columns + */ +template +Status DequantizeBlockwise4b( + T* dst, + const uint8_t* qelements, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index bd5b6e0a8a1ce..ecf332715d470 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -25,6 +25,9 @@ class MatMulBnb4 final : public CudaKernel { ORT_ENFORCE( quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); } Status ComputeInternal(OpKernelContext* context) const override; @@ -34,6 +37,8 @@ class MatMulBnb4 final : public CudaKernel { int64_t N_; int64_t block_size_; int64_t quant_type_; + bool is_training_mode_; + bool transB_; }; template @@ -59,7 +64,7 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { static_cast(ctx->GetComputeStream()->GetHandle()))); constexpr bool transa = false; - constexpr bool transb = true; + const bool transb = transB_; MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); ORT_RETURN_IF_ERROR( @@ -69,17 +74,18 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMulBnb4( - reinterpret_cast(quant_map_buffer_data), - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - b_quant_data, - reinterpret_cast(absmax_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle())); + bool is_4bit_done = !is_training_mode_ // skip inference specific handle if in training mode + && TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); if (!is_4bit_done) { IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); @@ -98,16 +104,16 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, + transb ? CUBLAS_OP_T : CUBLAS_OP_N, // transB + CUBLAS_OP_N, // transA SafeInt(helper.N()), SafeInt(helper.M()), SafeInt(helper.K()), &alpha, reinterpret_cast(b_dequant_data), - SafeInt(K_), + helper.Ldb(transb), // ldb reinterpret_cast(a_data), - helper.Lda(transa), + helper.Lda(transa), // lda &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 14a8163fef500..5b0e61e197014 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -27,6 +27,9 @@ class MatMulNBits final : public CudaKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); } Status ComputeInternal(OpKernelContext* context) const override; @@ -36,6 +39,7 @@ class MatMulNBits final : public CudaKernel { int64_t N_; int64_t block_size_; int64_t nbits_; + bool column_wise_quant_blk_{true}; }; template @@ -50,8 +54,6 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); - ORT_ENFORCE(nbits_ == 4, "only 4 bits is supported now"); - typedef typename ToCudaType::MappedType CudaT; constexpr bool transa = false; @@ -81,14 +83,32 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); auto* b_data = b_data_ptr.get(); - ORT_RETURN_IF_ERROR(Dequantize4Bits(reinterpret_cast(b_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(K_padded), - SafeInt(N_), - SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + if (column_wise_quant_blk_) { + // column-wise block + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } #if 0 cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); T* b_data_cpu = new T[K_ * N_]; diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 4c3c345076416..f2600a506285d 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -96,6 +96,9 @@ __global__ void MatMulFloatInt4Kernel( constexpr int k_per_iter = 256; int k_iter = k / k_per_iter; + // blocks_per_k is the number of scales and zero points on the k dim + const int b_zp_k = (blocks_per_K + 1)/ 2; + extern __shared__ char shared_buffer[]; // load scale to shared buffer @@ -105,30 +108,39 @@ __global__ void MatMulFloatInt4Kernel( for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { b_scale_vec[i] = scales_data[offset + i]; } - for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K / 2; i += kColsPerThreadBlock * kWarpSize) { - b_zp_vec[i] = zero_points != nullptr ? zero_points[offset / 2 + i] : uint8_t(0x88); + + int zp_offset = n_block_id * kColsPerThreadBlock * b_zp_k; + for (int i = thread_id; i < kColsPerThreadBlock * b_zp_k; i += kColsPerThreadBlock * kWarpSize) { + b_zp_vec[i] = zero_points != nullptr ? zero_points[zp_offset + i] : uint8_t(0x88); } __syncthreads(); a_data += m_id * k; b_data_quant += n_id * blocks_per_K * (block_size / 2); + const int scale_col_offset = warp_id * blocks_per_K; + const int zp_col_offset = warp_id * b_zp_k; + float sum = 0.f; int k_id = 0; for (; k_id < (k & 0xffffff00); k_id += k_per_iter) { - uint32_t value = *(reinterpret_cast(b_data_quant + (k_id >> 1) + lane_id * 4)); - int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size; - T scale = b_scale_vec[block_idx]; - uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f); + const int t_k = k_id + (lane_id << 3); // k index for this thread + const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point + uint32_t value = *(reinterpret_cast(b_data_quant + (t_k >> 1))); + T scale = b_scale_vec[scale_col_offset + t_meta_k]; + uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; + zp = (t_meta_k & 0x01) ? (zp >> 4) : (zp & 0x0f); sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); } // handle reminder if (k_id + lane_id * 8 < k) { + const int t_k = k_id + (lane_id << 3); // k index for this thread + const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point uint32_t value = *(reinterpret_cast(b_data_quant + k_iter * 128 + lane_id * 4)); - int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size; - T scale = b_scale_vec[block_idx]; - uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f); + T scale = b_scale_vec[scale_col_offset + t_meta_k]; + uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; + zp = (t_meta_k & 0x01) ? (zp >> 4) : (zp & 0x0f); sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); } diff --git a/onnxruntime/contrib_ops/js/fused_conv.cc b/onnxruntime/contrib_ops/js/fused_conv.cc new file mode 100644 index 0000000000000..76402f0681976 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fused_conv.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/operators/conv.h" +namespace onnxruntime { +namespace contrib { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + onnxruntime::js::Conv); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 4641b006a7785..24d327576ecd9 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -11,6 +11,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -23,7 +24,9 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 246b66078537a..78983ac95e672 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -838,7 +838,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { Nop{}); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); if constexpr (USE_MASK) { ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh index cbf24ee2f5487..ea9040aa7875f 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -58,7 +58,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { auto zero = ToHipType::FromFloat(0.0f); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); auto nop = Nop{}; auto addfastgelu = AddFastGelu{}; @@ -67,7 +67,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { params->lda, params->ldb, std::array{0}, params->ldc, nop, nop, addfastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; @@ -95,7 +95,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); auto nop = Nop{}; auto fastgelu = FastGelu{}; @@ -108,7 +108,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { params->ldc, nop, nop, fastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc index c665da89af36c..e82e15a304f4c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } +Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* gamma = context->Input(1); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index e87813fb19956..0146e81c6cf8c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -79,7 +79,7 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { nullptr, activation); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index c4eef5b27c1bb..b2ef853119588 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -62,8 +62,13 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, auto create_error_message = [&node, &status](const std::string& prefix) { std::ostringstream errormsg; - errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")"; - errormsg << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). "; + errormsg << prefix; + const auto& domain = node.Domain(); + if (!domain.empty()) { + errormsg << domain << "."; + } + errormsg << node.OpType() << "(" << node.SinceVersion() << ")" + << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). "; if (!status.IsOK()) errormsg << status.ErrorMessage(); diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 96b4cc53a022c..6d2dd641f6bc6 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -232,14 +232,15 @@ class TunableOp { return timer.Duration() / num_iter; } - static bool IsSupported(Op& op, const ParamsT* param) { - Status status = op.IsSupported(param); + // Filter all Status, only OK and TUNABLE_OP_UNSUPPORTED is left, other error status will be thrown, and to be + // processed by onnxruntime. We return Status to avoid the construction of op and params signature string. + static Status IsSupported(Op& op, const ParamsT* params) { + Status status = op.IsSupported(params); if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) { - LOGS_DEFAULT(VERBOSE) << "unsupported reason: " << status.ErrorMessage(); - return false; + return status; } ORT_THROW_IF_ERROR(status); - return true; + return status; } protected: @@ -250,9 +251,9 @@ class TunableOp { int FindFastestImpl(const ParamsT* params, const std::vector>& candidates) { ITuningContext* ctx = params->TuningContext(); auto op_sig = Signature(); - auto param_sig = params->Signature(); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ')'; - auto min_time = std::numeric_limits::infinity(); + auto params_sig = params->Signature(); + LOGS_DEFAULT(VERBOSE) << "finding fastest for " << op_sig << '(' << params_sig << ')'; + auto min_duration_ms = std::numeric_limits::infinity(); int id = -1; constexpr const int max_tuning_iter = 100; @@ -260,30 +261,32 @@ class TunableOp { for (size_t i = 0; i < candidates.size(); i++) { auto& candidate = const_cast&>(candidates[i]); - if (!IsSupported(candidate, params)) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl found unsupported " << op_sig << '(' << param_sig << ") id=" << i; + auto status = IsSupported(candidate, params); + if (!status.IsOK()) { + LOGS_DEFAULT(VERBOSE) << "├──unsupported id=" << i << ", " << op_sig << '(' << params_sig << ")"; + LOGS_DEFAULT(VERBOSE) << "│ reason: " << status.ErrorMessage(); continue; } WarmUp(candidate, params); auto approx_duration = Profile(candidate, params, approx_num_iter); - if (approx_duration > 2 * min_time) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl skip slow instance " << op_sig << '(' << param_sig << ") id=" << i; + if (approx_duration > 2 * min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──skip slow instance id=" << i; continue; } int tuning_iter = std::max(1, int(std::min(double(max_tuning_iter), ctx->GetMaxTuningDurationMs() / approx_duration))); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl run instance " << op_sig << '(' << param_sig << ") id=" << i << " " << tuning_iter << " times."; - - auto time = Profile(candidate, params, tuning_iter); - if (time < min_time) { - min_time = time; + auto duration_ms = Profile(candidate, params, tuning_iter); + if (duration_ms < min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──found better instance, new best id=" << i << ", old id=" << id << ". " + << duration_ms << "ms, " << tuning_iter << " iters."; + min_duration_ms = duration_ms; id = static_cast(i); } } ORT_ENFORCE(id >= 0, "Could not find viable op"); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ") found fastest with id=" << id; + LOGS_DEFAULT(VERBOSE) << "└──found fastest with id=" << id << " for " << op_sig << '(' << params_sig << ")"; std::this_thread::sleep_for(std::chrono::milliseconds(50)); return id; } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 76c3f8716ff09..5bc18a4e69b47 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1051,15 +1051,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .Output(2, "present_value", "present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 070df487a264d..8b5b561c1ad87 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -273,6 +273,129 @@ void RegisterCollectiveOps() { OpSchema::NonDifferentiable) .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceSum) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMax) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMean) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e757e39130d39..39449bea6303a 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2693,7 +2693,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int64_t K, - int64_t N) { + int64_t N, + bool transB) { int input_a_idx = 0; if (!hasInputShape(ctx, input_a_idx)) { return; @@ -2707,15 +2708,15 @@ static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext // TODO: check B shape const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1); - if (dim_last.has_dim_value() && dim_last.dim_value() != K) { + ONNX_NAMESPACE::TensorShapeProto resultShape; + if (dim_last.has_dim_value() && dim_last.dim_value() != (transB ? K : N)) { fail_shape_inference("Incompatible dimensions for matrix multiplication"); } - ONNX_NAMESPACE::TensorShapeProto resultShape; for (int i = 0; i < a_shape.dim_size() - 1; ++i) { *resultShape.add_dim() = a_shape.dim(i); } - resultShape.add_dim()->set_dim_value(N); + resultShape.add_dim()->set_dim_value(transB ? N : K); *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; } @@ -3354,7 +3355,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored // Shape inference int64_t in_features = getAttribute(ctx, "K", -1); int64_t out_features = getAttribute(ctx, "N", -1); - MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true); }); static const char* MatMulBnb4_ver1_doc = R"DOC( @@ -3364,8 +3365,30 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's quantization constants or scales are specified by input 'absmax'. -Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. -Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] )DOC"; @@ -3377,6 +3400,12 @@ Input absmax is stored in same type as original type of B(float32, float16) with .Attr("N", "size of each output feature", AttributeProto::INT) .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Attr("training_mode", + "Indicate if the ops run in training_mode, by default, False.", + AttributeProto::INT, + static_cast(0)) + .Attr("transB", "Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.", + AttributeProto::INT, static_cast(1)) .Input(0, "A", "The input tensor, not quantized", "T1") .Input(1, "B", "1-dimensional quantized data for weight", "T2") .Input(2, "absmax", "quantization constants", "T1") @@ -3389,7 +3418,8 @@ Input absmax is stored in same type as original type of B(float32, float16) with // Shape inference int64_t in_features = getAttribute(ctx, "K", -1); int64_t out_features = getAttribute(ctx, "N", -1); - MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + bool transB = getAttribute(ctx, "transB", 1) != 0; + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, transB); }); #ifdef ENABLE_ATEN diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c2f5edaa6149b..f81c3b8e0182c 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The number of groups of channels. It should be a divisor of the number of channels C", AttributeProto::INT) .Attr("activation", - "Activation after group normalization: 0 for None, 1 for Swish", + "Activation after group normalization: 0 for None, 1 for SiLU", AttributeProto::INT) .Attr("channels_last", "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", @@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +constexpr const char* SkipGroupNorm_ver1_doc = R"DOC( +This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + +This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. +The num_channels must be divisible by num_groups. +The mean and standard-deviation of s are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + SkipGroupNorm, 1, + OpSchema() + .SetDoc(SkipGroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", + AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for SiLU", + AttributeProto::INT) + .Attr("channels_last", + "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "X", + "Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 " + " or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels," + " and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(3, + "skip", + "4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)", + "T") + .Input(4, + "bias", + "1D bias tensor. Dimensions are (C), where C is number of channels", + "T", + OpSchema::Optional) + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .Output(1, + "S", + "The element-wise sum of input x, skip and bias tensors. It has the same shape as X", + "T", + OpSchema::Optional) + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.") + .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } + + if (hasInputShape(ctx, 0)) { + propagateShapeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + })); + constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left tensor multiplies the Gelu activation result of right tensor. diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index d3fc5873cb274..03ad95260c0ad 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -90,12 +90,16 @@ void RegisterNHWCSchemaWithActivation(const RegistrationFunc& f, ::ONNX_NAMESPAC void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function& fn) { // if the operator may be fused with an activation, use the WITH_ACTIVATION variant to add optional attributes // for the activation parameters. - // For now we only register operators from opset 11 on. Models can easily have their opset updated using ONNX tools + // We mainly register operators from opset 11 on . Models can easily have their opset updated using ONNX tools // so supporting older opsets is unnecessary. + // Older opsets are included on a per-operator basis as needed. // NOTE: This should be in sync with GetLayoutSensitiveOps in // /onnxruntime/core/optimizer/transpose_optimization/transpose_optimizer.cc + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 7); + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 10); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 11); + REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, AveragePool, 19); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 9); REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 14); @@ -106,16 +110,18 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index cab9467501f55..4b3cafcb39b78 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3908,7 +3908,8 @@ Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std // kernel lookup works as per usual, if not using an existing schema. if (sub_graph.schema_source == IndexedSubGraph::SourceOfSchema::EXISTING) { ORT_ENFORCE(SetOpSchemaFromRegistryForNode(fused_node), - "Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType()); + "Schema was not found for fused node. Domain:", fused_node.Domain(), " OpType:", fused_node.OpType(), + " SinceVersion:", fused_node.SinceVersion()); } else if (IndexedSubGraph::SourceOfSchema::REUSE_OR_CREATE == sub_graph.schema_source) { auto schema_key = GenerateSchemaKey(sub_graph); if (reusable_fused_schema_map_.count(schema_key) == 0) { diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 076332a65c8f2..b3935e69ad7b1 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -232,6 +232,14 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, } } + // special-case the internal NHWC domain as it must match the ONNX opset if not explicitly imported + if (domain_to_version.find(kMSInternalNHWCDomain) == domain_to_version.end()) { + auto onnx_version = domain_to_version.find(kOnnxDomain); + if (onnx_version != domain_to_version.end()) { + domain_to_version[kMSInternalNHWCDomain] = onnx_version->second; + } + } + auto domain_map = allow_official_onnx_release_only_final ? schema_registry->GetLastReleasedOpsetVersions(false) : schema_registry->GetLatestOpsetVersions(false); diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index f3bc2a2434ab3..7c7b729117e4a 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -39,7 +39,7 @@ typedef enum { * @brief Computes the number of bytes required to pack and int4-quantize * a weight matrix * @param QType type of block quantization - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @return size of the packing buffer, 0 if the operation is not yet supported. */ @@ -53,11 +53,11 @@ MlasQ4GemmPackBSize( /** * @brief Prepack and Quantize fp32 weight tensor to int4 blocks - * + * * @param QType type of block quantization * @param PackedBuf destination buffer * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. + * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @param ldb leading dimension of B */ @@ -257,14 +257,14 @@ MlasBlockwiseQuantMetaShape( * matrix shape [rows, columns], compute the shape of the * quantized matrix [q_rows, q_cols]. The quantized matrix * is in column major layout, with bits packed on the column. - * - * @tparam T - * @param block_size - * @param columnwise - * @param rows - * @param columns - * @param q_rows - * @param q_cols + * + * @tparam T + * @param block_size + * @param columnwise + * @param rows + * @param columns + * @param q_rows + * @param q_cols */ template void @@ -283,21 +283,22 @@ MlasBlockwiseQuantizedShape( * parameters (scales, zero points) are packed into separate matrices * all in column major layout for faster access during subsequent matrix * multiplication. - * + * * @tparam ElementT type of the input matrix element, usually floating point - * + * @tparam qbits number of bits used for quantization, 4 for int4 + * * @param dst points to the quantized matrix, shape [rows, columns] column major - * @param scales points to the scales matrix, column major + * @param scales points to the scales matrix, column major * @param zero_points points to the zero_points matrix, column major * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row - * @param rows - * @param columns - * @param leading_dimension - * @param thread_pool + * @param rows + * @param columns + * @param leading_dimension + * @param thread_pool */ -template +template void MlasQuantizeBlockwise( uint8_t* dst, @@ -318,19 +319,21 @@ MlasQuantizeBlockwise( * parameters (scales, zero points) are from separate matrices packed * in column major layout. Output is a floating point matrix in column * major layout for faster access during subsequent matrix multiplication. - * + * * @tparam ElementT type of the dequantized matrix element, usually floating point + * @tparam qbits number of bits used for quantization, 4 for int4 + * * @param dst points to dequantized matrix shape [rows, columns] column major * @param src points to quantized matrix, column major * @param scales points to quantization scales, column major * @param zero_points points to quantization zero points, column major * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row - * @param rows - * @param columns - * @param thread_pool + * @param rows + * @param columns + * @param thread_pool */ -template +template void MlasDequantizeBlockwise( ElementT* dst, diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 24a2212ba0714..fbd1030de8ab7 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -364,7 +364,7 @@ range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) } else { zp = (uint8_t)roundf(zero_point_fp); } - scale = static_cast(scale_f); + scale = ScaleT(scale_f); } template @@ -377,7 +377,7 @@ range2scale(float min, float max, ScaleT& scale) max = fabsf(max) > fabsf(min) ? max : min; - scale = static_cast(max / mid_fp); + scale = ScaleT(max / mid_fp); }; @@ -773,7 +773,7 @@ MlasBlockwiseQuantizedShape( ); -template +template void MlasQuantizeBlockwise( uint8_t* dst, @@ -791,50 +791,50 @@ MlasQuantizeBlockwise( switch (block_size) { case 16: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 32: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 64: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 128: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 256: if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { - BlockwiseQuantizer::quantizeAndTranspose( + BlockwiseQuantizer::quantizeAndTranspose( dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; @@ -847,7 +847,7 @@ MlasQuantizeBlockwise( template void -MlasQuantizeBlockwise( +MlasQuantizeBlockwise( uint8_t* dst, float* scales, uint8_t* zero_points, @@ -860,8 +860,23 @@ MlasQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + MLAS_FP16* scales, + uint8_t* zero_points, + const MLAS_FP16* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + -template +template void MlasDequantizeBlockwise( T* dst, @@ -878,46 +893,46 @@ MlasDequantizeBlockwise( switch (block_size) { case 16: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; case 32: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; case 64: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; case 128: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; case 256: if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); } break; @@ -929,7 +944,7 @@ MlasDequantizeBlockwise( template void -MlasDequantizeBlockwise( +MlasDequantizeBlockwise( float* dst, const uint8_t* src, const float* scales, diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index c090ab2a6cc9b..d27603e4ab3a1 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -4,7 +4,7 @@ #include "core/optimizer/conv_activation_fusion.h" #include - +#include #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas.h" @@ -174,9 +174,29 @@ using NTO = NodesToOptimize; class FuseConvActivationAction : public ReplaceWithNew { private: - std::string OpType(const RuntimeState&) const override { return "FusedConv"; } + std::string OpType(const RuntimeState& runtime_state) const override { + const auto& domain = runtime_state.selected_nodes.Target().Domain(); + const auto& op_type = runtime_state.selected_nodes.Target().OpType(); + if (domain == kOnnxDomain) { + if (op_type == "Conv") { + return "FusedConv"; + } + } else if (domain == kMSDomain) { + if (op_type == "NhwcConv") { + return "NhwcFusedConv"; + } + } else if (domain == kMSInternalNHWCDomain) { + if (op_type == "Conv") { + return "Conv"; + } + } + ORT_THROW("Unsupported operator: ", op_type, " and domain: ", domain); + } - std::string Domain(const RuntimeState&) const override { return kMSDomain; } + std::string Domain(const RuntimeState& runtime_state) const override { + auto domain = runtime_state.selected_nodes.Target().Domain(); + return domain == kOnnxDomain ? kMSDomain : domain; + } NodeAttributes ExtraAttributes(const RuntimeState& state) const override { NodeAttributes extra_fused_conv_attributes; @@ -260,8 +280,11 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { const auto name = "ConvAct"; auto action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + const std::string msInternalNHWCDomainConv = SelectorActionRegistry::OpVersionsMapKey("Conv", kMSInternalNHWCDomain); + const std::string msDomainConv = SelectorActionRegistry::OpVersionsMapKey("NhwcConv", kMSDomain); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, + + registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}, {msInternalNHWCDomainConv, {11}}, {msDomainConv, {1}}}, std::move(selector), std::move(action)); #else registry.RegisterAction(name, std::move(action)); diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 7c8bfeaec5f0f..6f90eaf07ef4d 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -287,12 +287,9 @@ class FuseConvAddActivationAction : public ReplaceWithNew { void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) { auto action = std::make_unique(); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}}, + std::string msDomainNhwcFusedConv = SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain); + registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}, {msDomainNhwcFusedConv, {1, 11}}}, std::move(selector), std::move(action)); - auto action_nhwc = std::make_unique(); - auto selector_nhwc = std::make_unique(); - registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}}, - std::move(selector_nhwc), std::move(action_nhwc)); } SelectorActionRegistry CreateSelectorActionRegistry() { diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 5a441b1d1701e..c1397e92d9d26 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -51,6 +51,7 @@ #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" #include "core/optimizer/matmul_bn_fusion.h" +#include "core/optimizer/pad_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" @@ -128,6 +129,7 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); @@ -270,11 +272,12 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = @@ -296,7 +299,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 7c087ec77d9fe..959fcd6efdc3c 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -32,7 +32,7 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, int64_t to_type, onnxruntime::ProviderType providerType) { // insert cast op to cast input - std::string node_name = graph.GenerateNodeName("InsertedCast_" + old_arg->Name()); + std::string node_name = graph.GenerateNodeName("InsertedPrecisionFreeCast_" + old_arg->Name()); auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type); @@ -235,7 +235,8 @@ enum TypeGroup { Unknown = -1, Bool = 0, Integer = 1, - Float = 2, + Unsigned = 2, + Float = 3, }; TypeGroup GetTypeGroup(DataType type) { @@ -243,11 +244,14 @@ TypeGroup GetTypeGroup(DataType type) { return Bool; } - if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || - *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") { return Integer; } + if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Unsigned; + } + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { return Float; } @@ -255,6 +259,22 @@ TypeGroup GetTypeGroup(DataType type) { return Unknown; } +int BitLength(DataType type) { + if (*type == "tensor(bool)") { + return 1; + } else if (*type == "tensor(uint8)" || *type == "tensor(int8)") { + return 8; + } else if (*type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") { + return 16; + } else if (*type == "tensor(int32)" || *type == "tensor(uint32)" || *type == "tensor(float)") { + return 32; + } else if (*type == "tensor(int64)" || *type == "tensor(uint64)" || *type == "tensor(double)") { + return 64; + } else { + return -1; + } +} + /** Transformer to remove duplicate Cast nodes. */ class RemoveDuplicateCastTransformer : public GraphTransformer { public: @@ -262,6 +282,48 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } private: + static bool UnsafeCast(DataType src_type, DataType dst_type, const Node& node) { + // This is not a complete cast optimisation pass, and is more conservative than it could be. + // For instance, certain integral -> floating point casts could be optimised but this is left to an explicit cast optimisation pass. + + // The comparison with "InsertedPrecisionFreeCast_" reflects cast nodes that are inserted by InsertCastTransformer. + // Such casts should not be considered as loss of precision - the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) are inserted to support kernels when on a CPU EP without F16 support. + auto src_type_group = GetTypeGroup(src_type); + auto dst_type_group = GetTypeGroup(dst_type); + if (Unknown == src_type_group || Unknown == dst_type_group) { + return true; + } + + // Do not remove any signed -> unsigned cast. + if ((src_type_group != Bool && src_type_group != Unsigned) && Unsigned == dst_type_group) { + return true; + } + + // Do not remove any floating point -> non floating point cast. + if (Float == src_type_group && Float != dst_type_group) { + return true; + } + + auto src_bit_length = BitLength(src_type); + auto dst_bit_length = BitLength(dst_type); + + // unsigned integer -> integer cast may overflow if the destination integer is smaller or equal to the source integer. + if (Unsigned == src_type_group && Integer == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + // integral -> floating cast may overflow if integer cannot be encoded in the mantissa. This check could be more precise. + if ((Integer == src_type_group || Unsigned == src_type_group) && Float == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + return true; + } + + return src_bit_length > dst_bit_length && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")); + } + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override { auto output_args = graph.GetOutputs(); InlinedHashSet graph_outputs; @@ -293,17 +355,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { // - for each consumer cast node, it meets above condition for this optimization. auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); - TypeGroup src_type_group = GetTypeGroup(src_type); - TypeGroup dst_type_group = GetTypeGroup(dst_type); - if (src_type_group == Unknown || dst_type_group == Unknown) { - continue; - } - - bool loss_precision_cast = false; - if (src_type_group > dst_type_group) { - loss_precision_cast = true; - } + bool loss_precision_cast = UnsafeCast(src_type, dst_type, node); size_t num_children = node.GetOutputEdgesCount(); bool inconsistent_casts = false; @@ -312,10 +365,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { if (output_node.OpType() == "Cast") { auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); - TypeGroup src_type_group1 = GetTypeGroup(src_type1); - TypeGroup dst_type_group1 = GetTypeGroup(dst_type1); - if (src_type_group1 == Unknown || dst_type_group1 == Unknown || - (loss_precision_cast && dst_type_group1 > src_type_group1)) { + if (loss_precision_cast && UnsafeCast(dst_type1, src_type1, output_node)) { inconsistent_casts = true; break; } diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 290380cabb036..4505d4afdf1e0 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -66,17 +66,6 @@ bool ConvertNodeLayout(const api::NodeRef& node) { const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); // handle special cases -#if defined(USE_XNNPACK) - if (node.GetExecutionProviderType() == kXnnpackExecutionProvider) { - if (node.OpType() == "Resize") { - // XNNPACK supports NCHW and NHWC for Resize so we don't need to use the internal NHWC domain and wrap the Resize - // with Transpose nodes. EPAwareHandleResize will allow an NCHW <-> NHWC Transpose to be pushed through - // the Resize during transpose optimization. - return false; - } - } -#endif - #if defined(USE_JSEP) // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed if (node.GetExecutionProviderType() == kJsExecutionProvider) { diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc new file mode 100644 index 0000000000000..b25e7618802dd --- /dev/null +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/pad_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +/* + * It matches following pattern: + * Pad + * | + * Conv/MaxPool + */ +bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + // if Pad has input axis, don't fuse it. + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) || + node.GetOutputEdgesCount() != 1 || + node.InputDefs().size() > 3) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + const Node& child_node = *node.OutputNodesBegin(); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) { + return false; + } + + // Don't fuse if MaxPool has optional output indices tensor because output indices tensor + // does not incorporate pad values. Basically if we allow the fusion, then dimension values + // of input tensor < dimension values of input tensor without fusion. + // This will cause the range of values for output indices tensor to be less than what it + // should have been. + + if (child_node.OutputDefs().size() > 1) { + return false; + } + + // conv or maxpool node must use explicit padding to perform this fusion. + if (child_node.GetAttributes().find("auto_pad") != child_node.GetAttributes().end() && + child_node.GetAttributes().at("auto_pad").s() != "NOTSET") { + return false; + } + + const NodeAttributes& pad_attributes = node.GetAttributes(); + if (pad_attributes.find("mode") != pad_attributes.end() && + pad_attributes.at("mode").s() != "constant") { + return false; + } + + // Since opset 11, and moved to inputs. + // Both of these should be initializer because we have to verify the values. + if (node.SinceVersion() >= 11) { + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + (node.InputDefs().size() > 2 && !graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[2]))) { + return false; + } + + // constant_value should be zero because Conv and MaxPool allow only 0 as padding value. + if (node.InputDefs().size() > 2) { + const auto* pad_constant_value_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name()); + Initializer pad_constant_value{*pad_constant_value_proto, graph.ModelPath()}; + if (std::any_of(pad_constant_value.DataAsByteSpan().begin(), pad_constant_value.DataAsByteSpan().end(), [](const uint8_t byte) { return byte != 0; })) { + return false; + } + } + } else { + if (pad_attributes.find("value") != pad_attributes.end() && + pad_attributes.at("value").f() != 0.0) { + return false; + } + } + + return true; +} + +/* + * - For 1st two dimension Pads array's value should be zero and for rest of them values should >= 0 + */ +Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + std::vector pads_values; + + if (pad_node.SinceVersion() >= 11) { + const auto* pads_proto = graph_utils::GetConstantInitializer(graph, pad_node.InputDefs()[1]->Name()); + Initializer pads{*pads_proto, graph.ModelPath()}; + pads_values.assign(pads.DataAsSpan().begin(), pads.DataAsSpan().end()); + } else { + pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end()); + } + + assert(static_cast(pads_values.size()) == (2 * static_cast(pad_node.InputDefs()[0]->Shape()->dim_size()))); + + uint32_t pads_size = static_cast(pads_values.size()); + // check if padding is applied only on feature dims + if (pads_values[0] != 0 || pads_values[1] != 0 || pads_values[pads_size / 2] != 0 || + pads_values[pads_size / 2 + 1] != 0) { + return Status::OK(); + } + + // check if padding is only positive + if (std::any_of(pads_values.begin(), pads_values.end(), [](int64_t value) { return value < 0; })) { + return Status::OK(); + } + + Node& child_node = *graph.GetNode(pad_node.OutputNodesBegin()->Index()); + auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints(); + uint32_t child_pads_size = static_cast(child_pads->size()); + + for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) { + child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]); + uint32_t mirrored_child_index = child_index + (child_pads_size / 2); + uint32_t mirrored_pad_index = pads_index + (pads_size / 2); + child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]); + } + + graph_utils::RemoveNodeOutputEdges(graph, pad_node); + graph_utils::ReplaceNodeInput(child_node, 0, *pad_node.MutableInputDefs()[0]); + graph.RemoveNode(pad_node.Index()); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/pad_fusion.h b/onnxruntime/core/optimizer/pad_fusion.h new file mode 100644 index 0000000000000..a1b6978a83d1e --- /dev/null +++ b/onnxruntime/core/optimizer/pad_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/* + * This fusion submerges a Pad operator to it's child + * Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition() + * is true. + */ +class PadFusion : public RewriteRule { + public: + PadFusion() : RewriteRule("Pad_Fusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Pad"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc index e182b6c695d2f..546d52b6f1682 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc @@ -3,9 +3,10 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" -#include #include +#include #include +#include #include #include "core/graph/op_identifier_utils.h" @@ -56,9 +57,9 @@ const SelectorActionRegistry::Entry* SelectorActionRegistry::LookUp(const std::s } #if !defined(ORT_MINIMAL_BUILD) -auto SelectorActionRegistry::LookUpByOpType(const std::string& op_type) const +auto SelectorActionRegistry::LookUpByOpTypeAndDomain(const std::string& op_type, const std::string& domain) const -> std::vector> { - const auto [range_begin, range_end] = op_type_to_entry_.equal_range(op_type); + const auto [range_begin, range_end] = op_type_to_entry_.equal_range(OpVersionsMapKey(op_type, domain)); std::vector> result{}; result.reserve(std::distance(range_begin, range_end)); std::transform(range_begin, range_end, std::back_inserter(result), @@ -93,20 +94,15 @@ static Status MatchAndProcess( Status status = Status::OK(); do { - // TODO: for now this just needs to support ONNX and Micrsoft Domain ops. - // If we ever had a transformer that was going to target non-ONNX ops, - // we'd need to rework a few things to include the op domain in the matches - if (node.Domain() != kOnnxDomain && node.Domain() != kMSDomain) { - break; - } - std::optional node_selection_opt{}; const SelectorActionRegistry::Entry* selector_action_entry_ptr = nullptr; - const auto selector_action_entries = selector_action_registry.LookUpByOpType(node.OpType()); + const auto selector_action_entries = + selector_action_registry.LookUpByOpTypeAndDomain(node.OpType(), node.Domain()); + std::string key = SelectorActionRegistry::OpVersionsMapKey(node.OpType(), node.Domain()); for (const auto& entry : selector_action_entries) { // check the supported versions if specified - const auto& versions = entry->ops_and_versions.find(node.OpType())->second; + const auto& versions = entry->ops_and_versions.find(key)->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node.SinceVersion()) == versions.cend()) { continue; diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h index 7eb162cc693f1..5caa949ebbe93 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h @@ -38,8 +38,20 @@ struct NodeSelector { // class to manage a set of selector and associated actions class SelectorActionRegistry { public: + // The key is a string representing the op, optionally specifying the domain using ':' as the + // separator with domain as the first part and operator as the second part, ":" or "". + // For ops in kOnnxDomain, the domain should be left unspecified (""). + // For ops in other domains, the domain should be specified (":"). + // Ex: "Conv", "com.microsoft:Conv", "com.ms.internal.nhwc:Conv" using OpVersionsMap = std::unordered_map>; + // Helper function to create a key to OpVersionsMap using domain and op_type. + static std::string OpVersionsMapKey(std::string_view op_type, std::string_view domain = kOnnxDomain) { + return (domain == kOnnxDomain) + ? std::string{op_type} + : std::string{domain} + ":" + std::string{op_type}; + } + struct Entry { Entry(const std::string& name_in, #if !defined(ORT_MINIMAL_BUILD) @@ -95,14 +107,15 @@ class SelectorActionRegistry { #if !defined(ORT_MINIMAL_BUILD) // return registered Entry or nullptr if not found - auto LookUpByOpType(const std::string& op_type) const -> std::vector>; + auto LookUpByOpTypeAndDomain(const std::string& op_type, + const std::string& domain) const -> std::vector>; #endif // !defined(ORT_MINIMAL_BUILD) private: std::unordered_map name_to_entry_; #if !defined(ORT_MINIMAL_BUILD) - // auxiliary mapping to enable lookup by op type + // auxiliary mapping to enable lookup by op type or "domain:op type" std::unordered_multimap op_type_to_entry_; #endif // !defined(ORT_MINIMAL_BUILD) }; diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index f4f3505128737..8eaac3d34c3af 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -17,7 +17,7 @@ static bool EPAwareHandleResize(HandlerArgs& args) { // layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled // by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs. const auto ep_type = args.node.GetExecutionProviderType(); - if (ep_type == kCpuExecutionProvider || ep_type == kXnnpackExecutionProvider) { + if (ep_type == kCpuExecutionProvider) { // allow NCHW <-> NHWC for now. not clear any other sort of transpose has a valid usage in a real model int64_t rank_int = gsl::narrow_cast(args.perm.size()); if (rank_int == 4) { diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc index b63c0d2161ad5..dfa27f1f44d5a 100644 --- a/onnxruntime/core/providers/cpu/generator/random.cc +++ b/onnxruntime/core/providers/cpu/generator/random.cc @@ -428,4 +428,14 @@ template Status MultinomialComputeShared(AllocatorPtr& alloc, std::default_random_engine& generator, Tensor& Y); +#if !defined(DISABLE_CONTRIB_OPS) +// used by onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h +template Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y); +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 3192c8573c5c0..1d524a90302e7 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -967,7 +967,7 @@ Status Xor::Compute(OpKernelContext* context) const { }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array() ^ per_iter_bh.EigenInput1().array(); + per_iter_bh.EigenInput0().array() != per_iter_bh.EigenInput1().array(); }}; UntypedBroadcastTwo(*context, funcs, 1.0); diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index 99522cdf0759a..0b3ce6f477843 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -352,7 +352,7 @@ class UpsampleBase { (scales.size() == 4 && scales[0] == 1 && scales[3] == 1) || scales.size() == 3 || (scales.size() == 5 && scales[0] == 1 && scales[1] == 1), - "'Linear' mode only support:\n" + "'Linear' mode only supports:\n" " * 2-D inputs or\n" " * 3-D inputs ('Bilinear', 'Trilinear') or\n" " * 4-D inputs with the corresponding outermost 2 scale values being 1" diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 4f5469ad8de36..2d242d7d6fb12 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1228,6 +1228,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, BFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double_t, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, int32_t, Where); @@ -2111,6 +2112,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index d46ed9c245a8e..bc78e577c5052 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -614,6 +614,30 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const { diff --git a/onnxruntime/core/providers/cuda/tensor/where.cc b/onnxruntime/core/providers/cuda/tensor/where.cc index b3f92c913a84b..4d98b3a8a145e 100644 --- a/onnxruntime/core/providers/cuda/tensor/where.cc +++ b/onnxruntime/core/providers/cuda/tensor/where.cc @@ -216,5 +216,6 @@ SPECIALIZED_COMPUTE(int64_t) SPECIALIZED_COMPUTE(float) SPECIALIZED_COMPUTE(double_t) SPECIALIZED_COMPUTE(MLFloat16) +SPECIALIZED_COMPUTE(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/where_impl.cu b/onnxruntime/core/providers/cuda/tensor/where_impl.cu index 0fbd2062ca1b2..d7909454e922c 100644 --- a/onnxruntime/core/providers/cuda/tensor/where_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/where_impl.cu @@ -238,6 +238,7 @@ SPECIALIZED_IMPL(int64_t) SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double_t) SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 1db22ac92e527..5c7b7bff1e370 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -93,7 +93,7 @@ namespace Dml onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override { - ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount()); + ORT_THROW_HR_IF(E_UNEXPECTED, static_cast(m_subgraphInputs.size()) != kernelContext->InputCount()); bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; @@ -159,7 +159,7 @@ namespace Dml if (iter != m_inferredInputShapes.end()) { auto tensorShape = *nodeArg->Shape(); - ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != iter->second.NumDimensions()); + ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != static_cast(iter->second.NumDimensions())); for (int i = 0; i < tensorShape.dim_size(); ++i) { diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 8dbd552dd0550..798244d7cb75b 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -238,18 +238,40 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Whe class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm); @@ -257,17 +279,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gem class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMax); @@ -291,11 +302,17 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather); @@ -304,11 +321,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, Resize); - class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice); @@ -322,8 +334,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, LayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, InstanceNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range); @@ -508,18 +521,40 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -527,17 +562,6 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -575,7 +599,7 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -594,8 +618,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index 2e07124dcd901..474fd260880ce 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -16,6 +16,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + ONNX_OPERATOR_KERNEL_EX( Conv, kOnnxDomain, @@ -23,6 +24,14 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Conv, + kMSInternalNHWCDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); ONNX_OPERATOR_VERSIONED_KERNEL_EX( Conv, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index fdf3e5b6c6b66..3a01a4aa46be4 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -3,23 +3,42 @@ #pragma once +#include +#include + #include "core/providers/js/js_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" namespace onnxruntime { namespace js { -template -class Conv : public JsKernel { +class ConvBase : public JsKernel { public: - Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { + ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info), + conv_attrs_(info), + w_is_const_(false) { + std::vector activation_params; TensorShapeVector kernel_shape; + const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size(); + std::vector local_pads(pads_vec_size, 0); + for (size_t i = 0; i < conv_attrs_.pads.size() && i < pads_vec_size; ++i) { + local_pads[i] = gsl::narrow_cast(conv_attrs_.pads[i]); + } + if (conv_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - + if (is_fused_conv) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_attrs_.activation)); + ORT_ENFORCE(info.GetAttrs("activation_params", activation_params).IsOK()); + } else { + conv_attrs_.activation = info.GetAttrOrDefault("activation", ""); + activation_params = info.GetAttrsOrDefault("activation_params", activation_params); + } + const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr; int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); - + auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0; + auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_attrs_.dilations.size() == 1 || (conv_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || @@ -30,44 +49,52 @@ class Conv : public JsKernel { "dilations" : [$2], "group" : $3, "kernel_shape" : [$4], - "pads" : [ $5, $6 ], + "pads" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : [], "strides" : [$7], - "w_is_const" : () JS_ARROW(!!HEAP8[$9]) + "w_is_const" : () JS_ARROW(!!HEAP8[$9]), + "activation" : UTF8ToString($10), + "activation_params" : $11 ? Array.from(HEAPF32.subarray($12, $12 + $11)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.group), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), - static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), - static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), + static_cast(kernel_shape_0), + static_cast(local_pads.size()), + reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str(), + activation_params.size(), + reinterpret_cast(activation_params_ptr) >> 2); } else { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ - "format" : $13 ? "NHWC" : "NCHW", + "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, "dilations" : [ $2, $3 ], "group" : $4, "kernel_shape" : [ $5, $6 ], - "pads" : [ $7, $8, $9, $10 ], - "strides" : [ $11, $12 ], - "w_is_const" : () JS_ARROW(!!HEAP8[$14]) + "pads" : $7 ? Array.from(HEAP32.subarray($8, $8 + $7)) : [], + "strides" : [ $9, $10 ], + "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "activation" : UTF8ToString($13), + "activation_params" : $14 ? Array.from(HEAPF32.subarray($15, $15 + $14)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.dilations.size() > 1 ? conv_attrs_.dilations[1] : 0), static_cast(conv_attrs_.group), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0), - static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), - static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), - static_cast(conv_attrs_.pads.size() > 2 ? conv_attrs_.pads[2] : 0), - static_cast(conv_attrs_.pads.size() > 3 ? conv_attrs_.pads[3] : 0), + static_cast(kernel_shape_0), + static_cast(kernel_shape_1), + static_cast(local_pads.size()), + reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str(), + activation_params.size(), + reinterpret_cast(activation_params_ptr) >> 2); } } @@ -94,5 +121,12 @@ class Conv : public JsKernel { // Tensor w_transposed_; }; +template +class Conv : public ConvBase { + public: + explicit Conv(const OpKernelInfo& info) : ConvBase(info, is_channels_last, is_fused_conv) { + } +}; + } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 2228343e1e6e3..2aaf438f30d4d 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -15,6 +15,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + ONNX_OPERATOR_KERNEL_EX( ConvTranspose, kOnnxDomain, @@ -22,6 +23,14 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); ONNX_OPERATOR_VERSIONED_KERNEL_EX( ConvTranspose, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 18ef73268005d..5d30dc851e00f 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -4,26 +4,45 @@ #pragma once #include +#include #include "core/common/gsl.h" #include "core/providers/cpu/nn/conv_transpose_attributes.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -template +template class ConvTranspose : public JsKernel { public: ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { TensorShapeVector kernel_shape; + if (is_fused_convtranspose) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_transpose_attrs_.activation)); + } else { + conv_transpose_attrs_.activation = info.GetAttrOrDefault("activation", ""); + } + if (conv_transpose_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); + std::vector local_output_shape(conv_transpose_attrs_.output_shape.begin(), + conv_transpose_attrs_.output_shape.end()); + std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), + conv_transpose_attrs_.output_padding.end()); + const auto* local_output_padding_ptr = + local_output_padding.size() > 0 ? local_output_padding.data() : nullptr; + const auto* local_output_shape_ptr = + local_output_shape.size() > 0 ? local_output_shape.data() : nullptr; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_transpose_attrs_.dilations.size() == 1 || (conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || conv_transpose_attrs_.strides.size() == 1) { + auto dilations = conv_transpose_attrs_.dilations.size() > 0 ? conv_transpose_attrs_.dilations[0] : 0; + auto kernel_shape_0 = conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0; + auto pads_0 = conv_transpose_attrs_.pads.size() > 0 ? conv_transpose_attrs_.pads[0] : 0; + auto pads_1 = conv_transpose_attrs_.pads.size() > 1 ? conv_transpose_attrs_.pads[1] : 0; + auto strides = conv_transpose_attrs_.strides.size() > 0 ? conv_transpose_attrs_.strides[0] : 0; JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $8 ? "NHWC" : "NCHW", "autoPad" : $1, @@ -34,21 +53,23 @@ class ConvTranspose : public JsKernel { "strides" : [$7], "wIsConst" : () JS_ARROW(!!HEAP8[$9]), "outputPadding" : $10 ? Array.from(HEAP32.subarray($11, $11 + $10)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [] + "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [], + "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), - static_cast(conv_transpose_attrs_.dilations.size() > 0 ? conv_transpose_attrs_.dilations[0] : 0), + static_cast(dilations), static_cast(conv_transpose_attrs_.group), - static_cast(conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() > 0) ? kernel_shape[0] : 0, - static_cast(conv_transpose_attrs_.pads.size()), - static_cast(conv_transpose_attrs_.pads.size() > 1) ? conv_transpose_attrs_.pads[1] : 0, - static_cast(conv_transpose_attrs_.strides.size() > 0) ? conv_transpose_attrs_.strides[0] : 0, + static_cast(kernel_shape_0), + static_cast(pads_0), + static_cast(pads_1), + static_cast(strides), static_cast(channels_last), reinterpret_cast(&w_is_const_), - gsl::narrow_cast(conv_transpose_attrs_.output_shape.size()), - reinterpret_cast(conv_transpose_attrs_.output_padding.size() > 0 ? conv_transpose_attrs_.output_padding.data() : nullptr) >> 2, - gsl::narrow_cast(conv_transpose_attrs_.output_shape.size()), - reinterpret_cast(conv_transpose_attrs_.output_shape.size() > 0 ? conv_transpose_attrs_.output_shape.data() : nullptr) >> 2); + gsl::narrow_cast(local_output_padding.size()), + reinterpret_cast(local_output_padding_ptr) >> 2, + gsl::narrow_cast(local_output_shape.size()), + reinterpret_cast(local_output_shape_ptr) >> 2, + conv_transpose_attrs_.activation.c_str()); } else { constexpr size_t pads_vec_size = 4; constexpr size_t strides_vec_size = 2; @@ -59,8 +80,6 @@ class ConvTranspose : public JsKernel { std::vector local_strides(strides_vec_size, 0); std::vector local_dilations(dialations_vec_size, 0); std::vector local_kernel_shape; - std::vector local_output_shape(conv_transpose_attrs_.output_shape.begin(), conv_transpose_attrs_.output_shape.end()); - std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), conv_transpose_attrs_.output_padding.end()); if (conv_transpose_attrs_.kernel_shape_specified) { for (size_t i = 0; i < kernel_shape.size() && i < kernel_shape_vec_size; ++i) { local_kernel_shape.push_back(gsl::narrow_cast(kernel_shape[i])); @@ -91,7 +110,8 @@ class ConvTranspose : public JsKernel { "strides" : Array.from(HEAP32.subarray($6, $6 + /* strides_vec_size */ 2)), "wIsConst" : () JS_ARROW(!!HEAP8[$8]), "outputPadding" : ($9 > 0) ? Array.from(HEAP32.subarray($10, $10 + $9)) : [], - "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [] + "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [], + "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), reinterpret_cast(local_dilations.data()) >> 2, @@ -102,9 +122,10 @@ class ConvTranspose : public JsKernel { static_cast(channels_last), reinterpret_cast(&w_is_const_), gsl::narrow_cast(local_output_padding.size()), - reinterpret_cast(local_output_padding.size() > 0 ? local_output_padding.data() : nullptr) >> 2, + reinterpret_cast(local_output_padding_ptr) >> 2, gsl::narrow_cast(local_output_shape.size()), - reinterpret_cast(local_output_shape.size() > 0 ? local_output_shape.data() : nullptr) >> 2); + reinterpret_cast(local_output_shape_ptr) >> 2, + conv_transpose_attrs_.activation.c_str()); } } diff --git a/onnxruntime/core/providers/js/operators/pool.cc b/onnxruntime/core/providers/js/operators/pool.cc index 7fdb4e5d114ea..7df1e483f52a1 100644 --- a/onnxruntime/core/providers/js/operators/pool.cc +++ b/onnxruntime/core/providers/js/operators/pool.cc @@ -52,15 +52,20 @@ namespace js { Pool); POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9) +POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 7, 9) POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10) +POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 10, 10) POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11) POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11) POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1) POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1) POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, MaxPool<1>, 1, 7) +POOLING_KERNEL_VERSIONED(MaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1, 7) POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 8, 9) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 8, 9) POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 10, 10) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 10, 10) POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 11, 11) POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 11, 11) POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 12) diff --git a/onnxruntime/core/providers/js/operators/resize.cc b/onnxruntime/core/providers/js/operators/resize.cc index 5b2e385777a37..2514ab75dff26 100644 --- a/onnxruntime/core/providers/js/operators/resize.cc +++ b/onnxruntime/core/providers/js/operators/resize.cc @@ -51,6 +51,7 @@ namespace js { REGISTER_RESIZE_KERNEL(domain, 19); REGISTER_RESIZE_VERSIONED_10_10_KERNEL(kOnnxDomain); +REGISTER_RESIZE_VERSIONED_10_10_KERNEL(kMSInternalNHWCDomain); REGISTER_RESIZE_KERNEL_DOMAIN(kOnnxDomain); REGISTER_RESIZE_KERNEL_DOMAIN(kMSInternalNHWCDomain); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 78467b646b195..7e4c0dc8d7267 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -2,9 +2,7 @@ // Licensed under the MIT License #include -#include -#include -#include +#include #include "core/providers/shared_library/provider_api.h" #include "contexts.h" @@ -18,7 +16,8 @@ namespace openvino_ep { static std::unique_ptr g_global_context; GlobalContext& BackendManager::GetGlobalContext() { - // This is not thread safe to call for the first time, but it is first called on the main thread by the constructor so it is safe. + // This is not thread safe to call for the first time, + // but it is first called on the main thread by the constructor so it is safe. if (!g_global_context) g_global_context = std::make_unique(); return *g_global_context; @@ -88,7 +87,9 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, << "Backend created for graph " << subgraph_context_.subgraph_name; } } else { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. Initializing backend for graph " << subgraph_context_.subgraph_name; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. " + << "Initializing backend for graph " + << subgraph_context_.subgraph_name; subgraph_context_.has_dynamic_input_shape = false; try { @@ -104,7 +105,7 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { bool has_batched_inputs = true; - for (int i = 0; i < (int)subgraph_context_.input_indexes.size(); i++) { + for (int i = 0; i < static_cast(subgraph_context_.input_indexes.size()); i++) { auto& input = model_proto.graph().input(subgraph_context_.input_indexes[i]); // Batch-process only raw image inputs (NCHW or NHWC layouts) @@ -215,7 +216,10 @@ BackendManager::ReWriteInputShapeInfo(const ONNX_NAMESPACE::ModelProto& model_pr auto graph_proto = model_copy->mutable_graph(); for (size_t i = 0, limit = input_shapes.size(); i < limit; i++) { - auto g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); + auto g_in_shape = graph_proto->mutable_input(static_cast(i)) + ->mutable_type() + ->mutable_tensor_type() + ->mutable_shape(); g_in_shape->clear_dim(); const auto& shape = input_shapes[i]; for (size_t dim = 0, end = shape.size(); dim < end; dim++) { @@ -234,7 +238,11 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p auto graph_proto = model_copy->mutable_graph(); for (int i = 0; i < graph_proto->input_size(); i++) { - ONNX_NAMESPACE::TensorShapeProto* g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); + ONNX_NAMESPACE::TensorShapeProto* g_in_shape = + graph_proto->mutable_input(static_cast(i)) + ->mutable_type() + ->mutable_tensor_type() + ->mutable_shape(); g_in_shape->mutable_dim(0)->clear_dim_value(); g_in_shape->mutable_dim(0)->set_dim_value(1); } diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index c247ab60d3a6f..a177324b23f7d 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -3,6 +3,11 @@ #pragma once +#include +#include +#include +#include + #include "ov_interface.h" #include "contexts.h" #include "ibackend.h" @@ -13,7 +18,9 @@ namespace openvino_ep { // Singleton class that manages all the backends class BackendManager { public: - BackendManager(const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger); + BackendManager(const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger); void Compute(OrtKernelContext* context); void ShutdownBackendManager(); static GlobalContext& GetGlobalContext(); @@ -21,7 +28,9 @@ class BackendManager { private: std::unique_ptr GetModelProtoFromFusedNode( - const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger) const; + const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger) const; bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index d49968cdb7f3d..d47c91dd46622 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -1,9 +1,7 @@ // Copyright (C) 2019-2022 Intel Corporation // Licensed under the MIT License -#include -#include -#include +#include #include #include @@ -58,7 +56,7 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext try { auto cnn_network = global_context.ie_core.ReadModel(model); if ((subgraph_context.precision == "FP16") && - (global_context.device_type.find("VPUX") == std::string::npos)) { + (global_context.device_type.find("NPU") == std::string::npos)) { // FP16 transformations ov::pass::ConvertFP32ToFP16 pass_obj; pass_obj.run_on_model(cnn_network); @@ -88,7 +86,8 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext size_t index = results.size() - 1; for (auto it = results.rbegin(); it != results.rend(); ++it) { - if (auto const_node = std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { + if (auto const_node = + std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { const_outputs_map[(*it)->get_friendly_name()] = const_node; results.erase(results.begin() + index); } @@ -254,7 +253,7 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName) { - long long totalTime = 0; + int64_t totalTime = 0; // Print performance counts stream << std::endl << "performance counts:" << std::endl diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index de78a150fe2dd..82b0351e87da5 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -4,9 +4,15 @@ #pragma once #define ORT_API_MANUAL_INIT +#include +#include +#include +#include +#include +#include + #include "core/session/onnxruntime_cxx_api.h" #include "contexts.h" -#include #include "ov_interface.h" #ifdef _WIN32 #include @@ -57,7 +63,9 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, size_t batch_slice_idx); std::shared_ptr -CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, const SubGraphContext& subgraph_context, +CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, + const GlobalContext& global_context, + const SubGraphContext& subgraph_context, std::map>& const_outputs_map); void printPerformanceCounts(const std::vector& performanceMap, diff --git a/onnxruntime/core/providers/openvino/backends/backend_factory.cc b/onnxruntime/core/providers/openvino/backends/backend_factory.cc index c339f24e7022f..c586dd8b38af9 100644 --- a/onnxruntime/core/providers/openvino/backends/backend_factory.cc +++ b/onnxruntime/core/providers/openvino/backends/backend_factory.cc @@ -16,7 +16,7 @@ BackendFactory::MakeBackend(const ONNX_NAMESPACE::ModelProto& model_proto, const SubGraphContext& subgraph_context) { std::string type = global_context.device_type; if (type == "CPU" || type.find("GPU") != std::string::npos || - type.find("VPUX") != std::string::npos || + type.find("NPU") != std::string::npos || type.find("HETERO") != std::string::npos || type.find("MULTI") != std::string::npos || type.find("AUTO") != std::string::npos) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index f9517d7942664..09e1322ff59fb 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -6,10 +6,10 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "../backend_utils.h" -// #include #include "basic_backend.h" #include "../backend_manager.h" @@ -57,33 +57,39 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, cl_context ctx = static_cast(global_context_.context); remote_context_ = new ov::intel_gpu::ocl::ClContext(global_context_.ie_core.Get(), ctx); ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) if (!subgraph_context_.has_dynamic_input_shape && dev_prec != "CPU_FP16") { const std::string model = model_proto.SerializeAsString(); - exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + model, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; #endif #endif } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } } catch (const char* msg) { @@ -127,10 +133,10 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { } #endif #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - if (global_context_.device_type.find("VPUX") != std::string::npos) { + if (global_context_.device_type.find("NPU") != std::string::npos) { std::pair device_property; - device_property = std::make_pair("VPU_COMPILER_TYPE", "MLIR"); - device_config.emplace(ov::device::properties("VPUX", device_property)); + device_property = std::make_pair("NPU_COMPILER_TYPE", "DRIVER"); + device_config.emplace(ov::device::properties("NPU", device_property)); } #endif } @@ -152,12 +158,12 @@ void BasicBackend::EnableCaching() { } void BasicBackend::EnableGPUThrottling(ov::AnyMap& device_config) { - if (global_context_.enable_opencl_throttling == true && global_context_.device_type.find("GPU") != std::string::npos) { + if (global_context_.enable_opencl_throttling == true && + global_context_.device_type.find("GPU") != std::string::npos) { LOGS_DEFAULT(INFO) << log_tag << "Enabled OpenCL queue throttling for GPU device"; std::pair device_property; device_property = std::make_pair("PLUGIN_THROTTLE", "1"); device_config.emplace(ov::device::properties("GPU_CONFIG_KEY", device_property)); - // device_config[GPU_CONFIG_KEY(PLUGIN_THROTTLE)] = "1"; } } @@ -187,7 +193,9 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque if (input_names.find(onnx_input_name) != input_names.end()) { input_name = onnx_input_name; } else { - throw(log_tag + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + " doesn't exist in the list of OpenVINO input tensor names"); + throw(log_tag + + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); } size_t batch_slice_idx = 0; if (subgraph_context_.has_dynamic_input_shape && @@ -197,6 +205,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); auto tensor_shape = tensor_info.GetShape(); auto tensor_size = tensor_shape.size(); + const char* tensor_data = tensor.GetTensorData(); auto tensor_iter = 0; ov::Shape input_tensor_shape = ov::Shape(tensor_size, 0); for (auto i = tensor_shape.begin(); i != tensor_shape.end(); ++i) { @@ -204,8 +213,16 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque tensor_iter += 1; } auto input = ie_cnn_network_->get_parameters().at(input_idx); - OVTensorPtr tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); - FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + OVTensorPtr tensor_ptr; + // avoid input copies on the CPU device + if (global_context_.device_type.find("CPU") != std::string::npos) { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape, + (void*)tensor_data); + } else { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); + FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + } + try { infer_request->SetTensor(input_name, tensor_ptr); } catch (const char* msg) { @@ -251,7 +268,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe if (input_names.find(onnx_input_name) != input_names.end()) { input_name = onnx_input_name; } else { - throw(log_tag + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + " doesn't exist in the list of OpenVINO input tensor names"); + throw(log_tag + + "Input names mismatch between OpenVINO and ONNX. " + + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); } input_idx++; // Kernel Context Input Buffer @@ -264,9 +284,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe const cl::Buffer* shared_buffer_const = static_cast(tensor_data); // Create an Input Remote Blob auto input = ie_cnn_network_->get_parameters().at(0); - auto remote_blob = remote_context_->create_tensor(input->get_element_type(), input->get_shape(), *shared_buffer_const); - ov::Tensor tensor = static_cast(remote_blob); - OVTensorPtr tensor_ptr = std::make_shared(tensor); + auto remote_blob = remote_context_->create_tensor( + input->get_element_type(), input->get_shape(), *shared_buffer_const); + ov::Tensor tensor_remote = static_cast(remote_blob); + OVTensorPtr tensor_ptr = std::make_shared(tensor_remote); infer_request->SetTensor(input_name, tensor_ptr); } else { OVTensorPtr graph_input_blob; @@ -295,7 +316,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe } } if (!output_name_found) { - throw std::string(log_tag + "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); + throw std::string( + log_tag + + "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); } size_t batch_size = 1; @@ -307,9 +331,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe const cl::Buffer* shared_buffer_const = static_cast(tensor_data); // Create a shared Blob, set the Infer Request Output Blob auto output = ie_cnn_network_->get_results().at(0); - auto remote_tensor = remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); - ov::Tensor tensor = static_cast(remote_tensor); - OVTensorPtr tensor_ptr = std::make_shared(tensor); + auto remote_tensor = + remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); + ov::Tensor tensor_t = static_cast(remote_tensor); + OVTensorPtr tensor_ptr = std::make_shared(tensor_t); try { infer_request->SetTensor(output_name, tensor_ptr); } catch (const char* msg) { @@ -364,7 +389,8 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe throw(msg); } size_t batch_size = 1; - auto output_tensor = GetOutputTensor(context, batch_size, infer_request, output_name, subgraph_context_.output_names); + auto output_tensor = + GetOutputTensor(context, batch_size, infer_request, output_name, subgraph_context_.output_names); auto mem_info = output_tensor.GetTensorMemoryInfo(); if (mem_info.GetAllocatorName() == OpenVINO_GPU) { return; @@ -465,7 +491,8 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { #ifndef IO_BUFFER_ENABLED // Printing performance counts is disabled when IO_BUFFER_ENABLED if (openvino_ep::backend_utils::IsDebugEnabled()) { inferRequestsQueue_->printstatus(); // Printing the elements of infer_requests_ vector pool only in debug mode - std::string& hw_target = (global_context_.device_id != "") ? global_context_.device_id : global_context_.device_type; + std::string& hw_target = + (global_context_.device_id != "") ? global_context_.device_id : global_context_.device_type; printPerformanceCounts(infer_request, std::cout, hw_target); } #endif diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 2f1d603640809..6eda641451a72 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -6,16 +6,17 @@ #include #define ORT_API_MANUAL_INIT -#include "core/session/onnxruntime_cxx_api.h" -#include "core/providers/openvino/contexts.h" -#include "core/providers/openvino/ibackend.h" -#include "core/providers/openvino/ov_interface.h" #include #include #include #include #include +#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/openvino/contexts.h" +#include "core/providers/openvino/ibackend.h" +#include "core/providers/openvino/ov_interface.h" + namespace onnxruntime { namespace openvino_ep { @@ -29,7 +30,7 @@ class BasicBackend : public IBackend { void Infer(OrtKernelContext* context) override; private: - bool ImportBlob(std::string hw_target, bool vpu_status); + bool ImportBlob(std::string hw_target, bool npu_status); void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&); bool ValidateSubgraph(std::map>& const_outputs_map); void PopulateConfigValue(ov::AnyMap& device_config); diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index b61dcf8ca4922..29233e72c33b9 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -3,6 +3,9 @@ #pragma once +#include +#include +#include #include "ov_interface.h" namespace onnxruntime { @@ -12,7 +15,7 @@ namespace openvino_ep { struct GlobalContext { OVCore ie_core; bool is_wholly_supported_graph = false; - bool enable_vpu_fast_compile = false; + bool enable_npu_fast_compile = false; bool enable_opencl_throttling = false; bool enable_dynamic_shapes = false; size_t num_of_threads; @@ -34,7 +37,7 @@ struct GlobalContext { struct SubGraphContext { bool has_dynamic_input_shape = false; bool enable_batching = false; - bool set_vpu_config = false; + bool set_npu_config = false; bool is_constant = false; void* context = 0; std::string subgraph_name; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 990809926299e..a4c6b0f851c04 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -17,17 +17,18 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv openvino_ep::BackendManager::GetGlobalContext().device_type = info.device_type_; openvino_ep::BackendManager::GetGlobalContext().precision_str = info.precision_; - openvino_ep::BackendManager::GetGlobalContext().enable_vpu_fast_compile = info.enable_vpu_fast_compile_; + openvino_ep::BackendManager::GetGlobalContext().enable_npu_fast_compile = info.enable_npu_fast_compile_; openvino_ep::BackendManager::GetGlobalContext().cache_dir = info.cache_dir_; openvino_ep::BackendManager::GetGlobalContext().num_streams = info.num_streams_; openvino_ep::BackendManager::GetGlobalContext().context = info.context_; openvino_ep::BackendManager::GetGlobalContext().enable_opencl_throttling = info.enable_opencl_throttling_; openvino_ep::BackendManager::GetGlobalContext().enable_dynamic_shapes = info.enable_dynamic_shapes_; - if ((int)info.num_of_threads_ <= 0) { + if (static_cast(info.num_of_threads_) <= 0) { openvino_ep::BackendManager::GetGlobalContext().num_of_threads = 8; - } else if ((int)info.num_of_threads_ > 8) { - std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; + } else if (static_cast(info.num_of_threads_) > 8) { + std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + + std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; ORT_THROW(err_msg); } else { openvino_ep::BackendManager::GetGlobalContext().num_of_threads = info.num_of_threads_; @@ -56,7 +57,8 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv device_found = true; break; } - if (info.device_type_.find("VPUX") != std::string::npos && (info.precision_ == "FP16" || info.precision_ == "U8")) { + if ((info.device_type_.find("NPU") != std::string::npos) && + (info.precision_ == "FP16" || info.precision_ == "U8")) { device_found = true; break; } @@ -109,11 +111,14 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, openvino_ep::BackendManager::GetGlobalContext().onnx_model_name = graph_viewer.Name(); #ifdef _WIN32 std::wstring onnx_path = graph_viewer.ModelPath().ToPathString(); - openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = std::string(onnx_path.begin(), onnx_path.end()); + openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = + std::string(onnx_path.begin(), onnx_path.end()); #else - openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = graph_viewer.ModelPath().ToPathString(); + openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = + graph_viewer.ModelPath().ToPathString(); #endif - openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = graph_viewer.DomainToVersionMap().at(kOnnxDomain); + openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = + graph_viewer.DomainToVersionMap().at(kOnnxDomain); #if defined(OPENVINO_2022_1) openvino_ep::GetCapability obj(graph_viewer, @@ -151,7 +156,8 @@ common::Status OpenVINOExecutionProvider::Compile( openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true; - std::shared_ptr backend_manager = std::make_shared(fused_node, graph_body_viewer, *GetLogger()); + std::shared_ptr backend_manager = + std::make_shared(fused_node, graph_body_viewer, *GetLogger()); compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index a4fc09362fa23..3b56b54410e40 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -3,19 +3,28 @@ #pragma once -#include "backend_manager.h" #include #include #include +#include +#include +#include + +#include "backend_manager.h" namespace onnxruntime { static void print_build_options() { std::cout << "[ERROR] INVALID DEVICE BUILD TYPE SPECIFIED" << std::endl; - std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority you want to build" << std::endl; - std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build "; - std::cout << "are ['CPU','GPU','VPUX']" << std::endl; - std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" << std::endl; + std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority " + << "you want to build" + << std::endl; + std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build " + << "are ['CPU','GPU']" + << std::endl; + std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. " + << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" + << std::endl; } static std::vector split(const std::string& s, char delim) { @@ -39,7 +48,7 @@ static std::vector parseDevices(const std::string& device_string) { print_build_options(); ORT_THROW("Invalid device string: " + device_string); } - std::vector dev_options = {"CPU", "GPU", "VPUX"}; + std::vector dev_options = {"CPU", "GPU"}; for (std::string dev : devices) { if (!std::count(dev_options.begin(), dev_options.end(), dev)) { print_build_options(); @@ -53,7 +62,7 @@ static std::vector parseDevices(const std::string& device_string) { struct OpenVINOExecutionProviderInfo { std::string device_type_; std::string precision_; - bool enable_vpu_fast_compile_; + bool enable_npu_fast_compile_; std::string device_id_; size_t num_of_threads_; std::string cache_dir_; @@ -62,11 +71,18 @@ struct OpenVINOExecutionProviderInfo { bool enable_opencl_throttling_; bool enable_dynamic_shapes_; - explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_vpu_fast_compile, std::string dev_id, + explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_npu_fast_compile, std::string dev_id, size_t num_of_threads, std::string cache_dir, int num_streams, void* context, bool enable_opencl_throttling, bool enable_dynamic_shapes) - : enable_vpu_fast_compile_(enable_vpu_fast_compile), device_id_(dev_id), num_of_threads_(num_of_threads), cache_dir_(cache_dir), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), enable_dynamic_shapes_(enable_dynamic_shapes) { + : enable_npu_fast_compile_(enable_npu_fast_compile), + device_id_(dev_id), + num_of_threads_(num_of_threads), + cache_dir_(cache_dir), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + enable_dynamic_shapes_(enable_dynamic_shapes) { if (dev_type == "") { LOGS_DEFAULT(INFO) << "[OpenVINO-EP]" << "No runtime device selection option provided."; @@ -82,11 +98,11 @@ struct OpenVINOExecutionProviderInfo { #elif defined OPENVINO_CONFIG_GPU_FP16 device_type_ = "GPU"; precision_ = "FP16"; -#elif defined OPENVINO_CONFIG_VPUX_FP16 - device_type_ = "VPUX"; +#elif defined OPENVINO_CONFIG_NPU_FP16 + device_type_ = "NPU"; precision_ = "FP16"; -#elif defined OPENVINO_CONFIG_VPUX_U8 - device_type_ = "VPUX"; +#elif defined OPENVINO_CONFIG_NPU_U8 + device_type_ = "NPU"; precision_ = "U8"; #elif defined OPENVINO_CONFIG_HETERO || defined OPENVINO_CONFIG_MULTI || defined OPENVINO_CONFIG_AUTO #ifdef DEVICE_NAME @@ -126,11 +142,11 @@ struct OpenVINOExecutionProviderInfo { } else if (dev_type == "GPU.1_FP16") { device_type_ = "GPU.1"; precision_ = "FP16"; - } else if (dev_type == "VPUX_FP16") { - device_type_ = "VPUX"; + } else if (dev_type == "NPU_FP16") { + device_type_ = "NPU"; precision_ = "FP16"; - } else if (dev_type == "VPUX_U8") { - device_type_ = "VPUX"; + } else if (dev_type == "NPU_U8") { + device_type_ = "NPU"; precision_ = "U8"; } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0) { std::vector devices = parseDevices(dev_type); diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 95b39bcc05983..fbb89710c8008 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -8,11 +8,16 @@ namespace onnxruntime { struct OpenVINOProviderFactory : IExecutionProviderFactory { - OpenVINOProviderFactory(const char* device_type, bool enable_vpu_fast_compile, + OpenVINOProviderFactory(const char* device_type, bool enable_npu_fast_compile, const char* device_id, size_t num_of_threads, const char* cache_dir, int num_streams, void* context, bool enable_opencl_throttling, bool enable_dynamic_shapes) - : enable_vpu_fast_compile_(enable_vpu_fast_compile), num_of_threads_(num_of_threads), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), enable_dynamic_shapes_(enable_dynamic_shapes) { + : enable_npu_fast_compile_(enable_npu_fast_compile), + num_of_threads_(num_of_threads), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + enable_dynamic_shapes_(enable_dynamic_shapes) { device_type_ = (device_type == nullptr) ? "" : device_type; device_id_ = (device_id == nullptr) ? "" : device_id; cache_dir_ = (cache_dir == nullptr) ? "" : cache_dir; @@ -24,7 +29,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { private: std::string device_type_; - bool enable_vpu_fast_compile_; + bool enable_npu_fast_compile_; std::string device_id_; size_t num_of_threads_; std::string cache_dir_; @@ -35,7 +40,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { - OpenVINOExecutionProviderInfo info(device_type_, enable_vpu_fast_compile_, device_id_, num_of_threads_, + OpenVINOExecutionProviderInfo info(device_type_, enable_npu_fast_compile_, device_id_, num_of_threads_, cache_dir_, num_streams_, context_, enable_opencl_throttling_, enable_dynamic_shapes_); return std::make_unique(info); @@ -59,17 +64,18 @@ struct OpenVINO_Provider : Provider { std::string device_type = ""; // [device_type]: Overrides the accelerator hardware type and precision // with these values at runtime. - bool enable_vpu_fast_compile = false; // [enable_vpu_fast_compile]: Fast-compile may be optionally enabled to - // speeds up the model's compilation to VPU device specific format. + bool enable_npu_fast_compile = false; // [enable_npu_fast_compile]: Fast-compile may be optionally enabled to + // speeds up the model's compilation to NPU device specific format. const char* device_id = ""; // [device_id]: Selects a particular hardware device for inference. - size_t num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of + int num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of // threads with this value at runtime. const char* cache_dir = ""; // [cache_dir]: specify the path to // dump and load the blobs for the model caching/kernel caching (GPU) // feature. If blob files are already present, it will be directly loaded. int num_streams = 1; // [num_streams]: Option that specifies the number of parallel inference // requests to be processed on a given `device_type`. Overrides the - // accelerator default value of number of streams with this value at runtime. + // accelerator default value of number of streams + // with this value at runtime. bool enable_opencl_throttling = false; // [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU // device (Reduces CPU Utilization when using GPU) bool enable_dynamic_shapes = false; // [enable_dynamic_shapes]: Enables Dynamic Shapes feature for CPU device) @@ -80,14 +86,15 @@ struct OpenVINO_Provider : Provider { std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16", - "VPUX_FP16", "VPUX_U8"}; + "GPU.0_FP16", "GPU.1_FP16"}; if (!((ov_supported_device_types.find(device_type) != ov_supported_device_types.end()) || - (device_type.find("HETERO:") == 0) || (device_type.find("MULTI:") == 0) || (device_type.find("AUTO:") == 0))) { + (device_type.find("HETERO:") == 0) || + (device_type.find("MULTI:") == 0) || + (device_type.find("AUTO:") == 0))) { ORT_THROW( "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " - "'GPU.0_FP16', 'GPU.1_FP16', 'VPUX_FP16', 'VPUX_U8' or from" + "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); } } @@ -97,30 +104,37 @@ struct OpenVINO_Provider : Provider { if (provider_options_map.find("cache_dir") != provider_options_map.end()) { cache_dir = provider_options_map.at("cache_dir").c_str(); } + if (provider_options_map.find("context") != provider_options_map.end()) { - context = (void*)provider_options_map.at("context").c_str(); + std::string str = provider_options_map.at("context"); + uint64_t number = std::strtoull(str.c_str(), nullptr, 16); + context = reinterpret_cast(number); } if (provider_options_map.find("num_of_threads") != provider_options_map.end()) { num_of_threads = std::stoi(provider_options_map.at("num_of_threads")); if (num_of_threads <= 0) { num_of_threads = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_threads' should be in the positive range.\n " + << "Executing with num_threads=1"; } } if (provider_options_map.find("num_streams") != provider_options_map.end()) { num_streams = std::stoi(provider_options_map.at("num_streams")); - if (num_streams <= 0 && num_streams > 8) { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'num_streams' should be in the range of 1-8 \n"); + if (num_streams <= 0) { + num_streams = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_streams' should be in the range of 1-8.\n " + << "Executing with num_streams=1"; } } std::string bool_flag = ""; - if (provider_options_map.find("enable_vpu_fast_compile") != provider_options_map.end()) { - bool_flag = provider_options_map.at("enable_vpu_fast_compile"); + if (provider_options_map.find("enable_npu_fast_compile") != provider_options_map.end()) { + bool_flag = provider_options_map.at("enable_npu_fast_compile"); if (bool_flag == "true" || bool_flag == "True") - enable_vpu_fast_compile = true; + enable_npu_fast_compile = true; else if (bool_flag == "false" || bool_flag == "False") - enable_vpu_fast_compile = false; + enable_npu_fast_compile = false; bool_flag = ""; } @@ -141,7 +155,7 @@ struct OpenVINO_Provider : Provider { enable_dynamic_shapes = false; } return std::make_shared(const_cast(device_type.c_str()), - enable_vpu_fast_compile, + enable_npu_fast_compile, device_id, num_of_threads, cache_dir, @@ -157,7 +171,6 @@ struct OpenVINO_Provider : Provider { void Shutdown() override { openvino_ep::BackendManager::ReleaseGlobalContext(); } - } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 3914488fc523b..d2ce378c97e02 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -29,7 +29,10 @@ std::shared_ptr OVCore::ReadModel(const std::string& model) const { } } -OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name) { +OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name) { ov::CompiledModel obj; try { obj = oe.compile_model(ie_cnn_network, hw_target, device_config); @@ -43,7 +46,10 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std } #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) -OVExeNetwork OVCore::LoadNetwork(const std::string& model, std::string& hw_target, ov::AnyMap& device_config, std::string name) { +OVExeNetwork OVCore::LoadNetwork(const std::string& model, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name) { ov::CompiledModel obj; try { obj = oe.compile_model(model, ov::Tensor(), hw_target, device_config); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index ed9583033ab34..935ac8f68411d 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -4,6 +4,7 @@ #pragma once #include +#include #if defined(OPENVINO_2022_1) || (OPENVINO_2022_2) || (OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) #define OV_API_20 @@ -43,9 +44,15 @@ class OVCore { public: std::shared_ptr ReadModel(const std::string& model_stream) const; - OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name); + OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name); #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - OVExeNetwork LoadNetwork(const std::string& model_stream, std::string& hw_target, ov::AnyMap& device_config, std::string name); + OVExeNetwork LoadNetwork(const std::string& model_stream, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name); #endif void SetCache(std::string cache_dir_path); #ifdef IO_BUFFER_ENABLED @@ -62,7 +69,7 @@ class OVExeNetwork { ov::CompiledModel obj; public: - OVExeNetwork(ov::CompiledModel md) { obj = md; } + explicit OVExeNetwork(ov::CompiledModel md) { obj = md; } OVExeNetwork() { obj = ov::CompiledModel(); } ov::CompiledModel& Get() { return obj; } OVInferRequest CreateInferRequest(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/capabilities.h b/onnxruntime/core/providers/openvino/ov_versions/capabilities.h index b76d1cf534c2a..5bcf9d68cd94e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capabilities.h +++ b/onnxruntime/core/providers/openvino/ov_versions/capabilities.h @@ -3,6 +3,8 @@ #pragma once #include +#include +#include #include "data_ops.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 171dd45c508cc..b030efa238209 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -24,7 +24,8 @@ namespace openvino_ep { // Constructor GetCapability::GetCapability(const GraphViewer& graph_viewer_param, std::string device_type_param, - const std::string version_param) : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { + const std::string version_param) + : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { if (version_param == "V_2022_1") { data_ops_ = new DataOps(graph_viewer_, V_2022_1, device_type_); } else if (version_param == "V_2022_2") { @@ -114,11 +115,11 @@ std::vector> GetCapability::Execute() { } openvino_ep::BackendManager::GetGlobalContext().is_wholly_supported_graph = true; - } else { // unsupported_nodes_idx.empty() - + } else { // unsupported_nodes_idx.empty() #if defined(OPENVINO_DISABLE_GRAPH_PARTITION) // disables graph partition at build time LOGS_DEFAULT(INFO) << "[OpenVINO-EP] DISABLE_GRAPH_PARTITION option is set"; - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, so making the full model fall back to default CPU Execution Provider"; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, " + << "so making the full model fall back to default CPU Execution Provider"; return result; #endif @@ -159,7 +160,13 @@ std::vector> GetCapability::Execute() { std::vector cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs; - GetInputsOutputsOfCluster(graph_viewer_, this_cluster, ng_required_initializers, cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs); + GetInputsOutputsOfCluster(graph_viewer_, + this_cluster, + ng_required_initializers, + cluster_graph_inputs, + cluster_inputs, + const_inputs, + cluster_outputs); bool omit_subgraph = false; // Omitting zero dim subgraphs diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 70118c94f9ff8..a5a0faa3a8f24 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -2,11 +2,15 @@ // Licensed under the MIT License #include +#include +#include +#include +#include +#include + #include "core/providers/shared_library/provider_api.h" #include "../backend_utils.h" #include "../backend_manager.h" -#include -#include #include "data_ops.h" #include "capabilities.h" #include "utils.h" @@ -72,269 +76,355 @@ std::set ops_supported_as_function = { std::vector supported_op_mode = { {"Abs", V_2020_4, {"CPU", "GPU"}}, - {"Abs", V_2023_0, {"VPUX"}}, + {"Abs", V_2023_0, {"NPU"}}, {"Acos", V_2020_4, {"CPU"}}, {"Acos", V_2022_1, {"GPU"}}, + {"Acos", V_2023_1, {"NPU"}}, {"Acosh", V_2020_4, {"CPU"}}, {"Acosh", V_2022_1, {"GPU"}}, + {"Acosh", V_2023_1, {"NPU"}}, {"Add", V_2020_4, {"CPU", "GPU"}}, - {"Add", V_2023_0, {"VPUX"}}, + {"Add", V_2023_0, {"NPU"}}, {"And", V_2020_4, {"CPU", "GPU"}}, + {"And", V_2023_1, {"NPU"}}, {"ArgMax", V_2020_4, {"CPU"}}, {"ArgMax", V_2021_1, {"GPU"}}, {"ArgMin", V_2020_4, {"CPU"}}, {"ArgMin", V_2022_1, {"GPU"}}, {"Asin", V_2020_4, {"CPU", "GPU"}}, + {"Asin", V_2023_1, {"NPU"}}, {"Asinh", V_2020_4, {"CPU", "GPU"}}, + {"Asinh", V_2023_1, {"NPU"}}, {"Atan", V_2020_4, {"CPU", "GPU"}}, + {"Atan", V_2023_1, {"NPU"}}, {"Atanh", V_2020_4, {"CPU"}}, {"Atanh", V_2022_1, {"GPU"}}, + {"Atanh", V_2023_1, {"NPU"}}, {"AveragePool", V_2020_4, {"CPU", "GPU"}}, - {"AveragePool", V_2023_0, {"VPUX"}}, + {"AveragePool", V_2023_0, {"NPU"}}, {"BatchNormalization", V_2020_4, {"CPU", "GPU"}}, - {"BatchNormalization", V_2023_0, {"VPUX"}}, + {"BatchNormalization", V_2023_0, {"NPU"}}, {"BitShift", V_2022_1, {"CPU"}}, + {"BitShift", V_2023_1, {"NPU"}}, {"Cast", V_2020_4, {"CPU", "GPU"}}, - {"Cast", V_2023_0, {"VPUX"}}, + {"Cast", V_2023_0, {"NPU"}}, + {"CastLike", V_2023_1, {"CPU", "GPU", "NPU"}}, {"Ceil", V_2020_4, {"GPU"}}, {"Ceil", V_2021_4, {"CPU"}}, + {"Ceil", V_2023_1, {"NPU"}}, {"Celu", V_2022_1, {"CPU", "GPU"}}, {"Clip", V_2020_4, {"CPU", "GPU"}}, - {"Clip", V_2023_0, {"VPUX"}}, + {"Clip", V_2023_0, {"NPU"}}, + {"Compress", V_2023_1, {"CPU", "GPU"}}, {"Concat", V_2020_4, {"CPU", "GPU"}}, - {"Concat", V_2023_0, {"VPUX"}}, + {"Concat", V_2023_0, {"NPU"}}, {"Constant", V_2020_4, {"CPU", "GPU"}}, - {"Constant", V_2023_0, {"VPUX"}}, + {"Constant", V_2023_0, {"NPU"}}, {"ConstantOfShape", V_2020_4, {"CPU", "GPU"}}, - {"ConstantOfShape", V_2023_0, {"VPUX"}}, // Gets mapped to broadcast op in the plugin. + {"ConstantOfShape", V_2023_0, {"NPU"}}, // Gets mapped to broadcast op in the plugin. {"Conv", V_2020_4, {"CPU", "GPU"}}, - {"Conv", V_2023_0, {"VPUX"}}, + {"Conv", V_2023_0, {"NPU"}}, {"ConvInteger", V_2022_1, {"CPU", "GPU"}}, + {"ConvInteger", V_2023_1, {"NPU"}}, {"ConvTranspose", V_2020_4, {"CPU", "GPU"}}, + {"ConvTranspose", V_2023_1, {"NPU"}}, {"Cos", V_2020_4, {"CPU"}}, {"Cos", V_2022_1, {"GPU"}}, - {"Cos", V_2023_0, {"VPUX"}}, + {"Cos", V_2023_0, {"NPU"}}, {"Cosh", V_2020_4, {"CPU"}}, {"Cosh", V_2022_1, {"GPU"}}, + {"Cosh", V_2023_1, {"NPU"}}, {"CumSum", V_2022_1, {"CPU", "GPU"}}, - {"CumSum", V_2023_0, {"VPUX"}}, + {"CumSum", V_2023_0, {"NPU"}}, {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, - {"DepthToSpace", V_2023_0, {"VPUX"}}, + {"DepthToSpace", V_2023_0, {"NPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, - {"DequantizeLinear", V_2023_0, {"VPUX"}}, + {"DequantizeLinear", V_2023_0, {"NPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, - {"Div", V_2023_0, {"VPUX"}}, + {"Div", V_2023_0, {"NPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, - {"Dropout", V_2023_0, {"VPUX"}}, + {"Dropout", V_2023_0, {"NPU"}}, {"Elu", V_2020_4, {"CPU", "GPU"}}, - {"Elu", V_2023_0, {"VPUX"}}, + {"Elu", V_2023_0, {"NPU"}}, // {"Einsum", V_2023_0, {"CPU", "GPU"}}, {"Equal", V_2020_4, {"CPU", "GPU"}}, - {"Equal", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Equal", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Erf", V_2020_4, {"CPU", "GPU"}}, - {"Erf", V_2023_0, {"VPUX"}}, + {"Erf", V_2023_0, {"NPU"}}, {"Exp", V_2020_4, {"CPU", "GPU"}}, - {"Exp", V_2023_0, {"VPUX"}}, + {"Exp", V_2023_0, {"NPU"}}, {"Expand", V_2022_1, {"CPU", "GPU"}}, - {"Expand", V_2023_0, {"VPUX"}}, // Gets mapped to broadcast op and multiply op in the plugin. + {"Expand", V_2023_0, {"NPU"}}, // Gets mapped to broadcast op and multiply op in the plugin. {"EyeLike", V_2022_1, {"CPU"}}, - {"EyeLike", V_2023_0, {"VPUX"}}, // NoOP + {"EyeLike", V_2023_0, {"NPU"}}, // NoOP {"Flatten", V_2020_4, {"CPU", "GPU"}}, - {"Flatten", V_2023_0, {"VPUX"}}, + {"Flatten", V_2023_0, {"NPU"}}, {"Floor", V_2020_4, {"CPU", "GPU"}}, + {"Floor", V_2023_1, {"NPU"}}, {"Gather", V_2020_4, {"CPU", "GPU"}}, - {"Gather", V_2023_0, {"VPUX"}}, + {"Gather", V_2023_0, {"NPU"}}, {"GatherElements", V_2022_2, {"CPU", "GPU"}}, + {"GatherElements", V_2023_1, {"NPU"}}, {"GatherND", V_2021_4, {"CPU", "GPU"}}, + {"GatherND", V_2023_1, {"NPU"}}, {"Gemm", V_2020_4, {"CPU", "GPU"}}, - {"Gemm", V_2023_0, {"VPUX"}}, + {"Gemm", V_2023_0, {"NPU"}}, {"GlobalAveragePool", V_2020_4, {"CPU", "GPU"}}, - {"GlobalAveragePool", V_2023_0, {"VPUX"}}, + {"GlobalAveragePool", V_2023_0, {"NPU"}}, {"GlobalLpPool", V_2020_4, {"CPU", "GPU"}}, + {"GlobalLpPool", V_2023_1, {"NPU"}}, {"GlobalMaxPool", V_2022_1, {"CPU", "GPU"}}, + {"GlobalMaxPool", V_2023_1, {"NPU"}}, {"Greater", V_2020_4, {"CPU", "GPU"}}, - {"Greater", V_2023_0, {"VPUX"}}, + {"Greater", V_2023_0, {"NPU"}}, {"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}}, - {"GreaterOrEqual", V_2023_0, {"VPUX"}}, + {"GreaterOrEqual", V_2023_0, {"NPU"}}, {"GridSample", V_2022_3, {"CPU"}}, {"GridSample", V_2023_0, {"GPU"}}, + {"GridSample", V_2023_1, {"NPU"}}, + {"HardMax", V_2023_1, {"CPU", "GPU", "NPU"}}, {"Identity", V_2020_4, {"CPU", "GPU"}}, - {"Identity", V_2023_0, {"VPUX"}}, // NoOP + {"Identity", V_2023_0, {"NPU"}}, // NoOP {"If", V_2022_3, {"CPU", "GPU"}}, + {"If", V_2023_1, {"NPU"}}, {"ImageScaler", V_2022_1, {"CPU", "GPU"}}, - {"ImageScaler", V_2023_0, {"VPUX"}}, + {"ImageScaler", V_2023_0, {"NPU"}}, {"InstanceNormalization", V_2020_4, {"CPU", "GPU"}}, - {"InstanceNormalization", V_2023_0, {"VPUX"}}, + {"InstanceNormalization", V_2023_0, {"NPU"}}, {"HardSigmoid", V_2020_4, {"CPU", "GPU"}}, + {"HardSigmoid", V_2023_1, {"NPU"}}, {"HardMax", V_2022_1, {"CPU", "GPU"}}, {"LeakyRelu", V_2020_4, {"CPU", "GPU"}}, - {"LeakyRelu", V_2023_0, {"VPUX"}}, + {"LeakyRelu", V_2023_0, {"NPU"}}, {"Less", V_2020_4, {"CPU", "GPU"}}, - {"Less", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Less", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"LessOrEqual", V_2022_1, {"CPU", "GPU"}}, - {"LessOrEqual", V_2023_0, {"VPUX"}}, + {"LessOrEqual", V_2023_0, {"NPU"}}, {"Log", V_2020_4, {"CPU", "GPU"}}, - {"Log", V_2023_0, {"VPUX"}}, + {"Log", V_2023_0, {"NPU"}}, {"LogSoftMax", V_2022_1, {"CPU", "GPU"}}, {"Loop", V_2021_4, {"CPU", "GPU"}}, + {"LpNormalization", V_2023_1, {"CPU", "GPU", "NPU"}}, + {"LpPool", V_2023_1, {"CPU", "GPU", "NPU"}}, {"LRN", V_2020_4, {"CPU", "GPU"}}, - {"LRN", V_2023_0, {"VPUX"}}, + {"LRN", V_2023_0, {"NPU"}}, {"LSTM", V_2020_4, {"CPU", "GPU"}}, + {"LSTM", V_2023_1, {"NPU"}}, {"MatMul", V_2020_4, {"CPU", "GPU"}}, - {"MatMul", V_2023_0, {"VPUX"}}, + {"MatMul", V_2023_0, {"NPU"}}, {"MatMulInteger", V_2022_1, {"CPU"}}, + {"MatMulInteger", V_2023_1, {"NPU"}}, {"Max", V_2020_4, {"CPU", "GPU"}}, - {"Max", V_2023_0, {"VPUX"}}, + {"Max", V_2023_0, {"NPU"}}, {"MaxPool", V_2020_4, {"CPU", "GPU"}}, - {"MaxPool", V_2023_0, {"VPUX"}}, + {"MaxPool", V_2023_0, {"NPU"}}, {"Mean", V_2020_4, {"CPU", "GPU"}}, - {"Mean", V_2023_0, {"VPUX"}}, + {"Mean", V_2023_0, {"NPU"}}, {"MeanVarianceNormalization", V_2022_1, {"CPU", "GPU"}}, + {"MeanVarianceNormalization", V_2023_1, {"NPU"}}, {"Min", V_2020_4, {"CPU", "GPU"}}, - {"Min", V_2023_0, {"VPUX"}}, + {"Min", V_2023_0, {"NPU"}}, {"Mod", V_2022_1, {"CPU", "GPU"}}, {"Mul", V_2020_4, {"CPU", "GPU"}}, - {"Mul", V_2023_0, {"VPUX"}}, + {"Mul", V_2023_0, {"NPU"}}, {"Neg", V_2020_4, {"CPU", "GPU"}}, - {"Neg", V_2023_0, {"VPUX"}}, + {"Neg", V_2023_0, {"NPU"}}, {"NonMaxSuppression", V_2021_1, {"CPU", "GPU"}}, + {"NonMaxSuppression", V_2023_1, {"NPU"}}, {"NonZero", V_2021_1, {"CPU"}}, {"NonZero", V_2023_0, {"GPU"}}, {"Not", V_2021_1, {"CPU", "GPU"}}, {"Not", V_2020_4, {"CPU", "GPU"}}, + {"Not", V_2023_1, {"NPU"}}, {"OneHot", V_2020_4, {"CPU", "GPU"}}, + {"OneHot", V_2023_1, {"NPU"}}, {"Or", V_2022_1, {"CPU", "GPU"}}, + {"Or", V_2023_1, {"NPU"}}, {"Pad", V_2020_4, {"CPU", "GPU"}}, - {"Pad", V_2023_0, {"VPUX"}}, + {"Pad", V_2023_0, {"NPU"}}, {"Pow", V_2020_4, {"CPU", "GPU"}}, - {"Pow", V_2023_0, {"VPUX"}}, + {"Pow", V_2023_0, {"NPU"}}, {"PRelu", V_2020_4, {"CPU", "GPU"}}, - {"PRelu", V_2023_0, {"VPUX"}}, + {"PRelu", V_2023_0, {"NPU"}}, {"QLinearMatMul", V_2022_3, {"CPU"}}, + // {"QLinearMatMul", V_2023_1, {"NPU"}}, {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}}, - {"QuantizeLinear", V_2023_0, {"VPUX"}}, + {"QuantizeLinear", V_2023_0, {"NPU"}}, + {"RNN", V_2023_1, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_1, {"NPU"}}, {"RandomNormal", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormal", V_2023_1, {"NPU"}}, {"Range", V_2022_1, {"CPU", "GPU"}}, - {"Range", V_2023_0, {"VPUX"}}, + {"Range", V_2023_0, {"NPU"}}, {"Reciprocal", V_2020_4, {"CPU", "GPU"}}, - {"Reciprocal", V_2023_0, {"VPUX"}}, + {"Reciprocal", V_2023_0, {"NPU"}}, {"ReduceL1", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL1", V_2023_1, {"NPU"}}, {"ReduceL2", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL2", V_2023_1, {"NPU"}}, {"ReduceLogSum", V_2020_4, {"CPU"}}, {"ReduceLogSum", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSum", V_2023_1, {"NPU"}}, {"ReduceLogSumExp", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSumExp", V_2023_1, {"NPU"}}, {"ReduceMax", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMax", V_2023_1, {"NPU"}}, {"ReduceMean", V_2020_4, {"CPU", "GPU"}}, - {"ReduceMean", V_2023_0, {"VPUX"}}, + {"ReduceMean", V_2023_0, {"NPU"}}, {"ReduceMin", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMin", V_2023_1, {"NPU"}}, {"ReduceProd", V_2020_4, {"CPU"}}, {"ReduceProd", V_2022_1, {"GPU"}}, + {"ReduceProd", V_2023_1, {"NPU"}}, {"ReduceSum", V_2020_4, {"CPU", "GPU"}}, + // {"ReduceSum", V_2023_1, {"NPU"}}, {"ReduceSumSquare", V_2020_4, {"CPU"}}, {"ReduceSumSquare", V_2022_1, {"CPU", "GPU"}}, + {"ReduceSumSquare", V_2023_1, {"NPU"}}, {"Relu", V_2020_4, {"CPU", "GPU"}}, - {"Relu", V_2023_0, {"VPUX"}}, + {"Relu", V_2023_0, {"NPU"}}, {"Resize", V_2020_4, {"CPU"}}, {"Resize", V_2022_1, {"GPU"}}, + {"Resize", V_2023_1, {"NPU"}}, {"Reshape", V_2020_4, {"CPU", "GPU"}}, - {"Reshape", V_2023_0, {"VPUX"}}, + {"Reshape", V_2023_0, {"NPU"}}, {"ReverseSequence", V_2022_1, {"CPU", "GPU"}}, {"RoiAlign", V_2021_1, {"CPU", "GPU"}}, + {"RoiAlign", V_2023_1, {"NPU"}}, {"Round", V_2021_4, {"CPU", "GPU"}}, + {"Round", V_2023_1, {"NPU"}}, {"Scatter", V_2022_1, {"CPU", "GPU"}}, + {"Scatter", V_2023_1, {"NPU"}}, {"ScatterElements", V_2022_1, {"CPU", "GPU"}}, + {"ScatterElements", V_2023_1, {"NPU"}}, {"ScatterND", V_2022_1, {"CPU", "GPU"}}, + {"ScatterND", V_2023_1, {"NPU"}}, {"Selu", V_2020_4, {"CPU", "GPU"}}, + {"Selu", V_2023_1, {"NPU"}}, {"Shape", V_2020_4, {"CPU", "GPU"}}, - {"Shape", V_2023_0, {"VPUX"}}, + {"Shape", V_2023_0, {"NPU"}}, {"Shrink", V_2022_1, {"CPU", "GPU"}}, - {"Shrink", V_2023_0, {"VPUX"}}, + {"Shrink", V_2023_0, {"NPU"}}, {"Sigmoid", V_2020_4, {"CPU", "GPU"}}, - {"Sigmoid", V_2023_0, {"VPUX"}}, + {"Sigmoid", V_2023_0, {"NPU"}}, {"Sign", V_2020_4, {"CPU"}}, {"Sign", V_2022_1, {"GPU"}}, - {"Sign", V_2023_0, {"VPUX"}}, + {"Sign", V_2023_0, {"NPU"}}, {"Sin", V_2022_1, {"CPU", "GPU"}}, - {"Sin", V_2023_0, {"VPUX"}}, + {"Sin", V_2023_0, {"NPU"}}, {"Sinh", V_2020_4, {"CPU"}}, + {"Sinh", V_2023_1, {"NPU"}}, {"Size", V_2022_1, {"CPU", "GPU"}}, + {"Size", V_2023_1, {"NPU"}}, {"Slice", V_2020_4, {"CPU", "GPU"}}, - {"Slice", V_2023_0, {"VPUX"}}, + {"Slice", V_2023_0, {"NPU"}}, {"Softmax", V_2020_4, {"CPU", "GPU"}}, - {"Softmax", V_2023_0, {"VPUX"}}, + {"Softmax", V_2023_0, {"NPU"}}, {"Softplus", V_2022_1, {"CPU", "GPU"}}, - {"Softplus", V_2023_0, {"VPUX"}}, + {"Softplus", V_2023_0, {"NPU"}}, {"Softsign", V_2022_1, {"CPU", "GPU"}}, {"SpaceToDepth", V_2020_4, {"CPU", "GPU"}}, - {"SpaceToDepth", V_2023_0, {"VPUX"}}, + {"SpaceToDepth", V_2023_0, {"NPU"}}, {"Split", V_2020_4, {"CPU", "GPU"}}, - {"Split", V_2023_0, {"VPUX"}}, + {"Split", V_2023_0, {"NPU"}}, {"Sqrt", V_2020_4, {"CPU", "GPU"}}, - {"Sqrt", V_2023_0, {"VPUX"}}, + {"Sqrt", V_2023_0, {"NPU"}}, {"Squeeze", V_2020_4, {"CPU", "GPU"}}, - {"Squeeze", V_2023_0, {"VPUX"}}, + {"Squeeze", V_2023_0, {"NPU"}}, {"Softsign", V_2020_4, {"CPU"}}, {"Sub", V_2020_4, {"CPU", "GPU"}}, - {"Sub", V_2023_0, {"VPUX"}}, + {"Sub", V_2023_0, {"NPU"}}, {"Sum", V_2020_4, {"CPU", "GPU"}}, - {"Sum", V_2023_0, {"VPUX"}}, + {"Sum", V_2023_0, {"NPU"}}, {"Tan", V_2020_4, {"CPU", "GPU"}}, + {"Tan", V_2023_1, {"NPU"}}, {"Tanh", V_2020_4, {"CPU", "GPU"}}, - {"Tanh", V_2023_0, {"VPUX"}}, + {"Tanh", V_2023_0, {"NPU"}}, {"ThresholdedRelu", V_2022_1, {"CPU", "GPU"}}, - {"ThresholdedRelu", V_2023_0, {"VPUX"}}, + {"ThresholdedRelu", V_2023_0, {"NPU"}}, {"Tile", V_2021_3, {"CPU", "GPU"}}, - {"Tile", V_2023_0, {"VPUX"}}, + {"Tile", V_2023_0, {"NPU"}}, {"Transpose", V_2020_4, {"CPU", "GPU"}}, - {"Transpose", V_2023_0, {"VPUX"}}, + {"Transpose", V_2023_0, {"NPU"}}, {"Trilu", V_2023_0, {"CPU", "GPU"}}, + {"Trilu", V_2023_1, {"NPU"}}, {"TopK", V_2020_4, {"CPU", "GPU"}}, - {"TopK", V_2023_0, {"VPUX"}}, + {"TopK", V_2023_0, {"NPU"}}, + {"Upsample", V_2020_4, {"CPU", "GPU"}}, {"Unsqueeze", V_2020_4, {"CPU", "GPU"}}, - {"Unsqueeze", V_2023_0, {"VPUX"}}, - {"Upsample", V_2021_1, {"CPU"}}, - {"Upsample", V_2021_4, {"GPU"}}, - {"Upsample", V_2023_0, {"VPUX"}}, + {"Unsqueeze", V_2023_0, {"NPU"}}, {"Where", V_2022_1, {"CPU", "GPU"}}, - {"Where", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Where", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Xor", V_2022_1, {"CPU", "GPU"}}, + {"Xor", V_2023_1, {"NPU"}}, }; void DataOps::populate_types_supported() { - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_initializer_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_initializer_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_initializer_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_initializer_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_vpu_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_npu_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_cpu_.insert(std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_cpu_.insert( + std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_gpu_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_gpu_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_gpu_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_gpu_.insert(std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_gpu_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_gpu_.insert( + std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); } void DataOps::populate_op_mode_supported() { @@ -349,10 +439,10 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}}); - no_dimension_supported_.push_back({"Greater", V_2023_0, {"VPUX"}}); + no_dimension_supported_.push_back({"Greater", V_2023_0, {"NPU"}}); no_dimension_supported_.push_back({"Less", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); - no_dimension_supported_.push_back({"Max", V_2023_0, {"VPUX"}}); + no_dimension_supported_.push_back({"Max", V_2023_0, {"NPU"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"QuantizeLinear", V_2021_4, {"All"}}); @@ -382,11 +472,14 @@ void DataOps::populate_op_mode_supported() { { UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3}, [this](const Node* node, const InitializedTensorSet&) { - // Abs is not supproted with INT8 or INT32 as input data type on GPU - if (device_id_.find("GPU") != std::string::npos) { + // Abs is not supproted with INT8 or INT32 as input data type on GPU and NPU + if ((device_id_.find("GPU") != std::string::npos) || + (device_id_.find("NPU") != std::string::npos)) { for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || - node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || + node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) return true; } } @@ -399,11 +492,14 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { // tensor type does not support select last index auto& attributes = node->GetAttributes(); - auto last_index_arg = attributes.count("select_last_index") > 0 ? attributes.at("select_last_index").i() : 0; + auto last_index_arg = + attributes.count("select_last_index") > 0 ? attributes.at("select_last_index").i() + : 0; if (last_index_arg != 0) return true; // tensor type supports float as input for argmax and argmin - if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) + if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) return true; return false; }}; @@ -415,7 +511,8 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { if (device_id_.find("GPU") != std::string::npos) { // int64 data type is not supported on GPU - const bool data_is_int64 = node->InputDefs()[0]->Type()->find("int64") != std::string::npos; + const bool data_is_int64 = + node->InputDefs()[0]->Type()->find("int64") != std::string::npos; return data_is_int64; } return false; @@ -506,9 +603,12 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { auto x_data_type = node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); auto y_data_type = node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); - // currently both inputs with int32 are not supported and also both input datatypes should be same - const bool A_is_int32 = node->InputDefs()[0]->Type()->find("int32") != std::string::npos; - const bool B_is_int32 = node->InputDefs()[1]->Type()->find("int32") != std::string::npos; + // currently both inputs with int32 are not supported + // and also both input datatypes should be same + const bool A_is_int32 = + node->InputDefs()[0]->Type()->find("int32") != std::string::npos; + const bool B_is_int32 = + node->InputDefs()[1]->Type()->find("int32") != std::string::npos; if ((A_is_int32 && B_is_int32) || (x_data_type != y_data_type)) return true; } @@ -589,11 +689,13 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { auto slope = node->InputDefs()[1]; // PRelu slope has to be an initializer or needs to come from a constant node - if (initializers.count(slope->Name())) + if (initializers.count(slope->Name())) { return false; - else { - for (auto input_node = node->InputNodesBegin(); input_node != node->InputNodesEnd(); ++input_node) { - if (GetInputCount(this->graph_viewer_.GetNode((*input_node).Index()), initializers) == 0) + } else { + for (auto input_node = node->InputNodesBegin(); + input_node != node->InputNodesEnd(); ++input_node) { + if (GetInputCount( + this->graph_viewer_.GetNode((*input_node).Index()), initializers) == 0) return false; } } @@ -603,12 +705,12 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"PRelu", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_arg = node->InputDefs()[1]; auto shape = input_arg->Shape(); // Reshape op with empty dim is Rejected for Myriad - //[TODO] Is this condition required anymore with Myriad removed? + // [TODO] Is this condition required anymore with Myriad removed? if (shape != nullptr) { for (const auto& dim : input_arg->Shape()->dim()) { if (utils::HasDimValue(dim) && dim.dim_value() == 0) @@ -638,7 +740,8 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { // INT32 dataype is not supported as input for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) return true; } } @@ -650,9 +753,11 @@ void DataOps::populate_op_mode_supported() { UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3}, [this](const Node* node, const InitializedTensorSet&) { if (device_id_.find("GPU") != std::string::npos) { - auto output_data_type = node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + auto output_data_type = + node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); // If the output of ScatterND op is BOOL, it is rejected for GPU. - if (output_data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) + if (output_data_type == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) return true; } return false; @@ -666,7 +771,8 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { // If the Input of Shrink op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) return true; } return false; @@ -714,10 +820,11 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Squeeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze - // If axes is an input, then we cannot produce a static graph. Conversion fails in convert_function_to_cnn_network. + // If axes is an input, then we cannot produce a static graph. + // Conversion fails in convert_function_to_cnn_network. for (size_t i = 0; i < node->InputDefs().size(); i++) { if (node->InputDefs()[i]->Name() == "axes") { return true; @@ -728,14 +835,15 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); if (upsample_attr.count("scales") > 0) { auto& upsample_arg = upsample_attr.at("scales"); auto float_size = upsample_arg.floats_size(); - if (float_size > 2 && (upsample_arg.floats(0) != 1.f || upsample_arg.floats(1) != 1.f)) { + if (float_size > 2 && + (upsample_arg.floats(0) != 1.f || upsample_arg.floats(1) != 1.f)) { return true; } } @@ -750,9 +858,12 @@ void DataOps::populate_op_mode_supported() { } } // x_arg supports only float, int8 and float16 type - if ((x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || - (x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || - (x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)) { + if ((x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || + (x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || + (x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)) { return false; } else { return true; @@ -849,9 +960,9 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { } else { auto dtype = type_proto->tensor_type().elem_type(); - if (device_id_.find("VPUX") != std::string::npos || device_id_.find("HETERO") != std::string::npos || + if (device_id_.find("NPU") != std::string::npos || device_id_.find("HETERO") != std::string::npos || device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) { - for (auto const& var : supported_types_vpu_) { + for (auto const& var : supported_types_npu_) { if ((var.first <= version_id_) && (var.second == dtype)) { return true; @@ -1079,7 +1190,9 @@ bool DataOps::node_is_supported(const std::mapsecond.find(optype) == opset->second.end() && op_fun == ops_supported_as_function.end()) { #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { - std::cout << "The operator is not available in OpenVINO ngraph operators list nor the operator is a special ONNX function" << std::endl; + std::cout << "The operator is not available in OpenVINO ngraph operators list" + << "nor the operator is a special ONNX function" + << std::endl; } #endif return false; @@ -1095,10 +1208,12 @@ std::vector DataOps::GetUnsupportedNodeIndices(std::unordered_setForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, bool is_input) { - if(is_input && this->graph_viewer_.GetAllInitializedTensors().count(node_arg.Name())) { + graph_viewer_.GetNode(node_idx)->ForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, + bool is_input) { + if (is_input && this->graph_viewer_.GetAllInitializedTensors().count(node_arg.Name())) { ng_required_initializers.insert(node_arg.Name()); - } }, true); + } }, + true); } else { unsupported_nodes_idx.push_back(node_idx); } @@ -1110,7 +1225,8 @@ bool DataOps::IsOpSupportedOnlyInModel(std::string name) { return ops_supported_only_in_model.find(name) != ops_supported_only_in_model.end(); } -bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, const Node* node) { +bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, + const Node* node) { if (node->OpType() == "Reshape") { const auto& shape_arg = node->InputDefs()[1]; if (ng_required_initializers.find(shape_arg->Name()) == ng_required_initializers.end()) { @@ -1119,15 +1235,20 @@ bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& } else if (node->OpType() == "Expand") { // nGraph only supports constant shape input values const auto& output = node->OutputDefs()[0]; - if (output->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) + if (output->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) return true; } else if (node->OpType() == "RoiAlign") { using onnx_dtype = ONNX_NAMESPACE::TensorProto_DataType; - onnx_dtype input_0_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype input_1_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype input_2_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[2]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype output_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_0_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_1_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_2_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[2]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype output_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); if ((input_0_data_type != onnx_dtype::TensorProto_DataType_FLOAT16) || (input_1_data_type != onnx_dtype::TensorProto_DataType_FLOAT16) || diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index cc968d02ea644..a5aa3f825602c 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -3,6 +3,11 @@ #pragma once #include +#include +#include +#include +#include +#include namespace onnxruntime { namespace openvino_ep { @@ -47,7 +52,7 @@ class DataOps { std::multimap op_list_; std::vector subgraph_supported_; std::vector no_dimension_supported_; - std::set supported_types_vpu_; + std::set supported_types_npu_; std::set supported_types_cpu_; std::set supported_types_gpu_; std::set supported_types_initializer_; @@ -64,14 +69,16 @@ class DataOps { const NodeIndex node_idx); public: - DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, std::string dev_id) : graph_viewer_(graph_viewer_param), version_id_(ver), device_id_(dev_id) { + DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, std::string dev_id) + : graph_viewer_(graph_viewer_param), version_id_(ver), device_id_(dev_id) { populate_op_mode_supported(); populate_types_supported(); } virtual std::vector GetUnsupportedNodeIndices(std::unordered_set& ng_required_initializers); virtual bool IsOpSupportedOnlyInModel(std::string name); - virtual bool SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, const Node* node); + virtual bool SpecialConditionForClusterSizeOne( + std::unordered_set& ng_required_initializers, const Node* node); virtual bool DoNotOmitSubGraph(const std::string& name); virtual bool InsertNode(const std::string& name); VersionNum GetVersion() const { return version_id_; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index be509b6743621..74369d39b9a24 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License #include "core/providers/shared_library/provider_api.h" +#include "utils.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245 5208) @@ -113,7 +114,8 @@ std::map> GetNgSupportedOps(const int onnx_op * supported_cluster + (UNsupported_node + rest_of_the_graph). This functions returns vector of all supported_clusters by nGraph */ std::vector> -GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes) { +GetPartitionedClusters(const std::vector& topological_order, + const std::vector& unsupported_nodes) { std::vector> ng_clusters; auto prev = topological_order.begin(); @@ -140,7 +142,10 @@ GetPartitionedClusters(const std::vector& topological_order, const st return ng_clusters; } -void IdentifyConnectedNodes(const GraphViewer& graph_viewer, NodeIndex curr_node_index, std::vector& cluster, std::vector& sub_cluster) { +void IdentifyConnectedNodes(const GraphViewer& graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster) { if (std::find(cluster.begin(), cluster.end(), curr_node_index) == cluster.end()) return; @@ -205,7 +210,8 @@ void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const auto& ext_node = graph_viewer.GetNode((*it).Index()); if (std::find(cluster.begin(), cluster.end(), ext_node->Index()) == cluster.end()) { - // Node is external to this_cluster. Search through its inputs to find the output that is generated by this_cluster. + // Node is external to this_cluster. Search through its inputs to + // find the output that is generated by this_cluster. std::set ext_node_inputs; ext_node->ForEachDef( [&ext_node_inputs](const NodeArg& arg, bool is_input) { diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index 70f6954ea991c..c256cde97956e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -1,5 +1,15 @@ // Copyright (C) 2019-2022 Intel Corporation // Licensed under the MIT License +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include namespace onnxruntime { namespace openvino_ep { @@ -18,9 +28,14 @@ int GetOnnxOpSet(const GraphViewer& graph_viewer); std::map> GetNgSupportedOps(const int onnx_opset); std::vector> -GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes); - -void IdentifyConnectedNodes(const GraphViewer& graph_viewer, NodeIndex curr_node_index, std::vector& cluster, std::vector& sub_cluster); +GetPartitionedClusters( + const std::vector& topological_order, const std::vector& unsupported_nodes); + +void IdentifyConnectedNodes( + const GraphViewer& graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster); std::vector> GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector>& clusters); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index ccbc1acaa2f9e..3e17fb157b160 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -1,16 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include +#include + #include "core/providers/common.h" +#include "core/util/qmath.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "base_op_builder.h" -#include - namespace onnxruntime { namespace qnn { class BatchNormOpBuilder : public BaseOpBuilder { @@ -18,9 +22,446 @@ class BatchNormOpBuilder : public BaseOpBuilder { BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BatchNormOpBuilder); + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + + std::pair CheckMinMax(float rmin, float rmax) const { + // Ensure a minimum range of 0.0001 (required by QNN) + rmax = std::max(rmax, rmin + 0.0001f); + + // Both QNN and ORT require the range to include 0.0f + rmin = std::min(rmin, 0.0f); + rmax = std::max(rmax, 0.0f); + + return std::make_pair(rmin, rmax); + } + + template + Status GetQminQmax(const Qnn_DataType_t qnn_data_type, + T& qmin, + T& qmax) const { + if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + Status GetQuantParams(float rmin, + float rmax, + const Qnn_DataType_t qnn_data_type, + float& scale, + int& zero_point) const { + std::tie(rmin, rmax) = CheckMinMax(rmin, rmax); + float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + + scale = (rmax - rmin) / (qmax - qmin); + const float initial_zero_point = qmin - (rmin / scale); + zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); + // To match QNN quantization definition + zero_point = 0 - zero_point; + return Status::OK(); + } + + inline Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const uint8_t* raw_ptr, + double& value, + int& offset) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int8_t); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int16_t); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int32_t); + break; + } + case QNN_DATATYPE_INT_64: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int64_t); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint8_t); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint16_t); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint32_t); + break; + } + case QNN_DATATYPE_UINT_64: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint64_t); + break; + } + case QNN_DATATYPE_FLOAT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(float); + break; + } + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline Status AssertUnpackedTensorSize(const Qnn_DataType_t qnn_data_type, + const uint32_t channel, + const size_t raw_ptr_length) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int8_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int16_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int32_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_64: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int64_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint8_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint16_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint32_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_64: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint64_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_FLOAT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(float)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline Status ConvertToRawOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const std::vector& double_tensor, + std::vector& raw_tensor) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: { + raw_tensor.resize(double_tensor.size() * sizeof(int8_t)); + int8_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_16: { + raw_tensor.resize(double_tensor.size() * sizeof(int16_t)); + int16_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(int32_t)); + int32_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_64: { + raw_tensor.resize(double_tensor.size() * sizeof(int64_t)); + int64_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_8: { + raw_tensor.resize(double_tensor.size() * sizeof(uint8_t)); + uint8_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_16: { + raw_tensor.resize(double_tensor.size() * sizeof(uint16_t)); + uint16_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(uint32_t)); + uint32_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_64: { + raw_tensor.resize(double_tensor.size() * sizeof(uint64_t)); + uint64_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_FLOAT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(float)); + float* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UFIXED_POINT_32: + case QNN_DATATYPE_UFIXED_POINT_16: + case QNN_DATATYPE_UFIXED_POINT_8: + case QNN_DATATYPE_SFIXED_POINT_32: + case QNN_DATATYPE_SFIXED_POINT_16: + case QNN_DATATYPE_SFIXED_POINT_8: + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline double Dequantize(const OnnxInputInfo& info, + const double quant_value) const { + auto offset = static_cast(info.quant_param.scaleOffsetEncoding.offset); + auto scale = static_cast(info.quant_param.scaleOffsetEncoding.scale); + return (quant_value + offset) * scale; + } + + template + inline T Saturate(const T qmax, + const T qmin, + const T quant_value) const { + if (quant_value > qmax) { + return qmax; + } else if (quant_value < qmin) { + return qmin; + } else { + return quant_value; + } + } + + inline Status Quantize(const double double_value, + const float scale, + const int zero_point, + const Qnn_DataType_t qnn_data_type, + int& quant_value) const { + int qmin = 0; + int qmax = 255; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + quant_value = Saturate(qmax, qmin, static_cast(std::round((double_value / scale) - zero_point))); + return Status::OK(); + } + + Status PreprocessMean(const OnnxInputInfo& mean_info, + const bool is_npu_backend, + const uint8_t* mean_raw_ptr, + const size_t mean_raw_ptr_length, + std::vector& mean_out) const { + // tensor length (channel) + uint32_t channel = mean_info.shape[0]; + mean_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(mean_info.qnn_data_type, channel, mean_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double mean_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset)); + mean_out[i] = (is_npu_backend) ? Dequantize(mean_info, mean_value) : mean_value; + } + return Status::OK(); + } + + Status PreprocessStd(const OnnxInputInfo& var_info, + const bool is_npu_backend, + const uint8_t* var_raw_ptr, + const size_t var_raw_ptr_length, + const float epsilon, + std::vector& std_out) const { + // tensor length (channel) + uint32_t channel = var_info.shape[0]; + std_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(var_info.qnn_data_type, channel, var_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double var_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset)); + std_out[i] = (is_npu_backend) ? Dequantize(var_info, var_value) : var_value; + std_out[i] = std::sqrt(std_out[i] + static_cast(epsilon)); + } + return Status::OK(); + } + + Status PreprocessScale(const OnnxInputInfo& scale_info, + const bool is_npu_backend, + const uint8_t* scale_raw_ptr, + const size_t scale_raw_ptr_length, + const std::vector& std_double_tensor, + double& rmax, + double& rmin, + std::vector& scale_out) const { + // tensor length (channel) + uint32_t channel = scale_info.shape[0]; + scale_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(scale_info.qnn_data_type, channel, scale_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double scale_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset)); + scale_out[i] = (is_npu_backend) ? Dequantize(scale_info, scale_value) : scale_value; + scale_out[i] = scale_out[i] / std_double_tensor[i]; + rmax = std::max(rmax, scale_out[i]); + rmin = std::min(rmin, scale_out[i]); + } + return Status::OK(); + } + + Status PreprocessBias(const OnnxInputInfo& bias_info, + const bool is_npu_backend, + const uint8_t* bias_raw_ptr, + const size_t bias_raw_ptr_length, + const std::vector& scale_double_tensor, + const std::vector& mean_double_tensor, + double& rmax, + double& rmin, + std::vector& bias_out) const { + // tensor length (channel) + uint32_t channel = bias_info.shape[0]; + bias_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(bias_info.qnn_data_type, channel, bias_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double bias_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset)); + bias_out[i] = (is_npu_backend) ? Dequantize(bias_info, bias_value) : bias_value; + bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]); + rmax = std::max(rmax, bias_out[i]); + rmin = std::min(rmin, bias_out[i]); + } + return Status::OK(); + } + + Status Postprocess(const OnnxInputInfo& info, + const bool is_npu_backend, + const std::vector& double_tensor, + const double rmax, + const double rmin, + Qnn_QuantizeParams_t& quant_param, + std::vector& raw_tensor) const { + if (is_npu_backend) { + raw_tensor.resize(double_tensor.size()); + float scale = 0.0f; + int zero_point = 0; + ORT_RETURN_IF_ERROR(GetQuantParams(static_cast(rmin), + static_cast(rmax), + info.qnn_data_type, + scale, + zero_point)); + quant_param = QNN_QUANTIZE_PARAMS_INIT; + utils::InitializeQuantizeParam(quant_param, true, scale, zero_point); + for (size_t i = 0; i < double_tensor.size(); ++i) { + // onnx only supports 8 bits quantization + int quant_value_int = 0; + ORT_RETURN_IF_ERROR(Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int)); + if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + raw_tensor[i] = static_cast(quant_value_int); + } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + int8_t quant_value = static_cast(quant_value_int); + raw_tensor[i] = *reinterpret_cast(&quant_value); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type); + } + } + } else { + ORT_RETURN_IF_ERROR(ConvertToRawOnQnnDataType(info.qnn_data_type, double_tensor, raw_tensor)); + } + return Status::OK(); + } }; // BatchNorm is sensitive with data layout, no special validation so far @@ -34,11 +475,6 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, // Still do it here so hopefully QNN Op validation API can tell us some details why it's not supported return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } else { - NodeAttrHelper node_helper(node_unit); - const float default_epsilon = 1e-05f; - const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. - ORT_RETURN_IF(abs(epsilon - default_epsilon) > default_epsilon, "QNN BatchNorm doesn't support epsilon."); - const auto& inputs = node_unit.Inputs(); ORT_ENFORCE(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec."); @@ -56,11 +492,16 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, std::vector scale_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale)."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[1].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic scale."); ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels, "QNN BatchNorm input 1 (scale) must have 1D shape [channel]."); std::vector bias_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias)."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[2].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic bias."); + ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels, "QNN BatchNorm input 2 (bias) must have 1D shape [channel]."); @@ -68,13 +509,15 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape), "Cannot get shape of input 3 (mean)."); ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels, "QNN BatchNorm input 3 (mean) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), "QNN BatchNorm doesn't support dynamic mean."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic mean."); std::vector var_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape), "Cannot get shape of input 4 (var)."); ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels, "QNN BatchNorm input 4 (var) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), "QNN BatchNorm doesn't support dynamic var."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic var."); ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output."); } @@ -82,6 +525,134 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + ORT_UNUSED_PARAMETER(logger); + + const auto& inputs = node_unit.Inputs(); + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + // + // Input 0 + // + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); + + // + // Input 1: scale + // Input 2: bias + // QNN only accept 3 input. We need to first combine mean and variance into scale and bias. + // + { + const std::string& scale_name = inputs[1].node_arg.Name(); + const std::string& bias_name = inputs[2].node_arg.Name(); + OnnxInputInfo var_info = {}; + OnnxInputInfo mean_info = {}; + OnnxInputInfo scale_info = {}; + OnnxInputInfo bias_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], scale_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], bias_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[3], mean_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[4], var_info)); + + // scale, bias, mean, and var must be initializers + ORT_RETURN_IF_NOT(scale_info.is_initializer, "scale must be initializers"); + ORT_RETURN_IF_NOT(bias_info.is_initializer, "bias must be initializers"); + ORT_RETURN_IF_NOT(mean_info.is_initializer, "mean must be initializers"); + ORT_RETURN_IF_NOT(var_info.is_initializer, "var must be initializers"); + + std::vector scale_unpacked_tensor; + std::vector bias_unpacked_tensor; + std::vector var_unpacked_tensor; + std::vector mean_unpacked_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*scale_info.initializer_tensor, scale_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*bias_info.initializer_tensor, bias_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*mean_info.initializer_tensor, mean_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*var_info.initializer_tensor, var_unpacked_tensor)); + + std::vector mean_double_tensor; + std::vector std_double_tensor; + std::vector scale_double_tensor; + std::vector bias_double_tensor; + + NodeAttrHelper node_helper(node_unit); + const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. + + double scale_rmax = std::numeric_limits::min(); + double scale_rmin = std::numeric_limits::max(); + double bias_rmax = std::numeric_limits::min(); + double bias_rmin = std::numeric_limits::max(); + + // Calculate and convert new scale, new bias, mean and std to double array (may be dequantized) + ORT_RETURN_IF_ERROR(PreprocessMean(mean_info, + is_npu_backend, + mean_unpacked_tensor.data(), + mean_unpacked_tensor.size(), + mean_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessStd(var_info, + is_npu_backend, + var_unpacked_tensor.data(), + var_unpacked_tensor.size(), + epsilon, + std_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessScale(scale_info, + is_npu_backend, + scale_unpacked_tensor.data(), + scale_unpacked_tensor.size(), + std_double_tensor, + scale_rmax, + scale_rmin, + scale_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessBias(bias_info, + is_npu_backend, + bias_unpacked_tensor.data(), + bias_unpacked_tensor.size(), + scale_double_tensor, + mean_double_tensor, + bias_rmax, + bias_rmin, + bias_double_tensor)); + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(scale_name)) { + std::vector scale_raw_tensor; + Qnn_QuantizeParams_t scale_quant_param = scale_info.quant_param; + ORT_RETURN_IF_ERROR(Postprocess(scale_info, + is_npu_backend, + scale_double_tensor, + scale_rmax, + scale_rmin, + scale_quant_param, + scale_raw_tensor)); + Qnn_TensorType_t scale_tensor_type = GetInputTensorType(qnn_model_wrapper, scale_name); + QnnTensorWrapper input_tensorwrapper(scale_name, scale_tensor_type, scale_info.qnn_data_type, scale_quant_param, + std::move(scale_info.shape), std::move(scale_raw_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + input_names.push_back(scale_name); + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(bias_name)) { + std::vector bias_raw_tensor; + Qnn_QuantizeParams_t bias_quant_param = bias_info.quant_param; + ORT_RETURN_IF_ERROR(Postprocess(bias_info, + is_npu_backend, + bias_double_tensor, + bias_rmax, + bias_rmin, + bias_quant_param, + bias_raw_tensor)); + Qnn_TensorType_t bias_tensor_type = GetInputTensorType(qnn_model_wrapper, bias_name); + QnnTensorWrapper input_tensorwrapper(bias_name, bias_tensor_type, bias_info.qnn_data_type, bias_quant_param, + std::move(bias_info.shape), std::move(bias_raw_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + input_names.push_back(bias_name); + } + + return Status::OK(); +} + void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index 2dfdfffe5fa54..fc8c5c357682c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -202,16 +202,8 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap // Qnn format is begin_0, end_0, begin_1, end_1, ... ReArranagePads(pad_amount); - std::vector pad_amount_dim{static_cast(pad_amount.size() / 2), static_cast(2)}; - QnnParamWrapper multiples_param(node_unit.Index(), node_unit.Name(), QNN_OP_PAD_PARAM_PAD_AMOUNT, std::move(pad_amount_dim), - std::move(pad_amount)); - param_tensor_names.push_back(multiples_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(multiples_param)); - - // Process optional input constant_value - if (node_unit.Inputs().size() > 2) { - ORT_RETURN_IF_ERROR(ProcessConstantValue(qnn_model_wrapper, param_tensor_names, node_unit, inputs[2])); - } // constant_value + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0."); NodeAttrHelper node_helper(node_unit); std::string mode = node_helper.Get("mode", "constant"); @@ -220,6 +212,10 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap if ("constant" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_PAD_SCHEME_CONSTANT; } else if ("reflect" == mode) { + for (size_t i = 0; i < input_shape.size(); i++) { + ORT_RETURN_IF(pad_amount[i * 2] > input_shape[i] - 1 || pad_amount[(i * 2) + 1] > input_shape[i] - 1, + "Pad amount should not be greater than shape(input[0])[i] - 1"); + } mode_qnn_scalar.uint32Value = QNN_OP_PAD_SCHEME_MIRROR_REFLECT; } else if ("edge" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_PAD_SCHEME_EDGE; @@ -227,10 +223,21 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Pad mode only support constant."); } + std::vector pad_amount_dim{static_cast(pad_amount.size() / 2), static_cast(2)}; QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_PAD_PARAM_SCHEME, mode_qnn_scalar); param_tensor_names.push_back(mode_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(mode_param)); + QnnParamWrapper multiples_param(node_unit.Index(), node_unit.Name(), QNN_OP_PAD_PARAM_PAD_AMOUNT, + std::move(pad_amount_dim), std::move(pad_amount)); + param_tensor_names.push_back(multiples_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(multiples_param)); + + // Process optional input constant_value + if (node_unit.Inputs().size() > 2) { + ORT_RETURN_IF_ERROR(ProcessConstantValue(qnn_model_wrapper, param_tensor_names, node_unit, inputs[2])); + } // constant_value + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), diff --git a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh index 5830c9dd0bf27..f87b436d04a17 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh +++ b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh @@ -58,7 +58,7 @@ auto GetCKSoftmaxTypeStringAndOps() { auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, alpha, beta, params->input, params->output, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh index 86d023886cfaf..2518f45e0995e 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh @@ -61,7 +61,7 @@ auto GetCKGemmTypeStringAndOps() { params->lda, params->ldb, params->ldc, nop, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; @@ -164,7 +164,7 @@ auto GetCKStridedBatchedGemmTypeStringAndOps() { auto zero = ToHipType::FromFloat(0.0f); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); auto nop = Nop{}; auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, @@ -174,7 +174,7 @@ auto GetCKStridedBatchedGemmTypeStringAndOps() { params->batch, nop, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index d5f9de26ada22..b9c0cdcc1c341 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -221,7 +221,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != HIPBLAS_STATUS_SUCCESS, - "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported (", params->Signature(), ")"); + "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported"); IAllocatorUniquePtr workspace_buffer; if (workspace_size > 0) { diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h index 8e894e63c5de1..a391d1af8868c 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h @@ -168,8 +168,7 @@ auto GetRocBlasGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; @@ -238,8 +237,7 @@ auto GetRocBlasBatchedGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; @@ -308,8 +306,7 @@ auto GetRocBlasStridedBatchedGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index f8b23b6bc3b2d..9b237a40e3073 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -365,6 +365,46 @@ std::unique_lock TensorrtExecutionProvider::GetApiLock() const { return std::unique_lock(singleton); } +Status GetShapeOfShapeTensor(Ort::ConstValue& input_tensor, + std::vector& shape_values, + nvinfer1::ICudaEngine* trt_engine, + int binding_index, + cudaStream_t stream) { + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + const auto tensor_type = tensor_info.GetElementType(); + nvinfer1::Dims dims = trt_engine->getBindingDimensions(static_cast(binding_index)); + int nb_dims = dims.nbDims; + int shape_size = nb_dims == 0 ? 1 : static_cast(tensor_shapes[0]); // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension + shape_values.resize(shape_size, 1); + + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto input = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + for (int j = 0; j < shape_size; ++j) { + shape_values[j] = input[j]; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + auto input = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + for (int j = 0; j < shape_size; ++j) { + shape_values[j] = static_cast(input[j]); + } + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported."); + } + } + return Status::OK(); +} + /* * Get the shape of "shape tensor" input */ diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 26c739e9a1ce1..02a3d16b5b64f 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -26,11 +26,7 @@ WebNNExecutionProvider::WebNNExecutionProvider( ORT_THROW("Failed to get ml from navigator."); } emscripten::val context_options = emscripten::val::object(); - // Currently WebNN implementation in Chromium temporarily reuses the MLContextOptions - // defined in Model Loader API, which uses MLDevicePreference instead of MLDeviceType - // defined in WebNN. Because there's an ongoing spec discussion to simplify this API at - // https://github.com/webmachinelearning/webnn/issues/302. - context_options.set("devicePreference", emscripten::val(webnn_device_flags)); + context_options.set("deviceType", emscripten::val(webnn_device_flags)); // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { preferred_layout_ = DataLayout::NHWC; diff --git a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc index ea5c75a955cc4..8e7e228f974e6 100644 --- a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc +++ b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc @@ -7,22 +7,22 @@ #include "core/common/common.h" #include "core/framework/op_node_proto_helper.h" -#include "core/graph/graph_viewer.h" #include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" #include "core/providers/common.h" #include "core/providers/cpu/nn/pool_attributes.h" -#include "core/providers/xnnpack/detail/utils.h" #include "core/providers/shared/node_unit/node_unit.h" +#include "core/providers/xnnpack/detail/utils.h" // each operator provides a helper to check if supported -#include "core/providers/xnnpack/nn/conv.h" -#include "core/providers/xnnpack/nn/conv_transpose.h" -#include "core/providers/xnnpack/nn/max_pool.h" #include "core/providers/xnnpack/math/gemm.h" #include "core/providers/xnnpack/math/matmul.h" +#include "core/providers/xnnpack/math/softmax.h" #include "core/providers/xnnpack/nn/average_pool.h" -#include "core/providers/xnnpack/nn/resize.h" -#include "core/providers/xnnpack/nn/softmax.h" +#include "core/providers/xnnpack/nn/conv.h" +#include "core/providers/xnnpack/nn/conv_transpose.h" +#include "core/providers/xnnpack/nn/max_pool.h" +#include "core/providers/xnnpack/tensor/resize.h" namespace onnxruntime { namespace xnnpack { diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index baca4eef537d7..1a32612981120 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -25,7 +25,7 @@ const char* OpTypeToString(OpComputeType opCtype) { case op_compute_type_fp16: return "fp16"; case op_compute_type_qs8_per_channel: - return "qc8"; + return "qs8_qc8w"; case op_compute_type_qs8: return "qs8"; case op_compute_type_qu8: diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index 24c233e2415ca..f7b736b0ff903 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -78,7 +78,7 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra return supported; } -Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info) { +Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*enable_caches*/ true) { const auto& node{Node()}; info.GetAttrOrDefault("alpha", &alpha_, 1.f); @@ -146,14 +146,9 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride, B_->Data(), // const float* kernel, bias_Data, // const float* bias, - output_min, - output_max, + output_min, output_max, flags, -#ifdef XNN_CACHE_ENABLE - &xnn_caches_, -#else - 0, -#endif + GetCodeCache(), GetWeightsCache(), &p); if (status != xnn_status_success) { @@ -165,20 +160,25 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, } Status Gemm::Compute(OpKernelContext* context) const { - pthreadpool_t t_pool = GetThreadPool(); + pthreadpool_t threadpool = GetThreadPool(); const auto* A = context->Input(0); auto Y = context->Output(0, {M_, N_}); // if input is empty tensor, return as nothing need to be calculated and we've set the shape for the output - if (M_ == 0 || N_ == 0) + if (M_ == 0 || N_ == 0) { return Status::OK(); + } + + xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(), + // Number of rows to multiply + trans_A_ == CblasNoTrans ? M_ : K_, + threadpool); + + if (status != xnn_status_success) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status); + } - xnn_status status = xnn_setup_fully_connected_nc_f32( - op0_.get(), - trans_A_ == CblasNoTrans ? M_ : K_, // Number of rows to multiply - A->Data(), - Y->MutableData(), - t_pool); + status = xnn_setup_fully_connected_nc_f32(op0_.get(), A->Data(), Y->MutableData()); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status); @@ -192,7 +192,15 @@ Status Gemm::Compute(OpKernelContext* context) const { return Status::OK(); } -ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 7, 12, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 7, 8, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gemm); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 9, 10, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gemm); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 11, 12, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Gemm); diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.h b/onnxruntime/core/providers/xnnpack/math/gemm.h index 9191ba204bc25..6d11a8531c20f 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.h +++ b/onnxruntime/core/providers/xnnpack/math/gemm.h @@ -41,14 +41,6 @@ class Gemm : protected GemmBase, public XnnpackKernel { float alpha_; float beta_; - -#ifdef XNN_CACHE_ENABLE -#if XNN_PLATFORM_JIT - xnn_code_cache code_cache_; -#endif - xnn_caches xnn_caches_ = {0, 0}; - xnn_weights_cache weights_cache_; -#endif }; } // namespace xnnpack diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index fc7335c79b603..e90aa11c9d087 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -62,7 +62,7 @@ bool MatMul::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& g return supported; } -MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info) {} +MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ true) {} Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -99,9 +99,11 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, output_max, flags, #ifdef XNN_CACHE_ENABLE - &xnn_caches_, + GetCodeCache(), + GetWeightsCache(), #else - 0, + nullptr, + nullptr, #endif &p); @@ -116,7 +118,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, Status MatMul::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); - pthreadpool_t t_pool = GetThreadPool(); + pthreadpool_t threadpool = GetThreadPool(); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_)); Tensor* y = ctx->Output(0, helper.OutputShape()); @@ -126,13 +128,12 @@ Status MatMul::Compute(OpKernelContext* ctx) const { auto* y_data = y->MutableData(); - xnn_status status = xnn_setup_fully_connected_nc_f32( - op0_.get(), - a->Shape()[0], - a->Data(), - y_data, - t_pool); + xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool); + if (status != xnn_status_success) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status); + } + status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data(), y_data); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status); } @@ -144,7 +145,11 @@ Status MatMul::Compute(OpKernelContext* ctx) const { return Status::OK(); } -ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 12, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 8, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 9, 12, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.h b/onnxruntime/core/providers/xnnpack/math/matmul.h index f4ed92c6146fb..b76e42c4d3729 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.h +++ b/onnxruntime/core/providers/xnnpack/math/matmul.h @@ -32,14 +32,6 @@ class MatMul : public XnnpackKernel { AllocatorPtr myAlloc; XnnpackOperator op0_ = nullptr; - -#ifdef XNN_CACHE_ENABLE -#if XNN_PLATFORM_JIT - xnn_code_cache code_cache_; -#endif - xnn_caches xnn_caches_ = {0, 0}; - xnn_weights_cache weights_cache_; -#endif }; } // namespace xnnpack diff --git a/onnxruntime/core/providers/xnnpack/nn/softmax.cc b/onnxruntime/core/providers/xnnpack/math/softmax.cc similarity index 80% rename from onnxruntime/core/providers/xnnpack/nn/softmax.cc rename to onnxruntime/core/providers/xnnpack/math/softmax.cc index bca84317ad891..87440b7814176 100644 --- a/onnxruntime/core/providers/xnnpack/nn/softmax.cc +++ b/onnxruntime/core/providers/xnnpack/math/softmax.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/xnnpack/nn/softmax.h" +#include "core/providers/xnnpack/math/softmax.h" #include @@ -25,6 +25,7 @@ bool IsQuantSoftmaxSupported(const NodeUnit& node_unit, const GraphViewer& graph output_type != TensorTypeUint8) { break; } + // to ensure its output scale and zp are 1/256 and 0, otherwise xnnpack EP has to do extra requantization // idealy, QlinearSoftmax or QDQSoftmax will keep this output scale and zp, but we have to handle some // qdq models converted from other framework @@ -33,6 +34,7 @@ bool IsQuantSoftmaxSupported(const NodeUnit& node_unit, const GraphViewer& graph if (fabs(q_scale.DataAsSpan()[0] - 1.0f / 256.0f) > 0.0001f) { break; } + if (zero_tensor) { Initializer q_zp(*zero_tensor, node_unit.ModelPath()); if (q_zp.DataAsSpan()[0] != 0) { @@ -57,6 +59,7 @@ bool Softmax::IsOnnxNodeSupported(const NodeUnit& node_unit, IsQuantSoftmaxSupported(node_unit, graph) == false) { return false; } + // use do {} while(false) so it's easier to set a breakpoint on the return do { // SoftMax has 1 input. @@ -133,6 +136,7 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} { ORT_ENFORCE(status.IsOK(), "opset must be existed in attributes of QlinearSoftmax"); opset_ = gsl::narrow_cast(opset); } + int64_t axis = -1; Status status = info.GetAttr("axis", &axis); // our op checker function has ensured that axis must be the last dim @@ -162,23 +166,22 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} { if (op_type_ == OpComputeType::op_compute_type_qu8) { // the order of input tensor, x,x_scale, x_zp, y_scale, y_zp OpQuantParam quant_param = ParseQuantParamForOp(info, x_dtype, 1); - xstatus = xnn_create_softmax_nc_qu8( - channels, - channels, - channels, - quant_param[0].first[0], // x_scale - quant_param[1].second, // y_zp - quant_param[1].first[0], // y_scale - 0, // flags, - &p); + xstatus = xnn_create_softmax_nc_qu8(channels, + channels, + channels, + quant_param[0].first[0], // x_scale + quant_param[1].second, // y_zp + quant_param[1].first[0], // y_scale + 0, // flags, + &p); } else if (op_type_ == OpComputeType::op_compute_type_fp32) { - xstatus = xnn_create_softmax_nc_f32( - channels, - channels, - channels, - 0, // flags, - &p); + xstatus = xnn_create_softmax_nc_f32(channels, + channels, + channels, + 0, // flags, + &p); } + ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_softmax_nc_", OpTypeToString(op_type_), " failed. Status:", xstatus); op0_.reset(p); @@ -194,39 +197,48 @@ Status Softmax::Compute(OpKernelContext* ctx) const { if (X_shape.Size() == 0) { return Status::OK(); } - pthreadpool_t t_pool = GetThreadPool(); + + pthreadpool_t threadpool = GetThreadPool(); const size_t N = X_shape.SizeToDimension(axis_); // const size_t D = X_shape.SizeFromDimension(axis_); // the step D is 1 xnn_status status = xnn_status_invalid_state; + + auto reshape_fn = op_type_ == OpComputeType::op_compute_type_qu8 ? xnn_reshape_softmax_nc_qu8 + : xnn_reshape_softmax_nc_f32; + status = reshape_fn(op0_.get(), N, threadpool); + + if (status != xnn_status_success) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_softmax_nc_", OpTypeToString(op_type_), + " returned ", status); + } + if (op_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_softmax_nc_qu8( - op0_.get(), - N, - X->Data(), - Y->MutableData(), - t_pool); + status = xnn_setup_softmax_nc_qu8(op0_.get(), X->Data(), Y->MutableData()); } else { - status = xnn_setup_softmax_nc_f32( - op0_.get(), - N, - X->Data(), - Y->MutableData(), - t_pool); + status = xnn_setup_softmax_nc_f32(op0_.get(), X->Data(), Y->MutableData()); } + if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_softmax_nc_", - OpTypeToString(op_type_), " returned ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_softmax_nc_", OpTypeToString(op_type_), + " returned ", status); } - status = xnn_run_operator(op0_.get(), t_pool); + + status = xnn_run_operator(op0_.get(), threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status); } + return Status::OK(); } -ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 1, 12, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 1, 10, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Softmax); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Softmax, kOnnxDomain, 11, 12, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Softmax); + ONNX_OPERATOR_KERNEL_EX(Softmax, kOnnxDomain, 13, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Softmax); diff --git a/onnxruntime/core/providers/xnnpack/nn/softmax.h b/onnxruntime/core/providers/xnnpack/math/softmax.h similarity index 100% rename from onnxruntime/core/providers/xnnpack/nn/softmax.h rename to onnxruntime/core/providers/xnnpack/math/softmax.h diff --git a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc index 767218fbfd20b..58c209a13cd0c 100644 --- a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc @@ -2,10 +2,13 @@ // Licensed under the MIT License. #include "core/providers/xnnpack/nn/average_pool.h" +#include + #include "core/common/status.h" #include "core/graph/graph.h" #include "core/providers/utils.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "core/providers/xnnpack/detail/utils.h" namespace onnxruntime { @@ -90,6 +93,10 @@ bool AveragePool::IsOnnxNodeSupported(const NodeUnit& node_unit, const auto& inputs = node_unit.Inputs(); // use do {} while(false) so it's easier to set a breakpoint on the return do { + if (node_unit.SinceVersion() < 7) { + break; + } + // AveragePool has 1 input. const auto& x_arg = inputs[0].node_arg; @@ -141,6 +148,11 @@ bool AveragePool::IsOnnxNodeSupported(const NodeUnit& node_unit, break; } + // need dilations to all be 1 + if (!pool_attrs.default_dilations) { + break; + } + supported = true; } while (false); @@ -221,24 +233,47 @@ Status AveragePool::Compute(OpKernelContext* context) const { return Status::OK(); } - pthreadpool_t t_pool = GetThreadPool(); - xnn_status status = xnn_status_invalid_state; + pthreadpool_t threadpool = GetThreadPool(); + + // setup allocator/automated dellocate for workspace + size_t workspace_size = 0; + size_t workspace_alignment = 0; + xnn_allocator* allocator = GetStoredAllocator().second; + auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; + + std::unique_ptr workspace(nullptr, deallocator); + + auto reshape_fn = (avgpool_type_ == OpComputeType::op_compute_type_fp32) + ? xnn_reshape_average_pooling2d_nhwc_f32 + : xnn_reshape_average_pooling2d_nhwc_qu8; + + auto status = reshape_fn(op0_.get(), N, H, W, + &workspace_size, &workspace_alignment, + /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, + threadpool); + + if (status != xnn_status_success) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_average_pooling2d_nhwc_", OpTypeToString(avgpool_type_), + " returned ", status); + } + + workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); + if (avgpool_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), N, H, W, - X.Data(), Y.MutableData(), - t_pool /*threadpool */); + status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), workspace.get(), + X.Data(), Y.MutableData()); + } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), N, H, W, - X.Data(), Y.MutableData(), - t_pool /*threadpool */); + status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), workspace.get(), + X.Data(), Y.MutableData()); } if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_average_pooling2d_nhwc_", - OpTypeToString(avgpool_type_), " returned ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_average_pooling2d_nhwc_", OpTypeToString(avgpool_type_), + " returned ", status); } - status = xnn_run_operator(op0_.get(), t_pool); + status = xnn_run_operator(op0_.get(), threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status); } @@ -246,8 +281,26 @@ Status AveragePool::Compute(OpKernelContext* context) const { return Status::OK(); } +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + AveragePool, kMSInternalNHWCDomain, 7, 9, + kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + AveragePool); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + AveragePool, kMSInternalNHWCDomain, 10, 10, + kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + AveragePool); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + AveragePool, kMSInternalNHWCDomain, 11, 18, + kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + AveragePool); + ONNX_OPERATOR_KERNEL_EX( - AveragePool, kMSInternalNHWCDomain, 11, + AveragePool, kMSInternalNHWCDomain, 19, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), AveragePool); diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 0772dec59e30e..0cdb9c840aa2d 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -3,12 +3,13 @@ #include "conv.h" +#include "core/common/gsl.h" #include "core/common/inlined_containers_fwd.h" +#include "core/framework/tensorprotoutils.h" #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "core/providers/xnnpack/detail/utils.h" -#include "core/framework/tensorprotoutils.h" -#include "core/common/gsl.h" namespace onnxruntime { namespace xnnpack { @@ -64,21 +65,48 @@ Status Conv::Compute(OpKernelContext* context) const { if (Y->Shape().Size() == 0) { return Status::OK(); } - pthreadpool_t t_pool = GetThreadPool(); - xnn_status status = xnn_status_invalid_state; + pthreadpool_t threadpool = GetThreadPool(); + + // setup allocator/automated dellocate for workspace + size_t workspace_size = 0; + size_t workspace_alignment = 0; + xnn_allocator* allocator = GetStoredAllocator().second; + auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; + std::unique_ptr workspace(nullptr, deallocator); + + auto reshape_fn = xnn_reshape_convolution2d_nhwc_f32; + if (conv_type_ == OpComputeType::op_compute_type_qs8) { + reshape_fn = xnn_reshape_convolution2d_nhwc_qs8; + } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { + reshape_fn = xnn_reshape_convolution2d_nhwc_qu8; + } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) { + reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w; + } + + auto status = reshape_fn(op0_.get(), N, H, W, + &workspace_size, &workspace_alignment, + /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, + threadpool); + if (status != xnn_status_success) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_), + "returned ", status); + } + + workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); + if (conv_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), N, H, W, X.Data(), Y->MutableData(), - t_pool /*threadpool*/); + status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data(), + Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qs8) { - status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), N, H, W, X.Data(), Y->MutableData(), - t_pool /*threadpool*/); + status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data(), + Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_convolution2d_nhwc_qu8(op0_.get(), N, H, W, X.Data(), Y->MutableData(), - t_pool /*threadpool*/); + status = xnn_setup_convolution2d_nhwc_qu8(op0_.get(), workspace.get(), X.Data(), + Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) { - status = xnn_setup_convolution2d_nhwc_qc8(op0_.get(), N, H, W, X.Data(), Y->MutableData(), - t_pool /*threadpool*/); + status = xnn_setup_convolution2d_nhwc_qs8_qc8w(op0_.get(), workspace.get(), X.Data(), + Y->MutableData()); } if (status != xnn_status_success) { @@ -86,7 +114,7 @@ Status Conv::Compute(OpKernelContext* context) const { OpTypeToString(conv_type_), "returned ", status); } - status = xnn_run_operator(op0_.get(), t_pool); + status = xnn_run_operator(op0_.get(), threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status); } @@ -94,6 +122,10 @@ Status Conv::Compute(OpKernelContext* context) const { return Status::OK(); } +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); + ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Conv); diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index b692f373ff4ce..d21014569234e 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -23,7 +23,8 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr, const std::optional>& clip_min_max, const Tensor& Weight, const Tensor* Bias, XnnpackOperator& op_uptr, - xnn_caches_t caches_t, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, const OpQuantParam& quant_param, OpComputeType conv_type, bool is_transpose = false) { @@ -75,7 +76,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr, C, M, // input channel stride, output channel stride Weight.Data(), B_data, foutput_min, foutput_max, flags, - caches_t, + code_cache, weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qs8) { const float output_scale = quant_param[2].first[0]; @@ -99,7 +100,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - caches_t, + code_cache, weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qs8_per_channel) { auto* B_data = Bias ? Bias->Data() : nullptr; @@ -107,7 +108,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr, const int8_t output_zero_point = quant_param[2].second; const int8_t output_min = xnn_u8s8_quantize(foutput_min, output_scale, output_zero_point); const int8_t output_max = xnn_u8s8_quantize(foutput_max, output_scale, output_zero_point); - status = xnn_create_convolution2d_nhwc_qc8( + status = xnn_create_convolution2d_nhwc_qs8_qc8w( input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, kernel_height, kernel_width, subsampling_height, subsampling_width, @@ -123,7 +124,7 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - caches_t, + code_cache, weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qu8) { const auto* B_data = Bias ? Bias->Data() : nullptr; @@ -148,15 +149,17 @@ Status CreateXnnpackKernel(const ConvAttributes* conv_attrs_ptr, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - caches_t, + code_cache, weights_cache, &p); } + if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create xnnpack kernel. xnn_create_", is_transpose ? "deconvolution2d" : "convolution2d", "_nhwc_", OpTypeToString(conv_type), " returned ", status); } + op_uptr.reset(p); return Status::OK(); } @@ -296,6 +299,11 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& const onnxruntime::Node& node = node_unit.GetNode(); // use do {} while(false) so it's easier to set a breakpoint on the return do { + // Internal NHWC domain starts at opset 11 + if (node_unit.SinceVersion() < 11) { + break; + } + // Conv has at least 2 inputs. const auto& inputs = node_unit.Inputs(); const auto& x_arg = inputs[0].node_arg; @@ -367,7 +375,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& } ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) - : XnnpackKernel(info), + : XnnpackKernel(info, /*enable_caches*/ true), conv_attrs_(info), conv_transpose_attrs_(info), convbase_attrs_ref_(is_transpose ? conv_transpose_attrs_ : conv_attrs_), @@ -383,16 +391,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) } } } - // xnnpack cache_code, unfortunately these definitions are only available in xnnpack/cache.h, -#ifdef XNN_CACHE_ENABLE -#if XNN_PLATFORM_JIT - xnn_init_code_cache(&code_cache_); - xnn_caches_.code_cache = &code_cache_; -#endif - // TODO(Jicwen) enable weight-cache and code-cache - xnn_init_weights_cache(&weights_cache_); - xnn_caches_.weights_cache = &weights_cache_; -#endif + const auto& node{Node()}; const auto& input_defs = node.InputDefs(); const NodeArg& X = *input_defs[0]; @@ -477,11 +476,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) Status ConvBase::CreateKernel() { auto ret = CreateXnnpackKernel(&convbase_attrs_ref_, C_, M_, kernel_shape_, clip_min_max_, packed_w_, B_, op0_, -#ifdef XNN_CACHE_ENABLE - &xnn_caches_, -#else - 0, -#endif + GetCodeCache(), GetWeightsCache(), quant_param_, conv_type_, is_transpose_); return ret; } diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.h b/onnxruntime/core/providers/xnnpack/nn/conv_base.h index d3501a56ea24c..53ad51378c6be 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.h @@ -39,14 +39,6 @@ class ConvBase : public XnnpackKernel { std::optional> clip_min_max_; XnnpackOperator op0_ = nullptr; - // we can't have the definition here because we can't import xnnpack/cache.h -#ifdef XNN_CACHE_ENABLE -#if XNN_PLATFORM_JIT - xnn_code_cache code_cache_; -#endif - xnn_caches xnn_caches_ = {0, 0}; - xnn_weights_cache weights_cache_; -#endif OpQuantParam quant_param_; OpComputeType conv_type_ = OpComputeType::op_compute_type_invalid; }; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index 61d8f7f488547..8698c0739509d 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -81,29 +81,34 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { if (Y->Shape().Size() == 0) { return Status::OK(); } - pthreadpool_t t_pool = GetThreadPool(); + pthreadpool_t threadpool = GetThreadPool(); auto output_pad_0 = gsl::narrow_cast(conv_transpose_attrs_.output_padding[0]); auto output_pad_1 = gsl::narrow_cast(conv_transpose_attrs_.output_padding[1]); xnn_status status = xnn_status_invalid_state; + + auto reshape_fn = xnn_reshape_deconvolution2d_nhwc_f32; + if (conv_type_ == OpComputeType::op_compute_type_qs8) { + reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8; + } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { + reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8; + } + + status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1, + /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, + threadpool); + + if (status != xnn_status_success) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_deconvolution2d_nhwc_", + OpTypeToString(conv_type_), " returned ", status); + } + if (conv_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_deconvolution2d_nhwc_f32( - op0_.get(), N, H, W, - output_pad_0, - output_pad_1, X.Data(), Y->MutableData(), - t_pool /*threadpool*/); + status = xnn_setup_deconvolution2d_nhwc_f32(op0_.get(), X.Data(), Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qs8) { - status = xnn_setup_deconvolution2d_nhwc_qs8( - op0_.get(), N, H, W, - output_pad_0, - output_pad_1, X.Data(), Y->MutableData(), - t_pool /*threadpool*/); + status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data(), Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_deconvolution2d_nhwc_qu8( - op0_.get(), N, H, W, - output_pad_0, - output_pad_1, X.Data(), Y->MutableData(), - t_pool /*threadpool*/); + status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data(), Y->MutableData()); } if (status != xnn_status_success) { @@ -111,7 +116,7 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { OpTypeToString(conv_type_), " returned ", status); } - status = xnn_run_operator(op0_.get(), t_pool); + status = xnn_run_operator(op0_.get(), threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status); } diff --git a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc index de6dd68bba9c3..2ef9f97f77b14 100644 --- a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc @@ -41,6 +41,10 @@ bool MaxPool::IsOnnxNodeSupported(const NodeUnit& node_unit, const onnxruntime::Node& node = node_unit.GetNode(); // use do {} while(false) so it's easier to set a breakpoint on the return do { + if (node_unit.SinceVersion() < 8) { + break; + } + // MaxPool has 1 input. auto input_defs = node.InputDefs(); const auto& x_arg = *input_defs[0]; @@ -220,20 +224,29 @@ Status MaxPool::Compute(OpKernelContext* context) const { return Status::OK(); } - pthreadpool_t t_pool = GetThreadPool(); - xnn_status status = xnn_status_invalid_state; + pthreadpool_t threadpool = GetThreadPool(); + + auto reshape_fn = xnn_reshape_max_pooling2d_nhwc_f32; + if (maxpool_type_ == OpComputeType::op_compute_type_qu8) + reshape_fn = xnn_reshape_max_pooling2d_nhwc_u8; + else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) { + reshape_fn = xnn_reshape_max_pooling2d_nhwc_s8; + } + + auto status = reshape_fn(op0_.get(), N, H, W, + /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, + threadpool); + if (status != xnn_status_success) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_max_pooling2d_nhwc_", + OpTypeToString(maxpool_type_), " returned ", status); + } + if (maxpool_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), N, H, W, - X.Data(), Y->MutableData(), - t_pool /*threadpool */); + status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), X.Data(), Y->MutableData()); } else if (maxpool_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), N, H, W, - X.Data(), Y->MutableData(), - t_pool /*threadpool */); + status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), X.Data(), Y->MutableData()); } else { - status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), N, H, W, - X.Data(), Y->MutableData(), - t_pool /*threadpool */); + status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), X.Data(), Y->MutableData()); } if (status != xnn_status_success) { @@ -241,7 +254,7 @@ Status MaxPool::Compute(OpKernelContext* context) const { OpTypeToString(maxpool_type_), " returned ", status); } - status = xnn_run_operator(op0_.get(), t_pool); + status = xnn_run_operator(op0_.get(), threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status); } @@ -249,12 +262,24 @@ Status MaxPool::Compute(OpKernelContext* context) const { return Status::OK(); } -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - MaxPool, kMSInternalNHWCDomain, 11, 11, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), - MaxPool); +ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 8, 9, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + MaxPool); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + MaxPool); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + MaxPool); + ONNX_OPERATOR_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, kXnnpackExecutionProvider, KernelDefBuilder() .TypeConstraint("T", {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/providers/xnnpack/nn/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc similarity index 67% rename from onnxruntime/core/providers/xnnpack/nn/resize.cc rename to onnxruntime/core/providers/xnnpack/tensor/resize.cc index 76c6b6acbfe32..0c9e2e9fc17a2 100644 --- a/onnxruntime/core/providers/xnnpack/nn/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/xnnpack/nn/resize.h" +#include "core/providers/xnnpack/tensor/resize.h" #include #include @@ -10,6 +10,7 @@ #include "core/common/inlined_containers_fwd.h" #include "core/framework/op_kernel.h" #include "core/optimizer/initializer.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { namespace xnnpack { @@ -18,26 +19,67 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph_viewer) { bool supported = false; do { + if (node_unit.SinceVersion() < 10) { + break; + } + // Resize has 1-4 input. const auto& inputs = node_unit.Inputs(); const auto& x_arg = inputs[0].node_arg; const auto* x_type = x_arg.TypeAsProto(); - if (x_type == nullptr || - (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) { + if (x_type == nullptr || (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 && + x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) { break; } const auto* x_shape = x_arg.Shape(); - //'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 (NCHW) or - // 4-D input with outermost and innermost scales as 1 (NHWC) - // but we just support 4-d tensor for now, and the channel must be known. + + // 'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 (NCHW) can be supported. + // we only support 4-d tensor for now, and the channel must be known. + // we assume the input in NCHW for this test. if (!x_shape || x_shape->dim_size() != 4 || x_shape->dim(1).dim_value() <= 0) { break; } + // validate it is in fact NCHW + // + // opset 10 had `scales` as input 1 and no sizes. later opsets added roi as input 1 followed by scales and sizes. + auto opset_version = node_unit.SinceVersion(); + size_t scale_idx = opset_version == 10 ? 1 : 2; + size_t size_idx = 3; + + // onnx shape inferencing validates that one and not both of sizes and scales are provided + const auto* scale_tensor = inputs.size() >= scale_idx + 1 + ? graph_viewer.GetConstantInitializer(inputs[scale_idx].node_arg.Name(), true) + : nullptr; + const auto* size_tensor = opset_version > 10 && inputs.size() >= size_idx + 1 + ? graph_viewer.GetConstantInitializer(inputs[size_idx].node_arg.Name(), true) + : nullptr; + + // if both scales and sizes are nullptr the one that was provided was not a constant initializer + if (!scale_tensor && !size_tensor) { + break; + } + + // check the scale for the second dim is 1 or the size of the second dim matches the input shape. + // if not, it is not the C dim as a Resize will not change the number of channels. + InlinedVector scale(4, 1.0F); + if (scale_tensor) { + const Initializer scale_val(*scale_tensor, node_unit.ModelPath()); + if (scale_val.DataAsSpan()[1] != 1.0F) { + break; + } + } + + if (size_tensor) { + const Initializer size_val(*size_tensor, node_unit.ModelPath()); + if (size_val.DataAsSpan()[1] != x_shape->dim(1).dim_value()) { + break; + } + } + const auto* output_shape = node_unit.Outputs()[0].node_arg.Shape(); bool length_resized_compatible_pytorch_half_pixel = true; // when length_resized > 1, there is no difference between pytorch_half_pixel and half_pixel @@ -48,18 +90,11 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, // if coordinate_transformation_mode is "pytorch_half_pixel", // x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0 // - if (output_shape->dim(2).dim_value() <= 1 || output_shape->dim(1).dim_value() <= 1) { + if (output_shape->dim(2).dim_value() <= 1 || output_shape->dim(3).dim_value() <= 1) { + // we don't know the output H or W so we don't know if it will be compatible length_resized_compatible_pytorch_half_pixel = false; } - // Refer to onnxruntime/core/providers/cpu/tensor/upsamplebase.h, - size_t scale_idx = 2; - size_t size_idx = 3; - auto opset_version = node_unit.SinceVersion(); - if (opset_version == 10) { - scale_idx = 1; - } - ProtoHelperNodeContext nc(node_unit.GetNode()); OpNodeProtoHelper info(&nc); @@ -78,6 +113,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, std::vector axes; if (info.GetAttrs("axes", axes).IsOK() && axes.size() > 0) { + // TODO: We should be able to handle this if required break; } @@ -95,9 +131,10 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, // Coordinate transformation mode attr was introduced in version 11. // before that asymmetric mode was the only available transformation mode std::string coordinate_transform_mode_name = - opset_version > 10 - ? info.GetAttrOrDefault("coordinate_transformation_mode", "half_pixel") - : "asymmetric"; + opset_version > 10 ? info.GetAttrOrDefault("coordinate_transformation_mode", "half_pixel") + : "asymmetric"; + + // TODO: Opset 19 added half_pixel_symmetric. Need to see if that can be supported. if (coordinate_transform_mode_name != "asymmetric" && coordinate_transform_mode_name != "half_pixel" && @@ -106,59 +143,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, break; } - auto exclude_outside = info.GetAttrOrDefault("exclude_outside", 0) == 0 ? false : true; - if (exclude_outside) { - break; - } - - // roi only takes effect when coordinate_transformation_mode is "tf_crop_and_resize" - - // size or scales shouldnt't be provided in the same time but should at least be provided one of them - const auto* scale_tensor = inputs.size() >= scale_idx + 1 - ? graph_viewer.GetConstantInitializer(inputs[scale_idx].node_arg.Name(), true) - : nullptr; - const auto* size_tensor = inputs.size() >= size_idx + 1 - ? graph_viewer.GetConstantInitializer(inputs[size_idx].node_arg.Name(), true) - : nullptr; - - bool has_size = false; - bool has_scale = false; - InlinedVector scale(4, 1.0F); - if (scale_tensor) { - const Initializer scale_val(*scale_tensor, node_unit.ModelPath()); - auto scale_span = scale_val.DataAsSpan(); - if (scale_span.size() == 4) { - has_scale = true; - std::copy(scale_span.begin(), scale_span.end(), scale.begin()); - } - } - - if (size_tensor) { - auto input_shape = utils::GetTensorShapeFromTensorShapeProto(*x_shape); - const Initializer size_val(*size_tensor, node_unit.ModelPath()); - - auto size_span = size_val.DataAsSpan(); - if (size_span.size() == 4) { - has_size = true; - scale = {size_span[0] / static_cast(input_shape[0]), - size_span[1] / static_cast(input_shape[1]), - size_span[2] / static_cast(input_shape[2]), - size_span[3] / static_cast(input_shape[3])}; - } - } - - if ((has_size && has_scale) || (!has_size && !has_scale)) { - break; - } - - if (scale[0] != 1.0F || (scale[1] != 1.0F && scale[3] != 1.0F)) { - break; - } - - // only support xnn_create_resize_bilinear2d_nchw_f32 - const bool is_NHWC = scale[3] == 1.0F; - if (!is_NHWC && (x_type->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || - x_type->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8)) { + if (info.GetAttrOrDefault("exclude_outside", 0) != 0) { break; } @@ -210,8 +195,7 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf } } - is_NHWC_ = scales_[3] == 1.0F; - int64_t channels = x_shape->dim(is_NHWC_ ? 3 : 1).dim_value(); + int64_t channels = x_shape->dim(3).dim_value(); uint32_t flags = 0; ORT_ENFORCE(mode_ == UpsampleMode::LINEAR, "only support bilinear resize"); @@ -225,18 +209,16 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf xnn_status xstatus = xnn_status_invalid_state; struct xnn_operator* p = nullptr; if (op_type_ == OpComputeType::op_compute_type_fp32) { - auto create_func = is_NHWC_ ? xnn_create_resize_bilinear2d_nhwc_f32 : xnn_create_resize_bilinear2d_nchw_f32; - xstatus = create_func( - channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_f32(channels, channels, channels, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - xstatus = xnn_create_resize_bilinear2d_nhwc_u8( - channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_u8(channels, channels, channels, flags, &p); } else { - xstatus = xnn_create_resize_bilinear2d_nhwc_s8( - channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_s8(channels, channels, channels, flags, &p); } - ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", - OpTypeToString(op_type_), " failed. Status:", xstatus); + + ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " failed. Status:", + xstatus); + op0_.reset(p); } @@ -245,48 +227,56 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, const TensorShapeVector& output_dims) const { const auto& X_shape = input->Shape(); auto N = X_shape[0]; - auto H = is_NHWC_ ? X_shape[1] : X_shape[2]; - auto W = is_NHWC_ ? X_shape[2] : X_shape[3]; + auto H = X_shape[1]; + auto W = X_shape[2]; Tensor* output = ctx->Output(0, TensorShape(output_dims)); - pthreadpool_t t_pool = GetThreadPool(); - xnn_status status = xnn_status_invalid_state; + pthreadpool_t threadpool = GetThreadPool(); + + // setup allocator/automated dellocate for workspace + size_t workspace_size = 0; + size_t workspace_alignment = 0; + xnn_allocator* allocator = GetStoredAllocator().second; + auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; + std::unique_ptr workspace(nullptr, deallocator); + + auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f32; + if (op_type_ == OpComputeType::op_compute_type_qu8) { + reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_u8; + } else if (op_type_ == OpComputeType::op_compute_type_qs8) { + reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8; + } + + auto status = reshape_fn(op0_.get(), N, H, W, output_dims[1], output_dims[2], + &workspace_size, &workspace_alignment, threadpool); + if (status != xnn_status_success) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), + " returned ", status); + } + + workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); + if (op_type_ == OpComputeType::op_compute_type_fp32) { - auto oH = is_NHWC_ ? output_dims[1] : output_dims[2]; - auto oW = is_NHWC_ ? output_dims[2] : output_dims[3]; - auto setup_func = is_NHWC_ ? xnn_setup_resize_bilinear2d_nhwc_f32 : xnn_setup_resize_bilinear2d_nchw_f32; - status = setup_func( - op0_.get(), - N, - H, W, oH, oW, - input->Data(), - output->MutableData(), - t_pool); + status = xnn_setup_resize_bilinear2d_nhwc_f32(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_resize_bilinear2d_nhwc_u8( - op0_.get(), - N, - H, W, output_dims[1], output_dims[2], - input->Data(), - output->MutableData(), - t_pool); + status = xnn_setup_resize_bilinear2d_nhwc_u8(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else { - status = xnn_setup_resize_bilinear2d_nhwc_s8( - op0_.get(), - N, - H, W, output_dims[1], output_dims[2], - input->Data(), - output->MutableData(), - t_pool); + status = xnn_setup_resize_bilinear2d_nhwc_s8(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } + if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " returned ", status); } - status = xnn_run_operator(op0_.get(), t_pool); + + status = xnn_run_operator(op0_.get(), threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status); } + return Status::OK(); } @@ -315,29 +305,29 @@ Status Resize::Compute(OpKernelContext* ctx) const { return ComputeInternal(ctx, X, output_shape); } -ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 10, 10, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 10, 10, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Resize); -ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 11, 12, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 11, 12, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Resize); -ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 13, 17, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 13, 17, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Resize); -ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 18, 18, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kMSInternalNHWCDomain, 18, 18, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Resize); -ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 19, kXnnpackExecutionProvider, +ONNX_OPERATOR_KERNEL_EX(Resize, kMSInternalNHWCDomain, 19, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), diff --git a/onnxruntime/core/providers/xnnpack/nn/resize.h b/onnxruntime/core/providers/xnnpack/tensor/resize.h similarity index 98% rename from onnxruntime/core/providers/xnnpack/nn/resize.h rename to onnxruntime/core/providers/xnnpack/tensor/resize.h index 4975510ee7db4..06ff1bdb61f59 100644 --- a/onnxruntime/core/providers/xnnpack/nn/resize.h +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.h @@ -31,7 +31,6 @@ class Resize : public UpsampleBase, public XnnpackKernel { const TensorShapeVector& output_dims) const; private: - bool is_NHWC_; XnnpackOperator op0_; TensorShapeVector output_dims_; OpComputeType op_type_ = OpComputeType::op_compute_type_invalid; diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 494c718cde081..a2a776df439e4 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -27,88 +27,117 @@ KernelCreateInfo BuildKernelCreateInfo() { return info; } -#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \ - BuildKernelCreateInfo< \ - ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, End, Op)> +#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op, Domain) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Op)> -#define KERNEL_CREATE_INFO(Start, Op) \ - BuildKernelCreateInfo< \ - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, Op)> +#define KERNEL_CREATE_INFO(Start, Op, Domain) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Op)> -#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \ - BuildKernelCreateInfo< \ - ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, Start, type, Op)> +#define KERNEL_CREATE_INFO_TYPED(Start, Type, Op, Domain) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)> +// Layout sensitive operators in NHWC domain +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 18, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 19, AveragePool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearConvTranspose); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, int8_t, QLinearConv); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearConvTranspose); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, QLinearAveragePool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, - kDynamicDomainByCreate, 1, QLinearSoftmax); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 12, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 19, Resize); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); + +// ONNX operators +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 8, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 9, 10, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Gemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Gemm); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, MatMul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 8, MatMul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 9, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, MatMul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 10, Softmax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Softmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax); + +// Internal domain +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kDynamicDomainByCreate, 1, QLinearSoftmax); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing - KERNEL_CREATE_INFO(11, Conv), - KERNEL_CREATE_INFO(11, ConvTranspose), - KERNEL_CREATE_INFO_VERSIONED(1, 10, ConvTranspose), - KERNEL_CREATE_INFO(1, QLinearConvTranspose), - KERNEL_CREATE_INFO_VERSIONED(11, 11, MaxPool), - KERNEL_CREATE_INFO(12, MaxPool), - KERNEL_CREATE_INFO(11, AveragePool), + // layout sensitive. nodes will be moved to kMSInternalNHWCDomain by layout transformation + KERNEL_CREATE_INFO_VERSIONED(7, 9, AveragePool, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED(10, 10, AveragePool, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED(11, 18, AveragePool, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO(19, AveragePool, kMSInternalNHWCDomain), + + KERNEL_CREATE_INFO_VERSIONED(1, 10, Conv, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO(11, Conv, kMSInternalNHWCDomain), + + KERNEL_CREATE_INFO_VERSIONED(1, 10, ConvTranspose, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO(11, ConvTranspose, kMSInternalNHWCDomain), + + KERNEL_CREATE_INFO_VERSIONED(8, 9, MaxPool, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED(10, 10, MaxPool, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED(11, 11, MaxPool, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO(12, MaxPool, kMSInternalNHWCDomain), + + KERNEL_CREATE_INFO(1, QLinearConvTranspose, kMSInternalNHWCDomain), + + KERNEL_CREATE_INFO_VERSIONED(10, 10, Resize, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED(11, 12, Resize, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED(13, 17, Resize, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED(18, 18, Resize, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO(19, Resize, kMSInternalNHWCDomain), + // layout insensitive, use ONNX-domain directly - BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Softmax)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 12, Gemm)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, Gemm)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, MatMul)>, - BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, MatMul)>, + KERNEL_CREATE_INFO_VERSIONED(1, 10, Softmax, kOnnxDomain), + KERNEL_CREATE_INFO_VERSIONED(11, 12, Softmax, kOnnxDomain), + KERNEL_CREATE_INFO(13, Softmax, kOnnxDomain), + + KERNEL_CREATE_INFO_VERSIONED(7, 8, Gemm, kOnnxDomain), + KERNEL_CREATE_INFO_VERSIONED(9, 10, Gemm, kOnnxDomain), + KERNEL_CREATE_INFO_VERSIONED(11, 12, Gemm, kOnnxDomain), + KERNEL_CREATE_INFO(13, Gemm, kOnnxDomain), + + KERNEL_CREATE_INFO_VERSIONED(1, 8, MatMul, kOnnxDomain), + KERNEL_CREATE_INFO_VERSIONED(9, 12, MatMul, kOnnxDomain), + KERNEL_CREATE_INFO(13, MatMul, kOnnxDomain), // quantization op - KERNEL_CREATE_INFO_TYPED(10, uint8_t, QLinearConv), - KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv), - KERNEL_CREATE_INFO(1, QLinearAveragePool), - BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kDynamicDomainByCreate, 1, QLinearSoftmax)>, + KERNEL_CREATE_INFO(1, QLinearAveragePool, kMSInternalNHWCDomain), + + KERNEL_CREATE_INFO_TYPED(10, uint8_t, QLinearConv, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv, kMSInternalNHWCDomain), + + KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate), }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_init.cc b/onnxruntime/core/providers/xnnpack/xnnpack_init.cc index 27634a8b7090c..c3aa1d987c194 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_init.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_init.cc @@ -26,13 +26,15 @@ void xnn_deallocate(void* context, void* pointer) { } void* xnn_aligned_allocate(void* context, size_t alignment, size_t size) { + if (size == 0) + return nullptr; + #if defined(__wasm__) && !defined(__wasm_relaxed_simd__) && !defined(__wasm_simd128__) ORT_ENFORCE(alignment <= 2 * sizeof(void*)); return xnn_allocate(context, size); #else void* ptr = xnn_allocate(context, size); - ORT_ENFORCE((int64_t(ptr) & (alignment - 1)) == 0, - " xnnpack wants to allocate a space with ", alignment, "bytes aligned. But it's not satisfied"); + ORT_ENFORCE((int64_t(ptr) & (alignment - 1)) == 0, "xnnpack allocation was not aligned to ", alignment, " bytes."); // if ptr is not aligned, we have to find a way to return a aligned ptr and store the original ptr return ptr; #endif diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_init.h b/onnxruntime/core/providers/xnnpack/xnnpack_init.h index d309edd0c3a4e..a1e64bf6046b2 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_init.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_init.h @@ -5,6 +5,47 @@ struct xnn_allocator; namespace onnxruntime { namespace xnnpack { +// copy #define logic from XNNPACK src/xnnpack/common.h to determine workspace alignment +#if defined(__APPLE__) +#include +#endif + +#if defined(__i386__) || defined(__i486__) || defined(__i586__) || defined(__i686__) || defined(_M_IX86) +#define XNN_ARCH_X86 1 +#else +#define XNN_ARCH_X86 0 +#endif + +#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) && !defined(_M_ARM64EC) +#define XNN_ARCH_X86_64 1 +#else +#define XNN_ARCH_X86_64 0 +#endif + +#if defined(__wasm__) && !defined(__wasm_relaxed_simd__) && !defined(__wasm_simd128__) +#define XNN_ARCH_WASM 1 +#else +#define XNN_ARCH_WASM 0 +#endif + +#if defined(__ANDROID__) || (defined(__APPLE__) && TARGET_OS_IPHONE) +#define XNN_PLATFORM_MOBILE 1 +#else +#define XNN_PLATFORM_MOBILE 0 +#endif + +#if XNN_ARCH_WASM +#define XNN_ALLOCATION_ALIGNMENT 4 +#elif XNN_ARCH_X86 || XNN_ARCH_X86_64 +#if XNN_PLATFORM_MOBILE +#define XNN_ALLOCATION_ALIGNMENT 32 +#else +#define XNN_ALLOCATION_ALIGNMENT 64 +#endif +#else +#define XNN_ALLOCATION_ALIGNMENT 16 +#endif + std::pair GetStoredAllocator(); } // namespace xnnpack diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h index ada39c767f7c6..0978a88288114 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h @@ -4,6 +4,7 @@ #pragma once #include "core/framework/op_kernel.h" #include "core/providers/xnnpack/xnnpack_execution_provider.h" +#include "xnnpack.h" struct pthreadpool; @@ -12,18 +13,59 @@ namespace xnnpack { class XnnpackKernel : public OpKernel { public: - explicit XnnpackKernel(const OpKernelInfo& info) - : OpKernel(info), - xnnpack_threadpool_( - static_cast(info.GetExecutionProvider()) - ->GetPrivateThreadPool()) { + explicit XnnpackKernel(const OpKernelInfo& info, bool enable_caches = false) + : OpKernel{info}, + xnnpack_threadpool_{ + static_cast(info.GetExecutionProvider())->GetPrivateThreadPool()}, + caches_{enable_caches} { } [[nodiscard]] pthreadpool* GetThreadPool() const { return xnnpack_threadpool_; } + // see comment below about enabling code cache + // xnn_code_cache_t GetCodeCache() { return caches_.auto_code_cache.get();} + xnn_code_cache_t GetCodeCache() { return nullptr; } + xnn_weights_cache_t GetWeightsCache() { return caches_.auto_weights_cache.get(); } + private: pthreadpool* xnnpack_threadpool_; + + // Helper class to wrap usage of the XNNPACK weights and code caches. + // NOTE: Currently creating/freeing the code cache is not exposed via the public xnnpack.h header so usage is + // commented out. If we need to use it, we'll need to add the 'src' directory of XNNPACK to the include path + // and #include "xnnpack/cache.h" + struct Caches { + Caches(bool enable) + : // auto_code_cache(nullptr, xnn_release_code_cache), + auto_weights_cache(nullptr, xnn_delete_weights_cache) { + if (enable) { +#ifdef XNN_CACHE_ENABLE + xnn_status status = xnn_status_success; +#if XNN_PLATFORM_JIT + // status = xnn_init_code_cache(&code_cache_); + // ORT_ENFORCE(status == xnn_status_success, "Failed to initialize XNNPACK code cache");) + // auto_code_cache.reset(&code_cache_); +#endif + // status = xnn_init_weights_cache(&weights_cache_); + xnn_weights_cache_t weights_cache = nullptr; + status = xnn_create_weights_cache(&weights_cache, 0); + ORT_ENFORCE(status == xnn_status_success, "Failed to create XNNPACK weights cache"); + auto_weights_cache.reset(weights_cache); +#endif + } + } + + // std::unique_ptr auto_code_cache; + std::unique_ptr auto_weights_cache; + + // private: + // #if defined(XNN_CACHE_ENABLE) && XNN_PLATFORM_JIT + // xnn_code_cache code_cache_; + // #endif + }; + + Caches caches_; }; } // namespace xnnpack } // namespace onnxruntime diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 3189346b44ed9..b3e7bd4935c1a 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -25,6 +25,7 @@ #if !defined(ORT_MINIMAL_BUILD) static constexpr uint32_t min_ort_version_with_optional_io_support = 8; static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; +static constexpr uint32_t min_ort_version_with_custom_version = 17; #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -714,8 +715,19 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust KernelDefBuilder def_builder; def_builder.SetName(op->GetName(op)) - .SetDomain(domain) - .SinceVersion(1); + .SetDomain(domain); + + if (op->version >= min_ort_version_with_custom_version) { + if (op->GetStartVersion && op->GetEndVersion) { + def_builder.SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op)); + } else if (op->GetStartVersion) { + def_builder.SinceVersion(op->GetStartVersion(op)); + } else { + def_builder.SinceVersion(1); + } + } else { + def_builder.SinceVersion(1); + } // GetInputMemoryType was introduced in ver 13. This check allows custom ops compiled using older versions // to work with newer versions (> 12) of the ORT binary. @@ -836,7 +848,11 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom schema.TypeConstraint(output_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types"); } schema.SetDomain(domain); - schema.SinceVersion(1); + if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) { + schema.SinceVersion(op->GetStartVersion(op)); + } else { + schema.SinceVersion(1); + } schema.AllowUncheckedAttributes(); if (op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn) { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 9e59883478227..df4dd55417755 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1432,7 +1432,7 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - ov_options_converted_map["enable_vpu_fast_compile"] = legacy_ov_options->enable_vpu_fast_compile; + ov_options_converted_map["enable_npu_fast_compile"] = legacy_ov_options->enable_npu_fast_compile; if (legacy_ov_options->device_id != nullptr) ov_options_converted_map["device_id"] = legacy_ov_options->device_id; diff --git a/onnxruntime/python/onnxruntime_pybind_iobinding.cc b/onnxruntime/python/onnxruntime_pybind_iobinding.cc index 7638a12bb820c..59d5a77bfbea3 100644 --- a/onnxruntime/python/onnxruntime_pybind_iobinding.cc +++ b/onnxruntime/python/onnxruntime_pybind_iobinding.cc @@ -60,8 +60,6 @@ void addIoBindingMethods(pybind11::module& m) { }) // This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo .def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector& shape, int64_t data_ptr) -> void { - ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid"); - PyArray_Descr* dtype; if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) { throw std::runtime_error("Not a valid numpy type"); diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 04dfa9b51e112..ff76887e917cd 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -5,7 +5,7 @@ #include #include -#include "contrib_ops/cpu/quantization/dequantize_blockwise.h" +#include "core/mlas/inc/mlas_q4.h" #include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" #include "core/util/thread_utils.h" @@ -53,15 +53,16 @@ void QuantizeMatMul4BitsBlockwise( py::buffer_info scale_buf = scale.request(); py::buffer_info zp_buf = zero_points.request(); - contrib::QuantizeBlockwise( - static_cast(dst_buf.ptr), - static_cast(src_buf.ptr), - static_cast(scale_buf.ptr), - is_symmetric ? nullptr : static_cast(zp_buf.ptr), + MlasQuantizeBlockwise( + reinterpret_cast(dst_buf.ptr), + reinterpret_cast(scale_buf.ptr), + is_symmetric ? nullptr : reinterpret_cast(zp_buf.ptr), + reinterpret_cast(src_buf.ptr), block_size, - 4, - N, + true, K, + N, + N, tp.get()); } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 7faca3b4681b8..2027b592326df 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -813,10 +813,10 @@ std::unique_ptr CreateExecutionProviderInstance( if (option.first == "device_type") { OV_provider_options_map[option.first] = option.second; continue; - } else if (option.first == "enable_vpu_fast_compile") { + } else if (option.first == "enable_npu_fast_compile") { if (!(option.second == "True" || option.second == "true" || option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_vpu_fast_compile: ", option.second); + ORT_THROW("Invalid value passed for enable_npu_fast_compile: ", option.second); } OV_provider_options_map[option.first] = option.second; } else if (option.first == "enable_opencl_throttling") { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 5bb6bcc38b6fe..a5bcbce89bac6 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -60,11 +60,11 @@ struct OrtStatus { #elif OPENVINO_CONFIG_GPU_FP16 #define BACKEND_OPENVINO "-OPENVINO_GPU_FP16" -#elif OPENVINO_CONFIG_VPUX_FP16 -#define BACKEND_OPENVINO "-OPENVINO_VPUX_FP16" +#elif OPENVINO_CONFIG_NPU_FP16 +#define BACKEND_OPENVINO "-OPENVINO_NPU_FP16" -#elif OPENVINO_CONFIG_VPUX_U8 -#define BACKEND_OPENVINO "-OPENVINO_VPUX_U8" +#elif OPENVINO_CONFIG_NPU_U8 +#define BACKEND_OPENVINO "-OPENVINO_NPU_U8" #elif OPENVINO_CONFIG_MULTI #define BACKEND_OPENVINO "-OPENVINO_MULTI" diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py index 9cb937a13ff27..111e156cd6d01 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py @@ -56,7 +56,7 @@ def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric): a = np.random.rand(m, k).astype(dtype) b = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8") scales = np.random.rand(n * ((k + 31) // 32)).astype(dtype) - zeropoints = np.random.rand((n * ((k + 31) // 32) + 1) // 2).astype(dtype) + zeropoints = np.random.rand(n * (((k + 31) // 32 + 1) // 2)).astype(dtype) output_d = ke.DeviceArray(output) a_d = ke.DeviceArray(a) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index fea9e5e8cb739..1c3c212b54fa4 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -61,7 +61,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: # block wise quantization, each block comes from a single column packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) - zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") + zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric) return (packed, scales, zero_point) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ef1c46b83946a..81ae3eb9e7e5d 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -147,6 +147,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GatherElements": self._infer_GatherElements, "GatherND": self._infer_GatherND, "Identity": self._pass_on_shape_and_type, + "AllReduce": self._pass_on_shape_and_type, "If": self._infer_If, "Loop": self._infer_Loop, "MatMul": self._infer_MatMul, @@ -200,6 +201,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFastGelu": self._infer_GemmFastGelu, "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, + "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, "MultiHeadAttention": self._infer_MultiHeadAttention, @@ -2376,6 +2378,11 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_SkipGroupNorm(self, node): # noqa: N802 + self._propagate_shape_and_type(node, 0, 0) + if len(node.output) > 1: + self._propagate_shape_and_type(node, 0, 1) + def _infer_BiasSplitGelu(self, node): # noqa: N802 input_shape = self._get_shape(node, 0) bias_shape = self._get_shape(node, 1) diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index 966c360a3a4f5..0f06676641a96 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -89,7 +89,13 @@ def split_and_sort_output(string_list): string_list = string_list.split("\n") - string_list.sort() + + def custom_sort(item): + # Parse digits + numbers = re.findall(r"\d+", item) + return int(numbers[0]) if numbers else float("inf") + + string_list.sort(key=custom_sort) return string_list diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index b32ae64c5b0c0..7aca5e8526a23 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,7 +1272,7 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0): +def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1): past_seq_len = past_seq_len_input if past_seq_len not in model.get_graphs_input_names(): # Add model input for past sequence length @@ -1282,6 +1282,10 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads # Replace MultiHeadAttention with GroupQueryAttention for node in model.model.graph.node: if node.op_type == "MultiHeadAttention": + num_heads_mha = 0 + for att in node.attribute: + if att.name == "num_heads": + num_heads_mha = att.i gqa_node = onnx.helper.make_node( "GroupQueryAttention", inputs=[ @@ -1295,8 +1299,8 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads outputs=node.output, name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), domain="com.microsoft", - num_heads=node.attribute[0].i, - kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads, + num_heads=num_heads_mha // world_size, + kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size, is_past_bsnh=0, ) model.model.graph.node.remove(node) diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index c5d7bc16d64f7..67f4f0b55cff8 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -130,3 +130,8 @@ def add_nodes_to_remove(self, nodes: List[NodeProto]): for node in nodes: if node not in self.nodes_to_remove: self.nodes_to_remove.append(node) + + def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]): + for node in nodes: + if node not in self.nodes_to_remove and node not in nodes_to_keep: + self.nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 3c5029ac5752f..ceee836e33f77 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -323,6 +323,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # qkv_nodes_1 is for LLaMA-2 Microsoft # qkv_nodes_2 is for LLaMA-2 Hugging Face + # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model qkv_nodes = None qkv_nodes_1 = self.model.match_parent_path( normalize_node, @@ -334,18 +335,27 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0], ) + qkv_nodes_3 = self.model.match_parent_path( + normalize_node, + ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0, 0], + ) if qkv_nodes_1 is not None: _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 qkv_nodes = qkv_nodes_1 elif qkv_nodes_2 is not None: _, reshape_qkv, _, matmul_qkv = qkv_nodes_2 qkv_nodes = qkv_nodes_2 + elif qkv_nodes_3 is not None: + _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3 + qkv_nodes = qkv_nodes_3 else: logger.debug("fuse_rotary_attention: failed to match qkv nodes") return # v_nodes_1 is for LLaMA-2 Microsoft # v_nodes_3 is for LLaMA-2 Hugging Face + # v_nodes_4 is for LLaMA-2 70B model past_v, present_v, past_seq_len = "", "", "" v_nodes = None v_nodes_1 = self.model.match_parent_path( @@ -363,6 +373,118 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "MatMul"], [1, 0, 0], ) + _, v_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qkv, + [ + ( + ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 2, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 3, 0, 0, 0, 1, 0, 0], + ), + ], + output_name_to_node=None, + ) if v_nodes_1 is not None: reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 v_nodes = v_nodes_1 @@ -388,6 +510,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): transpose_v, reshape_v, matmul_v = v_nodes_3 v_nodes = v_nodes_3 present_v = transpose_v.output[0] + elif v_nodes_4 is not None and len(v_nodes_4) == 9: + concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:] + v_nodes = v_nodes_4 + past_v = concat_v.input[0] + present_v = concat_v.output[0] else: logger.debug("fuse_rotary_attention: failed to match v path") return @@ -427,6 +554,16 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 2, 1, 0, 0, 0], ) + attn_mask_nodes_5 = self.model.match_parent_path( + add_qk, + ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 0, 2, 1, 0, 0, 0], + ) + attn_mask_nodes_6 = self.model.match_parent_path( + add_qk, + ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 2, 1, 0, 0, 0], + ) if attn_mask_nodes_1 is not None: _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1 attn_mask = slice_mask_1.output[0] @@ -439,12 +576,19 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): elif attn_mask_nodes_4 is not None: # Reshape from (B,1,S,T) to (B,N,S,T) add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0]) + elif attn_mask_nodes_5 is not None: + # The mask has already been reshaped to (B,N,S,T) + add_qk_str = attn_mask_nodes_5[0].output[0] + elif attn_mask_nodes_6 is not None: + # The mask has already been reshaped to (B,N,S,T) + add_qk_str = attn_mask_nodes_6[0].output[0] else: logger.debug("fuse_rotary_attention: failed to match attention mask nodes") return # k_nodes_1 is for LLaMA-2 Microsoft # k_nodes_2 is for LLaMA-2 Hugging Face + # k_nodes_4 is for LLaMA-2 70B Hugging Face past_k, present_k = "", "" k_nodes = None k_nodes_1 = self.model.match_parent_path( @@ -462,6 +606,174 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0, 0], ) + _, k_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qk, + [ + ( + [ + "Transpose", + "Reshape", + "Expand", + "Unsqueeze", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ], + output_name_to_node=None, + ) if k_nodes_1 is not None: reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 k_nodes = k_nodes_1 @@ -489,6 +801,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes = k_nodes_3 past_k = concat_k.input[0] present_k = concat_k.output[0] + elif k_nodes_4 is not None and len(k_nodes_4) == 9: + reshape_k, matmul_k = k_nodes_4[0][-2:] + concat_k, rotary_k = k_nodes_4[0][-5:-3] + k_nodes = k_nodes_4 + past_k = concat_k.input[0] + present_k = concat_k.output[0] else: logger.debug("fuse_rotary_attention: failed to match k nodes") return @@ -536,7 +854,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return root_output = reshape_qkv_2.output[0] - elif qkv_nodes == qkv_nodes_2: + elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3): if not self.check_runtime_shape_paths_for_nodes( reshape_qkv, reshape_q, @@ -557,6 +875,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) rotary_k.output[0] = rotary_k.name + "_output_0" + if qkv_nodes == qkv_nodes_3: + qkv_nodes = qkv_nodes[1:] + new_node = self.create_mha_node( matmul_q.input[0], root_output, @@ -578,7 +899,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.nodes_to_remove.extend(qkv_nodes[1:]) - self.nodes_to_remove.extend(v_nodes[:-1]) + + if v_nodes != v_nodes_4: + self.nodes_to_remove.extend(v_nodes[:-1]) + else: + nodes_to_keep = [v_nodes[0][-1]] + for temp_path in v_nodes: + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) + self.nodes_to_remove.extend(qk_nodes) if k_nodes == k_nodes_1: @@ -592,6 +920,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.append(k_nodes[1]) self.nodes_to_remove.append(k_nodes[3]) self.nodes_to_remove.append(k_nodes[4]) + elif k_nodes == k_nodes_4: + nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]] + for temp_path in k_nodes: + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) if q_nodes == q_nodes_1: self.nodes_to_remove.extend(q_nodes[:-2]) diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index de17f195c99cc..50703b9c17e03 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,6 +1,6 @@ import logging from collections import OrderedDict -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Union import numpy import torch @@ -229,7 +229,7 @@ def __del__(self): del self.io_binding del self.ort_session - def allocate_buffers(self, shape_dict: Dict[str, tuple]): + def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]): """Allocate tensors for I/O Binding""" if self.enable_cuda_graph: for name, shape in shape_dict.items(): diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 9619e6cb52a91..1bb6940d1cd74 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -10,6 +10,8 @@ Please note the package versions needed for using LLaMA-2 in the `requirements.t - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. - `requirements-quant.txt` - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) +- `requirements-70b-model.txt` + - For running the LLaMA-2 70B model on multiple GPUs - `requirements.txt` - Package versions needed in each of the above files @@ -79,6 +81,15 @@ model.save_pretrained(name.split("/")[-1] + "-onnx") Here are some additional examples for exporting LLaMA-2. +Export Model with Different GPU Device Ids +``` +# From source using first GPU: +$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b + +# From wheel using second GPU: +$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b +``` + Export Saved Model on Disk ``` # From source: @@ -153,6 +164,19 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu ``` +Export LLaMA-2 70B sharded model into 4 partitions +``` +# From source: +# 1. Install necessary packages from requirements-70b-model.txt + +# 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command: +$ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ + +# 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: +$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda + +``` + ## Benchmark LLaMA-2 Here are some examples of how you can benchmark LLaMA-2. @@ -220,11 +244,11 @@ python3 -m models.llama.benchmark \ --device cuda ``` -6. ONNX Runtime, FP32, convert_to_onnx +6. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ - --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ + --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -232,11 +256,11 @@ python3 -m models.llama.benchmark \ --device cpu ``` -7. ONNX Runtime, FP16, convert_to_onnx +7. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ - --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ + --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index a721979eb0bcb..be678931de5d1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,9 +11,9 @@ import onnx import psutil import torch -from benchmark_helper import setup_logger +from dist_settings import get_rank, get_size from llama_inputs import ( - convert_inputs_for_ort, + add_io_bindings, get_merged_sample_with_past_kv_inputs, get_msft_sample_inputs, get_sample_inputs, @@ -25,7 +25,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import measure_memory +from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger logger = logging.getLogger(__name__) @@ -48,9 +48,19 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): init_inputs, iter_inputs = None, None # For past_present_share_buffer: - # Set max_seq_len to 2048 for Hugging Face model since that is the default value - # Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported - max_seq_len = 2048 + # Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2) + # Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value + # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_seq_len = ( + 2048 + if args.benchmark_type == "ort-msft" + else 16384 + if "codellama" in temp_name + else 4096 + if "llama2" in temp_name + else 2048 + ) if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: init_inputs = get_sample_inputs( @@ -95,7 +105,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=args.sequence_length, past_seq_len=0, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="pt", return_dict=True, ) iter_inputs = get_merged_sample_with_past_kv_inputs( @@ -104,7 +116,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=1, past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="pt", return_dict=True, ) @@ -116,8 +130,11 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=args.sequence_length, past_seq_len=0, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="ort", return_dict=True, + world_size=args.world_size, ) iter_inputs = get_merged_sample_with_past_kv_inputs( args.config, @@ -125,26 +142,11 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=1, past_seq_len=args.sequence_length, - use_fp16=args.use_fp16, - return_dict=True, - ) - init_inputs = convert_inputs_for_ort( - init_inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=0, max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, - ) - iter_inputs = convert_inputs_for_ort( - iter_inputs, use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=args.sequence_length, - max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, + engine="ort", + return_dict=True, + world_size=args.world_size, ) elif args.benchmark_type == "ort-msft": @@ -156,6 +158,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, past_seq_len=0, seq_len=args.sequence_length, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, split_kv=split_kv, ) @@ -164,26 +167,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, past_seq_len=args.sequence_length, seq_len=1, - use_fp16=args.use_fp16, - split_kv=split_kv, - ) - init_inputs = convert_inputs_for_ort( - init_inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=0, max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, - ) - iter_inputs = convert_inputs_for_ort( - iter_inputs, use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=args.sequence_length, - max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, + split_kv=split_kv, ) else: @@ -261,10 +247,10 @@ def get_model(args: argparse.Namespace): if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx - logger.info(f"Loading model from {args.ort_model_path}") + logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}") start_time = time.time() model = ort.InferenceSession( - args.ort_model_path, + args.ort_model_path.format(args.rank), sess_options, providers=[args.execution_provider], ) @@ -332,10 +318,11 @@ def time_fn(args, fn, inputs): latency = total_time / args.num_runs throughput = args.batch_size / latency - logger.info(f"Batch Size: {args.batch_size}") - logger.info(f"Sequence Length: {args.sequence_length}") - logger.info(f"Latency: {latency} s") - logger.info(f"Throughput: {throughput} tps") + if args.rank == 0: + logger.info(f"Batch Size: {args.batch_size}") + logger.info(f"Sequence Length: {args.sequence_length}") + logger.info(f"Latency: {latency} s") + logger.info(f"Throughput: {throughput} tps") return @@ -375,7 +362,8 @@ def measure_fn(args, fn, inputs): process.cpu_percent(interval=0.1) fn(inputs) - logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%") + if args.rank == 0: + logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%") # Measure memory usage gc.collect() @@ -449,7 +437,7 @@ def get_logits(inputs): def run_ort_inference(args, init_inputs, iter_inputs, model): - def prepare_ort_inputs(inputs): + def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Check that all model inputs will be provided model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) user_inputs = set(inputs.keys()) @@ -467,29 +455,13 @@ def prepare_ort_inputs(inputs): # Add IO bindings for non-CPU execution providers if args.device != "cpu": - io_binding = model.io_binding() - - for k, v in inputs.items(): - if args.past_present_share_buffer: - # Bind all OrtValue inputs to device - io_binding.bind_ortvalue_input(k, v) - else: - io_binding.bind_cpu_input(k, v) - - for output in model.get_outputs(): - name = output.name - if args.past_present_share_buffer and ("out" in name or "present" in name): - # Bind present KV cache outputs to OrtValue with buffer sharing - io_binding.bind_ortvalue_output( - name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] - ) - else: - io_binding.bind_output(name, device_type=args.device, device_id=args.device_id) - + io_binding, kv_cache_ortvalues = add_io_bindings( + model, inputs, args.device, int(args.rank), kv_cache_ortvalues + ) setattr(args, "io_binding", io_binding) # noqa: B010 - return io_binding + return io_binding, kv_cache_ortvalues - return inputs + return inputs, kv_cache_ortvalues def with_io_binding(io_binding): # Inference pass with IO binding @@ -501,9 +473,10 @@ def without_io_binding(inputs): return outputs generate_fn = with_io_binding if args.device != "cpu" else without_io_binding + kv_cache_ortvalues = {} if args.profile: - ort_init_inputs = prepare_ort_inputs(init_inputs) + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt") # Turn profiling off to stop appending to log file @@ -513,7 +486,7 @@ def without_io_binding(inputs): # Re-initialize model for new log file instead of appending to old log file model = get_model(args) - ort_iter_inputs = prepare_ort_inputs(iter_inputs) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token") # Turn profiling off to stop appending to log @@ -524,12 +497,12 @@ def without_io_binding(inputs): # ORT evaluations logger.info("\nEvaluating `model(inputs)` step to get past_key_values") - ort_init_inputs = prepare_ort_inputs(init_inputs) + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_init_inputs) measure_fn(args, generate_fn, ort_init_inputs) logger.info("\nEvaluating `model(inputs)` step with past_key_values") - ort_iter_inputs = prepare_ort_inputs(iter_inputs) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_iter_inputs) measure_fn(args, generate_fn, ort_iter_inputs) @@ -543,7 +516,7 @@ def run_inference(args, init_inputs, iter_inputs, model): raise Exception(f"Cannot recognize {args.benchmark_type}") -def get_args(): +def get_args(rank=0): parser = argparse.ArgumentParser() parser.add_argument( "-bt", @@ -601,7 +574,7 @@ def get_args(): parser.add_argument( "-s", "--sequence-lengths", - default="8 16 32 64 128 256 512", + default="32 64 128 256 512", ) parser.add_argument( "-d", @@ -638,9 +611,9 @@ def get_args(): if "ort" in args.benchmark_type: setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010 if args.execution_provider == "CUDAExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = (args.execution_provider, {"device_id": rank}) elif args.execution_provider == "ROCMExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = (args.execution_provider, {"device_id": rank}) args.device = "cuda" # Check that paths have been specified for any benchmarking with ORT @@ -667,14 +640,19 @@ def get_args(): def main(): - args = get_args() + rank = get_rank() + world_size = get_size() + + args = get_args(rank) setup_logger(args.verbose) logger.info(args.__dict__) torch.backends.cudnn.benchmark = True + args.rank = rank + args.world_size = world_size tokenizer = LlamaTokenizer.from_pretrained(args.model_name) config = LlamaConfig.from_pretrained(args.model_name) - target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device + target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device use_fp16 = args.precision == "fp16" setattr(args, "tokenizer", tokenizer) # noqa: B010 @@ -688,7 +666,7 @@ def main(): # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA) if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}: - onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False) + onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False) gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" @@ -698,7 +676,8 @@ def main(): # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): - logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") + if args.rank == 0: + logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") setattr(args, "batch_size", int(batch_size)) # noqa: B010 setattr(args, "sequence_length", int(sequence_length)) # noqa: B010 diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh new file mode 100644 index 0000000000000..38f1916456658 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI python benchmark.py ${@:2}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index 951b2549368f7..b35a5e27f9ea3 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -247,6 +247,7 @@ def main(): torch.backends.cudnn.benchmark = True all_results = [] + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id) # Benchmark PyTorch without torch.compile if args.hf_pt_eager: @@ -266,8 +267,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -298,8 +297,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -332,8 +329,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -366,8 +361,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -399,8 +392,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh new file mode 100644 index 0000000000000..637d15c10e0c7 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI python convert_to_onnx.py ${@:2}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 69603fd3ed488..b0e0b41e75d3d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -1,16 +1,16 @@ import argparse import logging import os -import tempfile +import shutil from itertools import chain from typing import List import onnx import torch -from benchmark_helper import Precision, prepare_environment, setup_logger -from convert_generation import replace_mha_with_gqa +from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check +from llama_torch import setup_torch_model from onnx_model import OnnxModel from optimizer import optimize_model from packaging import version @@ -18,8 +18,11 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer +from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger +from onnxruntime.transformers.convert_generation import replace_mha_with_gqa logger = logging.getLogger("") +init_dist() def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): @@ -129,7 +132,9 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st # del onnx_model # temp_dir.cleanup() # -def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_dynamo_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): from torch._dynamo import config config.capture_scalar_outputs = True @@ -150,9 +155,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_model_fp32.onnx.data") + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") del onnx_model os.system( f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" @@ -160,7 +165,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll # Export decoder_with_past_model.onnx input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length + l_config, device, batch_size, sequence_length, world_size=world_size ) temp_dir = args.output # tempfile.TemporaryDirectory() temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") @@ -172,9 +177,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_with_past_model_fp32.onnx.data") + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data") del onnx_model os.system( f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" @@ -183,10 +188,21 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!") -def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def _prepare_dir(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + +def run_torchscript_separate_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): # Dummy values for export batch_size, sequence_length = 2, 8 - device = torch.device("cpu") + + # set device used to export model + # for llama-2-70b we will use current gpus to speed up export process + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") # Export decoder_model.onnx decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length) @@ -199,8 +215,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon ), ] dynamic_axes = get_model_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_inputs, @@ -218,18 +238,25 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) # Export decoder_with_past_model.onnx - decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length) + decoder_with_past_inputs = get_sample_with_past_kv_inputs( + l_config, + device, + batch_size, + sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, + ) input_names = [ "input_ids", "attention_mask", @@ -247,8 +274,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon ), ] dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_past_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_with_past_inputs, @@ -266,27 +297,45 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_with_past_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + logger.info( + f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!" + ) -def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_torchscript_merged_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): # Dummy values for export batch_size, sequence_length, past_sequence_length = 2, 8, 0 - device = torch.device("cpu") + + # set device used to export model + # for llama-2-70b we will use current gpus to speed up export process + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") + + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 # Export decoder_merged_model.onnx decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length, past_sequence_length + l_config, + device, + batch_size, + sequence_length, + past_sequence_length, + max_seq_len=max_sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, ) input_names = [ "input_ids", @@ -305,8 +354,12 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi ), ] dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_merged_inputs, @@ -324,17 +377,17 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_merged_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!") # Optimize the model as FP32 @@ -357,12 +410,16 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str): remove_existing_model(input_path) -def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]): - decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx") +def convert_to_float16( + args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int = 0, world_size: int = 1 +): + decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx") decoder_with_past_model_fp16_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx" + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx" + ) + decoder_merged_model_fp16_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx" ) - decoder_merged_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp16.onnx") new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] logger.info("Converting to float16...") @@ -370,7 +427,7 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: if os.path.exists(fp32_path): model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) model.convert_float_to_float16(keep_io_types=False) - model = use_group_query_attention(config, model) + model = use_group_query_attention(config, model, world_size) model.save_model_to_file(fp16_path, use_external_data_format=True) del model logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!") @@ -380,9 +437,11 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: return new_paths -def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel): +def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1): # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes - fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads) + fp16_model_opt = replace_mha_with_gqa( + fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size + ) fp16_model_opt.prune_graph() fp16_model_opt.update_graph(allow_remove_graph_inputs=True) return fp16_model_opt @@ -406,7 +465,7 @@ def smooth_quant( calibration_sampling_size=[args.calibration_sampling_size], recipes={ "optypes_to_exclude_output_quant": ["MatMul"], - "smooth_quant": args.smooth_quant, + "smooth_quant": True, "smooth_quant_args": {"alpha": args.smooth_quant_alpha}, }, op_type_dict={ @@ -526,15 +585,6 @@ def get_args(): help="Execution provider to verify parity with", ) - parser.add_argument( - "-id", - "--device-id", - required=False, - type=str, - default="0", - help="Device ID for GPUs", - ) - parser.add_argument( "-r", "--reexport", @@ -655,6 +705,14 @@ def get_args(): ) parser.set_defaults(use_dynamo_export=False) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + args = parser.parse_args() return args @@ -673,145 +731,183 @@ def main(): remove_existing_files(args.output) logger.info(f"Arguments: {args}") + world_size = get_size() + rank = get_rank() + # Load model and config use_auth_token = args.input == os.path.join(".") setattr(args, "use_auth_token", use_auth_token) # noqa: B010 - location = args.model_name if use_auth_token else args.input - l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) - llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True) original_model_name = args.model_name setattr(args, "original_model_name", original_model_name) # noqa: B010 args.model_name = args.model_name.split("/")[-1] - # Set model paths for FP32 model - decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") - decoder_with_past_model_fp32_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx" - ) - decoder_merged_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - - missing_separate_exports = ( - args.no_merged - and not os.path.exists(decoder_model_fp32_path) - and not os.path.exists(decoder_with_past_model_fp32_path) - ) - missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) - - # Export to ONNX - if missing_separate_exports or missing_merged_export: - if args.use_dynamo_export and missing_separate_exports: - logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") - logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") - logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") - logger.warning( - "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" - ) - logger.warning( - "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." - ) - run_dynamo_export(args, l_config, llama) - elif args.no_merged: - run_torchscript_separate_export(args, l_config, llama) - else: - run_torchscript_merged_export(args, l_config, llama) + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 + setattr(args, "device", torch.device(args.device_name)) # noqa: B010 - # Set model paths to store FP32 optimized model - decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") - decoder_with_past_model_fp32_opt_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx" - ) - decoder_merged_model_fp32_opt_path = os.path.join( - args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx" + location = args.original_model_name if use_auth_token else args.input + + # use cuda for Llama-2-70b to speedup export, other models use CPU by default + l_config, llama = setup_torch_model( + args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None ) - new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path] - # Run the optimizer script - logger.info("Optimizing models...") - for orig_path, opt_path in zip(old_paths, new_paths): - if os.path.exists(orig_path): - optimize_export(l_config, input_path=orig_path, output_path=opt_path) + assert l_config.num_attention_heads % world_size == 0 and l_config.num_key_value_heads % world_size == 0 - # Re-assign default FP32 model paths as their optimized versions - decoder_model_fp32_path = decoder_model_fp32_opt_path - decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path - decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + barrier() + for i in range(world_size): + if i == rank: + # Set model paths for FP32 model + decoder_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx" + ) + decoder_with_past_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx" + ) + decoder_merged_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx" + ) + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - logger.info( - f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" - ) - - # Change precision of exported models from FP32 - if args.precision == Precision.FLOAT16: - new_paths = convert_to_float16(args, l_config, old_paths) - - elif args.precision == Precision.INT8: - decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx") - decoder_with_past_model_int8_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx" - ) - decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx") - new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] - - if args.quantization_method == "smooth_quant": - if not args.no_merged: - logger.error("SmoothQuant must be used on separately exported models") - else: - logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8") - smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) - - elif args.quantization_method == "quantize_dynamic": - logger.warning( - "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." + missing_separate_exports = ( + args.no_merged + and not os.path.exists(decoder_model_fp32_path) + and not os.path.exists(decoder_with_past_model_fp32_path) + ) + missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) + + # Export to ONNX + if missing_separate_exports or missing_merged_export: + if args.use_dynamo_export and missing_separate_exports: + logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") + logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") + logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") + logger.warning( + "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" + ) + logger.warning( + "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." + ) + run_dynamo_export(args, l_config, llama) + elif args.no_merged: + run_torchscript_separate_export(args, l_config, llama, rank, world_size) + else: + run_torchscript_merged_export(args, l_config, llama, rank, world_size) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check + + # Set model paths to store FP32 optimized model + decoder_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx" + ) + decoder_with_past_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx" + ) + decoder_merged_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx" + ) + new_paths = [ + decoder_model_fp32_opt_path, + decoder_with_past_model_fp32_opt_path, + decoder_merged_model_fp32_opt_path, + ] + + # Run the optimizer script + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(l_config, input_path=orig_path, output_path=opt_path) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" ) - logger.info("Quantizing to int8...") - for fp32_path, int8_path in zip(old_paths, new_paths): - if os.path.exists(fp32_path): - ort_quantization.quantize_dynamic( - fp32_path, - int8_path, - op_types_to_quantize=["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"], - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, + # Change precision of exported models from FP32 + if args.precision == Precision.FLOAT16: + new_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + elif args.precision == Precision.INT8: + decoder_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" + ) + decoder_with_past_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" + ) + decoder_merged_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" + ) + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] + + if args.quantization_method == "smooth_quant": + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info( + f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + ) + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) + + elif args.quantization_method == "quantize_dynamic": + logger.warning( + "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." ) - logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!") - remove_existing_model(decoder_model_fp32_path) - logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"], + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) + logger.info( + f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" + ) + remove_existing_model(decoder_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + else: + raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") + + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + decoder_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" + ) + decoder_with_past_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" + ) + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, load_external_data=True) + quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) + barrier() - else: - raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - - elif args.precision == Precision.INT4: - if args.execution_provider != "cpu": - old_paths = convert_to_float16(args, l_config, old_paths) - - decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx") - decoder_with_past_model_int4_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx" - ) - decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx") - new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] - - for fp_path, int4_path in zip(old_paths, new_paths): - if os.path.exists(fp_path): - model = onnx.load_model(fp_path, load_external_data=True) - quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) - quant.process() - quant.model.save_model_to_file(int4_path, use_external_data_format=True) - del model - del quant - logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") - remove_existing_model(fp_path) - - del llama # Delete LLaMA model from memory since it will be loaded again during parity check logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models @@ -824,7 +920,12 @@ def main(): # Verify parity on all saved ONNX models for filename in os.listdir(args.output): - if ".data" in filename or ".onnx" not in filename: + if ( + ".data" in filename + or ".onnx" not in filename + or args.precision not in filename + or f"rank_{rank}" not in filename + ): continue parity_cmd = [ @@ -834,10 +935,10 @@ def main(): os.path.join(args.output, filename), "-ep", args.execution_provider, - "-id", - args.device_id, "-fp", args.precision, + "--cache_dir", + args.cache_dir, ] if "with_past" in filename: parity_cmd.append("--use_past_kv") @@ -845,6 +946,7 @@ def main(): parity_cmd.append("--merged") try: + logger.debug(f"check parity with cmd: {parity_cmd}") parity_check(parity_cmd) except Exception as e: logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py new file mode 100644 index 0000000000000..50b0669d6d83a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -0,0 +1,45 @@ +import os + +import torch.distributed as dist + +comm = None + + +def init_dist(): + if "LOCAL_RANK" in os.environ: + int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank) + elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: + from mpi4py import MPI + + comm = MPI.COMM_WORLD # noqa: F841 + + int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank) + else: + # don't need to do init for single process + pass + + +def get_rank(): + return comm.Get_rank() if comm is not None else 0 + + +def get_size(): + return comm.Get_size() if comm is not None else 1 + + +def barrier(): + if comm is not None: + comm.Barrier() + + +def print_out(*args): + if get_rank() == 0: + print(*args) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 2652e9f0ca64e..6530eead55f03 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -4,7 +4,7 @@ import torch from transformers import LlamaConfig -from onnxruntime import OrtValue +from onnxruntime import InferenceSession, OrtValue # Get position_ids from attention_mask @@ -12,22 +12,36 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if use_past_kv: + # Shape: (batch_size, 1) position_ids = position_ids[:, -1].unsqueeze(-1) + + # Shape: (batch_size, sequence_length) return position_ids # Inputs for first pass to get initial past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, sequence_length) +# position_ids: (batch_size, sequence_length) def get_sample_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False + config: LlamaConfig, + device: torch.device, + batch_size: int, + seq_len: int, + engine: str = "pt", + return_dict: bool = False, ): - input_ids = torch.randint( - low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 - ) - attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64) - # position_ids is of shape (batch_size, seq_len) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64) position_ids = get_position_ids(attention_mask, use_past_kv=False) + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + if not return_dict: + # For export return (input_ids, attention_mask, position_ids) inputs = { @@ -39,85 +53,194 @@ def get_sample_inputs( # Inputs for subsequent passes with past_key_values +# input_ids: (batch_size, 1) +# attention_mask: (batch_size, past_sequence_length + 1) +# position_ids: (batch_size, 1) +# past_key: (batch_size, num_heads, past_sequence_length, head_size) +# past_value: (batch_size, num_heads, past_sequence_length, head_size) def get_sample_with_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool = False, + engine: str = "pt", return_dict: bool = False, + world_size: int = 1, ): - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64) - attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64) # position_ids is of shape (batch_size, 1) position_ids = get_position_ids(attention_mask, use_past_kv=True) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) if not return_dict: + # For export + assert isinstance(past_kv, list) return (input_ids, attention_mask, position_ids, past_kv) inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, - "past_key_values": past_kv, } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + return inputs # Inputs for all passes with past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, past_sequence_length + sequence_length) +# position_ids: (batch_size, sequence_length) +# past_key: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length +# past_value: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length def get_merged_sample_with_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, past_seq_len: int, + max_seq_len: int, use_fp16: bool = False, + engine: str = "pt", return_dict: bool = False, + world_size: int = 1, ): - input_ids = torch.randint( - low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 - ) - attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) if not return_dict: + # For export + assert isinstance(past_kv, list) return (input_ids, attention_mask, position_ids, past_kv) inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, - "past_key_values": past_kv, } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + + if use_fp16: # If model has GQA + del inputs["attention_mask"] + inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) + inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) + + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + return inputs -# Create past_key_values -def get_sample_past_kv_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool +# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx +def get_msft_sample_inputs( + config: LlamaConfig, + batch_size: int, + past_seq_len: int, + seq_len: int, + max_seq_len: int, + use_fp16: bool, + split_kv: bool, ): - num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads + np_dtype = np.float16 if use_fp16 else np.float32 + head_size = config.hidden_size // config.num_attention_heads + + if not split_kv: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), + "k_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "v_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "pos": np.array(past_seq_len, dtype=np.int64), + } + else: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( + np.int32 + ), + "pos": np.array(past_seq_len, dtype=np.int64), + } + for i in range(config.num_hidden_layers): + ort_inputs.update( + { + f"k_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + f"v_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + } + ) + + if use_fp16: # If model has GQA + del ort_inputs["attn_mask"] + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) + + return ort_inputs + + +# Create past_key_values +# Each is of shape (batch_size, num_heads, past_sequence_length, head_size) +def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1): + num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( - torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), - torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), ) for _ in range(config.num_hidden_layers) ] return past_kv -# Convert list of past_kv to dict of past_key and past_value -def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool): +# Convert list of past_key_values to dict of past_key and past_value +def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]): past_kv = {} - np_dtype = np.float16 if use_fp16 else np.float32 for i, (past_k, past_v) in enumerate(past_key_values): - past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype) - past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype) + past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy() + past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy() return past_kv @@ -136,7 +259,7 @@ def convert_inputs_for_ort( if isinstance(v, np.ndarray): ort_inputs[k] = v elif k == "past_key_values": - ort_inputs.update(flatten_past_kv_inputs(v, use_fp16)) + ort_inputs.update(flatten_past_kv_inputs(v)) elif k == "attention_mask" and use_fp16 and use_buffer_share: # Skip because FP16 model has GroupQueryAttention, uses buffer sharing, # and GQA supports a causal mask by default @@ -146,59 +269,55 @@ def convert_inputs_for_ort( else: ort_inputs[k] = v.detach().cpu().numpy() - # Enable past-present-share-buffer by using device memory directly + # Reshape kv caches if using past-present-share-buffer if use_buffer_share and device != "" and device != "cpu" and device_id > -1: - for k, v in ort_inputs.items(): - new_v = v - # Allocate new buffers with max_sequence_length for GQA - if "cache" in k or "past_key_values" in k: - # Copy v (BxSxPxH) into new_v (BxSxMxH) - batch_size, num_heads, _, head_size = v.shape - new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) - new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v - ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id) + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) return ort_inputs -# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx -def get_msft_sample_inputs( - config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool -): - np_dtype = np.float16 if use_fp16 else np.float32 - head_size = config.hidden_size // config.num_attention_heads - max_seq_len = 2048 +def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int): + for k, v in ort_inputs.items(): + # Allocate new buffers with max_sequence_length for GQA + if "cache" in k or "past_key_values" in k: + # Copy v (BxSxPxH) into new_v (BxSxMxH) + batch_size, num_heads, _, head_size = v.shape + new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) + new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v + ort_inputs[k] = new_v + return ort_inputs - if not split_kv: - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), - "k_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "v_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "pos": np.array(past_seq_len, dtype=np.int64), - } - else: - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( - np.int32 - ), - "pos": np.array(past_seq_len, dtype=np.int64), - } - for i in range(config.num_hidden_layers): - ort_inputs.update( - { - f"k_{i}_cache": np.random.rand( - batch_size, config.num_attention_heads, past_seq_len, head_size - ).astype(np_dtype), - f"v_{i}_cache": np.random.rand( - batch_size, config.num_attention_heads, past_seq_len, head_size - ).astype(np_dtype), - } - ) - return ort_inputs +# Add IO bindings for execution providers +def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, device_id: int, kv_cache_ortvalues: dict): + use_fp16 = False + io_binding = model.io_binding() + + for k, v in ort_inputs.items(): + # Detect if model is in FP16 + if v.dtype == np.float16: + use_fp16 = True + + # Bind OrtValue inputs to device + if use_fp16 and ("cache" in k or "past_key_values" in k): + if k not in kv_cache_ortvalues: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + kv_cache_ortvalues[k] = v_device + else: + kv_cache_ortvalues[k].update_inplace(v) + io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k]) + else: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + + for output in model.get_outputs(): + name = output.name + if use_fp16 and ("out" in name or "present" in name): + # Bind present KV cache outputs to past KV cache inputs in order to buffer share + input_name = name.replace("out", "cache").replace("present", "past_key_values") + io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name]) + else: + io_binding.bind_output(name, device_type=device, device_id=device_id) + + return io_binding, kv_cache_ortvalues diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 4353d0606803d..42581caf3bb9e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -6,44 +6,57 @@ import numpy as np import torch -from benchmark_helper import setup_logger +from dist_settings import get_rank, get_size from llama_inputs import ( + add_io_bindings, convert_inputs_for_ort, get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs, ) +from llama_torch import setup_torch_model from transformers import LlamaConfig, LlamaForCausalLM import onnxruntime as ort +from onnxruntime.transformers.benchmark_helper import setup_logger logger = logging.getLogger("") def get_sequence_lengths(args: argparse.Namespace): past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) - max_sequence_length = 2048 + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 return past_sequence_length, curr_sequence_length, max_sequence_length def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity + world_size = get_size() batch_size = 2 - past_sequence_length, sequence_length, _ = get_sequence_lengths(args) + past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args) if args.merged: inputs = get_merged_sample_with_past_kv_inputs( config, args.device, batch_size, - sequence_length, - past_sequence_length, + seq_len=sequence_length, + past_seq_len=past_sequence_length, + max_seq_len=max_sequence_length, use_fp16=args.use_fp16, return_dict=True, + world_size=world_size, ) elif args.use_past_kv: inputs = get_sample_with_past_kv_inputs( - config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True + config, + args.device, + batch_size, + sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + world_size=world_size, ) else: inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) @@ -51,31 +64,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): return inputs -def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict): - # Add IO bindings for non-CPU execution providers - io_binding = model.io_binding() - - for k, v in inputs.items(): - if args.use_fp16: - # Bind all OrtValue inputs to device - io_binding.bind_ortvalue_input(k, v) - else: - io_binding.bind_cpu_input(k, v) - - for output in model.get_outputs(): - name = output.name - if args.use_fp16 and ("out" in name or "present" in name): - # Bind present KV cache outputs to OrtValue with buffer sharing - io_binding.bind_ortvalue_output( - name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] - ) - else: - io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id)) - - return io_binding - - -def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): +def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM, kv_cache_ortvalues: dict): inputs = get_inputs(args, config) # Run inference with PyTorch @@ -87,6 +76,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama torch.cuda.synchronize() end_time = time.time() logger.info(f"PyTorch took {end_time - start_time} s") + del pt_model # Run inference with ORT past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) @@ -97,12 +87,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, device=args.execution_provider, - device_id=int(args.device_id), + device_id=int(args.rank), ) ep = f"{args.execution_provider.upper()}ExecutionProvider" if ep == "CUDAExecutionProvider": - ep = (ep, {"device_id": args.device_id}) + ep = (ep, {"device_id": args.rank}) ort_model = ort.InferenceSession( args.onnx_model_path, sess_options=ort.SessionOptions(), @@ -111,7 +101,9 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama # Add IO bindings for non-CPU execution providers if args.execution_provider != "cpu": - io_binding = add_io_bindings(args, ort_model, inputs) + io_binding, kv_cache_ortvalues = add_io_bindings( + ort_model, inputs, args.execution_provider, int(args.rank), kv_cache_ortvalues + ) io_binding.synchronize_inputs() start_time = time.time() @@ -120,6 +112,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama end_time = time.time() ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits + del ort_model else: start_time = time.time() @@ -131,17 +124,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama logger.info(f"ONNX Runtime took {end_time - start_time} s") # Compare PyTorch and ONNX Runtime accuracy - tol = ( - 2e1 - if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path - else 1e-3 - if args.precision == "fp32" - else 5e-1 - ) + tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1 parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol) logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}") if not parity: logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}") + return kv_cache_ortvalues def get_args(argv: List[str]): @@ -179,15 +167,6 @@ def get_args(argv: List[str]): help="Execution provider to verify parity with", ) - parser.add_argument( - "-id", - "--device-id", - required=False, - type=str, - default="0", - help="Device ID for GPUs", - ) - parser.add_argument( "-v", "--verbose", @@ -219,6 +198,14 @@ def get_args(argv: List[str]): help="Precision of model", ) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + args = parser.parse_args() if argv == [] else parser.parse_args(argv) # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models @@ -234,32 +221,35 @@ def main(argv: List[str] = []): # noqa: B006 args = get_args(argv) setup_logger(args.verbose) logger.info(f"Arguments: {args}") + rank = get_rank() # Load model and config setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 - setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010 + args.rank = rank + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 setattr(args, "device", torch.device(args.device_name)) # noqa: B010 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory - config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) - llama = LlamaForCausalLM.from_pretrained( + config, llama = setup_torch_model( + args, location, + use_auth_token, torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), - use_auth_token=use_auth_token, - use_cache=True, - ).to(args.device) + device=args.device, + ) + kv_cache_ortvalues = {} if not args.merged: - verify_parity(args, config, llama) + verify_parity(args, config, llama, kv_cache_ortvalues) else: # Verify prompt generation in merged model (decoder_model.onnx) args.use_past_kv = False - verify_parity(args, config, llama) + kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues) # Verify token generation in merged model (decoder_with_past_model.onnx) args.use_past_kv = True - verify_parity(args, config, llama) + verify_parity(args, config, llama, kv_cache_ortvalues) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py new file mode 100644 index 0000000000000..cf6406dde5be0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -0,0 +1,38 @@ +import logging +import os + +import torch +from dist_settings import barrier, get_rank, get_size +from transformers import LlamaConfig, LlamaForCausalLM + +logger = logging.getLogger("") + + +def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, device=None): + world_size = get_size() + logger.info(f"world_size: {world_size}") + rank = get_rank() + barrier() + + if not os.path.exists(args.cache_dir): + os.makedirs(args.cache_dir, exist_ok=True) + + for i in range(world_size): + if i == rank % (world_size): + l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) + l_config.use_cache = True + llama = LlamaForCausalLM.from_pretrained( + location, + use_auth_token=use_auth_token, + config=l_config, + torch_dtype=torch_dtype, + cache_dir=args.cache_dir, + ) + if world_size > 1: + llama.parallel_model() + if device: + llama.to(device) + llama.eval() + llama.requires_grad_(False) + barrier() + return l_config, llama diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt new file mode 100644 index 0000000000000..572cfdb71be4a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt @@ -0,0 +1,4 @@ +-r requirements.txt +git+https://github.com/frankdongms/transformers.git@frdong/shard_llama +mpi4py +psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 07870373e90b0..66ec0de88b44c 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -337,6 +337,18 @@ def match_parent_paths(self, node, paths, output_name_to_node): return i, matched, return_indice return -1, None, None + def match_parent_paths_all(self, node, paths, output_name_to_node): + match_i, matches, return_indices = [], [], [] + for i, path in enumerate(paths): + assert isinstance(path, (List, Tuple)) + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + match_i.append(i) + matches.append(matched) + return_indices.append(return_indice) + return match_i, matches, return_indices + def match_parent_path( self, node, diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index dc8efbbaf3709..918ee0e6eb976 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -14,7 +14,6 @@ #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "core/util/qmath.h" -#include "contrib_ops/cpu/quantization/dequantize_blockwise.h" #include #include @@ -25,6 +24,8 @@ namespace onnxruntime { namespace test { +static constexpr int QBits = 4; + void QuantizeDequantize(std::vector& raw_vals, std::vector& quant_vals, std::vector& scales, @@ -35,27 +36,29 @@ void QuantizeDequantize(std::vector& raw_vals, OrtThreadPoolParams to; auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); - contrib::QuantizeBlockwise( + + MlasQuantizeBlockwise( quant_vals.data(), - raw_vals.data(), scales.data(), zp != nullptr ? zp->data() : nullptr, + raw_vals.data(), block_size, - 4, - N, + true, K, + N, + N, tp.get()); // Note that input1_f_vals is NxK after dequant - contrib::DequantizeBlockwise( - raw_vals.data(), - quant_vals.data(), - scales.data(), - zp != nullptr ? zp->data() : nullptr, - block_size, - 4, - N, - K, + MlasDequantizeBlockwise( + raw_vals.data(), // dequantized output + quant_vals.data(), // quantized input + scales.data(), // quantization scales + zp != nullptr ? zp->data() : nullptr, // quantization zero points + block_size, // quantization block size + true, // columnwise quantization + K, // number of rows + N, // number of columns tp.get()); } @@ -69,13 +72,21 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); #endif - int64_t block_per_k = (K + block_size - 1) / block_size; - int64_t number_of_block = block_per_k * N; - int64_t block_blob_size = block_size * 4 / 8; - int64_t buf_size = number_of_block * (block_size * 4 / 8); - std::vector input1_vals(buf_size); - std::vector scales(number_of_block); - std::vector zp((N * block_per_k + 1) / 2); + int meta_rows; + int meta_cols; + MlasBlockwiseQuantMetaShape((int)block_size, true, (int)K, (int)N, meta_rows, meta_cols); + + int q_rows; + int q_cols; + MlasBlockwiseQuantizedShape((int)block_size, true, (int)K, (int)N, q_rows, q_cols); + + std::vector input1_vals(q_rows * q_cols); + std::vector scales(meta_rows * meta_cols); + + // TODO!! THIS SHOULD BE PROVIDED BY MLAS + // sub 8b packing always happen on the column dimension + const int packed_meta_rows = (meta_rows * QBits + 7) / 8; + std::vector zp(packed_meta_rows * meta_cols); QuantizeDequantize(input1_f_vals, input1_vals, @@ -100,13 +111,13 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop test.AddAttribute("K", K); test.AddAttribute("N", N); test.AddAttribute("block_size", block_size); - test.AddAttribute("bits", 4); + test.AddAttribute("bits", QBits); if (use_float16) { test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); - test.AddInput("B", {N, block_per_k, block_blob_size}, input1_vals, true); - test.AddInput("scales", {N * block_per_k}, ToFloat16(scales), true); + test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + test.AddInput("scales", {meta_cols * meta_rows}, ToFloat16(scales), true); if (has_zeropoint) { - test.AddInput("zero_points", {(N * block_per_k + 1) / 2}, zp, true); + test.AddInput("zero_points", {meta_cols * packed_meta_rows}, zp, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -117,10 +128,10 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else { test.AddInput("A", {M, K}, input0_vals, false); - test.AddInput("B", {N, block_per_k, block_blob_size}, input1_vals, true); - test.AddInput("scales", {N * block_per_k}, scales, true); + test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + test.AddInput("scales", {meta_cols * meta_rows}, scales, true); if (has_zeropoint) { - test.AddInput("zero_points", {(N * block_per_k + 1) / 2}, zp, true); + test.AddInput("zero_points", {meta_cols * packed_meta_rows}, zp, true); } test.AddOutput("Y", {M, N}, expected_vals); diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc new file mode 100644 index 0000000000000..fefd5722054de --- /dev/null +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/provider_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace std; + +namespace onnxruntime { +namespace test { + +TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { + constexpr int64_t B = 2; + constexpr int64_t C = 16; + constexpr int64_t H = 2; + constexpr int64_t W = 2; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + -0.768555f, 1.575195f, -0.698242f, 1.587891f, 0.371826f, -0.280029f, -1.328125f, 0.127197f, + -0.197144f, 0.982422f, -0.671387f, -1.925781f, 1.800781f, -0.020218f, -0.782227f, 1.291992f, + -0.935059f, 1.782227f, -0.674316f, -1.943359f, -0.218994f, 0.054138f, -1.539062f, -0.546387f, + -2.160156f, 1.195312f, 1.653320f, -0.674316f, 0.224731f, -0.093262f, 1.160156f, -0.389404f, + 1.748047f, 0.766113f, 0.234375f, 0.011177f, -0.055847f, -0.930664f, -0.490234f, -0.655762f, + -0.382568f, -0.554688f, 0.910645f, -0.227295f, 1.687500f, 0.028397f, -0.241699f, -0.480957f, + -0.355713f, -2.095703f, -0.443359f, -0.126221f, -0.815918f, 0.792969f, -0.450439f, -0.952148f, + -1.174805f, 0.242798f, 0.138550f, -0.237061f, -0.994141f, 0.346436f, 0.147705f, 0.125854f, + -0.517090f, 0.253906f, 0.400146f, -0.540039f, -0.788574f, 0.146606f, -0.409668f, 0.281982f, + 1.444336f, 0.044434f, -0.366699f, 2.250000f, -0.453613f, -0.652344f, 1.828125f, -0.244751f, + 0.307129f, -0.051361f, 0.106384f, 0.844727f, 1.648438f, -0.904785f, -0.353760f, 0.510742f, + 0.074829f, -0.311279f, 0.274902f, 1.594727f, 1.367188f, 0.098755f, 0.043304f, -0.207397f, + 0.068298f, -0.601074f, 0.083008f, 0.264893f, -0.659180f, -0.216797f, -0.086548f, -0.683594f, + -0.964844f, -2.591797f, -0.817383f, -0.461914f, -1.840820f, -0.712402f, -0.052094f, -0.583008f, + 1.114258f, 0.190308f, 1.087891f, 0.005146f, 1.041992f, 1.363281f, -0.273682f, -0.465576f, + -0.027618f, 1.345703f, 0.789551f, -0.015991f, 0.401611f, 0.726562f, 0.598633f, 0.133667f}; + + std::vector gamma_data = { + 0.241255f, 0.556660f, -0.835532f, 0.564596f, -1.338308f, -0.278924f, 0.357326f, -1.745484f, + 0.277184f, 0.101415f, -0.018637f, -0.526188f, -0.011698f, -2.349411f, 0.206578f, 0.357679f}; + + std::vector beta_data = { + -1.194839f, 0.209146f, -0.677225f, -0.547338f, 1.275685f, -1.099577f, 0.470916f, 0.293907f, + -1.094209f, 2.350204f, -1.633769f, 0.248753f, -0.180166f, 0.365134f, -0.555731f, 1.843083f}; + + std::vector skip_data_nhwc = { + 0.892578f, -0.471924f, -0.423096f, 1.277344f, 0.257080f, -1.366211f, 1.552734f, 0.441406f, + -0.033142f, -0.059418f, 1.536133f, -0.225464f, 1.472656f, 0.591309f, -0.386230f, -2.197266f, + 0.089600f, -0.256592f, -1.873047f, 0.916992f, 0.392090f, 0.015526f, -0.949219f, 0.566895f, + -0.220459f, 1.262695f, -0.437744f, -2.283203f, -0.264893f, -0.660156f, 2.353516f, 1.992188f, + 0.865723f, -0.854004f, -1.014648f, 0.899414f, -1.041016f, 1.378906f, -0.075073f, -2.541016f, + -0.883789f, -0.428711f, 0.981934f, -0.072754f, 2.214844f, 0.658203f, 0.170166f, -1.727539f, + -0.672363f, -1.373047f, 0.318115f, 0.422363f, 0.260742f, -0.547852f, 0.545898f, -0.155762f, + 0.679688f, 2.861328f, -0.300781f, -0.504883f, 1.548828f, 0.353760f, -0.387695f, -1.595703f, + -0.170166f, -0.002897f, 0.273193f, -0.383545f, -1.082031f, -0.894043f, -1.048828f, -0.044708f, + 0.049286f, 0.220215f, 0.272705f, -0.853027f, -0.489258f, 0.513672f, 0.977051f, 0.310547f, + -0.577148f, -0.479004f, 0.838867f, 0.872559f, -0.510254f, 0.101807f, -0.299805f, -1.179688f, + -1.555664f, 0.668457f, 0.939453f, 0.118103f, -0.376709f, 0.735352f, -0.214233f, -1.987305f, + -0.931152f, 1.268555f, 1.427734f, -0.757812f, -1.324219f, 0.375488f, 1.364258f, -1.708008f, + 0.976562f, -0.037659f, -1.779297f, -0.196655f, 1.636719f, 0.690430f, 0.941895f, -1.882812f, + 0.431641f, 0.203857f, 1.306641f, -0.126343f, 1.408203f, 1.188477f, 0.432861f, -2.296875f, + -0.475342f, 1.517578f, -0.824219f, 1.288086f, -0.028244f, 1.918945f, 0.352295f, 0.693359f}; + + std::vector bias_data = { + -0.537598f, 0.500488f, -0.252441f, -0.460693f, -1.640625f, -1.298828f, 0.331787f, -1.588867f, + 1.000977f, 1.458984f, 0.702637f, 0.147827f, 1.143555f, 0.533691f, -0.072510f, 0.511230f}; + + std::vector norm_data_nhwc = { + -1.213867f, 0.856445f, -0.119141f, 0.386475f, 0.714355f, -0.804688f, + 1.048828f, -0.426270f, -1.091797f, 2.435547f, -1.641602f, 0.989746f, + -0.200928f, 0.267334f, -0.800781f, 1.577148f, -1.357422f, 1.000977f, + 0.613281f, -0.963867f, 1.179688f, -1.169922f, 0.308350f, 0.304199f, + -1.396484f, 2.513672f, -1.644531f, 1.206055f, -0.180664f, 1.896484f, + -0.294678f, 2.046875f, -0.844238f, 0.448486f, -0.294189f, -0.291504f, + 2.480469f, -1.250977f, 0.833008f, 4.593750f, -1.238281f, 2.335938f, + -1.651367f, 0.491943f, -0.204834f, 0.125610f, -0.682129f, 1.333984f, + -1.384766f, -0.708008f, -0.630859f, -0.504883f, 1.924805f, -1.208008f, + 1.013672f, 1.809570f, -1.128906f, 2.546875f, -1.631836f, 0.610840f, + -0.184326f, 0.110046f, -0.700195f, 1.471680f, -1.511719f, 0.492188f, + -0.847168f, -1.373047f, 2.837891f, -0.998047f, 0.521484f, 0.262207f, + -0.810547f, 2.400391f, -1.628906f, 0.049896f, -0.174927f, 1.076172f, + -0.252197f, 1.784180f, -1.418945f, 0.090820f, -1.056641f, 0.002945f, + 0.627441f, -0.989746f, 0.679199f, 1.130859f, -1.371094f, 2.408203f, + -1.645508f, -0.062988f, -0.192017f, -0.655762f, -0.718262f, 1.170898f, + -1.550781f, 0.706055f, -1.492188f, -1.148438f, 2.921875f, -1.136719f, + 1.058594f, 2.781250f, -1.089844f, 2.201172f, -1.597656f, 0.785645f, + -0.181396f, 0.868164f, -0.552246f, 1.097656f, -1.015625f, 0.565430f, + -2.173828f, -0.955078f, -0.336426f, -1.503906f, 0.838867f, 3.136719f, + -1.186523f, 2.580078f, -1.629883f, 0.094604f, -0.186523f, -3.884766f, + -0.542480f, 1.990234f}; + + std::vector add_out_data_nhwc = { + -0.414062f, 1.604492f, -1.374023f, 2.404297f, -1.011719f, -2.945312f, 0.556641f, -1.020508f, + 0.770508f, 2.382812f, 1.567383f, -2.003906f, 4.417969f, 1.105469f, -1.240234f, -0.394531f, + -1.382812f, 2.027344f, -2.800781f, -1.487305f, -1.466797f, -1.229492f, -2.156250f, -1.568359f, + -1.379883f, 3.917969f, 1.917969f, -2.808594f, 1.103516f, -0.219727f, 3.441406f, 2.113281f, + 2.076172f, 0.412598f, -1.033203f, 0.449951f, -2.738281f, -0.851562f, -0.233521f, -4.785156f, + -0.265625f, 0.475586f, 2.595703f, -0.152222f, 5.046875f, 1.220703f, -0.144043f, -1.697266f, + -1.566406f, -2.968750f, -0.377686f, -0.164551f, -2.195312f, -1.053711f, 0.427246f, -2.697266f, + 0.505859f, 4.562500f, 0.540527f, -0.594238f, 1.698242f, 1.233398f, -0.312500f, -0.958496f, + -1.224609f, 0.751465f, 0.420898f, -1.384766f, -3.511719f, -2.046875f, -1.126953f, -1.351562f, + 2.494141f, 1.724609f, 0.608398f, 1.544922f, 0.200684f, 0.395020f, 2.732422f, 0.577148f, + -0.807617f, -0.029785f, 0.692871f, 1.256836f, -0.502441f, -2.101562f, -0.321777f, -2.257812f, + -0.479492f, 1.816406f, 1.916992f, 1.860352f, 2.134766f, 1.367188f, -0.243408f, -1.683594f, + -1.400391f, 1.167969f, 1.257812f, -0.953613f, -3.625000f, -1.140625f, 1.609375f, -3.980469f, + 1.012695f, -1.170898f, -1.894531f, -0.510742f, 0.939453f, 0.511719f, 0.817383f, -1.955078f, + 1.007812f, 0.894531f, 2.142578f, -0.582031f, 0.809570f, 1.252930f, 0.490967f, -4.351562f, + 0.497803f, 4.320312f, 0.667969f, 1.419922f, 1.516602f, 3.179688f, 0.878906f, 1.337891f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array channels_last_values = {-1, 1}; + + for (const int channels_last : channels_last_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 4); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + test.AddInput("skip", dims_nhwc, ToFloat16(skip_data_nhwc)); + test.AddInput("bias", {C}, ToFloat16(bias_data)); + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } +} + +TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { + constexpr int64_t B = 1; + constexpr int64_t C = 64; + constexpr int64_t H = 1; + constexpr int64_t W = 1; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + 0.588867f, 0.896484f, -0.213623f, 0.803223f, 0.659180f, -0.216187f, 1.197266f, -0.486084f, + -0.718750f, 0.332031f, -0.364746f, -0.831543f, -0.031219f, -1.059570f, 0.161621f, 1.519531f, + 0.169312f, 1.048828f, 1.330078f, 0.450195f, -2.867188f, -1.456055f, 0.708496f, -1.120117f, + -1.208984f, -1.199219f, -1.505859f, -0.549316f, 0.505371f, 0.723145f, -0.359131f, -0.250977f, + -0.879883f, -0.305664f, 0.709473f, 0.815430f, 0.617676f, -0.638672f, 0.066772f, -2.330078f, + -1.316406f, 1.744141f, 1.122070f, -0.633789f, -1.802734f, -0.825684f, 0.622559f, -0.481689f, + -1.364258f, -0.536621f, -0.464111f, 0.247437f, -0.213989f, 0.384521f, 0.556641f, -0.303711f, + -0.160034f, 0.882324f, -0.212036f, -0.796387f, 0.153076f, -1.311523f, 2.212891f, 0.685059f}; + + std::vector gamma_data = { + 0.789682f, 0.869051f, -0.010169f, -0.021685f, 0.506611f, 1.267444f, -0.312695f, 0.877844f, + 0.598637f, 0.598314f, -1.721544f, -0.593328f, 0.986705f, -0.419391f, -0.852584f, -0.572351f, + 0.912797f, -0.586863f, 0.477761f, -0.484418f, -0.193835f, 0.347757f, 0.327637f, -1.100304f, + 1.233108f, -0.272569f, -0.688656f, 0.687245f, 0.398386f, 0.888089f, -0.792587f, -0.769029f, + -0.427778f, 0.100768f, -2.187060f, 1.279301f, 1.109054f, 0.375992f, 1.514775f, 1.271436f, + 0.822896f, -0.476750f, 0.475507f, -1.011297f, 1.177197f, 1.586540f, -1.059944f, -0.145351f, + 0.841555f, -2.014113f, -0.230498f, 0.302128f, -0.180508f, 0.980534f, -0.126871f, 0.203151f, + -0.754841f, 0.420570f, -1.085798f, 1.335042f, -0.674930f, 2.453507f, 2.139259f, 1.087436f}; + + std::vector beta_data = { + -0.064518f, -0.262683f, 0.827528f, -0.960938f, 1.062519f, 2.417941f, 0.212789f, -1.638430f, + 1.875453f, -0.883058f, -0.006704f, 0.424894f, -0.869972f, 0.727008f, 0.879303f, -3.024141f, + -2.610873f, 1.269641f, 0.883006f, 0.804167f, -1.510324f, 2.258091f, -0.006750f, -1.553668f, + -1.659453f, 0.579603f, 0.652358f, 0.007077f, 0.099180f, 0.418658f, -0.273778f, -1.036199f, + -1.128691f, -0.296022f, -0.224056f, 1.476306f, 0.577624f, -0.372049f, -0.581659f, -1.841807f, + -0.361721f, 0.051160f, -0.749332f, -2.634807f, 0.562719f, -0.738667f, 0.024864f, -1.135937f, + -1.368144f, -1.458886f, -0.946683f, 1.953936f, -1.198661f, 0.166648f, 0.447206f, -0.458140f, + -0.553395f, 0.112900f, 0.255989f, -0.184551f, 1.254163f, -0.260479f, -1.232429f, 1.902575f}; + + std::vector skip_data = { + 0.952148f, 1.342773f, -0.172974f, -0.395264f, 1.119141f, 0.330566f, + 0.281494f, 0.472900f, -0.692871f, -0.634766f, 0.013504f, -1.866211f, + -0.428223f, 0.669922f, -0.323486f, 0.713867f, -0.350586f, 0.659180f, + -0.288574f, 0.324219f, -0.300781f, -0.789551f, -0.216431f, -0.221436f, + -0.086670f, 0.366211f, -0.643555f, -0.977051f, 0.001021f, 0.415527f, + -0.271729f, 0.836426f, 0.035370f, -0.806152f, 0.936035f, -0.021332f, + -1.095703f, 0.971680f, 1.648438f, 0.840820f, 0.837402f, 0.607910f, + -1.894531f, 0.666016f, -0.171143f, 1.625977f, -0.620117f, -0.039581f, + 1.702148f, -2.410156f, 1.565430f, -0.756348f, 1.446289f, 0.583496f, + -0.497559f, -0.271729f, -0.956055f, -1.642578f, 0.833496f, -1.136719f, + 1.248047f, -2.515625f, 0.080383f, 0.376221f}; + + std::vector norm_data_nhwc = { + 0.494873f, 1.017578f, 0.841797f, -0.949219f, 1.552734f, 1.333984f, 0.012703f, -2.511719f, + 1.424805f, -0.818359f, -0.128418f, 1.462891f, -0.882812f, 0.709961f, 0.693848f, -4.210938f, + -2.505859f, 0.513184f, 1.300781f, 0.460938f, -1.172852f, 1.851562f, 0.167969f, -0.885254f, + -2.535156f, 0.656738f, 1.683594f, -0.627441f, 0.478271f, 1.782227f, -0.196777f, -1.824219f, + -0.791016f, -0.398682f, -3.197266f, 2.275391f, 0.052704f, -0.286865f, 1.567383f, -3.552734f, + -0.646973f, -0.927734f, -1.032227f, -2.722656f, -1.337891f, 0.432129f, -0.040253f, -1.080078f, + -1.118164f, 3.123047f, -1.153320f, 1.843750f, -1.378906f, 0.941406f, 0.437256f, -0.542969f, + -0.218872f, 0.006115f, -0.265869f, -1.356445f, 0.649902f, -4.882812f, 1.696289f, 2.679688f}; + + std::vector add_out_data_nhwc = { + 1.541016f, 2.238281f, -0.386719f, 0.407959f, 1.778320f, 0.114380f, + 1.478516f, -0.013184f, -1.412109f, -0.302734f, -0.351318f, -2.697266f, + -0.459473f, -0.389648f, -0.161865f, 2.234375f, -0.181274f, 1.708008f, + 1.041016f, 0.774414f, -3.167969f, -2.246094f, 0.492188f, -1.341797f, + -1.295898f, -0.833008f, -2.148438f, -1.526367f, 0.506348f, 1.138672f, + -0.630859f, 0.585449f, -0.844727f, -1.111328f, 1.645508f, 0.793945f, + -0.478027f, 0.333008f, 1.714844f, -1.489258f, -0.479004f, 2.351562f, + -0.772461f, 0.032227f, -1.973633f, 0.800293f, 0.002441f, -0.521484f, + 0.337891f, -2.947266f, 1.101562f, -0.508789f, 1.232422f, 0.967773f, + 0.059082f, -0.575195f, -1.116211f, -0.760254f, 0.621582f, -1.933594f, + 1.401367f, -3.828125f, 2.292969f, 1.061523f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array has_add_out_values = {true, false}; + std::array skip_dims = {2, 4}; + + constexpr int channels_last = 1; + for (const int skip_dim : skip_dims) { + for (const bool has_add_out : has_add_out_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 8); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + if (skip_dim == 2) { + test.AddInput("skip", {B, C}, ToFloat16(skip_data)); + } else { + test.AddInput("skip", {B, 1, 1, C}, ToFloat16(skip_data)); + } + // no bias + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + + if (has_add_out) { + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + } + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index c38baee39216b..1804c09043c7b 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -4,6 +4,7 @@ #include "core/framework/allocator.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/graph/model.h" +#include "core/graph/node_attr_utils.h" #include "gtest/gtest.h" #include "test_utils.h" #include "test/test_environment.h" @@ -110,6 +111,70 @@ TEST(TransformerTest, InsertCastAllCPUTest) { } } +TEST(TransformerTest, CastRemovalDoesNotLowerPrecisionTest) { + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + TypeProto tensor_float_32; + tensor_float_32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + TypeProto tensor_float_64; + tensor_float_64.mutable_tensor_type()->set_elem_type(TensorProto_DataType_DOUBLE); + onnxruntime::NodeArg n1_def("N1", &tensor_float_64), + n2_def("N2", &tensor_float_32), + n3_def("N3", &tensor_float_64); + + NodeAttributes n1_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT))}}; + NodeAttributes n2_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))}}; + + graph.AddNode("node1", "Cast", "F64 to F32 cast", ArgMap{&n1_def}, ArgMap{&n2_def}, &n1_attrs); + graph.AddNode("node2", "Cast", "F32 to F64 cast", ArgMap{&n2_def}, ArgMap{&n3_def}, &n2_attrs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + InsertCastTransformer cast_inserter("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + status = cast_inserter.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // When casting f64 -> f32 -> f64 we should not be optimising away the cast since there is a loss of precision. + EXPECT_EQ(graph.NumberOfNodes(), 2); +} + +TEST(TransformerTest, CastRemovalDoesNotRemoveSignednessTest) { + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + TypeProto tensor_uint32; + tensor_uint32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_UINT32); + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + onnxruntime::NodeArg n1_def("N1", &tensor_int32), + n2_def("N2", &tensor_uint32), + n3_def("N3", &tensor_int32); + + NodeAttributes n1_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_UINT32))}}; + NodeAttributes n2_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT32))}}; + + graph.AddNode("node1", "Cast", "I32 to UI32 cast", ArgMap{&n1_def}, ArgMap{&n2_def}, &n1_attrs); + graph.AddNode("node2", "Cast", "UI32 to I32 cast", ArgMap{&n2_def}, ArgMap{&n3_def}, &n2_attrs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + InsertCastTransformer cast_inserter("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + status = cast_inserter.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // When casting i32 -> ui32 -> i32 we should not be optimising away the cast since applying the casts produces a very different result. + EXPECT_EQ(graph.NumberOfNodes(), 2); +} + // test that when there are 3 Cast ops in a row we remove the correct ones TEST(TransformerTest, ThreeInARowRemoval) { auto model_uri = MODEL_FOLDER ORT_TSTR("triple-cast.onnx"); diff --git a/onnxruntime/test/mlas/unittest/test_activation.cpp b/onnxruntime/test/mlas/unittest/test_activation.cpp index eb3e35d739bb3..2bb0bbcd35e26 100644 --- a/onnxruntime/test/mlas/unittest/test_activation.cpp +++ b/onnxruntime/test/mlas/unittest/test_activation.cpp @@ -256,9 +256,6 @@ class MlasActivationTest : public MlasTestBase { } }; -template <> -MlasActivationTest* MlasTestFixture::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { return is_short_execute ? MlasDirectShortExecuteTests::RegisterShortExecute() : 0; }); diff --git a/onnxruntime/test/mlas/unittest/test_blkq8.cpp b/onnxruntime/test/mlas/unittest/test_blkq8.cpp index 15bbd1b4cb28d..5cff86d411ca9 100644 --- a/onnxruntime/test/mlas/unittest/test_blkq8.cpp +++ b/onnxruntime/test/mlas/unittest/test_blkq8.cpp @@ -150,12 +150,6 @@ class MlasBlkQ8ShortExeTest : public MlasTestFixture> { size_t M_, K_; }; -template <> -MlasBlkQ8Test* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasBlkQ8Test* MlasTestFixture>::mlas_tester(nullptr); - static size_t BlkQ8ReisterShortTests() { size_t cnt = 0; cnt += MlasBlkQ8ShortExeTest::RegisterShortExecuteTests(); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index 6f06e0f2eead8..f836da8277bb8 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -96,7 +96,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase { } } - MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr); + MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, + columnwise, rows, columns, threadpool_ptr); MlasTranspose(dequant_buf, transposed, columns, rows); @@ -104,7 +105,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase { float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); uint8_t* o_zp = symmetric ? nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); - MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, columnwise, rows, columns, columns, threadpool_ptr); + MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, + columnwise, rows, columns, columns, threadpool_ptr); for (int c = 0; c < columns; c++) { for (int r = 0; r < rows; r += 2) { @@ -194,9 +196,6 @@ class MlasBlockwiseQdqTest : public MlasTestBase { MlasBlockwiseQdqTest() = default; }; -template <> -MlasBlockwiseQdqTest* MlasTestFixture::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { diff --git a/onnxruntime/test/mlas/unittest/test_conv2d.cpp b/onnxruntime/test/mlas/unittest/test_conv2d.cpp index 97560bbfc2e7e..1700cd8f1800f 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d.cpp +++ b/onnxruntime/test/mlas/unittest/test_conv2d.cpp @@ -4,11 +4,6 @@ #include "test_conv2d.h" #include "test_conv2d_fixture.h" -template <> -MlasConv2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasConv2DTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t Conv2dRegistLongExecute() { size_t count = MlasLongExecuteTests>::RegisterLongExecute(); if (GetMlasThreadPool() != nullptr) { diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp index 78a047e385b99..e5a536eb9e4f0 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp +++ b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp @@ -4,11 +4,6 @@ #include "test_conv2d_nchwc.h" #include "test_conv2d_fixture.h" -template <> -MlasNchwcConv2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasNchwcConv2DTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t Conv2dNchwcRegistLongExecute() { size_t count = 0; diff --git a/onnxruntime/test/mlas/unittest/test_exp.cpp b/onnxruntime/test/mlas/unittest/test_exp.cpp index ce8c4e97748f8..f9cdffef1947d 100644 --- a/onnxruntime/test/mlas/unittest/test_exp.cpp +++ b/onnxruntime/test/mlas/unittest/test_exp.cpp @@ -50,9 +50,6 @@ class MlasComputeExpTest : public MlasTestBase { } }; -template <> -MlasComputeExpTest* MlasTestFixture::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { // no long execute needed return is_short_execute ? MlasDirectShortExecuteTests::RegisterShortExecute() : 0; diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.cpp b/onnxruntime/test/mlas/unittest/test_fgemm.cpp index 6b8d4529faadb..e3f50baf3633d 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_fgemm.cpp @@ -7,24 +7,6 @@ #include #include -template <> -MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); - -#ifdef MLAS_SUPPORTS_GEMM_DOUBLE - -template <> -MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasFgemmTest* MlasTestFixture>::mlas_tester(nullptr); - -#endif - static size_t FGemmRegistLongExecute() { size_t count = 0; diff --git a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp index a9e062e0b6534..484a9a22429d5 100644 --- a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp +++ b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp @@ -148,9 +148,6 @@ class MlasFp16ActivationTest : public MlasTestBase { } }; -template <> -MlasFp16ActivationTest* MlasTestFixture::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { return is_short_execute ? MlasDirectShortExecuteTests::RegisterShortExecute() : 0; }); diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp index 1a307d339b0f2..2a478675d09eb 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp @@ -89,42 +89,6 @@ class HalfGemmShortExecuteTest : public MlasTestFixture -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t HalfGemmRegistLongExecute() { size_t count = 0; diff --git a/onnxruntime/test/mlas/unittest/test_minmax.cpp b/onnxruntime/test/mlas/unittest/test_minmax.cpp index f0df504720c0c..245879deccffd 100644 --- a/onnxruntime/test/mlas/unittest/test_minmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_minmax.cpp @@ -46,9 +46,6 @@ class MlasFindMinMaxElementsTest : public MlasTestBase { } }; -template <> -MlasFindMinMaxElementsTest* MlasTestFixture::mlas_tester(nullptr); - #ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" diff --git a/onnxruntime/test/mlas/unittest/test_pool2d.cpp b/onnxruntime/test/mlas/unittest/test_pool2d.cpp index 012e7f25fddce..8cefb8332ec32 100644 --- a/onnxruntime/test/mlas/unittest/test_pool2d.cpp +++ b/onnxruntime/test/mlas/unittest/test_pool2d.cpp @@ -4,20 +4,6 @@ #include "test_pool2d.h" #include "test_pool2d_fixture.h" -template <> -MlasPool2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasPool2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasPool2DTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasPool2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasPool2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasPool2DTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t Pool2dRegistLongExecute() { size_t count = 0; count += MlasLongExecuteTests>::RegisterLongExecute(); diff --git a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.cpp b/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.cpp index 190fbe7d5a6f1..bee690b10b737 100644 --- a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.cpp +++ b/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.cpp @@ -4,20 +4,6 @@ #include "test_pool2d_nchwc.h" #include "test_pool2d_fixture.h" -template <> -MlasNchwcPool2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasNchwcPool2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasNchwcPool2DTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasNchwcPool2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasNchwcPool2DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasNchwcPool2DTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t Pool2dNchwcRegistLongExecute() { size_t count = 0; if (MlasNchwcGetBlockSize() > 1) { diff --git a/onnxruntime/test/mlas/unittest/test_pool3d.cpp b/onnxruntime/test/mlas/unittest/test_pool3d.cpp index a93698234f7da..e0ce4c240be80 100644 --- a/onnxruntime/test/mlas/unittest/test_pool3d.cpp +++ b/onnxruntime/test/mlas/unittest/test_pool3d.cpp @@ -4,20 +4,6 @@ #include "test_pool3d.h" #include "test_pool3d_fixture.h" -template <> -MlasPool3DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasPool3DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasPool3DTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasPool3DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasPool3DTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasPool3DTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t Pool3dRegistLongExecute() { size_t count = 0; count += MlasLongExecuteTests>::RegisterLongExecute(); diff --git a/onnxruntime/test/mlas/unittest/test_q4gemm.cpp b/onnxruntime/test/mlas/unittest/test_q4gemm.cpp index 2c3bf23a9330b..dccd7d00b6d3f 100644 --- a/onnxruntime/test/mlas/unittest/test_q4gemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_q4gemm.cpp @@ -83,19 +83,6 @@ class Q4GemmShortExecuteTest : public MlasTestFixture -MlasQ4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ4GemmTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t Q4GemmRegistShortExecute() { size_t count = 0; diff --git a/onnxruntime/test/mlas/unittest/test_q4qdq.cpp b/onnxruntime/test/mlas/unittest/test_q4qdq.cpp index 8215c63a2cc56..955c3b1201989 100644 --- a/onnxruntime/test/mlas/unittest/test_q4qdq.cpp +++ b/onnxruntime/test/mlas/unittest/test_q4qdq.cpp @@ -141,9 +141,6 @@ class MlasQ4dqTest : public MlasTestBase { MlasQ4dqTest() = default; }; -template <> -MlasQ4dqTest* MlasTestFixture::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { if (MlasQ4GemmPackBSize(BlkQ4Sym, 32, 32) == 0) { return (size_t)0; diff --git a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp index bac16b0103a6e..a78a3261d1f2a 100644 --- a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp @@ -271,19 +271,6 @@ class Q8Q4GemmShortExecuteTest : public MlasTestFixture -MlasQ8Q4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ8Q4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ8Q4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ8Q4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ8Q4GemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQ8Q4GemmTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t Q8Q4GemmRegistShortExecute() { size_t count = 0; diff --git a/onnxruntime/test/mlas/unittest/test_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_qgemm.cpp index a55331f1377fa..6bb93d35357f8 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_qgemm.cpp @@ -1,60 +1,6 @@ #include "test_qgemm.h" #include "test_qgemm_fixture.h" -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); - -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQgemmTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t QGemmRegistLongExecute() { size_t count = 0; diff --git a/onnxruntime/test/mlas/unittest/test_qlinear_binaryop.cpp b/onnxruntime/test/mlas/unittest/test_qlinear_binaryop.cpp index 93dda4bee183b..5876f186eaa0d 100644 --- a/onnxruntime/test/mlas/unittest/test_qlinear_binaryop.cpp +++ b/onnxruntime/test/mlas/unittest/test_qlinear_binaryop.cpp @@ -163,11 +163,6 @@ class MlasQLinearMulTest : public MlasQLinearBinaryOpTest { } }; -template <> -MlasQLinearAddTest* MlasTestFixture::mlas_tester(nullptr); -template <> -MlasQLinearMulTest* MlasTestFixture::mlas_tester(nullptr); - static bool UNUSED_VARIABLE added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { diff --git a/onnxruntime/test/mlas/unittest/test_qlinear_gavgpool.cpp b/onnxruntime/test/mlas/unittest/test_qlinear_gavgpool.cpp index aeb13af5b941a..e6c230df57fbc 100644 --- a/onnxruntime/test/mlas/unittest/test_qlinear_gavgpool.cpp +++ b/onnxruntime/test/mlas/unittest/test_qlinear_gavgpool.cpp @@ -162,11 +162,6 @@ class MlasQLinearGlobalAveragePoolTest : public MlasTestBase { } }; -template <> -MlasQLinearGlobalAveragePoolTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQLinearGlobalAveragePoolTest* MlasTestFixture>::mlas_tester(nullptr); - template <> const std::vector MlasQLinearGlobalAveragePoolTest::ZeroPoints = {-128, -110, 1, 103, 127}; diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp index 2832598fef1a9..986d158d2b1b9 100644 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp @@ -71,15 +71,6 @@ class MlasQuantizeLinearTest : public MlasTestBase { } }; -template <> -MlasQuantizeLinearTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQuantizeLinearTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQuantizeLinearTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasQuantizeLinearTest* MlasTestFixture>::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { diff --git a/onnxruntime/test/mlas/unittest/test_reorder_output.cpp b/onnxruntime/test/mlas/unittest/test_reorder_output.cpp index 21373fe9f66e7..e39abd8578da4 100644 --- a/onnxruntime/test/mlas/unittest/test_reorder_output.cpp +++ b/onnxruntime/test/mlas/unittest/test_reorder_output.cpp @@ -88,9 +88,6 @@ class MlasReorderOutputTest : public MlasTestBase { } }; -template <> -MlasReorderOutputTest* MlasTestFixture::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { return (MlasNchwcGetBlockSize() > 1 && is_short_execute) ? MlasDirectShortExecuteTests::RegisterShortExecute() diff --git a/onnxruntime/test/mlas/unittest/test_scaleoutput.cpp b/onnxruntime/test/mlas/unittest/test_scaleoutput.cpp index 7732b1fa8c72e..34f17843b0726 100644 --- a/onnxruntime/test/mlas/unittest/test_scaleoutput.cpp +++ b/onnxruntime/test/mlas/unittest/test_scaleoutput.cpp @@ -77,9 +77,6 @@ class MlasScaleOutputTest : public MlasTestBase { } }; -template <> -MlasScaleOutputTest* MlasTestFixture::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { return is_short_execute ? MlasDirectShortExecuteTests::RegisterShortExecute() : 0; }); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 3df2b88f9652a..4c5e11bbe9566 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -97,11 +97,6 @@ class MlasSoftmaxTest : public MlasTestBase { } }; -template <> -MlasSoftmaxTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasSoftmaxTest* MlasTestFixture>::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { diff --git a/onnxruntime/test/mlas/unittest/test_symm_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_symm_qgemm.cpp index adfe5564ebbbf..bb3aea02cc011 100644 --- a/onnxruntime/test/mlas/unittest/test_symm_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_symm_qgemm.cpp @@ -1,10 +1,5 @@ #include "test_symm_qgemm_fixture.h" -template <> -MlasSymmQgemmTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasSymmQgemmTest* MlasTestFixture>::mlas_tester(nullptr); - static size_t SymmQgemmRegistLongExecute() { if (MlasSymmQgemmPackBSize(16, 16, true) == 0) { return 0; diff --git a/onnxruntime/test/mlas/unittest/test_transpose.cpp b/onnxruntime/test/mlas/unittest/test_transpose.cpp index 74ce5868f411d..8fa98411a21ab 100644 --- a/onnxruntime/test/mlas/unittest/test_transpose.cpp +++ b/onnxruntime/test/mlas/unittest/test_transpose.cpp @@ -45,13 +45,6 @@ class MlasTransposeTest : public MlasTestBase { } }; -template <> -MlasTransposeTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasTransposeTest* MlasTestFixture>::mlas_tester(nullptr); -template <> -MlasTransposeTest* MlasTestFixture>::mlas_tester(nullptr); - static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index c5ee8b4b6115a..db528ef7291cc 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -188,8 +188,7 @@ class MlasTestFixture : public testing::Test { mlas_tester = nullptr; }; - // Do not forgot to define this static member element when upon usage. - static TMlasTester* mlas_tester; + static inline TMlasTester* mlas_tester = nullptr; }; // Long Execute test. It is too heavy to register each single test, treat long execute big groups. diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 46b95a127b75c..e0f63ea58e772 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -32,6 +32,7 @@ #include "core/optimizer/conv_add_fusion.h" #include "core/optimizer/conv_bn_fusion.h" #include "core/optimizer/matmul_bn_fusion.h" +#include "core/optimizer/pad_fusion.h" #include "core/optimizer/conv_mul_fusion.h" #include "core/optimizer/div_mul_fusion.h" #include "core/optimizer/dropout_elimination.h" @@ -1080,6 +1081,163 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) { } } +TEST_F(GraphTransformationTests, FusePadWithConv) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-conv.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::vector expected_pads; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "Pad") { + const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); + Initializer pads{*pads_proto, graph.ModelPath()}; + gsl::span pads_values = pads.DataAsSpan(); + expected_pads.resize(pads_values.size() - 4); + + for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) { + expected_pads[index] = pads_values[pads_index]; + expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)]; + } + } else if (node.OpType() == "Conv") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + for (uint32_t index = 0; index < expected_pads.size(); index++) { + expected_pads[index] += child_pads->Get(index); + } + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Pad"], 0); + ASSERT_EQ(op_to_count["Conv"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Conv") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + ASSERT_EQ(child_pads->size(), static_cast(expected_pads.size())) + << "fusion should produce the same size of pads integer as the Conv node"; + for (uint32_t index = 0; index < expected_pads.size(); index++) { + ASSERT_EQ(expected_pads[index], child_pads->Get(index)) + << "fusion does not produce correct padding value"; + } + } + } +} + +TEST_F(GraphTransformationTests, FusePadWithMaxPool) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-maxpool.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::vector expected_pads; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "Pad") { + const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); + Initializer pads{*pads_proto, graph.ModelPath()}; + gsl::span pads_values = pads.DataAsSpan(); + expected_pads.resize(pads_values.size() - 4); + + for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) { + expected_pads[index] = pads_values[pads_index]; + expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)]; + } + } else if (node.OpType() == "MaxPool") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + for (uint32_t index = 0; index < expected_pads.size(); index++) { + expected_pads[index] += child_pads->Get(index); + } + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Pad"], 0); + ASSERT_EQ(op_to_count["MaxPool"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "MaxPool") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + ASSERT_EQ(child_pads->size(), static_cast(expected_pads.size())) + << "fusion should produce the same size of pads integer as the MaxPool node"; + for (uint32_t index = 0; index < expected_pads.size(); index++) { + ASSERT_EQ(expected_pads[index], child_pads->Get(index)) + << "fusion does not produce correct padding value"; + } + } + } +} + +TEST_F(GraphTransformationTests, FusePadWithMaxPoolOpsetLessThan11) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-maxpool-opset8.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::vector expected_pads; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "Pad") { + gsl::span pads_values = node.GetAttributes().at("pads").ints(); + expected_pads.resize(pads_values.size() - 4); + + for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) { + expected_pads[index] = pads_values[pads_index]; + expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)]; + } + } else if (node.OpType() == "MaxPool") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + for (uint32_t index = 0; index < expected_pads.size(); index++) { + expected_pads[index] += child_pads->Get(index); + } + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Pad"], 0); + ASSERT_EQ(op_to_count["MaxPool"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "MaxPool") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + ASSERT_EQ(child_pads->size(), static_cast(expected_pads.size())) + << "fusion should produce the same size of pads integer as the MaxPool node"; + for (uint32_t index = 0; index < expected_pads.size(); index++) { + ASSERT_EQ(expected_pads[index], child_pads->Get(index)) + << "fusion does not produce correct padding value"; + } + } + } +} + TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; @@ -1438,7 +1596,7 @@ TEST_F(GraphTransformationTests, NotWhereFusion) { ASSERT_TRUE(op_to_count["Not"] == 1); // can't remove Not if it is graph output/ has consumer that's not where } -#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS) +#if (defined(USE_CUDA) || defined(USE_JSEP)) && !defined(DISABLE_CONTRIB_OPS) // Conv->Add->Relu will be transformed to FusedConv TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx"; @@ -1618,6 +1776,10 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCudaExecutionProvider); } +#elif defined(USE_JSEP) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kJsExecutionProvider); + } #endif std::map op_to_count_before_fusion = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count_before_fusion[model.second] >= 1); @@ -1632,6 +1794,13 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { std::set cuda_rocm_supported = {"Relu"}; if (cuda_rocm_supported.find(model.second) == cuda_rocm_supported.end()) { ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); + } else { + ASSERT_EQ(op_to_count_after_fusion[model.second], 0); + } +#elif defined(USE_JSEP) + std::set js_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu"}; + if (js_supported.find(model.second) == js_supported.end()) { + ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); } else { ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); } diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index b1a04a00e89b1..6d075fec997b5 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -60,7 +60,7 @@ namespace perftest { "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" - "\t [OpenVINO only] [enable_vpu_fast_compile]: Optionally enabled to speeds up the model's compilation on VPU device targets.\n" + "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" "\t [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" @@ -72,7 +72,7 @@ namespace perftest { "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" "\t [Usage]: -e -i '| |'\n\n" - "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_vpu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" "\t [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 41a1eafebbb50..b7a111783fc94 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -240,8 +240,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (key == "device_type") { std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16", - "VPUX_FP16", "VPUX_U8"}; + "GPU.0_FP16", "GPU.1_FP16"}; if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { ov_options[key] = value; } else if (value.find("HETERO:") == 0) { @@ -254,17 +253,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW( "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " - "'GPU.0_FP16', 'GPU.1_FP16', 'VPUX_FP16', 'VPUX_U8' or from" + "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); } } else if (key == "device_id") { ov_options[key] = value; - } else if (key == "enable_vpu_fast_compile") { + } else if (key == "enable_npu_fast_compile") { if (value == "true" || value == "True" || value == "false" || value == "False") { ov_options[key] = value; } else { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_vpu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_npu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); } } else if (key == "enable_opencl_throttling") { if (value == "true" || value == "True" || @@ -299,7 +298,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ov_options[key] = value; } } else { - ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_vpu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); + ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); } } session_options.AppendExecutionProvider("OpenVINO", ov_options); diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 459a8c71ad611..16cce85f7cb0a 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -399,6 +399,8 @@ bool SetEpsForAllNodes(Graph& graph, const std::vector>& execution_providers, const std::vector>* custom_registries) { const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; + const KernelRegistry::TypeConstraintMap type_constraint_map{}; + for (auto& node : graph.Nodes()) { if (node.OpType() == kConstant) continue; @@ -426,13 +428,28 @@ bool SetEpsForAllNodes(Graph& graph, break; } + // check the internal NHWC domain if EP requests NHWC as it may only have a kernel registered in that domain + if (ep->GetPreferredLayout() == DataLayout::NHWC) { + const KernelCreateInfo* kci = nullptr; + auto status = ep->GetKernelRegistry()->TryFindKernel(ep->Type(), + std::string_view(node.OpType()), + std::string_view(kMSInternalNHWCDomain), + node.SinceVersion(), + type_constraint_map, + &kci); + if (status.IsOK() && kci != nullptr) { + found = true; + break; + } + } + // Check the EP has an impl for the node from custom_registries if (custom_registries != nullptr && std::any_of(custom_registries->cbegin(), custom_registries->cend(), - [&](auto reg) { return KernelRegistry::HasImplementationOf( - *reg->GetKernelRegistry(), - node, ep->Type(), - kernel_type_str_resolver); })) { + [&](auto reg) { + return KernelRegistry::HasImplementationOf(*reg->GetKernelRegistry(), node, ep->Type(), + kernel_type_str_resolver); + })) { found = true; break; } @@ -760,7 +777,7 @@ void BaseTester::ExecuteModelForEps( for (const auto& ep : execution_providers) { providers.append(ep->Type() + " "); } - LOGS_DEFAULT(WARNING) << "registered execution providers " << providers << "were unable to run the model."; + LOGS_DEFAULT(WARNING) << "registered execution providers " << providers << " were unable to run the model."; return; } diff --git a/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc index e37206d6aebf2..b7cead66bd7fb 100644 --- a/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc @@ -143,7 +143,7 @@ void L1NormalizationWithZeroNorm() { vector expected_output = {0.5f, 0.5f, 0.f, 0.f}; test.AddOutput("Y", input_dims, expected_output); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(LpNormalizationTest, L1NormalizationWithZeroNorm) { @@ -163,7 +163,7 @@ void L2NormalizationWithZeroNorm() { vector expected_output = {1.f, 0.f, 0.f, 0.f}; test.AddOutput("Y", input_dims, expected_output); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(LpNormalizationTest, L2NormalizationWithZeroNorm) { diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index d1a523b1eecf9..b9875b9553a55 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -762,7 +762,7 @@ TEST(RNNTest, RNN_invalid_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // the CUDA RNN version allows the invalid sequence lengths, so disable testing on CUDA and TensorRT - test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); }; // should batch batch_size to be valid @@ -860,7 +860,7 @@ TEST(RNNTest, RNN_bidirectional_with_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(RNNTest, RNN_with_invalid_activation_load_failure) { diff --git a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc index c95ac1603a317..c3d91100605e9 100644 --- a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc @@ -66,7 +66,7 @@ TEST(CompressTest, Compress_3dims_has_extra_condition) { // has condition length = 3 > input_dim[axis] = 2 test.AddInput("condition", {3}, {0, 1, 1}); test.AddOutput("output", {2, 1, 3}, {4.0f, 5.0f, 6.0f, 10.0f, 11.0f, 12.0f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(CompressTest, Compress_3dims_has_extra_input) { diff --git a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc index 2120da604f94a..d2aa5dd428fec 100644 --- a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc @@ -99,7 +99,7 @@ TEST(TensorOpTest, Unsqueeze_scalar_2) { test.AddInput("input", {}, std::vector{1.0f}); test.AddInput("axes", {2}, std::vector{0, -1}, axes_is_initializer); test.AddOutput("output", {1, 1}, std::vector{1.0f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); }; run_test(false); run_test(true); diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc index be0082f95feb8..13d4546d669e3 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc @@ -21,7 +21,7 @@ struct ConvOp { std::unique_ptr get_test() { RandomValueGenerator random{}; - auto test = std::make_unique("Conv", 7); + auto test = std::make_unique("Conv", 11); // internal NHWC domain starts at opset 11 std::vector input_data = random.Uniform(input_dims, 0.0f, 1.0f); std::vector weight_dims{channels, input_dims[1] / group, kernel_shape[0], kernel_shape[1]}; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index f7499fd7ad812..8955a83e66c01 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -203,7 +203,8 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { } TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model.onnx"; + // the internal NHWC domain supports opset 11 and later + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model_opset11.onnx"; SessionOptions so; // set this if you want to manually inspect the optimized model diff --git a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc index 9b65ca7bda3e2..b4e8f5390787c 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc @@ -175,13 +175,7 @@ static void RunBatchNormQDQTest(const TestInputDef& input_def, // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 3. -// QNN v2.13 -// Inaccuracy detected for output 'output', element 4. -// Output quant params: scale=0.019084848463535309, zero_point=9. -// Expected val: 1.7755576372146606 -// QNN QDQ val: 2.9963212013244629 (err 1.2207635641098022) -// CPU QDQ val: 0.82064849138259888 (err 0.95490914583206177) -TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm1D) { +TEST_F(QnnHTPBackendTests, BatchNorm1D) { constexpr int64_t num_channels = 2; RunBatchNormQDQTest(TestInputDef({1, num_channels, 3}, false, {-5.0f, -4.0f, -3.0f, 0.0f, 2.0f, 5.0f}), // Input data @@ -193,13 +187,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm1D) { // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 4. -// QNN v2.13 -// Inaccuracy detected for output 'output', element 14. -// Output quant params: scale=0.023071292787790298, zero_point=19. -// Expected val: 2.8554618358612061 -// QNN QDQ val: 5.3294687271118164 (err 2.4740068912506104) -// CPU QDQ val: 1.6611330509185791 (err 1.194328784942627) -TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm2D) { +TEST_F(QnnHTPBackendTests, BatchNorm2D) { constexpr int64_t num_channels = 2; std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; @@ -226,4 +214,4 @@ TEST_F(QnnHTPBackendTests, BatchNorm3D) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/providers/qnn/pad_op_test.cpp b/onnxruntime/test/providers/qnn/pad_op_test.cpp index e92f0ae770a88..792dbeadfa758 100644 --- a/onnxruntime/test/providers/qnn/pad_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pad_op_test.cpp @@ -167,7 +167,7 @@ TEST_F(QnnCPUBackendTests, Pad2dPadsNotIni) { TEST_F(QnnCPUBackendTests, DISABLED_PadModeReflect) { bool has_constant_value = false; RunPadOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}), - TestInputDef({4}, true, {0, 2, 0, 0}), + TestInputDef({4}, true, {0, 1, 0, 0}), TestInputDef({1}, true, {0.0f}), {utils::MakeAttribute("mode", "reflect")}, ExpectedEPNodeAssignment::All, @@ -266,13 +266,37 @@ TEST_F(QnnHTPBackendTests, PadHasConstantValueQuantized) { constant_value_quantized); } -// QNN graph execute error. Error code: 6031 -TEST_F(QnnHTPBackendTests, DISABLED_PadReflectMode) { +TEST_F(QnnHTPBackendTests, PadReflectMode) { + bool has_constant_value_input = false; + RunQDQPadOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}), + TestInputDef({4}, true, {0, 1, 0, 0}), + TestInputDef({1}, true, {0.0f}), + {utils::MakeAttribute("mode", "reflect")}, + ExpectedEPNodeAssignment::All, + has_constant_value_input); +} + +// Pad amount should not be greater than shape(input[0])[i] - 1 +TEST_F(QnnHTPBackendTests, PadReflectModeOutOfRangePadAmount) { bool has_constant_value_input = false; RunQDQPadOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}), TestInputDef({4}, true, {0, 2, 0, 0}), TestInputDef({1}, true, {0.0f}), {utils::MakeAttribute("mode", "reflect")}, + ExpectedEPNodeAssignment::None, + has_constant_value_input); +} + +TEST_F(QnnHTPBackendTests, Pad4dReflectMode) { + bool has_constant_value_input = false; + RunQDQPadOpTest(TestInputDef({1, 2, 2, 2}, false, + {1.0f, 2.0f, + 3.0f, 4.0f, + 5.0f, 6.0f, + 7.0f, 8.0f}), + TestInputDef({8}, true, {0, 1, 1, 1, 0, 1, 1, 1}), + TestInputDef({1}, true, {0.0f}), + {utils::MakeAttribute("mode", "reflect")}, ExpectedEPNodeAssignment::All, has_constant_value_input); } diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index ecf4b001eec68..c48b07422d452 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -140,6 +140,9 @@ def create_backend_test(test_name=None): if backend.supports_device("OPENVINO_CPU_FP16"): current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP16") + if backend.supports_device("OPENVINO_NPU_FP16"): + current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU_FP16") + if backend.supports_device("OPENVINO"): current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_opset18") diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index e0fb3979a9f55..6f691972181b5 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -7,7 +7,7 @@ import numpy as np import onnxscript from mpi4py import MPI -from onnxscript import FLOAT, INT64 +from onnxscript import FLOAT, FLOAT16, INT64 import onnxruntime as ort @@ -27,12 +27,23 @@ def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): return np.concatenate(selected_shards, axis=axis) -def translate_device_mesh_to_attrs(device_mesh: np.ndarray): +def translate_single_device_mesh(device_mesh: np.ndarray): device_mesh_shape = "[" + ",".join(str(dim) for dim in device_mesh.shape) + "]" device_mesh_elements = "[" + ",".join(str(elem) for elem in device_mesh.flat) + "]" return device_mesh_shape, device_mesh_elements +def translate_all_device_meshes(device_meshes: np.ndarray): + assert all(len(mesh.shape) == 1 for mesh in device_meshes) + device_mesh_shapes = [] + device_mesh_elements = [] + for device_mesh in device_meshes: + device_mesh_shape, device_mesh_element = translate_single_device_mesh(device_mesh) + device_mesh_shapes.append(device_mesh_shape) + device_mesh_elements.append(device_mesh_element) + return device_mesh_shapes, device_mesh_elements + + def parse_sharding_spec(spec: str): axis_conditions = [] sharding_device_axes = [] @@ -90,29 +101,13 @@ def _check_distributed_reshape( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -134,11 +129,11 @@ def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = np.reshape(data_tensor, shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_reshape_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -176,9 +171,9 @@ def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self): 3, ), target_shape=(6,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]",), ) @@ -191,9 +186,9 @@ def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self): 4, ), target_shape=(8,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]",), ) @@ -210,9 +205,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self): 2, 15, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -229,9 +224,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self): 2, 20, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -248,9 +243,9 @@ def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self): 2, 18, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) # Two axis fusion. @@ -268,9 +263,9 @@ def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self): 2, 3, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -283,9 +278,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self): 1, 16, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -298,9 +293,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -313,9 +308,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self): 4, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -328,9 +323,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self): 8, 2, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -343,9 +338,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self): 16, 1, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -359,9 +354,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self) 1, 16, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -375,9 +370,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -390,9 +385,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self): 4, 4, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -405,9 +400,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self): 8, 2, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -420,9 +415,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self) 16, 1, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -444,9 +439,9 @@ def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -471,9 +466,9 @@ def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rr 64, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]R",), ) @@ -495,9 +490,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrr_01_shape_21_4906_rr_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -519,9 +514,9 @@ def test_reshape_two_axis_fusion_shape_21_4096_rrr_01_shape_3_7_4906_rr_01(self) 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRR",), ) @@ -546,9 +541,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_64_rsrr_01_shape_192_7_64_srr_0101 7, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -573,9 +568,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_7_srr_010101_shape_3_64_7_7_ 7, 7, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -600,9 +595,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101 7, 7, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -627,9 +622,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_6 7, 64, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -654,9 +649,9 @@ def test_reshape_two_axis_fusion_shape_3_7_64_64_rrsr_01_shape_3_7_4096_rrs_01(s 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -678,9 +673,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -690,29 +685,16 @@ def _check_distributed_expand( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -734,11 +716,11 @@ def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = data_tensor * np.ones(shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_expand_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -780,9 +762,9 @@ def test_expand_sharded_on_expanded_axis(self): 8, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -799,9 +781,9 @@ def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self): 8, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -818,9 +800,9 @@ def test_expand_replicated_on_expanded_axis(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -837,12 +819,12 @@ def test_expand_with_pass_through_sharding_spec(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=( "S[0]R", "R", ), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -863,13 +845,160 @@ def test_expand_in_tiny_llama(self): 256, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) +class TestDistributedReduce(unittest.TestCase): + def _check_distributed_reduce( + self, + keepdims: int, + dtype: np.dtype, + shape: Tuple[int, ...], + axes: Tuple[int, ...], + input_device_meshes: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshes: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) + + @onnxscript.script() + def distributed_reduce_sum_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceSum( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_max_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMax( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMean( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + + for onnx_func, np_func in zip( + [distributed_reduce_sum_instance, distributed_reduce_max_instance, distributed_reduce_mean_instance], + [np.sum, np.maximum.reduce, np.mean], + ): + data = np.random.randint(4, size=shape).astype(dtype) + expected = np_func(data, axis=axes, keepdims=bool(keepdims)) + + assert len(input_shard_specs) == 2 and len(input_device_meshes) == 2, "Reduce has two inputs." + assert "S" not in input_shard_specs[1], "Tensor `axes` should not be sharded." + assert len(output_shard_specs) == 1 and len(output_device_meshes) == 1, "Reduce has only one output." + + local_data = shard_tensor_per_spec(data, rank, input_shard_specs[0], input_device_meshes[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) + + if dtype == np.float32: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + elif dtype == np.int64: + onnx_model = onnx_func.to_model_proto( + input_types=[INT64[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[INT64[tuple(local_expected.shape)]], + ) + elif dtype == np.float16: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT16[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT16[tuple(local_expected.shape)]], + ) + else: + raise RuntimeError(f"Unsupported dtype: {dtype}") + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data, + "axes_tensor": np.array(axes, dtype=np.int64), + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_reduce(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(0,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_reduce_sharded(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(1,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]R", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2]. diff --git a/onnxruntime/test/python/quantization/test_op_gemm.py b/onnxruntime/test/python/quantization/test_op_gemm.py index 54ef1cc1d5446..bac0f6d48e9fc 100644 --- a/onnxruntime/test/python/quantization/test_op_gemm.py +++ b/onnxruntime/test/python/quantization/test_op_gemm.py @@ -192,24 +192,9 @@ def static_quant_test( check_qtype_by_node_type(self, model_int8_path, qnode_io_qtypes) data_reader.rewind() if activation_type_str == "f8e4m3fn" and weight_type_str == "f8e4m3fn": - # QGemm is not implemented for CPU. - try: - check_model_correctness( - self, - model_fp32_path, - model_int8_path, - data_reader.get_next(), - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], - is_gemm=True, - ) - except Exception as e: - if ( - "Type 'tensor(float8e4m3fn)' of input parameter (input_quantized) of operator (QGemm) in node () is invalid." - in str(e) - ): - warnings.warn("Fix this test when QGemm is implemented.") - return - raise e + # QGemm for float 8 is not implemented. The test should be updated when it is. + warnings.warn("Fix this test when QGemm is implemented for float 8 types.") + return else: check_model_correctness(self, model_fp32_path, model_int8_path, data_reader.get_next(), is_gemm=True) diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py index e03a0167d070a..765825d4b86e3 100644 --- a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py @@ -38,8 +38,8 @@ def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, i matrix_float_padded = np.pad(matrix_float, ((0, pad_len), (0, 0)), "constant") packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") - scales = np.zeros((cols * k_blocks), dtype=matrix_float_padded.dtype) - zero_point = np.full((cols * k_blocks + 1) // 2, 136, dtype="uint8") + scales = np.zeros((cols, k_blocks), dtype=matrix_float_padded.dtype) + zero_point = np.full((cols, (k_blocks + 1) // 2), 136, dtype="uint8") matrix_float_padded = np.transpose(matrix_float_padded) for n in range(cols): @@ -61,10 +61,12 @@ def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, i zp = min(15, max(0, round(zero_point_fp))) reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 - block_idx = n * k_blocks + k_id // block_size - scales[block_idx] = scale - zp_pair = zero_point[block_idx // 2] - zero_point[block_idx // 2] = ((zp_pair & 0x0F) | (zp << 4)) if (block_idx & 1) else ((zp_pair & 0xF0) | zp) + block_idx = k_id // block_size + scales[n, block_idx] = scale + zp_pair = zero_point[n, block_idx // 2] + zero_point[n, block_idx // 2] = ( + ((zp_pair & 0x0F) | (zp << 4)) if (block_idx & 1) else ((zp_pair & 0xF0) | zp) + ) blk_int0 = np.clip( np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 2] * reciprocal_scale + zp)), @@ -76,7 +78,7 @@ def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, i 0, 15, ).astype("uint8") - packed[n, k_id // block_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) + packed[n, block_idx] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) return (packed, scales, zero_point) @@ -88,8 +90,8 @@ def quantize_blockwise_4bits_target(matrix_float: npt.ArrayLike, block_size: int k_blocks = (rows + block_size - 1) // block_size packed = np.zeros((cols, k_blocks, block_size // 2), dtype="uint8") - scales = np.zeros((cols * k_blocks), dtype=matrix_float.dtype) - zero_point = np.full((cols * k_blocks + 1) // 2, 136, dtype="uint8") + scales = np.zeros((cols, k_blocks), dtype=matrix_float.dtype) + zero_point = np.full((cols, (k_blocks + 1) // 2), 136, dtype="uint8") from onnxruntime.capi._pybind_state import quantize_matmul_4bits quantize_matmul_4bits(packed, matrix_float, scales, zero_point, block_size, cols, rows, is_symmetric) @@ -116,24 +118,22 @@ def test_quantize_blockwise_4bits(self): assert np.allclose(zero_point_ref, zero_point) for c in range(quant_value_ref.shape[0]): for k in range(quant_value_ref.shape[1]): - block_idx = c * quant_value_ref.shape[1] + k - zp_idx = block_idx // 2 assert np.allclose( dequantize_blockwise_4bits( - quant_value_ref[c][k], - scales_ref[block_idx], - (zero_point_ref[zp_idx] >> 4) - if (block_idx & 1) - else (zero_point_ref[zp_idx] & 0x0F), + quant_value_ref[c, k], + scales_ref[c, k], + (zero_point_ref[c, k // 2] >> 4) + if (k & 1) + else (zero_point_ref[c, k // 2] & 0x0F), min(block_size, rows - k * block_size), ), dequantize_blockwise_4bits( - quant_value[c][k], - scales[block_idx], - (zero_point[zp_idx] >> 4) if (block_idx & 1) else (zero_point[zp_idx] & 0x0F), + quant_value[c, k], + scales[c, k], + (zero_point[c, k // 2] >> 4) if (k & 1) else (zero_point[c, k // 2] & 0x0F), min(block_size, rows - k * block_size), ), - atol=1.2 * abs(scales[block_idx]), + atol=1.2 * abs(scales[c, k]), ) diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 04351cd6e6782..319fed87dc9eb 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -10,7 +10,10 @@ # license information. # ------------------------------------------------------------------------- import math +import os +import platform import random +import unittest import numpy import torch @@ -22,6 +25,8 @@ torch.manual_seed(0) +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + class Formats: BSNH = 0 @@ -159,7 +164,7 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_no_past(config, causal=False): +def create_group_query_attention_graph_no_past(config, causal=False, present_kv_format=Formats.BSNH): nodes = [ helper.make_node( "GroupQueryAttention", @@ -168,11 +173,12 @@ def create_group_query_attention_graph_no_past(config, causal=False): "key", "value", ], - ["output"], + ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, unidirectional=1 if causal else 0, + is_past_bsnh=1 if present_kv_format == Formats.BSNH else 0, domain="com.microsoft", ), ] @@ -213,6 +219,26 @@ def create_group_query_attention_graph_no_past(config, causal=False): TensorProto.FLOAT16, [config.batch_size, config.sequence_length, config.num_heads * config.head_size], ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), ] graph = helper.make_graph( @@ -514,7 +540,6 @@ def generate_token_offset(cu_seqlens, max_seqlen): return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) -# TODO(aciddelgado): rename def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False): onnx_model_str = create_packed_multihead_attention_graph(config) qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) @@ -548,8 +573,8 @@ def mha_func(q, k, v, config): return output -def gqa_no_past_func(q, k, v, config, causal=True): - onnx_model_str = create_group_query_attention_graph_no_past(config, causal) +def gqa_no_past_func(q, k, v, config, causal=True, present_kv_format=Formats.BSNH): + onnx_model_str = create_group_query_attention_graph_no_past(config, causal, present_kv_format=present_kv_format) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) @@ -560,7 +585,7 @@ def gqa_no_past_func(q, k, v, config, causal=True): } sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) - ort_output = ort_session.run(None, ort_inputs) + ort_output, _, _ = ort_session.run(None, ort_inputs) ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) return output @@ -689,17 +714,12 @@ def attention_ref( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - # causal_mask = torch.triu( - # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 - # ) causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) if causal: # Some rows are completely masked out so we fill them with zero instead of NaN attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) - # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling - # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: @@ -1072,12 +1092,6 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(present_k[0, 0, config.past_sequence_length, :10]) - # print(k_cache_ref[0, 0, config.past_sequence_length, :10]) - # print(k_cache_ref.shape) - - # print(present_k - k_cache_ref.detach().cpu().numpy()) - # Make sure past-present buffer updating correctly if past_format == Formats.BSNH: assert numpy.allclose( @@ -1141,84 +1155,185 @@ def parity_check_gqa_past_no_buff( ) +class TestMHA(unittest.TestCase): + def test_packed_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST PACKED MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048] + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s, 0, n, n, h) + parity_check_mha(config, True) + + def test_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ] + ) + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s, s2 in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s2, 0, n, n, h) + parity_check_mha(config, False) + + +class TestGQA(unittest.TestCase): + def test_gqa_no_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + torch.manual_seed(69) + print("-------- TEST GQA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1024, 1024), + (1023, 1024), + (2048, 2048), + ] + ) + num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + if major < 5 or (major == 5 and minor < 3): + return + print("------- MEMORY EFFICIENT ATTENTION ---------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION --------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + + def test_gqa_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + if major < 5 or (major == 5 and minor < 3): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- TEST GQA PAST ---------") + print("-------- MEMORY EFFICEINT --------") + batches = [2] if pipeline_mode else [1, 2] + seqs = ( + [(1, 128), (3, 1024), (64, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + (1, 128 * 512), + (16, 128 * 512), + (128, 128), + ] + ) + num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + + if __name__ == "__main__": - print("-------- TEST PACKED MHA ---------") - for b in [5]: - for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]: - for n in [6]: - for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s, 0, n, n, h) - parity_check_mha(config, True) - print("-------- TEST MHA ---------") - for b in [5]: - for s, s2 in [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ]: - for n in [6]: - for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s2, 0, n, n, h) - parity_check_mha(config, False) - print("-------- TEST GQA ---------") - for b in [5]: - for s, s2 in [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ]: - for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: - for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: - for causal in [True, False]: - config = Config(b, s, s2, 0, n, n2, h) - parity_check_gqa_no_past(config, causal=causal) - print("-------- TEST GQA PAST ---------") - random.seed(69) - for b in [2]: - for s, s2 in [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 512), - (16, 128 * 512), - (128, 128), - ]: - for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: - for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: - for causal in [True]: - for past_kv_format in [Formats.BNSH, Formats.BSNH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_group_norm.py b/onnxruntime/test/python/transformers/test_group_norm.py new file mode 100644 index 0000000000000..bf295a65c8b53 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_group_norm.py @@ -0,0 +1,541 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import statistics +from dataclasses import dataclass +from enum import Enum +from time import perf_counter +from typing import Optional, Tuple + +import numpy +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + +torch.manual_seed(0) + + +class GroupNormOpType(Enum): + GROUP_NORM = 1 + SKIP_GROUP_NORM = 2 + + +@dataclass +class GroupNormConfig: + batch_size: int + height: int + width: int + channels: int + epsilon: float = 1e-5 + num_groups: int = 32 + activation: bool = False + channels_last: bool = True + fp16: bool = False + + op_type: GroupNormOpType = GroupNormOpType.GROUP_NORM + has_bias: bool = False + has_add_out: bool = False + broadcast_skip: int = 0 # 2 for (N, C), 4 for (N, 1, 1, C) + + def get_skip_symbolic_shape(self): + skip_shape = {0: ["N", "H", "W", "C"], 2: ["N", "C"], 4: ["N", 1, 1, "C"]} + return skip_shape[self.broadcast_skip] + + def get_skip_shape(self): + skip_shape = { + 0: [self.batch_size, self.height, self.width, self.channels], + 2: [self.batch_size, self.channels], + 4: [self.batch_size, 1, 1, self.channels], + } + return skip_shape[self.broadcast_skip] + + def broadcast(self, skip: torch.Tensor): + if self.broadcast_skip == 2: + return skip.reshape(self.batch_size, 1, 1, self.channels) + + return skip + + @staticmethod + def create( + b: int, + h: int, + w: int, + c: int, + fp16: bool = False, + activation: bool = False, + template: int = 0, + num_groups: int = 32, + ): + if template == 0: + return GroupNormConfig( + b, h, w, c, fp16=fp16, activation=activation, op_type=GroupNormOpType.GROUP_NORM, num_groups=num_groups + ) + + if template == 1: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 2: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=2, + num_groups=num_groups, + ) + + if template == 3: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=False, + broadcast_skip=4, + num_groups=num_groups, + ) + + if template == 4: # No bias + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 5: # No bias, no add_out + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=0, + num_groups=num_groups, + ) + + return None + + +def create_group_norm_graph(config: GroupNormConfig) -> bytes: + inputs = ["input", "gamma", "beta"] + outputs = ["output"] + op_type = "GroupNorm" + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + op_type = "SkipGroupNorm" + inputs = [*inputs, "skip"] + if config.has_bias: + inputs = [*inputs, "bias"] + if config.has_add_out: + outputs = [*outputs, "add_out"] + + nodes = [ + helper.make_node( + op_type, + inputs, + outputs, + op_type + "_0", + activation=int(config.activation), + channels_last=int(config.channels_last), + epsilon=config.epsilon, + groups=config.num_groups, + domain="com.microsoft", + ), + ] + + float_type = TensorProto.FLOAT16 if config.fp16 else TensorProto.FLOAT + + input_shapes = [ + helper.make_tensor_value_info("input", float_type, ["N", "H", "W", "C"]), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, ["C"]), + ] + output_shapes = [ + helper.make_tensor_value_info("output", float_type, ["N", "H", "W", "C"]), + ] + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + input_shapes = [ + *input_shapes, + helper.make_tensor_value_info("skip", float_type, config.get_skip_symbolic_shape()), + ] + if config.has_bias: + input_shapes = [*input_shapes, helper.make_tensor_value_info("bias", float_type, ["C"])] + if config.has_add_out: + output_shapes = [*output_shapes, helper.make_tensor_value_info("add_out", float_type, ["N", "H", "W", "C"])] + + graph = helper.make_graph( + nodes, + "Group_Norm_Graph", + input_shapes, + output_shapes, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def group_norm_ort( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, + measure_latency=False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]: + onnx_model_str = create_group_norm_graph(config) + ort_session = InferenceSession(onnx_model_str, providers=["CUDAExecutionProvider"]) + + session = CudaSession(ort_session, device=torch.device("cuda:0")) + + io_shape = { + "input": [config.batch_size, config.height, config.width, config.channels], + "gamma": [config.channels], + "beta": [config.channels], + "output": [config.batch_size, config.height, config.width, config.channels], + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + io_shape["skip"] = config.get_skip_shape() + if config.has_bias: + io_shape["bias"] = [config.channels] + if config.has_add_out: + io_shape["add_out"] = [config.batch_size, config.height, config.width, config.channels] + + session.allocate_buffers(io_shape) + + ort_inputs = { + "input": src, + "gamma": gamma, + "beta": beta, + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + ort_inputs["skip"] = skip + if config.has_bias: + ort_inputs["bias"] = bias + + ort_outputs = session.infer(ort_inputs) + output = ort_outputs["output"] + + add_out = ( + ort_outputs["add_out"] if config.op_type == GroupNormOpType.SKIP_GROUP_NORM and config.has_add_out else None + ) + + if measure_latency: + latency_list = [] + for _ in range(10000): + start_time = perf_counter() + session.infer(ort_inputs) + end_time = perf_counter() + latency_list.append(end_time - start_time) + average_latency = statistics.mean(latency_list) + return output, add_out, average_latency + + return output, add_out, None + + +def group_norm_torch( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + add_out = src + + if skip is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + config.broadcast(skip) + + if bias is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + bias.reshape(1, 1, 1, bias.shape[0]) + + x = add_out + if config.channels_last: + x = add_out.clone().permute(0, 3, 1, 2) # from NHWC to NCHW + + weight = gamma.to(x.dtype) + bias = beta.to(x.dtype) + output = torch.nn.functional.group_norm(x, config.num_groups, weight=weight, bias=bias, eps=config.epsilon) + + if config.activation: + torch.nn.functional.silu(output, inplace=True) + + if config.channels_last: + output = output.permute(0, 2, 3, 1) # from NCHW to NHWC + + return output, add_out + + +def print_tensor(name, tensor): + # Print in the format that could be directly added to unit tests in C++. + torch.set_printoptions(precision=6, sci_mode=False, linewidth=100, profile="full", threshold=1000) + print(name) + if tensor is not None: + print("shape", tensor.shape) + text = str(tensor.clone().flatten()) + print(text.replace("[", "[\n").replace("]", ",\n]").replace(",", "f,")) + else: + print(tensor) + + +def run_parity(config, measure_latency=True, verbose=False): + float_type = torch.float16 if config.fp16 else torch.float32 + + input_tensor = torch.randn( + config.batch_size, + config.height, + config.width, + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + gamma = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + beta = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + skip = None + bias = None + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + skip = torch.randn( + *config.get_skip_shape(), + device="cuda", + dtype=float_type, + requires_grad=False, + ) + if config.has_bias: + bias = torch.randn( + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + if verbose: + print(config) + print_tensor("input", input_tensor) + print_tensor("gamma", gamma) + print_tensor("beta", beta) + print_tensor("skip", skip) + print_tensor("bias", bias) + + out_ort, ort_add_out, latency = group_norm_ort( + input_tensor, gamma, beta, skip, bias, config, measure_latency=measure_latency + ) + + if verbose: + print_tensor("out_ort", out_ort) + print_tensor("ort_add_out", ort_add_out) + + torch_out, torch_add_out = group_norm_torch(input_tensor, gamma, beta, skip, bias, config) + + if verbose: + print_tensor("torch_out", torch_out) + print_tensor("torch_add_out", torch_add_out) + + average_diff = numpy.mean(numpy.abs(out_ort.detach().cpu().numpy() - torch_out.detach().cpu().numpy())) + + is_close = numpy.allclose( + out_ort.detach().cpu().numpy(), + torch_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + + is_add_out_close = ( + numpy.allclose( + ort_add_out.detach().cpu().numpy(), + torch_add_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + if ort_add_out is not None + else "" + ) + + # Compare results + print( + config.op_type.name, + " B:", + config.batch_size, + " H:", + config.height, + " W:", + config.width, + " C:", + config.channels, + " G:", + config.num_groups, + " activation:", + int(config.activation), + " channels_last:", + int(config.channels_last), + " fp16:", + int(config.fp16), + f" Latency(μs): {int(latency * 1e6)}" if isinstance(latency, float) else "", + " AvgDiff:", + average_diff, + " Pass:", + is_close, + is_add_out_close, + ) + + +def get_latent_height_width(): + default_size = [(512, 512), (768, 768), (1024, 1024)] + small_img_size = [(512, 768), (768, 512)] + xl_img_size = [ + (1152, 896), + (896, 1152), + (1216, 832), + (832, 1216), + (1344, 768), + (768, 1344), + (1536, 640), + (640, 1536), + ] + return [(int(h / 8), int(w / 8)) for (h, w) in default_size + small_img_size + xl_img_size] + + +def get_channels(): + return [128, 256, 512, 1024, 2048, 320, 640, 960, 1920, 2560, 384, 768, 1536, 3072, 1152, 2304] + + +def run_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm with Silu Activation for ", "fp16" if fp16 else "fp32") + for b in [2]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, activation=True, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_no_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm without Activation for ", "fp16" if fp16 else "fp32") + for b in [1, 2, 4]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_all_groups(template: int, fp16, measure_latency=False): + group_sizes = [1, 2, 4, 8, 16, 32] + print("Test GroupNorm for different group sizes:", group_sizes) + for group_size in group_sizes: + for h, w in get_latent_height_width()[:3]: + for c in get_channels()[:2]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=group_size, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_odd_channels(template: int, fp16, measure_latency=False): + # Test some random number of channels that can be divisible by 2 * num_groups + for h, w in get_latent_height_width(): + for c in [448, 704, 832, 1664, 2240, 2688, 2880, 3008]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_small_inputs(template: int, fp16): + config = GroupNormConfig.create(2, 2, 2, 16, fp16=fp16, activation=False, num_groups=4, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=False, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=True, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + +def run_performance(fp16): + # Run perf test to tune parameters for given number of channels. + for h, w in get_latent_height_width()[:3]: + for c in get_channels(): + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=0) + run_parity(config, measure_latency=True) + + +def run_all(template: int): + for fp16 in [True, False]: + run_small_inputs(template, fp16) + run_odd_channels(template, fp16) + run_all_groups(template, fp16) + run_activation(template, fp16) + run_no_activation(template, fp16) + + +def run_not_implemented(): + # Expect failure. Check whether the error message is expected. + try: + config = GroupNormConfig(1, 2, 2, 513, num_groups=3) + run_parity(config) + except RuntimeError as e: + assert "GroupNorm in CUDA does not support the input: n=1 h=2 w=2 c=513 groups=3" in str(e) + + +def main(): + run_performance(True) + + run_not_implemented() + + for template in range(6): + run_all(template) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py index fedba2a25dfc2..373ad86ced1a7 100644 --- a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -96,7 +96,7 @@ def create_inputs_and_outputs(self, model_type: str): helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size), ] - if model_type in {"past", "merged", "llama2_msft"}: + if model_type in {"past", "merged", "llama2_msft", "70b_distributed_merged"}: inputs.extend( [ helper.make_tensor_value_info( @@ -164,14 +164,14 @@ def get_first_rope_input(node_type: str): if is_fused or model_type == "llama2_msft": # q_out/k_out return f"{node_type}_out" - if model_type in {"no_past", "past", "merged"}: + if model_type in {"no_past", "past", "merged", "70b_distributed_merged"}: if node_type == "k": return "k_before_rope" return "q_before_rope" return "" def get_first_rope_output(node_type: str): - if is_fused or model_type in {"llama2_msft", "past", "merged"}: + if is_fused or model_type in {"llama2_msft", "past", "merged", "70b_distributed_merged"}: if node_type == "q": return "q_rope" return "k_rope" @@ -295,23 +295,225 @@ def create_k_path_hf(self, model_type: str): ) k_nodes = [reshape_k_node, transpose_k_1_node] - if model_type in {"past", "merged"}: + if model_type == "70b_distributed_merged": concat_k_node = helper.make_node( "Concat", inputs=["past_key", "k_rope"], outputs=["present_key"], axis=2, ) - k_nodes.append(concat_k_node) + shape_k1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k1_out"], name="Shape_k1") + shape_k2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k2_out"], name="Shape_k2") + shape_k3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k3_out"], name="Shape_k3") + shape_k4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k4_out"], name="Shape_k4") + + gather_k_1 = helper.make_node( + "Gather", + inputs=["shape_k1_out", "one"], + outputs=["gather_k1_out"], + name="Gather_k_1", + axis=0, + ) + gather_k_2 = helper.make_node( + "Gather", + inputs=["shape_k2_out", "one"], + outputs=["gather_k2_out"], + name="Gather_k_2", + axis=0, + ) + gather_k_3 = helper.make_node( + "Gather", + inputs=["shape_k3_out", "one"], + outputs=["gather_k3_out"], + name="Gather_k_3", + axis=0, + ) + gather_k_4 = helper.make_node( + "Gather", + inputs=["shape_k4_out", "one"], + outputs=["gather_k4_out"], + name="Gather_k_4", + axis=0, + ) - transpose_k_2_node = helper.make_node( - "Transpose", - inputs=["present_key"], - outputs=["k"], - name="Transpose_k_2", - perm=[0, 1, 3, 2], - ) - return k_nodes + [transpose_k_2_node] # noqa: RUF005 + unsqueeze_k_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_k1_out"], + name="Unsqueeze_k1", + ) + unsqueeze_k_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k2_out"], + name="Unsqueeze_k2", + ) + unsqueeze_k_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_k2_out", "zero"], + outputs=["unsqueeze_k3_out"], + name="Unsqueeze_k3", + ) + unsqueeze_k_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k4_out"], + name="Unsqueeze_k4", + ) + unsqueeze_k_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k5_out"], + name="Unsqueeze_k5", + ) + + concat_k_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_k2_out", "unsqueeze_k3_out", "One", "unsqueeze_k4_out", "unsqueeze_k5_out"], + outputs=["concat_k2_ouot"], + name="Concat_k2", + axis=0, + ) + reshape_k_2 = helper.make_node( + "Reshape", + inputs=["concat_k2_ouot", "One"], + outputs=["reshape_k2_out"], + name="Reshape_k_2", + ) + shape_k5 = helper.make_node("Shape", inputs=["reshape_k2_out"], outputs=["shape_k5_out"], name="Shape_k5") + constant_of_shape_k_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_k5_out"], + outputs=["constant_of_shape_k1_out"], + name="ConstantOfShape_k1", + ) + mul_k_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_k1_out", "One"], + outputs=["mul_k1_out"], + name="mul_k1", + ) + equal_k_1 = helper.make_node( + "Equal", + inputs=["reshape_k2_out", "mul_k1_out"], + outputs=["equal_k_1_out"], + name="equal_k1", + ) + where_k_1 = helper.make_node( + "Where", + inputs=["equal_k_1_out", "constant_of_shape_k1_out", "reshape_k2_out"], + outputs=["where_k_1_out"], + name="where_k1", + ) + unsqueeze_k_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k6_out"], + name="Unsqueeze_k6", + ) + mul_k_2 = helper.make_node( + "Mul", + inputs=["gather_k2_out", "One"], + outputs=["mul_k2_out"], + name="mul_k2", + ) + unsqueeze_k_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_k2_out", "zero"], + outputs=["unsqueeze_k7_out"], + name="Unsqueeze_k7", + ) + unsqueeze_k_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k8_out"], + name="Unsqueeze_k8", + ) + unsqueeze_k_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k9_out"], + name="Unsqueeze_k9", + ) + concat_k_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_k6_out", "unsqueeze_k7_out", "unsqueeze_k8_out", "unsqueeze_k9_out"], + outputs=["concat_k3_out"], + name="Concat_k3", + axis=0, + ) + expand_k_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_k1_out", "where_k_1_out"], + outputs=["expand_k1_out"], + name="expand_k1", + ) + reshape_k_3 = helper.make_node( + "Reshape", + inputs=["expand_k1_out", "concat_k3_out"], + outputs=["reshape_k3_out"], + name="Reshape_k_3", + ) + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["reshape_k3_out"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + + k_nodes_for_70b_model = [ + concat_k_node, + shape_k1, + shape_k2, + shape_k3, + shape_k4, + gather_k_1, + gather_k_2, + gather_k_3, + gather_k_4, + unsqueeze_k_1, + unsqueeze_k_2, + unsqueeze_k_3, + unsqueeze_k_4, + unsqueeze_k_5, + concat_k_2, + reshape_k_2, + shape_k5, + constant_of_shape_k_1, + mul_k_1, + equal_k_1, + where_k_1, + unsqueeze_k_6, + mul_k_2, + unsqueeze_k_7, + unsqueeze_k_8, + unsqueeze_k_9, + concat_k_3, + expand_k_1, + reshape_k_3, + transpose_k_2_node, + ] + k_nodes.extend(k_nodes_for_70b_model) + return k_nodes + else: + if model_type in {"past", "merged"}: + concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + k_nodes.append(concat_k_node) + + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + return k_nodes + [transpose_k_2_node] # noqa: RUF005 def create_k_path(self, model_type: str): if model_type == "llama2_msft": @@ -505,7 +707,7 @@ def create_v_path(self, model_type: str): if model_type == "no_past": return v_nodes - if model_type in {"past", "merged"}: + if model_type in {"past", "merged", "70b_distributed_merged"}: concat_v_node = helper.make_node( "Concat", inputs=["past_value", "transpose_v_1_out"], @@ -513,7 +715,194 @@ def create_v_path(self, model_type: str): name="Concat_v", axis=2, ) - return v_nodes + [concat_v_node] # noqa: RUF005 + + if model_type != "70b_distributed_merged": + return v_nodes + [concat_v_node] # noqa: RUF005 + + shape_v1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_1_out"], name="Shape_v1") + shape_v2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_2_out"], name="Shape_v2") + shape_v3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_3_out"], name="Shape_v3") + shape_v4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_4_out"], name="Shape_v4") + gather_v_1 = helper.make_node( + "Gather", + inputs=["shape_1_out", "one"], + outputs=["gather_1_out"], + name="Gather_v1", + axis=0, + ) + gather_v_2 = helper.make_node( + "Gather", + inputs=["shape_2_out", "one"], + outputs=["gather_2_out"], + name="Gather_v2", + axis=0, + ) + gather_v_3 = helper.make_node( + "Gather", + inputs=["shape_3_out", "one"], + outputs=["gather_3_out"], + name="Gather_v3", + axis=0, + ) + gather_v_4 = helper.make_node( + "Gather", + inputs=["shape_4_out", "one"], + outputs=["gather_4_out"], + name="Gather_v4", + axis=0, + ) + unsqueeze_v_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_v1_out"], + name="Unsqueeze_v1", + ) + unsqueeze_v_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v2_out"], + name="Unsqueeze_v2", + ) + unsqueeze_v_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_2_out", "zero"], + outputs=["unsqueeze_v3_out"], + name="Unsqueeze_v3", + ) + unsqueeze_v_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v4_out"], + name="Unsqueeze_v4", + ) + unsqueeze_v_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v5_out"], + name="Unsqueeze_v5", + ) + concat_v_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_v2_out", "unsqueeze_v3_out", "One", "unsqueeze_v4_out", "unsqueeze_v5_out"], + outputs=["concat_v2_ouot"], + name="Concat_v2", + axis=0, + ) + reshape_v_2 = helper.make_node( + "Reshape", + inputs=["concat_v2_ouot", "One"], + outputs=["reshape_v2_out"], + name="Reshape_v2", + ) + shape_v5 = helper.make_node("Shape", inputs=["reshape_v2_out"], outputs=["shape_5_out"], name="Shape_v5") + constant_of_shape_v_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_5_out"], + outputs=["constant_of_shape_v1_out"], + name="ConstantOfShape_v1", + ) + mul_v_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_v1_out", "One"], + outputs=["mul_v1_out"], + name="mul_v1", + ) + equal_v_1 = helper.make_node( + "Equal", + inputs=["reshape_v2_out", "mul_v1_out"], + outputs=["equal_v_1_out"], + name="equal_v1", + ) + where_v_1 = helper.make_node( + "Where", + inputs=["equal_v_1_out", "constant_of_shape_v1_out", "reshape_v2_out"], + outputs=["where_v_1_out"], + name="where_v1", + ) + unsqueeze_v_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v6_out"], + name="Unsqueeze_v6", + ) + mul_v_2 = helper.make_node( + "Mul", + inputs=["gather_2_out", "One"], + outputs=["mul_v2_out"], + name="mul_v2", + ) + unsqueeze_v_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_v2_out", "zero"], + outputs=["unsqueeze_v7_out"], + name="Unsqueeze_v7", + ) + unsqueeze_v_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v8_out"], + name="Unsqueeze_v8", + ) + unsqueeze_v_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v9_out"], + name="Unsqueeze_v9", + ) + concat_v_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_v6_out", "unsqueeze_v7_out", "unsqueeze_v8_out", "unsqueeze_v9_out"], + outputs=["concat_v3_out"], + name="Concat_v3", + axis=0, + ) + expand_v_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_v1_out", "where_v_1_out"], + outputs=["expand_v1_out"], + name="expand_v1", + ) + reshape_v_3 = helper.make_node( + "Reshape", + inputs=["expand_v1_out", "concat_v3_out"], + outputs=["reshape_v3_out"], + name="Reshape_v3", + ) + + v_nodes_for_70b_model = [ + concat_v_node, + shape_v1, + shape_v2, + shape_v3, + shape_v4, + gather_v_1, + gather_v_2, + gather_v_3, + gather_v_4, + unsqueeze_v_1, + unsqueeze_v_2, + unsqueeze_v_3, + unsqueeze_v_4, + unsqueeze_v_5, + concat_v_2, + reshape_v_2, + shape_v5, + constant_of_shape_v_1, + mul_v_1, + equal_v_1, + where_v_1, + unsqueeze_v_6, + mul_v_2, + unsqueeze_v_7, + unsqueeze_v_8, + unsqueeze_v_9, + concat_v_3, + expand_v_1, + reshape_v_3, + ] + v_nodes.extend(v_nodes_for_70b_model) + + return v_nodes # Create extra nodes for `position_ids` unsqueeze_v_node = helper.make_node( @@ -672,7 +1061,28 @@ def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[Nod return extra_nodes - def create_end_nodes(self): + def create_end_nodes(self, model_type): + if model_type == "70b_distributed_merged": + matmul_o_node = helper.make_node( + "MatMul", + inputs=["attn_output", "o_weight"], + outputs=["output_proj"], + name="MatMul_o_proj", + ) + all_reduce = helper.make_node( + "AllReduce", + inputs=["output_proj"], + outputs=["allreduce_proj"], + name="allreduce_proj", + ) + end_node = helper.make_node( + "Add", + inputs=["zero", "allreduce_proj"], + outputs=["output_0"], + name="Add_normalize_node", + ) + return [matmul_o_node, all_reduce, end_node] + matmul_o_node = helper.make_node( "MatMul", inputs=["attn_output", "o_weight"], @@ -711,7 +1121,7 @@ def create_fused_model(self, model_type: str, interleaved: bool, initializers: L num_heads=self.num_heads, ) - end_nodes = self.create_end_nodes() + end_nodes = self.create_end_nodes(model_type) graph = helper.make_graph( nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes, @@ -740,7 +1150,7 @@ def create_test_model(self, model_type: str, interleaved: bool, initializers: Li reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes)) extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes) - end_nodes = self.create_end_nodes() + end_nodes = self.create_end_nodes(model_type) first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes @@ -790,6 +1200,11 @@ def test_hf_decoder_merged_model(self): interleaved = False self.check_models(model_type, interleaved) + def test_hf_70b_distributed_decoder_merged_model(self): + model_type = "70b_distributed_merged" + interleaved = False + self.check_models(model_type, interleaved) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ba282193c5ca6..7dee0bc41a6f3 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -2832,6 +2832,58 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) { #endif #ifdef USE_TENSORRT +TEST(TensorrtExecutionProviderTest, ShapeTensorTest) { + const auto& api = Ort::GetApi(); + + // Test input tensor which is shape tensor with explicit trt profile shapes + Ort::SessionOptions session_options; + OrtTensorRTProviderOptionsV2* trt_options; + ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); + std::unique_ptr + rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); + + const char* trt_profile_min_shapes = "data:2x2,shape:4x1"; + const char* trt_profile_max_shapes = "data:2x2,shape:4x1"; + const char* trt_profile_opt_shapes = "data:2x2,shape:4x1"; + std::vector keys{"trt_profile_min_shapes", "trt_profile_max_shapes", "trt_profile_opt_shapes"}; + std::vector values{trt_profile_min_shapes, trt_profile_max_shapes, trt_profile_opt_shapes}; + ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( + static_cast(session_options), + rel_trt_options.get()) == nullptr); + + auto model_path = ORT_TSTR("testdata/trt_reshape.onnx"); + + std::vector input_value_0{1.1f, 1.2f, 1.3f, 1.4f}; + std::vector input_shape_0{2, 2}; + std::vector input_value_1{4, 1}; + std::vector input_shape_1{2}; + + std::vector input_names{"data", "shape"}; + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + + std::vector ort_inputs; + ort_inputs.emplace_back(Ort::Value::CreateTensor(info, input_value_0.data(), input_value_0.size(), input_shape_0.data(), input_shape_0.size())); + ort_inputs.emplace_back(Ort::Value::CreateTensor(info, input_value_1.data(), input_value_1.size(), input_shape_1.data(), input_shape_1.size())); + + const char* output_names[] = {"reshaped"}; + + Ort::Session session(*ort_env, model_path, session_options); + session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names)); + + // Test input tensor which is shape tensor with implicit trt profile shapes + Ort::SessionOptions session_options_2; + OrtTensorRTProviderOptionsV2* trt_options_2; + ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options_2) == nullptr); + std::unique_ptr + rel_trt_options_2(trt_options_2, api.ReleaseTensorRTProviderOptions); + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( + static_cast(session_options_2), + rel_trt_options_2.get()) == nullptr); + Ort::Session session_2(*ort_env, model_path, session_options_2); + session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names)); +} + TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) { const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; @@ -3323,6 +3375,22 @@ TEST(LiteCustomOpTest, CustomFunc) { ASSERT_TRUE(floats_output[1] == 16); } +TEST(LiteCustomOpTest, CustomFuncOpsetMismatch) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + session_options.SetLogSeverityLevel(0); +#if defined(_WIN32) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); +#elif defined(__APPLE__) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); +#else + session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); +#endif + + EXPECT_THROW(Ort::Session(*ort_env, TSTR("testdata/fuse_select_filter_opset_8.onnx"), session_options), std::exception); +} + struct Merge { Merge(const OrtApi* ort_api, const OrtKernelInfo* info) { int64_t reverse; diff --git a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc index ad99b675c7d20..85edfa0e59f1d 100644 --- a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc @@ -94,23 +94,28 @@ void Select(const Ort::Custom::Span& indices_in, } } -void Filter(const Ort::Custom::Tensor& floats_in, - Ort::Custom::Tensor& floats_out) { - const float* in = floats_in.Data(); - auto in_len = floats_in.NumberOfElement(); +struct Filter { + Filter(const OrtApi*, const OrtKernelInfo*) {} + Ort::Status Compute(const Ort::Custom::Tensor& floats_in, + Ort::Custom::Tensor& floats_out) { + const float* in = floats_in.Data(); + auto in_len = floats_in.NumberOfElement(); + + std::vector filter_floats; + for (int64_t i = 0; i < in_len; ++i) { + if (in[i] > 1.f) { + filter_floats.push_back(in[i]); + } + } - std::vector filter_floats; - for (int64_t i = 0; i < in_len; ++i) { - if (in[i] > 1.f) { - filter_floats.push_back(in[i]); + float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); + for (size_t j = 0; j < filter_floats.size(); ++j) { + out[j] = filter_floats[j]; } - } - float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); - for (size_t j = 0; j < filter_floats.size(); ++j) { - out[j] = filter_floats[j]; + return Ort::Status{nullptr}; } -} +}; void Box(const Ort::Custom::Tensor* float_in_1, const Ort::Custom::Tensor* float_in_2, @@ -293,9 +298,9 @@ void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; static const std::unique_ptr c_MulTopOpFloat{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; static const std::unique_ptr c_MulTopOpInt32{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; - static const std::unique_ptr c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse)}; + static const std::unique_ptr c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse, {}, 10, 12)}; static const std::unique_ptr c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)}; - static const std::unique_ptr c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; + static const std::unique_ptr c_Filter{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", 15, 17)}; static const std::unique_ptr c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)}; static const std::unique_ptr c_CopyTensorArrayAllVariadic{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayAllVariadic", "CPUExecutionProvider", CopyTensorArrayAllVariadic)}; static const std::unique_ptr c_CopyTensorArrayCombined{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayCombined", "CPUExecutionProvider", CopyTensorArrayCombined)}; @@ -314,7 +319,7 @@ void RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_MulTopOpInt32.get()); domain.Add(c_Fuse.get()); domain.Add(c_Select.get()); - domain.Add(c_Fill.get()); + domain.Add(c_Filter.get()); domain.Add(c_Box.get()); domain.Add(c_CopyTensorArrayAllVariadic.get()); domain.Add(c_CopyTensorArrayCombined.get()); diff --git a/onnxruntime/test/testdata/fuse_select_filter.onnx b/onnxruntime/test/testdata/fuse_select_filter.onnx index 15d7dd64788d3..0b881228edb9d 100644 --- a/onnxruntime/test/testdata/fuse_select_filter.onnx +++ b/onnxruntime/test/testdata/fuse_select_filter.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä P vector_1 vector_2 @@ -25,4 +25,5 @@ N ÿÿÿÿÿÿÿÿÿb& vector_filtered  - ÿÿÿÿÿÿÿÿÿB \ No newline at end of file + ÿÿÿÿÿÿÿÿÿB +v2 \ No newline at end of file diff --git a/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx b/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx new file mode 100644 index 0000000000000..3ea27767eb9f5 --- /dev/null +++ b/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx @@ -0,0 +1,29 @@ + :Ä +P +vector_1 +vector_2 +alpha vector_fused fuse_node"Fuse* + fuse_algo :v2 +4 +indicesindices_selected select_node"Select:v2 +N + vector_fused +indices_selectedvector_gathered gather_node"GatherElements +; +vector_gatheredvector_filtered filter_node"Filter:v2graphZ +vector_1 + + ÿÿÿÿÿÿÿÿÿZ +vector_2 + + ÿÿÿÿÿÿÿÿÿZ +alpha + + ÿÿÿÿÿÿÿÿÿZ +indices + + ÿÿÿÿÿÿÿÿÿb& +vector_filtered + + ÿÿÿÿÿÿÿÿÿB +v2 \ No newline at end of file diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 44db7c0078cfc..c552ec3aea72d 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -521,6 +521,10 @@ "test_scan_sum_cpu", // Disabled due to output mismatch with tolerance. "test_scan9_sum_cpu" // Disabled due to output mismatch with tolerance. ], + "current_failing_tests_OPENVINO_NPU_FP16": [ + "^test_prelu_broadcast", + "test_loop11_cpu" + ], "current_failing_tests_OPENVINO_opset18": [ // pending opset 18 support, RUNTIME_EXCEPTION : Encountered unknown exception in Initialize() "^test_center_crop_pad_crop_axes_chw", diff --git a/onnxruntime/test/testdata/squeezenet/model_opset11.onnx b/onnxruntime/test/testdata/squeezenet/model_opset11.onnx new file mode 100644 index 0000000000000..dcf322a58c042 Binary files /dev/null and b/onnxruntime/test/testdata/squeezenet/model_opset11.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-pad-conv.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-pad-conv.onnx new file mode 100644 index 0000000000000..ced1950005985 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-pad-conv.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-pad-maxpool-opset8.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-pad-maxpool-opset8.onnx new file mode 100644 index 0000000000000..feb1f024ceed7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-pad-maxpool-opset8.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-pad-maxpool.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-pad-maxpool.onnx new file mode 100644 index 0000000000000..32e959262f6b5 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-pad-maxpool.onnx differ diff --git a/onnxruntime/test/testdata/trt_reshape.onnx b/onnxruntime/test/testdata/trt_reshape.onnx new file mode 100644 index 0000000000000..7d195af2ae204 --- /dev/null +++ b/onnxruntime/test/testdata/trt_reshape.onnx @@ -0,0 +1,16 @@ + :‰ +) +data +shapereshapedReshape"Reshapetrt_engine_wrapperZ +data +  +N +Z +shape + + +b +reshaped +  + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/trt_reshape_test.py b/onnxruntime/test/testdata/trt_reshape_test.py new file mode 100644 index 0000000000000..42777bd3d50c7 --- /dev/null +++ b/onnxruntime/test/testdata/trt_reshape_test.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import onnx +from onnx import TensorProto, helper + + +def generate_model(model_name): + nodes = [ + helper.make_node( + "Reshape", + ["data", "shape"], + ["reshaped"], + "Reshape", + ), + ] + + graph = helper.make_graph( + nodes, + "trt_engine_wrapper", + [ # input + helper.make_tensor_value_info("data", TensorProto.FLOAT, ["N", 2]), + helper.make_tensor_value_info( + "shape", + TensorProto.INT64, + [ + 2, + ], + ), + ], + [ # output + helper.make_tensor_value_info("reshaped", TensorProto.FLOAT, [4, 1]), + ], + ) + + model = helper.make_model(graph) + onnx.save(model, model_name) + + +if __name__ == "__main__": + generate_model("trt_reshape.onnx") diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 8068d4825cee3..93ba836b53e22 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -70,6 +70,7 @@ static std::unordered_map> {"Split", {1}}, {"Clip", {1, 2}}, {"Pad", {1, 2}}, + {"MatMulBnb4", {1, 2}}, // quantified weight (non float) and absmax constant don't need gradients. {"Multinomial", {0}}, {"RandomNormalLike", {0}}, {"RandomUniformLike", {0}}, diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 6547f53a3c2ae..7100cedaf78a0 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -494,6 +494,42 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { return result; } +IMPLEMENT_GRADIENT_BUILDER(GetMatmulBnb4Gradient) { + auto attributes = SrcNodeAttributes(); + std::vector attrs; + bool find_transB = false; + for (auto& attr : attributes) { + if (attr.first == "transB") { + int64_t transB_value = attr.second.i(); + transB_value = (transB_value + 1) % 2; // revert the transpose + attrs.push_back(MakeAttribute("transB", transB_value)); + find_transB = true; + } else { + attrs.push_back(attr.second); + } + } + + if (!find_transB) { + attrs.push_back(MakeAttribute("transB", int64_t(0))); // default is 1, so we need to set it to 0 + } + + std::vector result; + // Y = A * B + // dA = dY * B', dB = A' * dY + if (IsGradientRequiredForSrcNodeInput(0)) { + // B is 1-D, so don't need transpose here. + result.push_back(NodeDef(OpDef{"MatMulBnb4", kMSDomain, 1}, + {GO(0), I(1), I(2)}, + {GI(0)}, + attrs)); + } + + ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(1), "Gradient propagation to B is not supported yet."); + ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(2), "Gradient propagation to absmax is not supported yet."); + + return result; +} + IMPLEMENT_GRADIENT_BUILDER(GetSplitGradient) { std::vector result = {}; std::vector input_args; diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 28a316261e2f6..08987a86ebda9 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -54,6 +54,7 @@ DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossGradient) DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossInternalGradient) DECLARE_GRADIENT_BUILDER(GetGlobalAveragePoolGradient) DECLARE_GRADIENT_BUILDER(GetGemmGradient) +DECLARE_GRADIENT_BUILDER(GetMatmulBnb4Gradient) DECLARE_GRADIENT_BUILDER(GetDropoutGradient) DECLARE_GRADIENT_BUILDER(GetGatherNDGradient) DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 4b8c68aef078a..f280a02cb490f 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -68,6 +68,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient); REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient); REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient); + REGISTER_GRADIENT_BUILDER("MatMulBnb4", GetMatmulBnb4Gradient); REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient); REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient); REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient); diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py index 40398b33d8f04..03bb0f4373d8d 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py @@ -393,7 +393,7 @@ def _bwd_kernel_one_col_block( dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) # There seems to be some problem with Triton pipelining that makes results wrong for # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop - # may have zero step, and pipelining with the bias matrix could screw it up. + # may have zero step, and pipelining with the bias matrix could cause the problem. # So we just exit early. if begin_m >= seqlen_q: dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 8c5469740d9bd..de2c04d262412 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -3,7 +3,10 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import sys +from typing import ClassVar import torch import torch.utils.checkpoint @@ -18,13 +21,55 @@ register_torch_autograd_function, ) from onnxruntime.training import ortmodule -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_scalar_type_to_pytorch_dtype, pytorch_type_to_onnx_dtype from ._custom_op_symbolic_registry import wrap_custom_export_function from ._fallback import ORTModuleONNXModelException, wrap_exception from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version +class _SpecialCustomFunctionHandler: + """A class to handle high priority export of torch.autograd.Function. + `register_high_priority_handler` can be used as function decorator to register a handler for a torch.autograd.Function. + """ + + _HIGH_PRIORITY_EXPORT_HANDLER_MAP: ClassVar[dict[str, callable]] = {} + + @staticmethod + def add_handler(func_name: str, handler: callable) -> None: + """Add a handler for a function name. + + Args: + func_name (str): The function name. + handler (callable): The handler. + + """ + _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP[func_name] = handler + + @staticmethod + def get_handler(func_name: str) -> callable | None: + """Get the handler for a function name. + + Args: + func_name (str): The function name. + + Returns: + callable | None: The handler. + + """ + return _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP.get(func_name, None) + + +def register_high_priority_handler(func_name): + """Register a handler for a torch.autograd.Function using its full qualified class name.""" + + def symbolic_wrapper(fn): + _SpecialCustomFunctionHandler.add_handler(func_name, fn) + return fn + + return symbolic_wrapper + + def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None: """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined. @@ -96,6 +141,30 @@ def alias_input(node_proto_str: str) -> Tuple[List[int], List[int]]: ) +def _get_training_mode() -> bool: + # TODO move to public API once the exporter team exposes that + training_mode = None + if get_runtime_pytorch_version() >= version.parse("1.12"): + # FIXME: using private modules + from torch.onnx import _globals + + # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9, + # training_mode is still a bool type + if isinstance(_globals.GLOBALS.training_mode, bool): + training_mode = _globals.GLOBALS.training_mode + else: + if _globals.GLOBALS.training_mode not in [ + torch.onnx.TrainingMode.EVAL, + torch.onnx.TrainingMode.TRAINING, + ]: + raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}") + training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING + else: + training_mode = symbolic_helper._training_mode + + return bool(training_mode) + + def _export_pt_1_10(g, n, *args, **kwargs): """Export torch.autograd.Function in ORT PythonOp. @@ -114,6 +183,15 @@ def _export_pt_1_10(g, n, *args, **kwargs): func_class = n.pyobj().__self__ func_full_qual_name = get_fully_qualified_class_name(func_class) + # Check if the function is handled by high priority exporter. + hi_pri_handler = _SpecialCustomFunctionHandler.get_handler(func_full_qual_name) + if hi_pri_handler: + try_export = hi_pri_handler(g, n, *args, **kwargs) + if try_export is not None: + return try_export + + # Fall back to common exporter if not handled by high priority exporter. + # Check if the checkpointing activation is allowed. is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1 if is_ckpt_activation_allowed is False and func_full_qual_name in _UNSUPPORTED_CKPT_FUNC_NAMES: @@ -123,26 +201,6 @@ def _export_pt_1_10(g, n, *args, **kwargs): "wrap exportable sub-nn.Module's as ORTModule." ) - # TODO move to public API once the exporter team exposes that - training_mode = None - if get_runtime_pytorch_version() >= version.parse("1.12"): - # FIXME: using private modules - from torch.onnx import _globals - - # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9, - # training_mode is still a bool type - if isinstance(_globals.GLOBALS.training_mode, bool): - training_mode = _globals.GLOBALS.training_mode - else: - if _globals.GLOBALS.training_mode not in [ - torch.onnx.TrainingMode.EVAL, - torch.onnx.TrainingMode.TRAINING, - ]: - raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}") - training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING - else: - training_mode = symbolic_helper._training_mode - cconv = n.cconv() input_tensor_types = [] @@ -179,7 +237,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): if call_type == "d": # Got a tensor variable. tensor_args.append(arg) - scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) input_tensor_types.append(scalar_type) input_tensor_ranks.append(arg.type().dim()) continue @@ -258,7 +316,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): output_tensor_ranks = [] for arg in n.outputs(): # Type of tensor's elements. - scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) output_tensor_types.append(scalar_type) output_tensor_ranks.append(arg.type().dim()) @@ -270,7 +328,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): "input_tensor_ranks_i": input_tensor_ranks, "output_tensor_types_i": output_tensor_types, "output_tensor_ranks_i": output_tensor_ranks, - "training_mode_i": 1 if training_mode else 0, + "training_mode_i": 1 if _get_training_mode() else 0, "comment_s": debug_comment, } @@ -336,3 +394,55 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model index += 1 return exported_model + + +@register_high_priority_handler("bitsandbytes.autograd._functions.MatMul4Bit") +def _matmul4bit_export(g, n, *args, **kwargs): + cconv = n.cconv() + can_converted = cconv[0] == "d" and cconv[1] == "d" and cconv[2] == "c" and cconv[3] == "c" and cconv[4] == "c" + can_converted = can_converted and (args[2] is None and args[3] is None and args[4] is not None) + if not can_converted: + return None + + quant_state = args[4] + absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + + # MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16 + if blocksize < 16 or blocksize & (blocksize - 1) != 0: + return None + + # MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too) + if compressed_stats is not None: + return None + + # The PyTorch linear weight shape is [out_feature, in_feature] + in_feature = shape[1] + out_feature = shape[0] + if quant_type == "fp4": + quant_type = 0 + elif quant_type == "nf4": + quant_type = 1 + else: + return None + attrs = { + "K_i": in_feature, + "N_i": out_feature, + "block_size_i": blocksize, + "quant_type_i": quant_type, + "training_mode_i": 1 if _get_training_mode() else 0, + } + + # Make sure the quant weight can be flatten to 1D tensor safely, which com.microsoft::MatMulBnb4 requires. + found_dim1 = any(v == 1 for v in args[1].type().sizes()) + if not found_dim1: + return None + + absmax = g.op( + "Constant", + value_t=torch.tensor(absmax, dtype=pytorch_scalar_type_to_pytorch_dtype(args[0].type().scalarType())), + ) + quant_weight = g.op( + "Reshape", args[1], g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) # flatten to 1D + tensor_args = [args[0], quant_weight, absmax] + return g.op("com.microsoft::MatMulBnb4", *tensor_args, **attrs) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 6e694dcdf2e39..99e8851b6a697 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -12,7 +12,7 @@ from torch.onnx import register_custom_op_symbolic from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype from ._utils import get_runtime_pytorch_version @@ -145,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index, weight_casted, ignore_index, reduction_s=reduction, - output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()), + output_type_i=pytorch_type_to_onnx_dtype(output_type.scalarType()), outputs=2, ) output.setType(output_type) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5e8805bfddb4d..ba61b2e4c8804 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -19,7 +19,7 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch +from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch_dtype from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils @@ -345,7 +345,7 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" ) if os.path.exists(cache_dir) and os.path.isfile(filename): - self._logger.info( + self._logger.warning( f"Cached model detected! Cached model will be used to save export and initialization time." f"If you want the model to be re-exported then DELETE {filename}." ) @@ -627,7 +627,7 @@ def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.devic kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, - dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), + dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), device=device, ).requires_grad_() diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index d0dea66fda2a7..3a5d7da926f2a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -14,7 +14,7 @@ register_shape_inference_function, register_torch_autograd_function, ) -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary from ._utils import get_fully_qualified_class_name @@ -149,7 +149,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: c, graph_input.name, len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank - pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type + pytorch_type_to_onnx_dtype(zero_stage3_named_params[graph_input.name].dtype), # new data type ) # Delete exported_model.graph.input diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 0eb6790d7a462..ff0cde37195cb 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -366,7 +366,7 @@ def _override_from_env_vars(self): # Cache exported model if "ORTMODULE_CACHE_DIR" in os.environ: - self._logger.info("ORTModule cache optimization is ON.") + self._logger.warning("ORTModule optimization for caching exported model is ON.") self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR") # Experimental features. diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index fa7c9f2750cdd..d40a6ddf7daf3 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -9,7 +9,11 @@ extract_data_and_schema, unflatten_data_using_schema, ) -from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx +from onnxruntime.training.utils.torch_type_map import ( + onnx_dtype_to_pytorch_dtype, + pytorch_scalar_type_to_pytorch_dtype, + pytorch_type_to_onnx_dtype, +) __all__ = [ "PrimitiveType", @@ -17,6 +21,7 @@ "ORTModelInputOutputSchemaType", "extract_data_and_schema", "unflatten_data_using_schema", - "pytorch_dtype_to_onnx", - "onnx_dtype_to_pytorch", + "pytorch_type_to_onnx_dtype", + "onnx_dtype_to_pytorch_dtype", + "pytorch_scalar_type_to_pytorch_dtype", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index b1cb5c19e8979..0d268a7a4a5cf 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -17,7 +17,7 @@ from onnxruntime.training.utils import ( ORTModelInputOutputType, extract_data_and_schema, - pytorch_dtype_to_onnx, + pytorch_type_to_onnx_dtype, unflatten_data_using_schema, ) @@ -324,7 +324,7 @@ def infer_shape( start_offset = len(tensor_input_shapes) - len(partitioned_params) for index, param in enumerate(partitioned_params): tensor_output_shapes[start_offset + index] = list(param.ds_shape) - tensor_output_dtypes[start_offset + index] = int(pytorch_dtype_to_onnx(param.dtype)) + tensor_output_dtypes[start_offset + index] = int(pytorch_type_to_onnx_dtype(param.dtype)) assert len(tensor_output_shapes) == len(tensor_input_shapes) assert len(tensor_output_dtypes) == len(tensor_input_dtypes) diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py index bdacab8ad04fe..2b429f3fd4f3a 100644 --- a/orttraining/orttraining/python/training/utils/torch_type_map.py +++ b/orttraining/orttraining/python/training/utils/torch_type_map.py @@ -36,8 +36,10 @@ _ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()} -def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: - """Converts a pytorch dtype or scalar type string to an onnx dtype.""" +def pytorch_type_to_onnx_dtype(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: + """Converts a pytorch dtype or scalar type string to an onnx dtype. + PyTorch type can be either a dtype or a scalar type string. + """ dtype = dtype_or_scalar_type if isinstance(dtype, str): if dtype not in _CAST_PYTORCH_TO_ONNX: @@ -49,7 +51,15 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc return _DTYPE_TO_ONNX[dtype] -def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype: +def pytorch_scalar_type_to_pytorch_dtype(dtype: str) -> torch.dtype: + """Converts a pytorch scalar type string to a pytorch dtype.""" + assert isinstance(dtype, str) + if dtype not in _CAST_PYTORCH_TO_ONNX: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _CAST_PYTORCH_TO_ONNX[dtype][1] + + +def onnx_dtype_to_pytorch_dtype(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype: """Converts an onnx dtype to a pytorch dtype.""" if dtype not in _ONNX_TO_DTYPE: raise RuntimeError(f"Unsupported dtype {dtype}") diff --git a/setup.py b/setup.py index f6308c56d0590..9eca9845c9e8b 100644 --- a/setup.py +++ b/setup.py @@ -203,19 +203,22 @@ def run(self): "libcurand.so.10", ] rocm_dependencies = [ - "librccl.so.1", - "libnuma.so.1", "libamd_comgr.so.2", + "libamdhip64.so.5", "libdrm.so.2", - "librocblas.so.0", "libdrm_amdgpu.so.1", - "libamdhip64.so.5", - "libroctracer64.so.4", - "libMIOpen.so.1", - "libtinfo.so.6", "libelf.so.1", - "librocm_smi64.so.5", + "libhipfft.so.0", + "libhiprtc.so.5", "libhsa-runtime64.so.1", + "libMIOpen.so.1", + "libnuma.so.1", + "librccl.so.1", + "librocblas.so.3", + "librocfft.so.0", + "librocm_smi64.so.5", + "libroctracer64.so.4", + "libtinfo.so.6", ] tensorrt_dependencies = ["libnvinfer.so.8", "libnvinfer_plugin.so.8", "libnvonnxparser.so.8"] diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 806e536cb4ddb..25d69ef3a1eb5 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -66,15 +66,13 @@ def _str_to_bool(s): def _openvino_verify_device_type(device_read): - choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16", "VPUX_FP16", "VPUX_U8"] + choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16"] choices1 = [ "CPU_FP32_NO_PARTITION", "CPU_FP16_NO_PARTITION", "GPU_FP32_NO_PARTITION", "GPU_FP16_NO_PARTITION", - "VPUX_FP16_NO_PARTITION", - "VPUX_U8_NO_PARTITION", ] status_hetero = True res = False @@ -89,7 +87,7 @@ def _openvino_verify_device_type(device_read): if len(comma_separated_devices) < 2: print("At least two devices required in Hetero/Multi/Auto Mode") status_hetero = False - dev_options = ["CPU", "GPU", "VPUX"] + dev_options = ["CPU", "GPU"] for dev in comma_separated_devices: if dev not in dev_options: status_hetero = False @@ -100,7 +98,7 @@ def invalid_hetero_build(): print("specify the keyword HETERO or MULTI or AUTO followed by the devices ") print("in the order of priority you want to build\n") print("The different hardware devices that can be added in HETERO or MULTI or AUTO") - print("are ['CPU','GPU', 'VPUX'] \n") + print("are ['CPU','GPU'] \n") print("An example of how to specify the hetero build type. Ex: HETERO:GPU,CPU \n") print("An example of how to specify the MULTI build type. Ex: MULTI:GPU,CPU \n") print("An example of how to specify the AUTO build type. Ex: AUTO:GPU,CPU \n") @@ -798,6 +796,7 @@ def run_subprocess( my_env.update(env) + log.info(" ".join(args)) return run(*args, cwd=cwd, capture_stdout=capture_stdout, shell=shell, env=my_env) @@ -1158,8 +1157,6 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_GPU_FP16=" + ("ON" if args.use_openvino == "GPU_FP16" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP32=" + ("ON" if args.use_openvino == "CPU_FP32" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP16=" + ("ON" if args.use_openvino == "CPU_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_FP16=" + ("ON" if args.use_openvino == "VPUX_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_U8=" + ("ON" if args.use_openvino == "VPUX_U8" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU_FP32_NP=" + ("ON" if args.use_openvino == "GPU_FP32_NO_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU_FP16_NP=" @@ -1168,9 +1165,6 @@ def generate_build_tree( + ("ON" if args.use_openvino == "CPU_FP32_NO_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP16_NP=" + ("ON" if args.use_openvino == "CPU_FP16_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_FP16_NP=" - + ("ON" if args.use_openvino == "VPUX_FP16_NP_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_U8_NP=" + ("ON" if args.use_openvino == "VPUX_U8_NP_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_HETERO=" + ("ON" if args.use_openvino.startswith("HETERO") else "OFF"), "-Donnxruntime_USE_OPENVINO_DEVICE=" + (args.use_openvino), "-Donnxruntime_USE_OPENVINO_MULTI=" + ("ON" if args.use_openvino.startswith("MULTI") else "OFF"), @@ -2031,13 +2025,6 @@ def build_python_wheel( run_subprocess(args, cwd=cwd) -def derive_linux_build_property(): - if is_windows(): - return '/p:IsLinuxBuild="false"' - else: - return '/p:IsLinuxBuild="true"' - - def build_nuget_package( cmake_path, source_dir, @@ -2050,7 +2037,6 @@ def build_nuget_package( use_dnnl, use_tvm, use_winml, - use_snpe, use_qnn, enable_training_apis, msbuild_extra_options, @@ -2061,83 +2047,93 @@ def build_nuget_package( ) csharp_build_dir = os.path.join(source_dir, "csharp") - is_linux_build = derive_linux_build_property() # in most cases we don't want/need to include the Xamarin mobile targets, as doing so means the Xamarin # mobile workloads must be installed on the machine. # they are only included in the Microsoft.ML.OnnxRuntime nuget package sln = "OnnxRuntime.DesktopOnly.CSharp.sln" + have_exclude_mobile_targets_option = "IncludeMobileTargets=false" in msbuild_extra_options # derive package name and execution provider based on the build args target_name = "/t:CreatePackage" - execution_provider = '/p:ExecutionProvider="None"' - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime"' - enable_training_tests = '/p:TrainingEnabledNativeBuild="false"' + execution_provider = "/p:ExecutionProvider=None" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime" + enable_training_tests = "/p:TrainingEnabledNativeBuild=false" + if enable_training_apis: - enable_training_tests = '/p:TrainingEnabledNativeBuild="true"' + enable_training_tests = "/p:TrainingEnabledNativeBuild=true" if use_cuda: - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.Training.Gpu"' + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Training.Gpu" else: - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.Training"' + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Training" elif use_winml: - package_name = '/p:OrtPackageId="Microsoft.AI.MachineLearning"' + package_name = "/p:OrtPackageId=Microsoft.AI.MachineLearning" target_name = "/t:CreateWindowsAIPackage" elif use_openvino: - execution_provider = '/p:ExecutionProvider="openvino"' - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.OpenVino"' + execution_provider = "/p:ExecutionProvider=openvino" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.OpenVino" elif use_tensorrt: - execution_provider = '/p:ExecutionProvider="tensorrt"' - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.TensorRT"' + execution_provider = "/p:ExecutionProvider=tensorrt" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" elif use_dnnl: - execution_provider = '/p:ExecutionProvider="dnnl"' - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.DNNL"' + execution_provider = "/p:ExecutionProvider=dnnl" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL" elif use_cuda: - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu" elif use_rocm: - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm"' + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm" elif use_tvm: - execution_provider = '/p:ExecutionProvider="tvm"' - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.Tvm"' - elif use_snpe: - execution_provider = '/p:ExecutionProvider="snpe"' - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.Snpe"' + execution_provider = "/p:ExecutionProvider=tvm" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Tvm" elif use_qnn: - execution_provider = '/p:ExecutionProvider="qnn"' - package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.QNN"' + execution_provider = "/p:ExecutionProvider=qnn" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.QNN" elif any(map(lambda x: "OrtPackageId=" in x, msbuild_extra_options)): pass else: - # use the solution file that includes Xamarin mobile targets - sln = "OnnxRuntime.CSharp.sln" + # we currently only allow building with mobile targets on Windows. + # it should be possible to allow building with android targets on Linux but that requires updating the + # csproj to separate the inclusion of ios and android targets. + if is_windows() and have_exclude_mobile_targets_option is False: + # use the sln that include the mobile targets + sln = "OnnxRuntime.CSharp.sln" + + # explicitly exclude mobile targets in this case + if sln != "OnnxRuntime.CSharp.sln" and have_exclude_mobile_targets_option is False: + msbuild_extra_options.append("IncludeMobileTargets=false") + + # expand extra_options to add prefix + extra_options = ["/p:" + option for option in msbuild_extra_options] + + # we have to use msbuild directly if including Xamarin targets as dotnet only supports MAUI (.net6) + use_dotnet = sln != "OnnxRuntime.CSharp.sln" + + if use_dotnet: + cmd_args = ["dotnet", "restore", sln, "--configfile", "NuGet.CSharp.config", *extra_options] + else: + cmd_args = ["msbuild", sln, "/t:restore", "/p:RestoreConfigFile=NuGet.CSharp.config", *extra_options] # set build directory based on build_dir arg native_dir = os.path.normpath(os.path.join(source_dir, build_dir)) - ort_build_dir = '/p:OnnxRuntimeBuildDirectory="' + native_dir + '"' + ort_build_dir = "/p:OnnxRuntimeBuildDirectory=" + native_dir - # dotnet restore - cmd_args = ["dotnet", "restore", sln, "--configfile", "NuGet.CSharp.config"] run_subprocess(cmd_args, cwd=csharp_build_dir) # build csharp bindings and create nuget package for each config for config in configs: - if is_linux(): - native_build_dir = os.path.join(native_dir, config) - cmd_args = [cmake_path, "-DCMAKE_INSTALL_PREFIX=./nuget-staging/usr/local", "-Pcmake_install.cmake"] - run_subprocess(cmd_args, cwd=native_build_dir) - - configuration = '/p:Configuration="' + config + '"' - + configuration = "/p:Configuration=" + config if not use_winml: - cmd_args = [ - "dotnet", + cmd_args = ["dotnet"] if use_dotnet else [] + cmd_args += [ "msbuild", sln, configuration, package_name, - is_linux_build, ort_build_dir, enable_training_tests, + *extra_options, ] + run_subprocess(cmd_args, cwd=csharp_build_dir) else: winml_interop_dir = os.path.join(source_dir, "csharp", "src", "Microsoft.AI.MachineLearning.Interop") @@ -2148,7 +2144,7 @@ def build_nuget_package( "msbuild", winml_interop_project, configuration, - '/p:Platform="Any CPU"', + "/p:Platform=Any CPU", ort_build_dir, "-restore", ] @@ -2162,26 +2158,28 @@ def build_nuget_package( # this path is setup by cmake/nuget_helpers.cmake for MSVC on Windows nuget_exe = os.path.normpath(os.path.join(native_dir, config, "nuget_exe", "src", "nuget.exe")) else: - # user needs to make sure nuget is installed and can be found - nuget_exe = "nuget" + # `dotnet pack` is used on Linux + nuget_exe = "NugetExe_not_set" nuget_exe_arg = '/p:NugetExe="' + nuget_exe + '"' - cmd_args = [ - "dotnet", + cmd_args = ["dotnet"] if use_dotnet else [] + cmd_args += [ "msbuild", "OnnxRuntime.CSharp.proj", target_name, package_name, configuration, execution_provider, - is_linux_build, ort_build_dir, nuget_exe_arg, + *extra_options, ] - cmd_args.extend(msbuild_extra_options) + run_subprocess(cmd_args, cwd=csharp_build_dir) + log.info(f"nuget package was created in the {config} build output directory.") + def run_csharp_tests(source_dir, build_dir, use_cuda, use_openvino, use_tensorrt, use_dnnl, enable_training_apis): # Currently only running tests on windows. @@ -2644,6 +2642,7 @@ def main(): enable_training_apis=args.enable_training_apis, enable_rocm_profiling=args.enable_rocm_profiling, ) + if args.build_nuget: build_nuget_package( cmake_path, @@ -2657,7 +2656,6 @@ def main(): args.use_dnnl, args.use_tvm, args.use_winml, - args.use_snpe, args.use_qnn, args.enable_training_apis, normalize_arg_list(args.msbuild_extra_options), diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 14a9bbedf09a0..ac07d8c525372 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -716,44 +716,29 @@ stages: versionSpec: 6.2.1 - task: PowerShell@2 - displayName: Install .NET 6 workloads + displayName: Install mobile workloads inputs: targetType: 'inline' script: | - dotnet workload install android ios macos - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: PowerShell@2 - displayName: Build .NET 6 targets using dotnet - inputs: - targetType: 'inline' - # we don't specify 'Any CPU' as the platform here because if we do it gets added to the output path - # e.g. csharp\src\Microsoft.ML.OnnxRuntime\bin\Any CPU\RelWithDebInfo\net6.0-ios\ - # which is inconsistent with the msbuild output path for the pre-.net6 targets - # e.g. csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\monoandroid11.0 - # and makes it harder to do the packing - # - # 'Any CPU' is the default (first 'mixed' platform specified in the csproj) so this should be fine. - script: | - dotnet build .\src\Microsoft.ML.OnnxRuntime\Microsoft.ML.OnnxRuntime.csproj -p:SelectedTargets=Net6 -p:Configuration=RelWithDebInfo -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + dotnet workload install android ios workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json for pre-.net6 targets' + displayName: 'Restore NuGet Packages and create project.assets.json' inputs: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:SelectedTargets=PreNet6 -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' + msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: MSBuild@1 - displayName: 'Build C# for pre-.net6 targets' + displayName: 'Build C# bindings' inputs: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' configuration: RelWithDebInfo platform: 'Any CPU' - msbuildArguments: '-p:SelectedTargets=PreNet6 -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' + msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - template: templates/win-esrp-dll.yml @@ -762,15 +747,6 @@ stages: DisplayName: 'ESRP - Sign C# dlls' DoEsrp: ${{ parameters.DoEsrp }} - - task: MSBuild@1 - displayName: Update projects.assets.json with combined list of all target frameworks - inputs: - solution: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\Microsoft.ML.OnnxRuntime.csproj' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:SelectedTargets=All -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: MSBuild@1 displayName: 'Build Nuget Packages' inputs: diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 81e8d67b79021..2d92108efb46d 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -137,7 +137,7 @@ stages: - task: MSBuild@1 displayName: 'Restore NuGet Packages' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: '$(BuildConfig)' msbuildArguments: '-t:restore -p:OrtPackageId=${{ parameters.OrtPackageId }}' @@ -146,7 +146,7 @@ stages: - task: MSBuild@1 displayName: 'Build C#' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' configuration: '$(BuildConfig)' platform: 'Any CPU' msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=${{ parameters.OrtPackageId }} -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index af245c99700ec..4ce39ecc35bfb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -398,44 +398,29 @@ stages: versionSpec: 6.2.1 - task: PowerShell@2 - displayName: Install .NET 6 workloads + displayName: Install mobile workloads inputs: targetType: 'inline' script: | - dotnet workload install android ios macos - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: PowerShell@2 - displayName: Build Microsoft.ML.OnnxRuntime .NET 6 targets using dotnet - inputs: - targetType: 'inline' - # we don't specify 'Any CPU' as the platform here because if we do it gets added to the output path - # e.g. csharp\src\Microsoft.ML.OnnxRuntime\bin\Any CPU\RelWithDebInfo\net6.0-ios\ - # which is inconsistent with the msbuild output path for the pre-.net6 targets - # e.g. csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\monoandroid11.0 - # and makes it harder to do the packing - # - # 'Any CPU' is the default (first 'mixed' platform specified in the csproj) so this should be fine. - script: | - dotnet build .\src\Microsoft.ML.OnnxRuntime\Microsoft.ML.OnnxRuntime.csproj -p:SelectedTargets=Net6 -p:Configuration=RelWithDebInfo -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + dotnet workload install android ios workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json for pre-.net6 targets' + displayName: 'Restore NuGet Packages and create project.assets.json' inputs: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:SelectedTargets=PreNet6 -p:OrtPackageId=$(OrtPackageId)' + msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: MSBuild@1 - displayName: 'Build C# for pre-.net6 targets' + displayName: 'Build C# bindings' inputs: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-p:SelectedTargets=PreNet6 -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' + msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - ${{ if eq(parameters.DoEsrp, true) }}: @@ -445,15 +430,6 @@ stages: DisplayName: 'ESRP - Sign C# dlls' DoEsrp: ${{ parameters.DoEsrp }} - - task: MSBuild@1 - displayName: Update projects.assets.json with combined list of all target frameworks - inputs: - solution: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\Microsoft.ML.OnnxRuntime.csproj' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:SelectedTargets=All -p:OrtPackageId=$(OrtPackageId)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: MSBuild@1 displayName: 'Build Nuget Packages' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 1373381e4c83e..dc41a2d398893 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.97 + version: 1.0.107 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.97 + version: 1.0.107 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 3b1fde6cb6e4f..404699f705344 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -169,7 +169,7 @@ jobs: - task: MSBuild@1 displayName: 'Restore NuGet Packages' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: '${{ parameters.BuildConfig }}' msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' @@ -178,7 +178,7 @@ jobs: - task: MSBuild@1 displayName: 'Build C#' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' configuration: '${{ parameters.BuildConfig }}' platform: 'Any CPU' msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId)' @@ -197,7 +197,7 @@ jobs: command: test projects: '$(Build.SourcesDirectory)\csharp\test\Microsoft.ML.OnnxRuntime.Tests.NetCoreApp\Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj' configuration: '${{ parameters.BuildConfig }}' - arguments: '--configuration ${{ parameters.BuildConfig }} -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) --blame' + arguments: '--configuration ${{ parameters.BuildConfig }} -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IncludeMobileTargets=false --blame' workingDirectory: '$(Build.SourcesDirectory)\csharp' - ${{ if eq(parameters.EnablePython, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index 792e828c9a880..24e46066a1f10 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -222,44 +222,29 @@ stages: versionSpec: 6.2.1 - task: PowerShell@2 - displayName: Install .NET 6 workloads + displayName: Install mobile workloads inputs: targetType: 'inline' script: | dotnet workload install android workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: PowerShell@2 - displayName: Build Microsoft.ML.OnnxRuntime .NET 6 targets using dotnet - inputs: - targetType: 'inline' - # we don't specify 'Any CPU' as the platform here because if we do it gets added to the output path - # e.g. csharp\src\Microsoft.ML.OnnxRuntime\bin\Any CPU\RelWithDebInfo\net6.0-ios\ - # which is inconsistent with the msbuild output path for the pre-.net6 targets - # e.g. csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\monoandroid11.0 - # and makes it harder to do the packing - # - # 'Any CPU' is the default (first 'mixed' platform specified in the csproj) so this should be fine. - script: | - dotnet build .\src\Microsoft.ML.OnnxRuntime\Microsoft.ML.OnnxRuntime.csproj -p:SelectedTargets=Net6 -p:Configuration=RelWithDebInfo -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json for pre-.net6 targets' + displayName: 'Restore NuGet Packages and create project.assets.json' inputs: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:SelectedTargets=PreNet6 -p:OrtPackageId=$(OrtPackageId)' + msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: MSBuild@1 - displayName: 'Build C# for pre-.net6 targets' + displayName: 'Build C# bindings' inputs: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-p:SelectedTargets=PreNet6 -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' + msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' workingDirectory: '$(Build.SourcesDirectory)\csharp' - ${{ if eq(parameters.DoEsrp, true) }}: @@ -269,15 +254,6 @@ stages: DisplayName: 'ESRP - Sign C# dlls' DoEsrp: ${{ parameters.DoEsrp }} - - task: MSBuild@1 - displayName: Update projects.assets.json with combined list of all target frameworks - inputs: - solution: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\Microsoft.ML.OnnxRuntime.csproj' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:SelectedTargets=All -p:OrtPackageId=$(OrtPackageId)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: MSBuild@1 displayName: 'Build Nuget Packages' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 187c7656602f5..8c926619c797d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -161,11 +161,13 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' condition: eq('${{ parameters.RunWebGpuTests }}', 'false') + retryCountOnTaskFailure: 3 - script: | npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + retryCountOnTaskFailure: 3 - script: | npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' @@ -180,10 +182,12 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + retryCountOnTaskFailure: 3 - script: | npm test -- --webgl-texture-pack-mode -b=webgl -e=edge workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebGL: packed mode' + retryCountOnTaskFailure: 3 - script: | npm test -- --wasm-enable-proxy -b=wasm -e=edge workingDirectory: '$(Build.SourcesDirectory)\js\web' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 2ba4b7bea3716..fdb9238071c9e 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -84,7 +84,7 @@ stages: job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} ORT_EP_NAME: DML - WITH_CACHE: true + WITH_CACHE: false MachinePool: onnxruntime-Win2022-GPU-dml-A10 - stage: kernelDocumentation diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt index 5341ae062d332..680b12602910e 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index b2893286803b0..8ef1fd4522973 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 5d48a93b09c90..5673bddfe058a 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -5,7 +5,7 @@ mypy pytest setuptools>=68.2.2 wheel>=0.35.1 -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx argparse sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/windows/setup_env_cuda.bat b/tools/ci_build/github/windows/setup_env_cuda.bat index 96569cbe0f648..2233f7611ab6a 100644 --- a/tools/ci_build/github/windows/setup_env_cuda.bat +++ b/tools/ci_build/github/windows/setup_env_cuda.bat @@ -1,15 +1,17 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ { - set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% -} else { +if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( +set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% +) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% -} +) + @REM The default version is still cuda v11.8, because set cuda v12.2 after it -if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ { +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 -} else { +) else ( set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 -} +) + set GRADLE_OPTS=-Dorg.gradle.daemon=false diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat index 4328c6eba1fe1..49b536e6ab81e 100644 --- a/tools/ci_build/github/windows/setup_env_gpu.bat +++ b/tools/ci_build/github/windows/setup_env_gpu.bat @@ -1,11 +1,21 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ { +if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% -} else { +) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% -} -set PATH=C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Current\Bin;%PATH% +) +set PATH=C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib;%PATH% + +@REM The default version is still cuda v11.8, because set cuda v12.2 after it +set PATH=%PATH%;C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( + set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 +) else ( + set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\extras\CUPTI\lib64 +) + + set GRADLE_OPTS=-Dorg.gradle.daemon=false set CUDA_MODULE_LOADING=LAZY diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index cc27cdc293646..df74e7e5599a8 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -552,11 +552,12 @@ def generate_files(line_list, args): files_list.append( "" ) + else: files_list.append( "' @@ -706,25 +707,9 @@ def generate_files(line_list, args): ) if is_windows(): - if "2022" in openvino_path: - dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\") - tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\") - else: - dll_list_path = os.path.join( - openvino_path, "deployment_tools\\inference_engine\\bin\\intel64\\Release\\" - ) - tbb_list_path = os.path.join(openvino_path, "deployment_tools\\inference_engine\\external\\tbb\\bin\\") - ngraph_list_path = os.path.join(openvino_path, "deployment_tools\\ngraph\\lib\\") - for ngraph_element in os.listdir(ngraph_list_path): - if ngraph_element.endswith("dll"): - files_list.append( - "' - ) + dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\") + tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\") + for dll_element in os.listdir(dll_list_path): if dll_element.endswith("dll"): files_list.append( @@ -735,26 +720,7 @@ def generate_files(line_list, args): + args.target_architecture + '\\native" />' ) - # plugins.xml - files_list.append( - "' - ) - # usb-ma2x8x.mvcmd - # OpenVINO 2022.3 doesn't have usb-ma2x8x.mvcmd - if "2022.3" not in openvino_path: - files_list.append( - "' - ) + for tbb_element in os.listdir(tbb_list_path): if tbb_element.endswith("dll"): files_list.append( @@ -827,8 +793,10 @@ def generate_files(line_list, args): "" ) - # Some tools to be packaged in nightly build only, should not be released + # Some tools to be packaged in nightly debug build only, should not be released # These are copied to the runtimes folder for convenience of loading with the dlls + # NOTE: nuget gives a spurious error on linux if these aren't in a separate directory to the library so + # we add them to a tools folder for that reason. if ( args.is_release_build.lower() != "true" and args.target_architecture == "x64" @@ -838,7 +806,10 @@ def generate_files(line_list, args): "" ) @@ -851,7 +822,10 @@ def generate_files(line_list, args): "" ) @@ -905,7 +879,6 @@ def generate_files(line_list, args): os.system(copy_command + " " + source_props + " " + target_props) files_list.append("') if not is_snpe_package and not is_qnn_package: - files_list.append("') files_list.append("') # Process targets file @@ -924,7 +897,6 @@ def generate_files(line_list, args): os.system(copy_command + " " + source_targets + " " + target_targets) files_list.append("') if not is_snpe_package and not is_qnn_package: - files_list.append("') files_list.append("') # Process xamarin targets files diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index dcc6a92d84ef2..7a77839c4a4e7 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -93,6 +93,10 @@ def main(): # checks "onnxruntime-python-checks-ci-pipeline", "onnxruntime-binary-size-checks-ci-pipeline", + # not currently required, but running ensures we're hitting all mobile platforms + "Android CI Pipeline", + "iOS CI Pipeline", + "ONNX Runtime React Native CI Pipeline", ] # remove pipelines that have already run successfully