diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 34911cfc7972e..3965fe063b148 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -97,6 +97,14 @@ jobs: --exclude=java/src/main/native/*.c --exclude=onnxruntime/core/mlas/inc/* --exclude=onnxruntime/core/mlas/lib/* + --exclude=onnxruntime/contrib_ops/cuda/bert/flash_attention/* + --exclude=build/Debug/* + --exclude=cmake/* + --exclude=csharp/test/* + --exclude=onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/* + --exclude=orttraining/orttraining/test/* + --exclude=onnxruntime/test/* + --exclude=winml/* filter: "-runtime/references" lint-js: diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 8aaec8adef979..3d94d30947c76 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -54,11 +54,10 @@ jobs: --test \ --build_shared_lib \ --build_objc \ + --use_coreml \ --use_xnnpack \ --use_binskim_compliant_compile_flags - # TODO add --use_coreml once unit test failures are addressed - Objective-C-StaticAnalysis: runs-on: macos-14 diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml index eb35f6a814987..0867d4c343e91 100644 --- a/.github/workflows/sca.yml +++ b/.github/workflows/sca.yml @@ -16,6 +16,8 @@ env: jobs: Onnxruntime-SCA-training-CUDA: + permissions: + security-events: write runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] steps: - uses: actions/checkout@v4 @@ -57,6 +59,8 @@ jobs: # No python Onnxruntime-SCA-win32-WINML-x64: + permissions: + security-events: write runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] steps: - uses: actions/checkout@v4 @@ -95,6 +99,8 @@ jobs: # No java, No python Onnxruntime-SCA-win32-WINML-x86: + permissions: + security-events: write runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] steps: - uses: actions/checkout@v4 diff --git a/.lintrunner.toml b/.lintrunner.toml index 08d8af1206a1c..e6d06b34726fe 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -136,6 +136,7 @@ exclude_patterns = [ 'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS based libs recommends NO automatic code formatting 'onnxruntime/core/mickey/gemm/**', # CUTLASS based libs recommends NO automatic code formatting 'winml/lib/Api.Image/shaders/**', # Contains data chunks + 'onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h', # Bool Switches hang Clang ] command = [ 'python', diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 4eb865968cb53..29eb7045fc299 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "4a2c63365eff8823a5221db86ef490e828306f9d", + "commitHash": "f46495ea96f68fc3f6c394f099b2992743f6ff7f", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 00097c25d2ba5..5555fa692eae8 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -101,8 +101,9 @@ option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to provide eigen_SOURCE_PATH if turn this on." OFF) option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) +option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF) -cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) +cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) @@ -123,7 +124,6 @@ option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code cov option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) -option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF) option(onnxruntime_USE_DML "Build with DirectML support" OFF) @@ -241,7 +241,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF) # composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) -option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON) +cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF) option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF) option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) @@ -374,6 +374,10 @@ if (onnxruntime_USE_ROCM) message(WARNING "ck_tile can only be enabled on ROCm >= 6.0 due to compatibility and compilation speed, disable automatically") set(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE OFF) endif() + if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE AND CMAKE_BUILD_TYPE STREQUAL "Debug") + message(WARNING "ck_tile hits compiler error in Debug build, disable automatically") + set(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE OFF) + endif() endif() @@ -439,13 +443,6 @@ if (onnxruntime_ENABLE_MEMORY_PROFILE) endif() set(ONNX_ML 1) -if (NOT onnxruntime_ENABLE_PYTHON) - set(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS OFF) -endif() - -if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - add_compile_definitions(ENABLE_LANGUAGE_INTEROP_OPS) -endif() if (NOT (UNIX AND onnxruntime_ENABLE_PYTHON AND onnxruntime_ENABLE_TRAINING AND (NOT onnxruntime_BUILD_SHARED_LIB))) if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) @@ -554,7 +551,7 @@ if(NOT WIN32 AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android") endif() find_package(Patch) -if (WIN32 AND NOT Patch_FOUND) +if (CMAKE_HOST_WIN32 AND NOT Patch_FOUND) # work around CI machines missing patch from the git install by falling back to the binary in this repo. # replicate what happens in https://github.com/Kitware/CMake/blob/master/Modules/FindPatch.cmake but without # the hardcoded suffixes in the path to the patch binary. @@ -578,11 +575,15 @@ endif() #Need python to generate def file if (onnxruntime_BUILD_SHARED_LIB OR onnxruntime_ENABLE_PYTHON) if (onnxruntime_ENABLE_PYTHON) - if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS OR onnxruntime_REQUIRE_PYTHON_EMBED_LIB) + if (onnxruntime_REQUIRE_PYTHON_EMBED_LIB) find_package(Python 3.8 COMPONENTS Interpreter Development NumPy) else() find_package(Python 3.8 COMPONENTS Interpreter Development.Module NumPy) endif() + message("Numpy version: ${Python_NumPy_VERSION}") + if(Python_NumPy_VERSION VERSION_LESS "2.0.0") + message(WARNING "The build binary will not be compatible with NumPy 2.0 because the NumPy installed on this machine is too low.") + endif() else() find_package(Python 3.8 COMPONENTS Interpreter) endif() @@ -664,10 +665,10 @@ else() check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) if(onnxruntime_ENABLE_TRAINING_APIS) - check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) - if(HAS_DANGLING_REFERENCE) - list(APPEND ORT_WARNING_FLAGS -Wno-dangling-reference) - endif() + check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) + if(HAS_DANGLING_REFERENCE) + list(APPEND ORT_WARNING_FLAGS -Wno-dangling-reference) + endif() endif() check_function_exists(reallocarray HAS_REALLOCARRAY) if (NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_target_platform STREQUAL "aarch64") @@ -742,6 +743,9 @@ if (onnxruntime_USE_CUDA) message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) + message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0") + set(onnxruntime_USE_FLASH_ATTENTION OFF) endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") @@ -794,6 +798,11 @@ if (onnxruntime_USE_RKNPU) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_RKNPU=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES rknpu) endif() +if (onnxruntime_USE_VSINPU) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_VSINPU=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VSINPU=1) + list(APPEND ONNXRUNTIME_PROVIDER_NAMES vsinpu) +endif() if (onnxruntime_USE_NNAPI_BUILTIN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_NNAPI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_NNAPI_BUILTIN=1) @@ -836,8 +845,8 @@ if (onnxruntime_USE_QNN) file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/libQnn*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/Qnn*.dll") if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc" OR ${QNN_ARCH_ABI} STREQUAL "arm64x-windows-msvc") file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" - "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so" - "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat") + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so" + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat") list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB}) endif() message(STATUS "QNN lib files: " ${QNN_LIB_FILES}) @@ -1031,7 +1040,7 @@ function(onnxruntime_set_compile_flags target_name) # Enable warning target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options -Wall>" "$<$>:-Wall>") target_compile_options(${target_name} PRIVATE "$<$>:-Wextra>") - if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "IBMClang") #external/protobuf/src/google/protobuf/arena.h:445:18: error: unused parameter 'p' target_compile_options(${target_name} PRIVATE "-Wno-unused-parameter") endif() @@ -1048,6 +1057,9 @@ function(onnxruntime_set_compile_flags target_name) foreach(FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target_name} PRIVATE "$<$:${FLAG}>") endforeach() + if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 13 AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 12) + target_compile_options(${target_name} PRIVATE "$<$:-Wno-maybe-uninitialized>") + endif() if (onnxruntime_USE_CUDA) foreach(FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") @@ -1128,6 +1140,13 @@ endfunction() function(onnxruntime_add_shared_library target_name) add_library(${target_name} SHARED ${ARGN}) onnxruntime_configure_target(${target_name}) + if(WIN32) + target_compile_definitions(${target_name} PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) + target_compile_definitions(${target_name} PRIVATE VER_MINOR=${VERSION_MINOR_PART}) + target_compile_definitions(${target_name} PRIVATE VER_BUILD=${VERSION_BUILD_PART}) + target_compile_definitions(${target_name} PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) + target_compile_definitions(${target_name} PRIVATE VER_STRING=\"${VERSION_STRING}\") + endif() endfunction() function(onnxruntime_add_static_library target_name) @@ -1142,6 +1161,13 @@ function(onnxruntime_add_shared_library_module target_name) else() #On Windows, this target shouldn't generate an import lib, but I don't know how to disable it. add_library(${target_name} MODULE ${ARGN}) + if(WIN32) + target_compile_definitions(${target_name} PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) + target_compile_definitions(${target_name} PRIVATE VER_MINOR=${VERSION_MINOR_PART}) + target_compile_definitions(${target_name} PRIVATE VER_BUILD=${VERSION_BUILD_PART}) + target_compile_definitions(${target_name} PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) + target_compile_definitions(${target_name} PRIVATE VER_STRING=\"${VERSION_STRING}\") + endif() endif() onnxruntime_configure_target(${target_name}) @@ -1189,11 +1215,11 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 if (onnxruntime_USE_ACL_2002) add_definitions(-DACL_2002=1) else() - if (onnxruntime_USE_ACL_2308) - add_definitions(-DACL_2308=1) - else() + if (onnxruntime_USE_ACL_2308) + add_definitions(-DACL_2308=1) + else() add_definitions(-DACL_1905=1) - endif() + endif() endif() endif() endif() @@ -1338,6 +1364,10 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DUSE_OPENVINO=1) + if(onnxruntime_NPU_NO_FALLBACK) + add_definitions(-DOPENVINO_DISABLE_NPU_FALLBACK=1) + endif() + if (onnxruntime_USE_OPENVINO_GPU) add_definitions(-DOPENVINO_CONFIG_GPU=1) endif() @@ -1407,14 +1437,6 @@ string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}") string(APPEND ORT_BUILD_INFO ", cmake cxx flags: ${CMAKE_CXX_FLAGS}") configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h) get_property(onnxruntime_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) -if (onnxruntime_GENERATOR_IS_MULTI_CONFIG) - configure_file(../requirements.txt.in ${CMAKE_CURRENT_BINARY_DIR}/Debug/requirements.txt) - configure_file(../requirements.txt.in ${CMAKE_CURRENT_BINARY_DIR}/Release/requirements.txt) - configure_file(../requirements.txt.in ${CMAKE_CURRENT_BINARY_DIR}/RelWithDebInfo/requirements.txt) - configure_file(../requirements.txt.in ${CMAKE_CURRENT_BINARY_DIR}/MinSizeRel/requirements.txt) -else() - configure_file(../requirements.txt.in ${CMAKE_CURRENT_BINARY_DIR}/requirements.txt) -endif() if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) @@ -1480,9 +1502,6 @@ if (onnxruntime_USE_CUDA) endif() if (onnxruntime_USE_MIGRAPHX) - if (WIN32) - message(FATAL_ERROR "MIGraphX does not support build in Windows!") - endif() set(AMD_MIGRAPHX_HOME ${onnxruntime_MIGRAPHX_HOME}) endif() @@ -1524,7 +1543,7 @@ if (onnxruntime_ENABLE_TRAINING) list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard) endif() -if (UNIX AND onnxruntime_USE_NCCL) +if (UNIX OR onnxruntime_USE_NCCL) # MPI is INDEPENDENT of NCCL for now. You can build NCLL without MPI and launch multi-GPU with your own launcher. if (onnxruntime_USE_MPI) if (EXISTS "${onnxruntime_MPI_HOME}") @@ -1552,7 +1571,7 @@ if (UNIX AND onnxruntime_USE_NCCL) if (onnxruntime_USE_NCCL) if (onnxruntime_USE_CUDA) set(NCCL_LIBNAME "nccl") - elseif (onnxruntime_USE_ROCM) + elseif (onnxruntime_USE_ROCM OR onnxruntime_USE_MIGRAPHX) set(NCCL_LIBNAME "rccl") endif() find_path(NCCL_INCLUDE_DIR @@ -1631,6 +1650,14 @@ set(VERSION_MINOR_PART 0 CACHE STRING "Second part of numeric file/product ver set(VERSION_BUILD_PART 0 CACHE STRING "Third part of numeric file/product version.") set(VERSION_PRIVATE_PART 0 CACHE STRING "Fourth part of numeric file/product version.") set(VERSION_STRING "Internal Build" CACHE STRING "String representation of file/product version.") +if(VERSION_MAJOR_PART STREQUAL "0" AND VERSION_MINOR_PART STREQUAL "0" AND VERSION_BUILD_PART STREQUAL "0" AND VERSION_PRIVATE_PART STREQUAL "0") + string(REPLACE "." ";" ORT_VERSION_STRING_LIST ${ORT_VERSION}) + list(GET ORT_VERSION_STRING_LIST 0 VERSION_MAJOR_PART) + list(GET ORT_VERSION_STRING_LIST 1 VERSION_MINOR_PART) + list(GET ORT_VERSION_STRING_LIST 2 VERSION_BUILD_PART) + set(VERSION_STRING ORT_VERSION) +endif() + if (WIN32) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${SYS_PATH_LIB}) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 690b6d4e66154..6eb784a4063ed 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -321,7 +321,7 @@ else() string(APPEND CMAKE_CXX_FLAGS " -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl") string(APPEND CMAKE_C_FLAGS " -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl") endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android" AND Onnxruntime_GCOV_COVERAGE) + if (CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_GCOV_COVERAGE) string(APPEND CMAKE_CXX_FLAGS " -g -O0 --coverage ") string(APPEND CMAKE_C_FLAGS " -g -O0 --coverage ") endif() diff --git a/cmake/deps.txt b/cmake/deps.txt index 96c183909bcb2..72469603a0889 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/f46495ea96f68fc3f6c394f099b2992743f6ff7f.zip;0e2b6d1dc7f0a808d1e23f7dd985f7bc18d52cbc coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 775576a771529..14e6ed515fd6e 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -46,6 +46,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(gtest_disable_pthreads ON) endif() + if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") + set(gtest_disable_pthreads ON CACHE BOOL "gtest_disable_pthreads" FORCE) + endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) if (IOS OR ANDROID) # on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index e15c8a046dc20..21ae0947f3788 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -57,6 +57,7 @@ foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) list(APPEND SYMBOL_FILES "${ONNXRUNTIME_ROOT}/core/providers/${f}/symbols.txt") endforeach() +if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") add_custom_command(OUTPUT ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c COMMAND ${Python_EXECUTABLE} "${REPO_ROOT}/tools/ci_build/gen_def.py" --version_file "${ONNXRUNTIME_ROOT}/../VERSION_NUMBER" --src_root "${ONNXRUNTIME_ROOT}" @@ -66,6 +67,7 @@ add_custom_command(OUTPUT ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_s WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) add_custom_target(onnxruntime_generate_def ALL DEPENDS ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) +endif() if(WIN32) onnxruntime_add_shared_library(onnxruntime ${SYMBOL_FILE} @@ -95,30 +97,33 @@ elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) FRAMEWORK TRUE FRAMEWORK_VERSION A MACOSX_FRAMEWORK_INFO_PLIST ${INFO_PLIST_PATH} - SOVERSION ${ORT_VERSION} # Note: The PUBLIC_HEADER and VERSION properties for the 'onnxruntime' target will be set later in this file. ) else() - onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) + if(${CMAKE_SYSTEM_NAME} MATCHES "AIX") + onnxruntime_add_shared_library(onnxruntime ${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc) + else() + onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c ) + endif() if (onnxruntime_USE_CUDA) set_property(TARGET onnxruntime APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN") endif() endif() -add_dependencies(onnxruntime onnxruntime_generate_def ${onnxruntime_EXTERNAL_DEPENDENCIES}) +if(${CMAKE_SYSTEM_NAME} MATCHES "AIX") + add_dependencies(onnxruntime ${onnxruntime_EXTERNAL_DEPENDENCIES}) +else() + add_dependencies(onnxruntime onnxruntime_generate_def ${onnxruntime_EXTERNAL_DEPENDENCIES}) +endif() target_include_directories(onnxruntime PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC "$") -target_compile_definitions(onnxruntime PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) -target_compile_definitions(onnxruntime PRIVATE VER_MINOR=${VERSION_MINOR_PART}) -target_compile_definitions(onnxruntime PRIVATE VER_BUILD=${VERSION_BUILD_PART}) -target_compile_definitions(onnxruntime PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) -target_compile_definitions(onnxruntime PRIVATE VER_STRING=\"${VERSION_STRING}\") + target_compile_definitions(onnxruntime PRIVATE FILE_NAME=\"onnxruntime.dll\") if(UNIX) if (APPLE) set(ONNXRUNTIME_SO_LINK_FLAG " -Xlinker -dead_strip") - else() + elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") set(ONNXRUNTIME_SO_LINK_FLAG " -Xlinker --version-script=${SYMBOL_FILE} -Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") endif() else() @@ -130,7 +135,6 @@ if (NOT WIN32) set(ONNXRUNTIME_SO_LINK_FLAG " -Wl,-exported_symbols_list,${SYMBOL_FILE}") if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") set_target_properties(onnxruntime PROPERTIES - SOVERSION ${ORT_VERSION} MACOSX_RPATH TRUE INSTALL_RPATH_USE_LINK_PATH FALSE BUILD_WITH_INSTALL_NAME_DIR TRUE @@ -138,7 +142,7 @@ if (NOT WIN32) else() set_target_properties(onnxruntime PROPERTIES INSTALL_RPATH "@loader_path") endif() - elseif (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + elseif (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-rpath='$ORIGIN'") endif() endif() @@ -189,6 +193,7 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_SNPE} ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} + ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} @@ -205,11 +210,8 @@ set(onnxruntime_INTERNAL_LIBRARIES onnxruntime_flatbuffers ) -if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - list(APPEND onnxruntime_INTERNAL_LIBRARIES - onnxruntime_language_interop - onnxruntime_pyop - ) +if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") + list(APPEND onnxruntime_INTERNAL_LIBRARIES iconv) endif() if (onnxruntime_USE_EXTENSIONS) @@ -228,13 +230,30 @@ target_link_libraries(onnxruntime PRIVATE ) set_property(TARGET onnxruntime APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_SO_LINK_FLAG} ${onnxruntime_DELAYLOAD_FLAGS}) -set_target_properties(onnxruntime PROPERTIES - PUBLIC_HEADER "${ONNXRUNTIME_PUBLIC_HEADERS}" - LINK_DEPENDS ${SYMBOL_FILE} - VERSION ${ORT_VERSION} - FOLDER "ONNXRuntime" -) - +#See: https://cmake.org/cmake/help/latest/prop_tgt/SOVERSION.html +if(NOT APPLE AND NOT WIN32) + if(${CMAKE_SYSTEM_NAME} MATCHES "AIX") + set_target_properties(onnxruntime PROPERTIES + PUBLIC_HEADER "${ONNXRUNTIME_PUBLIC_HEADERS}" + VERSION ${ORT_VERSION} + SOVERSION 1 + FOLDER "ONNXRuntime") + else() + set_target_properties(onnxruntime PROPERTIES + PUBLIC_HEADER "${ONNXRUNTIME_PUBLIC_HEADERS}" + LINK_DEPENDS ${SYMBOL_FILE} + VERSION ${ORT_VERSION} + SOVERSION 1 + FOLDER "ONNXRuntime") + endif() +else() + # Omit the SOVERSION setting in Windows/macOS/iOS/.. build + set_target_properties(onnxruntime PROPERTIES + PUBLIC_HEADER "${ONNXRUNTIME_PUBLIC_HEADERS}" + LINK_DEPENDS ${SYMBOL_FILE} + VERSION ${ORT_VERSION} + FOLDER "ONNXRuntime") +endif() install(TARGETS onnxruntime EXPORT ${PROJECT_NAME}Targets PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index c9bf2ac5c3dc6..43d16abd8fbae 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -108,7 +108,7 @@ add_dependencies(onnxruntime_framework ${onnxruntime_EXTERNAL_DEPENDENCIES}) # For the shared onnxruntime library, this is set in onnxruntime.cmake through CMAKE_SHARED_LINKER_FLAGS # But our test files don't use the shared library so this must be set for them. # For Win32 it generates an absolute path for shared providers based on the location of the executable/onnxruntime.dll -if (UNIX AND NOT APPLE AND NOT onnxruntime_MINIMAL_BUILD AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") +if (UNIX AND NOT APPLE AND NOT onnxruntime_MINIMAL_BUILD AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-rpath='$ORIGIN'") endif() diff --git a/cmake/onnxruntime_language_interop_ops.cmake b/cmake/onnxruntime_language_interop_ops.cmake deleted file mode 100644 index 5d88332eb2e9a..0000000000000 --- a/cmake/onnxruntime_language_interop_ops.cmake +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -include(onnxruntime_pyop.cmake) -file (GLOB onnxruntime_language_interop_ops_src "${ONNXRUNTIME_ROOT}/core/language_interop_ops/language_interop_ops.cc") -onnxruntime_add_static_library(onnxruntime_language_interop ${onnxruntime_language_interop_ops_src}) -add_dependencies(onnxruntime_language_interop onnxruntime_pyop) -onnxruntime_add_include_to_target(onnxruntime_language_interop onnxruntime_common onnxruntime_graph onnxruntime_framework onnxruntime_pyop onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers safeint_interface Boost::mp11) -target_include_directories(onnxruntime_language_interop PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS}) \ No newline at end of file diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 304aa77f5473c..66f4aea606ef5 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -39,6 +39,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm.h ${MLAS_SRC_DIR}/sqnbitgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h + ${MLAS_SRC_DIR}/flashattn.cpp ) target_sources(onnxruntime_mlas PRIVATE @@ -82,7 +83,10 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ) set(mlas_platform_preprocess_srcs @@ -350,9 +354,12 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ) - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") if (NOT APPLE) set(mlas_platform_srcs @@ -420,12 +427,24 @@ else() ) if(COMPILES_P10) check_cxx_source_compiles(" + #ifdef _AIX + #define POWER_10 0x40000 + #define POWER_10_ANDUP (POWER_10) + #include + #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) + int main() { + bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); + return 0; + } + #else #include int main() { unsigned long hwcap2 = getauxval(AT_HWCAP2); bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); return 0; - }" + } + } + #endif" HAS_P10_RUNTIME ) if (HAS_P10_RUNTIME) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 7e7819ac31a19..05a50a55db409 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -80,6 +80,9 @@ endif() if(onnxruntime_USE_RKNPU) set(PROVIDERS_RKNPU onnxruntime_providers_rknpu) endif() +if(onnxruntime_USE_VSINPU) + set(PROVIDERS_VSINPU onnxruntime_providers_vsinpu) +endif() if(onnxruntime_USE_DML) set(PROVIDERS_DML onnxruntime_providers_dml) endif() @@ -188,6 +191,10 @@ if (onnxruntime_USE_TVM) include(onnxruntime_providers_tvm.cmake) endif() +if (onnxruntime_USE_VSINPU) + include(onnxruntime_providers_vsinpu.cmake) +endif() + if (onnxruntime_USE_XNNPACK) include(onnxruntime_providers_xnnpack.cmake) endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index b211c02f712bd..d2afe19f36691 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -236,11 +236,6 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD set_target_properties(onnxruntime_providers_shared PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_providers_shared PROPERTIES LINKER_LANGUAGE CXX) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_MINOR=${VERSION_MINOR_PART}) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_BUILD=${VERSION_BUILD_PART}) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) - target_compile_definitions(onnxruntime_providers_shared PRIVATE VER_STRING=\"${VERSION_STRING}\") target_compile_definitions(onnxruntime_providers_shared PRIVATE FILE_NAME=\"onnxruntime_providers_shared.dll\") @@ -252,7 +247,9 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD if(APPLE) set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/shared/exported_symbols.lst") elseif(UNIX) - set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections") + if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") + set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections") + endif() elseif(WIN32) set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/shared/symbols.def") set(ONNXRUNTIME_PROVIDERS_SHARED onnxruntime_providers_shared) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 3b48a40bf1166..82c31ce6b6b4d 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -33,19 +33,21 @@ ) - if (onnxruntime_CUDA_MINIMAL) - set(onnxruntime_providers_cuda_shared_srcs "") - else() + if (NOT onnxruntime_CUDA_MINIMAL) file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" ) + else() + set(onnxruntime_providers_cuda_cu_srcs + "${ONNXRUNTIME_ROOT}/core/providers/cuda/math/unary_elementwise_ops_impl.cu" + ) endif() source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) # disable contrib ops conditionally - if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if(NOT onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL) if (NOT onnxruntime_ENABLE_ATEN) list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc" diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 01c4f8b2c8719..d7d83b0ce8d64 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -19,23 +19,25 @@ endif() # Add search paths for default rocm installation - list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH}) - find_package(hip) - find_package(migraphx PATHS ${AMD_MIGRAPHX_HOME}) + # Suppress the warning about the small capitals of the package name - Enable when support to CMake 3.27.0 is used + # cmake_policy(SET CMP0144 NEW) - find_package(miopen) - find_package(rocblas) + if(WIN32 AND NOT HIP_PLATFORM) + set(HIP_PLATFORM "amd") + endif() + + find_package(hip REQUIRED) + find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME}) - set(migraphx_libs migraphx::c hip::host MIOpen roc::rocblas) + set(migraphx_libs migraphx::c hip::host) file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.h" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) @@ -46,18 +48,16 @@ set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) - target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) - set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") - set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp) - - include(CheckLibraryExists) - check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC) - if(HAS_STREAM_SYNC) - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC) - message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED") + if(MSVC) + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32) else() - message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED") + target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") + endif() + if(UNIX) + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") + target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs) endif() if (onnxruntime_ENABLE_TRAINING_OPS) @@ -68,8 +68,16 @@ endif() endif() - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + install(TARGETS onnxruntime_providers_migraphx + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + else() + install(TARGETS onnxruntime_providers_migraphx + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + endif() diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 5876b2b5c448b..d738e29101cfe 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -45,11 +45,6 @@ target_include_directories(onnxruntime_providers_openvino SYSTEM PUBLIC ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${OpenVINO_INCLUDE_DIR} ${OPENVINO_INCLUDE_DIR_LIST} ${PYTHON_INCLUDE_DIRS} $ENV{OPENCL_INCS} $ENV{OPENCL_INCS}/../../cl_headers/) target_link_libraries(onnxruntime_providers_openvino ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${OPENVINO_LIB_LIST} ${ABSEIL_LIBS}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_MINOR=${VERSION_MINOR_PART}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_BUILD=${VERSION_BUILD_PART}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) - target_compile_definitions(onnxruntime_providers_openvino PRIVATE VER_STRING=\"${VERSION_STRING}\") target_compile_definitions(onnxruntime_providers_openvino PRIVATE FILE_NAME=\"onnxruntime_providers_openvino.dll\") if(MSVC) diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake index 38d4b52c97a84..71692ddb9391f 100644 --- a/cmake/onnxruntime_providers_rocm.cmake +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -49,7 +49,10 @@ find_library(RCCL_LIB rccl REQUIRED) find_library(ROCTRACER_LIB roctracer64 REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB}) + find_package(rocm_smi REQUIRED) + set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB}) + include_directories(${ROCM_SMI_INCLUDE_DIR}) + link_directories(${ROCM_SMI_LIB_DIR}) file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index e56de0c7124dc..3d46c139feea9 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - + if(onnxruntime_DISABLE_CONTRIB_OPS) + message( FATAL_ERROR "To compile TensorRT execution provider contrib ops have to be enabled to dump an engine using com.microsoft:EPContext node." ) + endif() add_definitions(-DUSE_TENSORRT=1) if (onnxruntime_TENSORRT_PLACEHOLDER_BUILDER) add_definitions(-DORT_TENSORRT_PLACEHOLDER_BUILDER) @@ -13,7 +15,6 @@ set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) if (WIN32) - add_definitions(-D_SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING=1) set(OLD_CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4099 /wd4551 /wd4505 /wd4515 /wd4706 /wd4456 /wd4324 /wd4701 /wd4804 /wd4702 /wd4458 /wd4703") if (CMAKE_BUILD_TYPE STREQUAL "Debug") @@ -155,8 +156,11 @@ # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries. # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. - set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) - + if(onnxruntime_CUDA_MINIMAL) + set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) + else() + set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) + endif() file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h" "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.cc" @@ -191,6 +195,9 @@ if (WIN32) target_compile_options(onnxruntime_providers_tensorrt INTERFACE /wd4456) endif() + if(onnxruntime_CUDA_MINIMAL) + target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE USE_CUDA_MINIMAL=1) + endif() # Needed for the provider interface, as it includes training headers when training is enabled if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_providers_vsinpu.cmake b/cmake/onnxruntime_providers_vsinpu.cmake new file mode 100644 index 0000000000000..4b987fd1e424b --- /dev/null +++ b/cmake/onnxruntime_providers_vsinpu.cmake @@ -0,0 +1,37 @@ + add_definitions(-DUSE_VSINPU=1) + file(GLOB_RECURSE onnxruntime_providers_vsinpu_srcs + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/builders/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/builders/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vsinpu_srcs}) + add_library(onnxruntime_providers_vsinpu ${onnxruntime_providers_vsinpu_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vsinpu + onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers Boost::mp11 + safeint_interface nsync::nsync_cpp) + add_dependencies(onnxruntime_providers_vsinpu ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_vsinpu PROPERTIES FOLDER "ONNXRuntime" LINKER_LANGUAGE CXX) + target_include_directories(onnxruntime_providers_vsinpu PRIVATE ${ONNXRUNTIME_ROOT} $ENV{TIM_VX_INSTALL}/include) + + find_library(TIMVX_LIBRARY NAMES tim-vx PATHS $ENV{TIM_VX_INSTALL}/lib NO_DEFAULT_PATH) + if(NOT TIMVX_LIBRARY) + message(FATAL_ERROR "TIM-VX library is not found!") + endif() + + if(CMAKE_CROSSCOMPILING) + message(STATUS "VSINPU ep will be cross compiled.") + if(EXISTS "$ENV{VIVANTE_SDK_DIR}/drivers") + set(DRIVER_DIR "$ENV{VIVANTE_SDK_DIR}/drivers") + elseif(EXISTS "$ENV{VIVANTE_SDK_DIR}/lib") + set(DRIVER_DIR "$ENV{VIVANTE_SDK_DIR}/lib") + else() + message(FATAL_ERROR "Neither drivers nor lib directory exists in this VIVANTE_SDK_DIR.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wl,-rpath-link ${DRIVER_DIR} ${TIMVX_LIBRARY}") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") + target_link_libraries(onnxruntime_providers_vsinpu PRIVATE ${TIMVX_LIBRARY}) + endif() diff --git a/cmake/onnxruntime_pyop.cmake b/cmake/onnxruntime_pyop.cmake deleted file mode 100644 index f7583690945a1..0000000000000 --- a/cmake/onnxruntime_pyop.cmake +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -onnxruntime_add_static_library(onnxruntime_pyop "${ONNXRUNTIME_ROOT}/core/language_interop_ops/pyop/pyop.cc") -add_dependencies(onnxruntime_pyop ${onnxruntime_EXTERNAL_DEPENDENCIES}) -onnxruntime_add_include_to_target(onnxruntime_pyop onnxruntime_common onnxruntime_graph onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers ${GSL_TARGET} Boost::mp11) -target_include_directories(onnxruntime_pyop PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS}) -onnxruntime_add_include_to_target(onnxruntime_pyop Python::Module Python::NumPy) -if (TARGET Python::Python) - target_link_libraries(onnxruntime_pyop PRIVATE Python::Python) -else() - target_link_libraries(onnxruntime_pyop PRIVATE Python::Module) -endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index da5732e484d6b..07c65e7986b05 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -193,10 +193,6 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE ${pybind11_lib} ) -if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_language_interop onnxruntime_pyop) -endif() - set(onnxruntime_pybind11_state_dependencies ${onnxruntime_EXTERNAL_DEPENDENCIES} ${pybind11_dep} @@ -566,6 +562,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_ROOT}/__init__.py $/onnxruntime/ + COMMAND ${CMAKE_COMMAND} -E copy + ${REPO_ROOT}/requirements.txt + $ COMMAND ${CMAKE_COMMAND} -E copy ${REPO_ROOT}/ThirdPartyNotices.txt $/onnxruntime/ @@ -668,6 +667,15 @@ add_custom_command( $ ) +if (onnxruntime_BUILD_SHARED_LIB) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $/onnxruntime/capi/ + ) +endif() + if (onnxruntime_USE_OPENVINO) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD @@ -1027,6 +1035,3 @@ if (onnxruntime_USE_QNN) endif() endif() -if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - include(onnxruntime_language_interop_ops.cmake) -endif() diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 2be68146b5e94..2966a4624a966 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -88,7 +88,6 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" - "bert/group_query_attention_helper.h" "bert/group_query_attention.h" "bert/group_query_attention.cc" "bert/group_query_attention_impl.h" diff --git a/cmake/onnxruntime_training.cmake b/cmake/onnxruntime_training.cmake index f9ba2b341f741..01590a431205c 100644 --- a/cmake/onnxruntime_training.cmake +++ b/cmake/onnxruntime_training.cmake @@ -141,10 +141,6 @@ if (onnxruntime_BUILD_UNIT_TESTS) Boost::mp11 safeint_interface ) - if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - list(APPEND ONNXRUNTIME_LIBS onnxruntime_language_interop onnxruntime_pyop) - endif() - if(UNIX AND NOT APPLE) if (HAS_NO_MAYBE_UNINITIALIZED) target_compile_options(onnxruntime_training_mnist PUBLIC "-Wno-maybe-uninitialized") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index fb96fd1cad39d..0159c35d1941b 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -546,6 +546,10 @@ if(onnxruntime_USE_NNAPI_BUILTIN) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nnapi) endif() +if(onnxruntime_USE_VSINPU) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_vsinpu) +endif() + if(onnxruntime_USE_JSEP) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js) endif() @@ -583,16 +587,13 @@ if(onnxruntime_USE_ARMNN) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_armnn) endif() -if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - set(ONNXRUNTIME_INTEROP_TEST_LIBS PRIVATE onnxruntime_language_interop onnxruntime_pyop) -endif() - set(ONNXRUNTIME_TEST_LIBS onnxruntime_session ${ONNXRUNTIME_INTEROP_TEST_LIBS} ${onnxruntime_libs} # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NNAPI} + ${PROVIDERS_VSINPU} ${PROVIDERS_JS} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} @@ -916,10 +917,6 @@ endif() if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS) target_compile_definitions(onnxruntime_test_all PRIVATE DEBUG_NODE_INPUTS_OUTPUTS) endif() - -if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - target_link_libraries(onnxruntime_test_all PRIVATE onnxruntime_language_interop onnxruntime_pyop) -endif() if (onnxruntime_USE_ROCM) if (onnxruntime_USE_COMPOSABLE_KERNEL) target_compile_definitions(onnxruntime_test_all PRIVATE USE_COMPOSABLE_KERNEL) @@ -1057,10 +1054,6 @@ set(onnx_test_libs onnx_test_data_proto ${onnxruntime_EXTERNAL_LIBRARIES}) -if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - list(APPEND onnx_test_libs onnxruntime_language_interop onnxruntime_pyop) -endif() - if (NOT IOS) onnxruntime_add_executable(onnx_test_runner ${onnx_test_runner_src_dir}/main.cc) if(MSVC) @@ -1232,6 +1225,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if (CMAKE_SYSTEM_NAME STREQUAL "Android") list(APPEND onnxruntime_perf_test_libs ${android_shared_libs}) endif() + if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") + list(APPEND onnxruntime_perf_test_libs onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 gtest absl_failure_signal_handler absl_examine_stack absl_flags_parse absl_flags_usage absl_flags_usage_internal) + endif() target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) if(WIN32) target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) @@ -1241,10 +1237,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS AND NOT onnxruntime_BUILD_SHARED_LIB) - target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_language_interop onnxruntime_pyop) - endif() - if (onnxruntime_USE_TVM) if (WIN32) target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000") @@ -1286,6 +1278,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS ${android_shared_libs}) endif() + if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") + list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2) + endif() + AddTest(DYN TARGET onnxruntime_shared_lib_test SOURCES ${onnxruntime_shared_lib_test_SRC} ${onnxruntime_unittest_main_src} @@ -1474,10 +1470,6 @@ endif() onnxruntime_flatbuffers ) - if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) - list(APPEND ONNXRUNTIME_TEST_LIBS onnxruntime_language_interop onnxruntime_pyop) - endif() - target_link_libraries(onnxruntime_test_trainer PRIVATE ${ONNXRUNTIME_TEST_LIBS} ${onnxruntime_EXTERNAL_LIBRARIES} @@ -1525,7 +1517,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if(UNIX) if (APPLE) set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-Xlinker -dead_strip") - else() + elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.lds -Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") endif() else() @@ -1589,6 +1581,9 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_customopregistration_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() + if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") + list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp) + endif() AddTest(DYN TARGET onnxruntime_customopregistration_test SOURCES ${onnxruntime_customopregistration_test_SRC} ${onnxruntime_unittest_main_src} @@ -1623,7 +1618,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI if(UNIX) if (APPLE) set(ONNXRUNTIME_CUSTOM_OP_INVALID_LIB_LINK_FLAG "-Xlinker -dead_strip") - else() + elseif (NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") string(CONCAT ONNXRUNTIME_CUSTOM_OP_INVALID_LIB_LINK_FLAG "-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_invalid_library/custom_op_library.lds " "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") @@ -1654,7 +1649,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI if(UNIX) if (APPLE) set(ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip") - else() + elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") string(CONCAT ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG "-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op_lib.lds " "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") @@ -1686,7 +1681,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI if(UNIX) if (APPLE) set(ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip") - else() + elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") string(CONCAT ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG "-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.lds " "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") @@ -1705,6 +1700,9 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" ${ONNXRUNTIME_LOGGING_APIS_TEST_SRC_DIR}/test_logging_apis.cc) set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils) + if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") + list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp) + endif() if(NOT WIN32) list(APPEND onnxruntime_logging_apis_test_LIBS nsync::nsync_cpp ${CMAKE_DL_LIBS}) @@ -1768,7 +1766,9 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD if(APPLE) set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst") elseif(UNIX) - set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\\$ORIGIN") + if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") + set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\\$ORIGIN") + endif() elseif(WIN32) set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def") else() diff --git a/cmake/patches/abseil/absl_windows.patch b/cmake/patches/abseil/absl_windows.patch index 584c49d612293..82983646527dc 100644 --- a/cmake/patches/abseil/absl_windows.patch +++ b/cmake/patches/abseil/absl_windows.patch @@ -1,8 +1,43 @@ +diff --git a/absl/base/attributes.h b/absl/base/attributes.h +index 5ea5ee3e..f4949898 100644 +--- a/absl/base/attributes.h ++++ b/absl/base/attributes.h +@@ -559,7 +559,7 @@ + #undef ABSL_ATTRIBUTE_UNUSED + #define ABSL_ATTRIBUTE_UNUSED __attribute__((__unused__)) + #else +-#define ABSL_ATTRIBUTE_UNUSED ++#define ABSL_ATTRIBUTE_UNUSED [[maybe_unused]] + #endif + + // ABSL_ATTRIBUTE_INITIAL_EXEC +diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h +index d4fe8f5c..27418d13 100644 +--- a/absl/container/internal/raw_hash_set.h ++++ b/absl/container/internal/raw_hash_set.h +@@ -1924,7 +1924,7 @@ HashtablezInfoHandle SampleHashtablezInfo(size_t sizeof_slot, size_t sizeof_key, + // In SOO, we sample on the first insertion so if this is an empty SOO case + // (e.g. when reserve is called), then we still need to sample. + if (kSooEnabled && was_soo && c.size() == 0) { +- return Sample(sizeof_slot, sizeof_key, sizeof_value, SooCapacity()); ++ return Sample(sizeof_slot, sizeof_key, sizeof_value, (int16_t)SooCapacity()); + } + // For non-SOO cases, we sample whenever the capacity is increasing from zero + // to non-zero. +@@ -3525,7 +3525,7 @@ class raw_hash_set { + assert(is_soo()); + if (!ShouldSampleHashtablezInfo()) return HashtablezInfoHandle{}; + return Sample(sizeof(slot_type), sizeof(key_type), sizeof(value_type), +- SooCapacity()); ++ (int16_t)SooCapacity()); + } + + inline void destroy_slots() { diff --git a/absl/copts/GENERATED_AbseilCopts.cmake b/absl/copts/GENERATED_AbseilCopts.cmake -index a4ab1aa2..dfd13fd7 100644 +index da2282fe..4c7fc26f 100644 --- a/absl/copts/GENERATED_AbseilCopts.cmake +++ b/absl/copts/GENERATED_AbseilCopts.cmake -@@ -129,8 +129,6 @@ list(APPEND ABSL_MSVC_FLAGS +@@ -181,8 +181,6 @@ list(APPEND ABSL_MSVC_FLAGS "/wd4005" "/wd4068" "/wd4180" @@ -10,12 +45,12 @@ index a4ab1aa2..dfd13fd7 100644 - "/wd4267" "/wd4503" "/wd4800" - ) + "/DNOMINMAX" diff --git a/absl/copts/GENERATED_copts.bzl b/absl/copts/GENERATED_copts.bzl -index a6efc98e..8c4de8e7 100644 +index b9e0071e..dd8410ec 100644 --- a/absl/copts/GENERATED_copts.bzl +++ b/absl/copts/GENERATED_copts.bzl -@@ -130,8 +130,6 @@ ABSL_MSVC_FLAGS = [ +@@ -182,8 +182,6 @@ ABSL_MSVC_FLAGS = [ "/wd4005", "/wd4068", "/wd4180", @@ -23,12 +58,12 @@ index a6efc98e..8c4de8e7 100644 - "/wd4267", "/wd4503", "/wd4800", - ] + "/DNOMINMAX", diff --git a/absl/copts/copts.py b/absl/copts/copts.py -index e6e11949..0aa7d868 100644 +index 2d85ac74..4875d668 100644 --- a/absl/copts/copts.py +++ b/absl/copts/copts.py -@@ -115,10 +115,6 @@ MSVC_WARNING_FLAGS = [ +@@ -118,10 +118,6 @@ MSVC_WARNING_FLAGS = [ "/wd4068", # unknown pragma # qualifier applied to function type has no meaning; ignored "/wd4180", diff --git a/cmake/patches/flatbuffers/flatbuffers.patch b/cmake/patches/flatbuffers/flatbuffers.patch index fbe8db37ecb0e..9fb58e301bba8 100644 --- a/cmake/patches/flatbuffers/flatbuffers.patch +++ b/cmake/patches/flatbuffers/flatbuffers.patch @@ -10,3 +10,21 @@ index 3987eac9..5e5462f1 100644 + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS} -Wno-error=stringop-overflow") endif() message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +diff --git a/include/flatbuffers/flatbuffers.h b/include/flatbuffers/flatbuffers.h +index bc828a31..3d3effe8 100644 +--- a/include/flatbuffers/flatbuffers.h ++++ b/include/flatbuffers/flatbuffers.h +@@ -213,7 +213,12 @@ inline const char * const *ElementaryTypeNames() { + // We're explicitly defining the signedness since the signedness of integer + // bitfields is otherwise implementation-defined and causes warnings on older + // GCC compilers. +-struct TypeCode { ++ ++struct ++#if defined(_AIX) && defined(__clang__) ++__attribute__((packed)) ++#endif ++TypeCode { + // ElementaryType + unsigned short base_type : 4; + // Either vector (in table) or array (in struct) diff --git a/cmake/patches/protobuf/protobuf_cmake.patch b/cmake/patches/protobuf/protobuf_cmake.patch index 00e1f6f81d6f0..e315420e4defe 100644 --- a/cmake/patches/protobuf/protobuf_cmake.patch +++ b/cmake/patches/protobuf/protobuf_cmake.patch @@ -12,11 +12,12 @@ index 04cb3303a..c023001de 100644 /wd4305 # 'identifier' : truncation from 'type1' to 'type2' /wd4307 # 'operator' : integral constant overflow /wd4309 # 'conversion' : truncation of constant value -@@ -259,7 +257,6 @@ if (MSVC) +@@ -259,7 +257,7 @@ if (MSVC) /wd4355 # 'this' : used in base member initializer list /wd4506 # no definition for inline function 'function' /wd4800 # 'type' : forcing value to bool 'true' or 'false' (performance warning) - /wd4996 # The compiler encountered a deprecated declaration. ++ ${onnxruntime_PROTOBUF_EXTRA_WARNING_DISABLEMENT} ) # Allow big object add_definitions(/bigobj) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index d74250b962628..ff6b71217ad87 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -718,11 +718,6 @@ target_compile_definitions(winml_dll PRIVATE ONNX_ML) target_compile_definitions(winml_dll PRIVATE LOTUS_LOG_THRESHOLD=2) target_compile_definitions(winml_dll PRIVATE LOTUS_ENABLE_STDERR_LOGGING) target_compile_definitions(winml_dll PRIVATE PLATFORM_WINDOWS) -target_compile_definitions(winml_dll PRIVATE VER_MAJOR=${VERSION_MAJOR_PART}) -target_compile_definitions(winml_dll PRIVATE VER_MINOR=${VERSION_MINOR_PART}) -target_compile_definitions(winml_dll PRIVATE VER_BUILD=${VERSION_BUILD_PART}) -target_compile_definitions(winml_dll PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) -target_compile_definitions(winml_dll PRIVATE VER_STRING=\"${VERSION_STRING}\") target_compile_definitions(winml_dll PRIVATE BINARY_NAME=\"${BINARY_NAME}\") if (onnxruntime_WINML_NAMESPACE_OVERRIDE STREQUAL "Windows") diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 13d925e0fc2ee..44d2222dbce16 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -1357,7 +1357,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca OrtAllocatorType allocatorType, int identifier, OrtMemType memType, - out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transfered to caller + out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transferred to caller ); public static DOrtCreateMemoryInfo OrtCreateMemoryInfo; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 163a2b394c4ae..5946e9fb1b165 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -22,7 +22,7 @@ public enum OnnxValueType ONNX_TYPE_MAP = 3, // It's a map ONNX_TYPE_OPAQUE = 4, // It's an experimental Opaque object ONNX_TYPE_SPARSETENSOR = 5, // It's a Sparse Tensor - ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKOWN) + ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKNOWN) } /// @@ -31,7 +31,7 @@ public enum OnnxValueType /// The class implements IDisposable and must /// be disposed of, otherwise native resources will leak /// and will eventually cause the application to slow down or crash. - /// + /// /// If the OrtValue instance is constructed over a managed memory, and it is not /// disposed properly, the pinned memory will continue to be pinned and interfere /// with GC operation. @@ -72,7 +72,7 @@ internal OrtValue(IntPtr handle, OnnxValueType onnxValueType) /// Constructor. The newly constructed OrtValue takes ownership of the native OrtValue instance /// and disposes of it when the OrtValue instance is disposed. The instance will take ownership and will /// dispose of compositeMembers instances. - /// + /// /// This constructor can only throw if OnnxType is not specified. /// /// native ortValue handle @@ -189,10 +189,10 @@ public OrtValue GetValue(int index, OrtAllocator allocator) /// /// Returns a ReadOnlySpan over tensor native buffer that /// provides a read-only view. - /// + /// /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU. /// To get memory descriptor use GetTensorMemoryInfo(). - /// + /// /// OrtValue must contain a non-string tensor. /// The span is valid as long as the OrtValue instance is alive (not disposed). /// @@ -210,10 +210,10 @@ public ReadOnlySpan GetTensorDataAsSpan() where T : unmanaged /// This enables you to safely and efficiently modify the underlying /// native buffer in a type-safe manner. This is useful for example in IOBinding scenarios /// where you want to modify results of the inference and feed it back as input. - /// + /// /// Note, that the memory may be device allocated. /// To get memory descriptor use GetTensorMemoryInfo(). - /// + /// /// OrtValue must contain a non-string tensor. /// The span is valid as long as the OrtValue instance is alive (not disposed). /// @@ -237,7 +237,7 @@ public Span GetTensorMutableRawData() /// /// Fetch string tensor element buffer pointer at the specified index, /// convert/copy to UTF-16 char[] and return a ReadOnlyMemory instance. - /// + /// /// Obtain TensorTypeAndShape to get shape and element count. /// /// flat string tensor element index @@ -256,7 +256,7 @@ public ReadOnlyMemory GetStringElementAsMemory(int index) /// /// Fetch string tensor element buffer pointer at the specified index, /// copy/convert UTF-8 into a UTF-16 string and return it. - /// + /// /// Obtain TensorTypeAndShape to get shape and element count. /// /// flat string tensor element index @@ -279,7 +279,7 @@ public string GetStringElement(int index) /// /// Get a span over the native memory of the string tensor element. /// The span is valid as long as the OrtValue is valid. - /// + /// /// This is useful if you want to perform your own UTF-8 decoding or /// you do not care about decoding. /// Obtain TensorTypeAndShape to get shape and element count. @@ -483,7 +483,7 @@ private Span GetTensorBufferRawData(Type requestedType) /// This can be a piece of arbitrary memory that may be allocated by OrtAllocator (possibly on a device), /// a chunk of managed memory (must be pinned for the duration of OrtValue lifetime) or a memory that is allocated /// natively allocated using Marshal.AllocHGlobal(), stackalloc or other means (may be on a device). - /// + /// /// The resulting OrtValue does not own the underlying memory buffer and will not attempt to /// deallocate it. The caller must make sure that the memory remains valid for the duration of OrtValue lifetime. /// @@ -769,12 +769,12 @@ out IntPtr valueHandle /// Converts the string argument represented by ReadOnlySpan to UTF-8, /// allocates space in the native tensor and copies it into the native tensor memory. /// Typically, this is used to populate a new empty string tensor element. - /// + /// /// The number of elements is according to the shape supplied to CreateTensorWithEmptyStrings(). /// However, this API can also be used to overwrite any existing element within the string tensor. - /// + /// /// In general, to obtain the number of elements for any tensor, use GetTensorTypeAndShape() which - /// would return a disposable instance of TensorTypeAndShapeInfo. + /// would return a disposable instance of TensorTypeAndShapeInfo. /// Then call GetElementCount() or GetShape(). /// /// ReadOnlySpan over chars @@ -795,12 +795,12 @@ public void StringTensorSetElementAt(ReadOnlySpan str, int index) /// Converts the string argument represented by ReadOnlyMemory to UTF-8, /// allocates space in the native tensor and copies it into the native tensor memory. /// Typically, this is used to populate a new empty string tensor element. - /// + /// /// The number of elements is according to the shape supplied to CreateTensorWithEmptyStrings(). /// However, this API can also be used to overwrite any existing element within the string tensor. - /// + /// /// In general, to obtain the number of elements for any tensor, use GetTensorTypeAndShape() which - /// would return a disposable instance of TensorTypeAndShapeInfo. + /// would return a disposable instance of TensorTypeAndShapeInfo. /// Then call GetElementCount() or GetShape(). /// /// @@ -815,7 +815,7 @@ public void StringTensorSetElementAt(ReadOnlyMemory rom, int index) /// /// This API resizes String Tensor element to the requested amount of bytes (UTF-8) /// and copies the bytes from the supplied ReadOnlySpan into the native tensor memory (resized buffer). - /// + /// /// The API is useful for quick loading of utf8 data into the native tensor memory. /// /// read only span of bytes @@ -841,7 +841,7 @@ public void StringTensorSetElementAt(ReadOnlySpan utf8Bytes, int index) /// Creates an OrtValue that contains a string tensor. /// String tensors are always allocated on CPU. /// String data will be converted to UTF-8 and copied to native memory. - /// + /// /// Note, this is different from creating an OrtValue from other primitive data types /// where memory is pinned (if necessary) and the OrtValue points to that chunk of memory. /// @@ -885,10 +885,10 @@ public static OrtValue CreateFromStringTensor(Tensor tensor) /// Creates a sequence of OrtValues from a collection of OrtValues. /// All OrtValues in the collection must be of the same Onnx type. /// I.e. (Tensor, SparseTensor, Map, Sequence, etc.) - /// + /// /// The ortValues that are passed as argument are taken possession of by the newly /// created OrtValue. The caller should not dispose them, unless this call fails. - /// + /// /// The ortValues would be empty on successful return. /// /// a collection of OrtValues. On success the ortValues contained in the list @@ -978,24 +978,24 @@ public void ProcessSequence(SequenceElementVisitor visitor, OrtAllocator allocat /// Creates a map OrtValue with keys and values. /// On a high level the Onnxruntime representation of the map always consists of two /// OrtValues, keys and values. - /// + /// /// According to ONNX standard map keys can be unmanaged types only (or strings). /// Those keys are contained in a single tensor within OrtValue keys. - /// + /// /// Map values, on the other hand, can be composite types. The values parameter /// can either contain a single tensor with unmanaged map values with the same number of /// elements as the keys, or it can be a sequence of OrtValues, /// each of those can be a composite type (tensor, sequence, map). If it is a sequence, /// then the number of elements must match the number of elements in keys. - /// + /// /// Keys and values must be in the same order. - /// + /// /// ORT supports only a subset of types for keys and values, however, this API does not /// restrict it. - /// + /// /// The ortValues that are passed as argument are taken possession of by the newly /// created OrtValue. The caller should not dispose them, unless this call fails. - /// + /// /// Keys and values arguments will be set to null on success. /// /// Contains keys @@ -1031,10 +1031,10 @@ public static OrtValue CreateMap(ref OrtValue keys, ref OrtValue values) /// This API helps to quickly creates a map OrtValue with unmanaged (primitive) keys and values specified as arrays. /// This helps the user not to create OrtValues for keys and values separately and deal only with the final result. /// The map would consist of two tensors, one for keys and one for values. - /// + /// /// The OrtValues would be created on top of the managed memory arrays and use it directly. /// The number of elements in keys and values must be the same and they must be in order. - /// + /// /// The types must be unmanaged. /// /// keys type @@ -1078,10 +1078,10 @@ public static OrtValue CreateMap(K[] keys, V[] values) where K : unmanaged /// This helps the user not to create OrtValues for keys and values separately. /// The number of elements in keys and values must be the same and they must be in order. /// The map would consist of two tensors, one for keys and one for values. - /// + /// /// string keys would be converted to UTF-8 encoding and copied to an allocated native memory. /// The OrtValue for values would be created on top of the managed memory using it directly. - /// + /// /// The values type must be unmanaged. /// /// @@ -1128,13 +1128,13 @@ public static OrtValue CreateMapWithStringKeys(IReadOnlyCollection ke /// /// Creates a map OrtValue with non-string keys and string values. - /// + /// /// This helps the user not to create OrtValues for keys and values separately. /// The number of elements in keys and values must be the same and they must be in order. - /// + /// /// The OrtValue for keys would be created on top of the managed memory using it directly. /// string values would be converted to UTF-8 encoding and copied to an allocated native memory. - /// + /// /// /// unmanaged type of keys /// @@ -1182,17 +1182,17 @@ public static OrtValue CreateMapWithStringValues(K[] keys, IReadOnlyCollectio /// Typically, when one uses GetValue() API, it creates a copy of OrtValue /// that points to the same buffer as keys or values. This API helps to deal with those /// temporary instances and avoid leaks. - /// + /// /// According to ONNX standard map keys can be unmanaged types only (or strings). /// Those keys are contained in a single tensor within OrtValue keys. So you can query those /// directly from keys argument. - /// + /// /// Map values, on the other hand, can be composite types. The values parameter /// can either contain a single tensor with unmanaged map values with the same number of /// elements as the keys, or it can be a sequence of OrtValues, /// each of those can be a composite type (tensor, sequence, map). If it is a sequence, /// then the number of elements must match the number of elements in keys. - /// + /// /// Depending on the structure of the values, one will either directly query a single tensor /// from values, or will have to iterate over the sequence of OrtValues and visit each of those /// resulting in a recursive visitation. @@ -1204,7 +1204,7 @@ public static OrtValue CreateMapWithStringValues(K[] keys, IReadOnlyCollectio /// /// This API helps the user to process a map OrtValue without /// having to deal with the lifespan of intermediate OrtValues. - /// + /// /// each API value is fed to the vistor functor. /// /// visitor function diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index d6a6b9627f418..0892e17fc97bc 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -116,7 +116,7 @@ public void TestSessionOptions() var directml_dll_path = AppDomain.CurrentDomain.BaseDirectory; SetDllDirectory(directml_dll_path); - + try { opt.AppendExecutionProvider_DML(0); @@ -124,7 +124,7 @@ public void TestSessionOptions() catch (OnnxRuntimeException ortException) { // if we run on a CI machine with the incorrect hardware we might get an error due to that. - // allow that as the call made it through to the DML EP so the C# layer is working correctly. + // allow that as the call made it through to the DML EP so the C# layer is working correctly. // any other exception type or error message is considered a failure. Assert.Contains("The specified device interface or feature level is not supported on this system.", ortException.Message); @@ -1895,7 +1895,7 @@ private void TestSharedAllocatorUsingCreateAndRegisterAllocator() sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1"); // Create two sessions to share the allocator - // Create a thrid session that DOES NOT use the allocator in the environment + // Create a third session that DOES NOT use the allocator in the environment using (var session1 = new InferenceSession(model, sessionOptions)) using (var session2 = new InferenceSession(model, sessionOptions)) using (var session3 = new InferenceSession(model)) // Use the default SessionOptions instance @@ -2127,7 +2127,7 @@ private void TestLoadAzureEP() } catch (Exception) { Assert.True(false); - } + } } } } diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs index 9370a03f7fbeb..a005efa749a1c 100644 --- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs +++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs @@ -32,7 +32,7 @@ class CommandOptions [Option('i', "input_file", Required = false, HelpText = "Input file.")] public string InputFile { get; set; } - [Option('p', Required = false, HelpText = "Run with parallel exection. Default is false")] + [Option('p', Required = false, HelpText = "Run with parallel execution. Default is false")] public bool ParallelExecution { get; set; } = false; [Option('o', "optimization_level", Required = false, HelpText = "Optimization Level. Default is 99, all optimization.")] diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 45306c852a906..ed9e2a0567d2f 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5646,7 +5646,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16), tensor(bfloat16)
+
T : tensor(float), tensor(float16), tensor(bfloat16)
Constrain input and output to float tensors.
M : tensor(int32)
Constrain integer type.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 5f19c16cba616..df5897529baae 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -512,6 +512,7 @@ Do not modify directly.* |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| +|SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_row_indices:**M**
*in* block_col_indices:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float)| |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| diff --git a/include/onnxruntime/core/common/gsl.h b/include/onnxruntime/core/common/gsl.h deleted file mode 100644 index 371c5b7543b54..0000000000000 --- a/include/onnxruntime/core/common/gsl.h +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "gsl/gsl" diff --git a/include/onnxruntime/core/common/logging/capture.h b/include/onnxruntime/core/common/logging/capture.h index 2af050918706a..13d3a3ad17aff 100644 --- a/include/onnxruntime/core/common/logging/capture.h +++ b/include/onnxruntime/core/common/logging/capture.h @@ -4,7 +4,7 @@ #pragma once #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/code_location.h" #include "core/common/logging/severity.h" diff --git a/include/onnxruntime/core/common/span_utils.h b/include/onnxruntime/core/common/span_utils.h index b2d1aefee9c04..9f7454625fcd1 100644 --- a/include/onnxruntime/core/common/span_utils.h +++ b/include/onnxruntime/core/common/span_utils.h @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include namespace onnxruntime { diff --git a/include/onnxruntime/core/eager/ort_kernel_invoker.h b/include/onnxruntime/core/eager/ort_kernel_invoker.h index 1d1046742db4b..fcf92de2ee39a 100644 --- a/include/onnxruntime/core/eager/ort_kernel_invoker.h +++ b/include/onnxruntime/core/eager/ort_kernel_invoker.h @@ -24,7 +24,10 @@ class ORTInvoker { public: ORTInvoker(std::shared_ptr execution_provider, const logging::Logger& logger, - const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) : execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) { + const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) + : execution_provider_(std::move(execution_provider)), + logger_(logger), + custom_op_registries_(custom_op_registries) { if (!execution_provider_) { ORT_THROW("Execution provider is nullptr"); } diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index b197d88090432..87feefa10ca4a 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -9,7 +9,7 @@ #include #include #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/exceptions.h" #include "core/framework/endian.h" diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 16ad943a5f47e..49c3d1bdd088a 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -231,7 +231,7 @@ class IExecutionProvider { const std::reference_wrapper filtered_graph; }; - // Fusion approach that is suppported + // Fusion approach that is supported // !!! The "Function" FusionStyle is deprecated. // !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style. enum class FusionStyle { diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index aff365dc9738f..0282b84bd0f82 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -6,7 +6,7 @@ #include #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 94c6d81ee9326..07625c38d8474 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -25,7 +25,7 @@ #include "core/graph/constants.h" #include "core/graph/graph_viewer.h" #include "core/graph/onnx_protobuf.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { class OpKernelContext; } @@ -94,7 +94,7 @@ class OpKernel { } // Override this function to return a list of attributes the session can safely remove - // after it is intialized and saved. This option is useful to reduce memory usage + // after it is initialized and saved. This option is useful to reduce memory usage // when the kernel does not reuse the operator attributes but copies them. // All attributes returned by this method will be removed by method // PruneRemovableAttributes of they exists. diff --git a/include/onnxruntime/core/framework/op_kernel_info.h b/include/onnxruntime/core/framework/op_kernel_info.h index a0bbfe50a700b..1510cdc9d1459 100644 --- a/include/onnxruntime/core/framework/op_kernel_info.h +++ b/include/onnxruntime/core/framework/op_kernel_info.h @@ -8,7 +8,7 @@ #include "core/framework/ort_value.h" #include "core/framework/op_node_proto_helper.h" #include "core/graph/graph_viewer.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { diff --git a/include/onnxruntime/core/framework/op_node_proto_helper.h b/include/onnxruntime/core/framework/op_node_proto_helper.h index e7ac01947af41..5cbaaa0212c5f 100644 --- a/include/onnxruntime/core/framework/op_node_proto_helper.h +++ b/include/onnxruntime/core/framework/op_node_proto_helper.h @@ -7,7 +7,7 @@ #include "core/common/status.h" #include "core/framework/tensor_shape.h" #include "core/graph/graph_viewer.h" -#include "core/common/gsl.h" +#include #endif class IMLOpKernel; diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index 26d78133b52fc..9c987f10ccadb 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -54,14 +54,14 @@ class Stream { // update its lookup table with the table snapshot in notification. // The memory reusing strategy is: // A kernel in current stream is safe to reuse another stream's memory chunk - // as long as the reused chunk's timestamp is less than the last synchonized + // as long as the reused chunk's timestamp is less than the last synchronized // timestamp recorded in the lookup table. // Get the current timestamp uint64_t GetCurrentTimestamp() const { return timestamp_; } // return the timestamp when the last synchronization happened between target stream and current stream. - // return 0 if no synchonization happened. + // return 0 if no synchronization happened. // if target_stream is nullptr, it means it is a sequence running on device doesn't support Stream (i.e. CPU) // we can safely return 0 in that case to save a lookup. uint64_t GetLastSyncTimestampWithTargetStream(Stream* target_stream) const { diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 96725aa103064..dd2603d214f63 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -9,7 +9,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/allocator.h" #include "core/framework/tensor_shape.h" diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index 82a1c1de83523..d4ee4a0e5e649 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -9,7 +9,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers_fwd.h" #include "core/common/span_utils.h" #include "onnxruntime_config.h" diff --git a/include/onnxruntime/core/graph/basic_types.h b/include/onnxruntime/core/graph/basic_types.h index 36984d0405bbd..cdd5e4c1e571b 100644 --- a/include/onnxruntime/core/graph/basic_types.h +++ b/include/onnxruntime/core/graph/basic_types.h @@ -19,6 +19,8 @@ class TensorProto; class SparseTensorProto; class TypeProto; class AttributeProto; +class FunctionProto; +class OperatorSetIdProto; // define types that would come from the ONNX library if we were building against it. #if defined(ORT_MINIMAL_BUILD) using OperatorSetVersion = int; diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index c4a46cd422219..39acb6b4f2aa4 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -52,6 +52,7 @@ constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider"; constexpr const char* kWebNNExecutionProvider = "WebNNExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; +constexpr const char* kVSINPUExecutionProvider = "VSINPUExecutionProvider"; constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path"; constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point"; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 4f3377f0aa0c0..9289e14c17dd1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -10,28 +10,19 @@ #include #include #include - -#ifdef _WIN32 -#pragma warning(push) -// disable some warnings from protobuf to pass Windows build -#pragma warning(disable : 4244) -#endif - -#ifdef _WIN32 -#pragma warning(pop) -#endif +#include #include "core/common/flatbuffers.h" -#include "core/common/gsl.h" +#include #include "core/common/common.h" +#include "core/common/path_string.h" #include "core/common/const_pointer_container.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/common/inlined_containers.h" #endif #include "core/common/inlined_containers_fwd.h" -#include "core/common/path.h" #include "core/common/span_utils.h" #include "core/common/status.h" #include "core/common/logging/logging.h" @@ -147,7 +138,7 @@ class Node { const std::string& Domain() const noexcept { return domain_; } /** Gets the path of the owning model if any. */ - const Path& ModelPath() const noexcept; + const std::filesystem::path& ModelPath() const noexcept; /** Gets the Node's execution priority. @remarks Lower value means higher priority */ @@ -693,7 +684,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const std::string& Description() const noexcept; /** Gets the path of the owning model, if any. */ - const Path& ModelPath() const; + const std::filesystem::path& ModelPath() const; /** Returns true if this is a subgraph or false if it is a high-level graph. */ bool IsSubgraph() const { return parent_graph_ != nullptr; } @@ -1149,13 +1140,14 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi ONNX_NAMESPACE::GraphProto ToGraphProto() const; /** Gets the GraphProto representation of this Graph - @params external_file_name name of the binary file to use for initializers + @param external_file_path File path of the binary file to use for initializers. + @param model_file_path path of the model file. @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved in the external file. Initializer smaller than this threshold are included in the onnx file. @returns GraphProto serialization of the graph. */ - ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name, - const PathString& file_path, + ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, + const std::filesystem::path& model_file_path, size_t initializer_size_threshold) const; /** Gets the ISchemaRegistry instances being used with this Graph. */ diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 1816099d3210f..9385e2f092e58 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -2,10 +2,11 @@ // Licensed under the MIT License. #pragma once +#include +#include #include "core/graph/graph.h" #include "core/framework/session_options.h" -#include namespace onnxruntime { class Function; @@ -43,7 +44,7 @@ class GraphViewer { const std::string& Description() const noexcept; /** Gets the path of the owning model if any **/ - const Path& ModelPath() const noexcept { return graph_->ModelPath(); } + const std::filesystem::path& ModelPath() const noexcept { return graph_->ModelPath(); } /** Gets a tensor created from an initializer. diff --git a/include/onnxruntime/core/optimizer/graph_transformer_config.h b/include/onnxruntime/core/optimizer/graph_transformer_config.h index c112d9b0480ab..6af48331270cd 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_config.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_config.h @@ -13,7 +13,7 @@ struct GraphTransformerConfiguration { /* * Cast propagation strategy. * One strategy is to insert casts around all the nodes with the allowed opcodes - * and reduce, by removing redundent-casts and back-to-back-casts etc., and + * and reduce, by removing redundant-casts and back-to-back-casts etc., and * the other is to propagate casts using flood-fill approach, expanding float16 regions in the graph * traversing the graph up/down. */ @@ -70,4 +70,4 @@ constexpr GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy constexpr bool operator==(GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy, GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy); constexpr bool operator!=(GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy, GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy); -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 7104e70c3a8a9..9ada01673d4d9 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -12,14 +12,16 @@ #define ORT_CUDA_CTX -#include "cuda_resource.h" -#include "core/providers/custom_op_context.h" #include #include #ifndef USE_CUDA_MINIMAL #include #include #endif + +#include "core/providers/cuda/cuda_resource.h" +#include "core/providers/custom_op_context.h" + namespace Ort { namespace Custom { diff --git a/include/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.h b/include/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.h new file mode 100644 index 0000000000000..a84067a19aa8a --- /dev/null +++ b/include/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.h @@ -0,0 +1,34 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include "onnxruntime_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_VSINPU, _In_ OrtSessionOptions* options); + +#ifdef __cplusplus +} +#endif diff --git a/java/build-android.gradle b/java/build-android.gradle index afbad9f03d08d..fd22fa27e8db9 100644 --- a/java/build-android.gradle +++ b/java/build-android.gradle @@ -105,7 +105,7 @@ task sourcesJar(type: Jar) { task javadoc(type: Javadoc) { source = android.sourceSets.main.java.srcDirs - classpath += project.files(android.getBootClasspath().join(File.pathSeparator)) + classpath += project.files(android.getBootClasspath()) } task javadocJar(type: Jar, dependsOn: javadoc) { diff --git a/java/build.gradle b/java/build.gradle index cebf67e085446..3219b082994ff 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -166,11 +166,14 @@ if (cmakeBuildDir != null) { } tasks.register('cmakeCheck', Copy) { + group = 'verification' from layout.buildDirectory.get() include 'reports/**' into cmakeBuildOutputDir dependsOn(check) } +} else { + println "cmakeBuildDir is not set. Skipping cmake tasks." } dependencies { diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java index 3915292648aee..1be8c22b40da8 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java @@ -349,7 +349,7 @@ private SessionOptions parseSessionOptions(ReadableMap options) throws OrtExcept if (options.hasKey("interOpNumThreads")) { int interOpNumThreads = options.getInt("interOpNumThreads"); if (interOpNumThreads > 0 && interOpNumThreads < Integer.MAX_VALUE) { - sessionOptions.setIntraOpNumThreads(interOpNumThreads); + sessionOptions.setInterOpNumThreads(interOpNumThreads); } } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 919b005ec4c21..3ee9441eeb981 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -58,7 +58,7 @@ Do not modify directly.* | HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | -| LayerNormalization | ai.onnx(17+) | | +| LayerNormalization | ai.onnx(1-16,17+) | | | LeakyRelu | ai.onnx(6-15,16+) | | | Less | ai.onnx(7-8,9-12,13+) | | | LessOrEqual | ai.onnx(12-15,16+) | | diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 4c6dab84fa973..8d077846fa6a4 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -16,12 +16,12 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 | | ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 | | AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 | -| BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✗ | ✓ | Only supports 'training_mode' value is 0, one output | +| BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✓ | ✓ | Only supports 'training_mode' value is 0, one output | | Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | ✓ | ✓ | WebNN CPU backend doesn't support casting to uint64 data type | | Ceil | ai.onnx(7-12, 13+) | ceil | ✓ | ✓ | | | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) | | Concat | ai.onnx(7-10, 11-12, 13+) | concat | ✓ | ✓ | | -| Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU requires the 'W' (weight) input to be a constant | +| Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) | | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✗ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | @@ -29,11 +29,11 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal | ✓ | ✓ | | | Erf | ai.onnx(7-9, 10-12, 13+) | erf | ✗ | ✓ | | | Exp | ai.onnx(7-12, 13+) | exp | ✓ | ✓ | | -| Expand | ai.onnx(8-12, 13+) | expand | ✗ | ✓ | 'shape' input should be a constant | +| Expand | ai.onnx(8-12, 13+) | expand | ✓ | ✓ | 'shape' input should be a constant | | Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | | Floor | ai.onnx(7-12, 13+) | floor | ✓ | ✓ | | | Gather | ai.onnx(7-10, 11-12, 13+) | gather | ✓ | ✓ | | -| Gelu | ai.onnx(20+) | gelu | ✗ | ✓ | | +| Gelu | ai.onnx(20+) | gelu | ✓ | ✓ | | | Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | ✓ | ✓ | Only supports 1-D 'C' input | | GlobalAveragePool | ai.onnx(7+) | averagePool2d | ✓ | ✓ | Only supports 4-D input | | GlobalMaxPool | ai.onnx(7+) | maxPool2d | ✓ | ✓ | Only supports 4-D input | @@ -43,8 +43,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | HardSigmoid | ai.onnx(7+) | hardSigmoid | ✓ | ✓ | | | HardSwish | ai.onnx(14+) | hardSwish | ✓ | ✓ | | | Identity | ai.onnx(7-13, 14-15, 16-18, 19-20, 21+) | identity | ✓ | ✓ | | -| InstanceNormalization | ai.onnx(7+) | instanceNormalization | ✗ | ✓ | | -| LayerNormalization | ai.onnx(7-16, 17+) | layerNormalization | ✗ | ✓ | | +| InstanceNormalization | ai.onnx(7+) | instanceNormalization | ✓ | ✓ | | +| LayerNormalization | ai.onnx(7-16, 17+) | layerNormalization | ✓ | ✓ | | | LeakyRelu | ai.onnx(7-15, 16+) | leakyRelu | ✓ | ✓ | | | Less | ai.onnx(7-8, 9-12, 13+) | lesser | ✓ | ✓ | | | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | | @@ -59,25 +59,25 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Not | ai.onnx(7+) | logicalnot | ✓ | ✓ | | | Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad | ✓ | ✓ | modes == 'wrap' is not supported | | Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | | -| PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU restricts slope to be a static value | -| Reciprocal | ai.onnx(7-12, 13+) | reciprocal | ✗ | ✓ | | -| ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✗ | ✓ | Input 'axes' if present should be a constant | -| ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✗ | ✓ | Input 'axes' if present should be a constant | -| ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum| ✗ | ✓ | Input 'axes' if present should be a constant | -| ReduceLogSumExp | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSumExp | ✗ | ✓ | Input 'axes' if present should be a constant | +| PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) | +| Reciprocal | ai.onnx(7-12, 13+) | reciprocal | ✓ | ✓ | | +| ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✓ | ✓ | Input 'axes' if present should be a constant | +| ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✓ | ✓ | Input 'axes' if present should be a constant | +| ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum| ✓ | ✓ | Input 'axes' if present should be a constant | +| ReduceLogSumExp | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSumExp | ✓ | ✓ | Input 'axes' if present should be a constant | | ReduceMax | ai.onnx(7-10, 11, 12, 13-17, 18-19, 20+) | reduceMax | ✓ | ✓ | Input 'axes' if present should be a constant | | ReduceMean | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceMean | ✓ | ✓ | Input 'axes' if present should be a constant | | ReduceMin | ai.onnx(7-10, 11, 12, 13-17, 18-19, 20+) | reduceMin | ✓ | ✓ | Input 'axes' if present should be a constant | | ReduceProd | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceProduct | ✓ | ✓ | Input 'axes' if present should be a constant | | ReduceSum | ai.onnx(7-10, 11-12, 13+) | reduceSum | ✓ | ✓ | Input 'axes' if present should be a constant | -| ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✗ | ✓ | Input 'axes' if present should be a constant | +| ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✓ | ✓ | Input 'axes' if present should be a constant | | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | | Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | | -| Softsign | ai.onnx(7+) | softsign | ✗ | ✓ | | +| Softsign | ai.onnx(7+) | softsign | ✓ | ✓ | | | Sin | ai.onnx(7+) | sin | ✓ | ✓ | | | Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice | ✓ | ✓ | Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 | | Softmax | ai.onnx(7-10, 11-12, 13+) | softmax | ✓ | ✓ | | diff --git a/js/web/lib/onnxjs/execution-plan.ts b/js/web/lib/onnxjs/execution-plan.ts index 5599087ab46f5..e155ff123f79d 100644 --- a/js/web/lib/onnxjs/execution-plan.ts +++ b/js/web/lib/onnxjs/execution-plan.ts @@ -92,7 +92,7 @@ export class ExecutionPlan { const inputTensors = inputList as Tensor[]; Logger.verbose( 'ExecPlan', - `Runing op:${thisOp.node.name} (${ + `Running op:${thisOp.node.name} (${ inputTensors.map((t, i) => `'${thisOp.node.inputs[i]}': ${t.type}[${t.dims.join(',')}]`).join(', ')})`); const outputList = await this.profiler.event( diff --git a/js/web/lib/onnxjs/graph.ts b/js/web/lib/onnxjs/graph.ts index f16da42815957..d444be2bf7ce0 100644 --- a/js/web/lib/onnxjs/graph.ts +++ b/js/web/lib/onnxjs/graph.ts @@ -674,7 +674,7 @@ class GraphImpl implements Graph, Graph.Transformer { } /** - * Delete the specifed node. Assume the node has one incoming input and the first output connected to other nodes. + * Delete the specified node. Assume the node has one incoming input and the first output connected to other nodes. * An input validation must be done before calling this function. * @param nodeIndex The index of node to be deleted */ diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index d697a8b3138cf..22c4e4c755f55 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -474,7 +474,7 @@ export class ProtoUtil { export class LongUtil { // This function is called to get a number from long type of data for attribute, dim, and ir version, // which values are signed integers. - // To make it more generic, add an optional paramter to convert to a unsigned number. + // To make it more generic, add an optional parameter to convert to a unsigned number. static longToNumber(n: Long|flatbuffers.Long|number, unsigned?: boolean) { if (Long.isLong(n)) { return n.toNumber(); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 29c7941e6bd30..9b37247167bab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -328,13 +328,6 @@ fn main(@builtin(local_invocation_id) localId : vec3, var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; - - // Without this initialization strange values show up in acc. - for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { - for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { - acc[innerRow][innerCol] = 0.0; - } - } ${matmulSnippet} } `; diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts new file mode 100644 index 0000000000000..f8a1e1966fd4c --- /dev/null +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +interface NavigatorML { + readonly ml: ML; +} +interface Navigator extends NavigatorML {} +interface WorkerNavigator extends NavigatorML {} +type MLDeviceType = 'cpu'|'gpu'|'npu'; +type MLPowerPreference = 'default'|'high-performance'|'low-power'; +interface MLContextOptions { + deviceType?: MLDeviceType; + powerPreference?: MLPowerPreference; + numThreads?: number; +} +interface ML { + createContext(options?: MLContextOptions): Promise; + createContext(gpuDevice: GPUDevice): Promise; +} +type MLNamedArrayBufferViews = Record; +interface MLComputeResult { + inputs?: MLNamedArrayBufferViews; + outputs?: MLNamedArrayBufferViews; +} +interface MLContext { + compute(graph: MLGraph, inputs: MLNamedArrayBufferViews, outputs: MLNamedArrayBufferViews): Promise; +} +interface MLGraph {} +type MLInputOperandLayout = 'nchw'|'nhwc'; +type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'; +interface MLOperandDescriptor { + dataType: MLOperandDataType; + dimensions?: number[]; +} +interface MLOperand { + dataType(): MLOperandDataType; + shape(): number[]; +} +interface MLActivation {} +type MLNamedOperands = Record; +interface MLGraphBuilder { + // eslint-disable-next-line @typescript-eslint/no-misused-new + new(context: MLContext): MLGraphBuilder; + input(name: string, descriptor: MLOperandDescriptor): MLOperand; + constant(descriptor: MLOperandDescriptor, bufferView: ArrayBufferView): MLOperand; + constant(type: MLOperandDataType, value: number): MLOperand; + build(outputs: MLNamedOperands): Promise; +} +interface MLArgMinMaxOptions { + axes?: number[]; + keepDimensions?: boolean; + selectLastIndex?: boolean; +} +interface MLGraphBuilder { + argMin(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand; + argMax(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand; +} +interface MLBatchNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + axis?: number; + epsilon?: number; +} +interface MLGraphBuilder { + batchNormalization(input: MLOperand, mean: MLOperand, variance: MLOperand, options?: MLBatchNormalizationOptions): + MLOperand; +} +interface MLGraphBuilder { + cast(input: MLOperand, type: MLOperandDataType): MLOperand; +} +interface MLClampOptions { + minValue?: number; + maxValue?: number; +} +interface MLGraphBuilder { + clamp(input: MLOperand, options?: MLClampOptions): MLOperand; + clamp(options?: MLClampOptions): MLActivation; +} +interface MLGraphBuilder { + concat(inputs: MLOperand[], axis: number): MLOperand; +} +type MLConv2dFilterOperandLayout = 'oihw'|'hwio'|'ohwi'|'ihwo'; +interface MLConv2dOptions { + padding?: number[]; + strides?: number[]; + dilations?: number[]; + groups?: number; + inputLayout?: MLInputOperandLayout; + filterLayout?: MLConv2dFilterOperandLayout; + bias?: MLOperand; +} +interface MLGraphBuilder { + conv2d(input: MLOperand, filter: MLOperand, options?: MLConv2dOptions): MLOperand; +} +type MLConvTranspose2dFilterOperandLayout = 'iohw'|'hwoi'|'ohwi'; +interface MLConvTranspose2dOptions { + padding?: number[]; + strides?: number[]; + dilations?: number[]; + outputPadding?: number[]; + outputSizes?: number[]; + groups?: number; + inputLayout?: MLInputOperandLayout; + filterLayout?: MLConvTranspose2dFilterOperandLayout; + bias?: MLOperand; +} +interface MLGraphBuilder { + convTranspose2d(input: MLOperand, filter: MLOperand, options?: MLConvTranspose2dOptions): MLOperand; +} +interface MLGraphBuilder { + add(a: MLOperand, b: MLOperand): MLOperand; + sub(a: MLOperand, b: MLOperand): MLOperand; + mul(a: MLOperand, b: MLOperand): MLOperand; + div(a: MLOperand, b: MLOperand): MLOperand; + max(a: MLOperand, b: MLOperand): MLOperand; + min(a: MLOperand, b: MLOperand): MLOperand; + pow(a: MLOperand, b: MLOperand): MLOperand; +} +interface MLGraphBuilder { + equal(a: MLOperand, b: MLOperand): MLOperand; + greater(a: MLOperand, b: MLOperand): MLOperand; + greaterOrEqual(a: MLOperand, b: MLOperand): MLOperand; + lesser(a: MLOperand, b: MLOperand): MLOperand; + lesserOrEqual(a: MLOperand, b: MLOperand): MLOperand; + logicalNot(a: MLOperand): MLOperand; +} +interface MLGraphBuilder { + abs(input: MLOperand): MLOperand; + ceil(input: MLOperand): MLOperand; + cos(input: MLOperand): MLOperand; + erf(input: MLOperand): MLOperand; + exp(input: MLOperand): MLOperand; + floor(input: MLOperand): MLOperand; + identity(input: MLOperand): MLOperand; + log(input: MLOperand): MLOperand; + neg(input: MLOperand): MLOperand; + reciprocal(input: MLOperand): MLOperand; + sin(input: MLOperand): MLOperand; + sqrt(input: MLOperand): MLOperand; + tan(input: MLOperand): MLOperand; +} +interface MLEluOptions { + alpha?: number; +} +interface MLGraphBuilder { + elu(input: MLOperand, options?: MLEluOptions): MLOperand; + elu(options?: MLEluOptions): MLActivation; +} +interface MLGraphBuilder { + expand(input: MLOperand, newShape: number[]): MLOperand; +} +interface MLGatherOptions { + axis?: number; +} +interface MLGraphBuilder { + gather(input: MLOperand, indices: MLOperand, options?: MLGatherOptions): MLOperand; +} +interface MLGraphBuilder { + gelu(input: MLOperand): MLOperand; + gelu(): MLActivation; +} +interface MLGemmOptions { + c?: MLOperand; + alpha?: number; + beta?: number; + aTranspose?: boolean; + bTranspose?: boolean; +} +interface MLGraphBuilder { + gemm(a: MLOperand, b: MLOperand, options?: MLGemmOptions): MLOperand; +} +type MLGruWeightLayout = 'zrn'|'rzn'; +type MLRecurrentNetworkDirection = 'forward'|'backward'|'both'; +interface MLGruOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + initialHiddenState?: MLOperand; + resetAfter?: boolean; + returnSequence?: boolean; + direction?: MLRecurrentNetworkDirection; + layout?: MLGruWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + gru(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number, + options?: MLGruOptions): MLOperand[]; +} +interface MLGruCellOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + resetAfter?: boolean; + layout?: MLGruWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + gruCell( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, hiddenSize: number, + options?: MLGruCellOptions): MLOperand; +} +interface MLHardSigmoidOptions { + alpha?: number; + beta?: number; +} +interface MLGraphBuilder { + hardSigmoid(input: MLOperand, options?: MLHardSigmoidOptions): MLOperand; + hardSigmoid(options?: MLHardSigmoidOptions): MLActivation; +} +interface MLGraphBuilder { + hardSwish(input: MLOperand): MLOperand; + hardSwish(): MLActivation; +} +interface MLInstanceNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + epsilon?: number; + layout?: MLInputOperandLayout; +} +interface MLGraphBuilder { + instanceNormalization(input: MLOperand, options?: MLInstanceNormalizationOptions): MLOperand; +} +interface MLLayerNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + axes?: number[]; + epsilon?: number; +} +interface MLGraphBuilder { + layerNormalization(input: MLOperand, options?: MLLayerNormalizationOptions): MLOperand; +} +interface MLLeakyReluOptions { + alpha?: number; +} +interface MLGraphBuilder { + leakyRelu(input: MLOperand, options?: MLLeakyReluOptions): MLOperand; + leakyRelu(options?: MLLeakyReluOptions): MLActivation; +} +interface MLLinearOptions { + alpha?: number; + beta?: number; +} +interface MLGraphBuilder { + linear(input: MLOperand, options?: MLLinearOptions): MLOperand; + linear(options?: MLLinearOptions): MLActivation; +} +type MLLstmWeightLayout = 'iofg'|'ifgo'; +interface MLLstmOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + peepholeWeight?: MLOperand; + initialHiddenState?: MLOperand; + initialCellState?: MLOperand; + returnSequence?: boolean; + direction?: MLRecurrentNetworkDirection; + layout?: MLLstmWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + lstm( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number, + options?: MLLstmOptions): MLOperand[]; +} +interface MLLstmCellOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + peepholeWeight?: MLOperand; + layout?: MLLstmWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + lstmCell( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, cellState: MLOperand, + hiddenSize: number, options?: MLLstmCellOptions): MLOperand[]; +} +interface MLGraphBuilder { + matmul(a: MLOperand, b: MLOperand): MLOperand; +} +type MLPaddingMode = 'constant'|'edge'|'reflection'|'symmetric'; +interface MLPadOptions { + mode?: MLPaddingMode; + value?: number; +} +interface MLGraphBuilder { + pad(input: MLOperand, beginningPadding: number[], endingPadding: number[], options?: MLPadOptions): MLOperand; +} +type MLRoundingType = 'floor'|'ceil'; +interface MLPool2dOptions { + windowDimensions?: number[]; + padding?: number[]; + strides?: number[]; + dilations?: number[]; + layout?: MLInputOperandLayout; + roundingType?: MLRoundingType; + outputSizes?: number[]; +} +interface MLGraphBuilder { + averagePool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; + l2Pool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; + maxPool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; +} +interface MLGraphBuilder { + prelu(input: MLOperand, slope: MLOperand): MLOperand; +} +interface MLReduceOptions { + axes?: number[]; + keepDimensions?: boolean; +} +interface MLGraphBuilder { + reduceL1(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceL2(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceLogSum(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceLogSumExp(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMax(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMean(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMin(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceProduct(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceSum(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceSumSquare(input: MLOperand, options?: MLReduceOptions): MLOperand; +} +interface MLGraphBuilder { + relu(input: MLOperand): MLOperand; + relu(): MLActivation; +} +type MLInterpolationMode = 'nearest-neighbor'|'linear'; +interface MLResample2dOptions { + mode?: MLInterpolationMode; + scales?: number[]; + sizes?: number[]; + axes?: number[]; +} +interface MLGraphBuilder { + resample2d(input: MLOperand, options?: MLResample2dOptions): MLOperand; +} +interface MLGraphBuilder { + reshape(input: MLOperand, newShape: number[]): MLOperand; +} +interface MLGraphBuilder { + sigmoid(input: MLOperand): MLOperand; + sigmoid(): MLActivation; +} +interface MLGraphBuilder { + slice(input: MLOperand, starts: number[], sizes: number[]): MLOperand; +} +interface MLGraphBuilder { + softmax(input: MLOperand, axis: number): MLOperand; + softmax(axis: number): MLActivation; +} +interface MLGraphBuilder { + softplus(input: MLOperand): MLOperand; + softplus(): MLActivation; +} +interface MLGraphBuilder { + softsign(input: MLOperand): MLOperand; + softsign(): MLActivation; +} +interface MLSplitOptions { + axis?: number; +} +interface MLGraphBuilder { + split(input: MLOperand, splits: number|number[], options?: MLSplitOptions): MLOperand[]; +} +interface MLGraphBuilder { + tanh(input: MLOperand): MLOperand; + tanh(): MLActivation; +} +interface MLTransposeOptions { + permutation?: number[]; +} +interface MLGraphBuilder { + transpose(input: MLOperand, options?: MLTransposeOptions): MLOperand; +} +interface MLTriangularOptions { + upper?: boolean; + diagonal?: number; +} +interface MLGraphBuilder { + triangular(input: MLOperand, options?: MLTriangularOptions): MLOperand; +} +interface MLGraphBuilder { + where(condition: MLOperand, input: MLOperand, other: MLOperand): MLOperand; +} + +// Experimental MLBuffer interface + +type MLSize64Out = number; +interface MLBuffer { + readonly size: MLSize64Out; + destroy(): void; +} +type MLSize64 = number; +interface MLBufferDescriptor { + size: MLSize64; +} +type MLNamedBuffers = Record; +interface MLContext { + createBuffer(descriptor: MLBufferDescriptor): MLBuffer; + writeBuffer( + dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: MLSize64, + srcElementSize?: MLSize64): void; + readBuffer(srcBuffer: MLBuffer): Promise; + dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; +} diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 4d2b80e31a47e..f289fc20bba40 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -66,8 +66,6 @@ const setExecutionProviders = const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; - const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; - const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; if (deviceType) { const keyDataOffset = allocWasmString('deviceType', allocs); const valueDataOffset = allocWasmString(deviceType, allocs); @@ -76,26 +74,6 @@ const setExecutionProviders = checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); } } - if (numThreads !== undefined) { - // Just ignore invalid webnnOptions.numThreads. - const validatedNumThreads = - (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 : - numThreads; - const keyDataOffset = allocWasmString('numThreads', allocs); - const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`); - } - } - if (powerPreference) { - const keyDataOffset = allocWasmString('powerPreference', allocs); - const valueDataOffset = allocWasmString(powerPreference, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`); - } - } } break; case 'webgpu': diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index a483ff09f0003..9fc8786192c5c 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; @@ -69,7 +74,7 @@ const initOrt = (numThreads: number, loggingLevel: number): void => { }; /** - * intialize runtime environment. + * initialize runtime environment. * @param env passed in the environment config object. */ export const initRuntime = async(env: Env): Promise => { @@ -253,11 +258,43 @@ export const createSession = async( await Promise.all(loadingPromises); } + for (const provider of options?.executionProviders ?? []) { + const providerName = typeof provider === 'string' ? provider : provider.name; + if (providerName === 'webnn') { + if (wasm.currentContext) { + throw new Error('WebNN execution provider is already set.'); + } + if (typeof provider !== 'string') { + const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption; + const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; + const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice; + const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; + const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; + const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; + if (context) { + wasm.currentContext = context as MLContext; + } else if (gpuDevice) { + wasm.currentContext = await navigator.ml.createContext(gpuDevice); + } else { + wasm.currentContext = await navigator.ml.createContext({deviceType, numThreads, powerPreference}); + } + } else { + wasm.currentContext = await navigator.ml.createContext(); + } + break; + } + } + sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } + // clear current MLContext after session creation + if (wasm.currentContext) { + wasm.currentContext = undefined; + } + const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); const enableGraphCapture = !!options?.enableGraphCapture; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 9ced89651e844..70728c82e7753 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + import type {Tensor} from 'onnxruntime-common'; /* eslint-disable @typescript-eslint/naming-convention */ @@ -19,7 +24,7 @@ export declare namespace JSEP { type CaptureEndFunction = () => void; type ReplayFunction = () => void; - export interface Module extends WebGpuModule { + export interface Module extends WebGpuModule, WebNnModule { /** * Mount the external data file to an internal map, which will be used during session initialization. * @@ -106,6 +111,13 @@ export declare namespace JSEP { */ jsepOnReleaseSession: (sessionId: number) => void; } + + export interface WebNnModule { + /** + * Active MLContext used to create WebNN EP. + */ + currentContext: MLContext; + } } export interface OrtInferenceAPIs { diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 6718dcb639a47..fbde81524ccec 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -613,7 +613,7 @@ async function main() { // == The Problem == // every time when a test is completed, it will be added to the recovery page list. // if we run the test 100 times, there will be 100 previous tabs when we launch Edge again. - // this run out of resources quickly and fails the futher test. + // this run out of resources quickly and fails the further test. // and it cannot recover by itself because every time it is terminated forcely or crashes. // and the auto recovery feature has no way to disable by configuration/commandline/registry // diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h b/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h index 37ef36ac911b9..c0ae5f8ceba94 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc index 63ec5be8c2900..72c5a813e3d76 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc @@ -57,7 +57,7 @@ void AttentionWrapper::ProcessOutput(const gsl::span& rnn_cell_outpu if (has_attn_layer_) { // concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) = // p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_ - // The first part is calulated above. Here just add the later. + // The first part is calculated above. Here just add the later. math::GemmEx(CblasNoTrans, CblasNoTrans, batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0}, attn_context_.data(), attn_context_depth_, diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h index aad077489186e..ce91760516cb2 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h @@ -11,7 +11,7 @@ #include "core/common/logging/logging.h" #include "core/framework/allocator.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index af902a713eaa2..a6782daa58f1a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -68,7 +68,6 @@ class AttentionBase { const Tensor* past_seq_len = nullptr) const; int num_heads_; // number of attention heads - int kv_num_heads_; // different for k and v for group query attention bool is_unidirectional_; // whether every token can only attend to previous tokens. std::vector qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute. bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V. diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index fc4905cd31819..dd52001c2ac6b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -3,9 +3,8 @@ #pragma once -#include "attention_base.h" -#include "attention_helper.h" - +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cpu/bert/attention_helper.h" #include "core/common/common.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 6b0c5f395cab0..137612a4bf902 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -3,8 +3,8 @@ #pragma once -#include "attention_base.h" -#include "attention_helper.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cpu/bert/attention_helper.h" #include "core/common/common.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -14,14 +14,31 @@ namespace onnxruntime { namespace contrib { -class GQAAttentionBase : public AttentionBase { +class GQAAttentionBase { protected: - GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size) - : AttentionBase(info, require_same_hidden_size) {} + GQAAttentionBase(const OpKernelInfo& info, bool has_local) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); - int local_window_size_; - bool do_rotary_; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + + local_window_size_ = has_local ? static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; + } + + int num_heads_; // number of attention heads of Q + int kv_num_heads_; // number of attention heads of K or V + float scale_; // the scaling factor applied before softmax + bool do_rotary_; // whether or not to use rotary embeddings bool rotary_interleaved_; + int local_window_size_; template Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index cad9274e68149..97388a9d6bce8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "group_query_attention.h" -#include "group_query_attention_helper.h" -#include "attention_utils.h" -#include "rotary_embedding.h" -#include "rotary_embedding_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/rotary_helper.h" +#include "contrib_ops/cpu/bert/attention_utils.h" +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" @@ -33,19 +34,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( GroupQueryAttention); template -GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - - mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); - do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; - rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; -} +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : OpKernel(info), GQAAttentionBase(info, true) {} template Status GroupQueryAttention::Compute(OpKernelContext* context) const { @@ -174,14 +164,14 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if (packed_qkv) { const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; - ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV(tp, - parameters.batch_size, - parameters.sequence_length, - parameters.num_heads, - parameters.kv_num_heads, - parameters.head_size, - v_input, - v_rotary)); + ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV(tp, + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.kv_num_heads, + parameters.head_size, + v_input, + v_rotary)); } } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index a7de02452aa58..7ffb72fe55d25 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -263,38 +263,6 @@ Status CheckInputs(const Tensor* query, return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale); } - -template -Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp, - int batch_size, - int sequence_length, - int num_heads, - int kv_num_heads, - int head_size, - const T* input, - T* output) { - int seq_stride = head_size; - int head_stride = sequence_length * seq_stride; - int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride; - - const int loop_len = batch_size * sequence_length * kv_num_heads; - const double cost = static_cast(head_size); - ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / kv_num_heads) / sequence_length); - const int s = static_cast((ptr / kv_num_heads) % sequence_length); - const int n = static_cast(ptr % kv_num_heads); - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input + block_offset; - T* output_data = output + block_offset; - for (int i = 0; i < head_size; i++) { - output_data[i] = input_data[i]; - } - } - }); - return Status::OK(); -} - } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index b39167f4498e0..9677c30f22d8a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -10,8 +10,12 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/common/safeint.h" +#include "core/platform/env_var_utils.h" #include "core/platform/threadpool.h" +#include "core/mlas/inc/mlas.h" +#include +#include #include #include @@ -39,6 +43,11 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + + const auto& env = Env::Default(); + l2_cache_size_ = env.GetL2CacheSize(); + + disable_flash_ = ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); } template @@ -60,7 +69,6 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { } AttentionParameters parameters = {}; - constexpr float scale = 1.0f; bool past_present_share_buffer = false; ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, @@ -74,7 +82,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { ¶meters, num_heads_, mask_filter_value_, - scale, + scale_, is_unidirectional_, past_present_share_buffer, false)); @@ -99,8 +107,14 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { const int v_bias_offset = 2 * qk_hidden_size; // If optional outputs aren't needed, present_k and present_v will be null - std::vector present_k_shape({static_cast(batch_size), static_cast(num_heads_), static_cast(total_kv_sequence_length), static_cast(qk_head_size)}); - std::vector present_v_shape({static_cast(batch_size), static_cast(num_heads_), static_cast(total_kv_sequence_length), static_cast(v_head_size)}); + std::vector present_k_shape({static_cast(batch_size), + static_cast(num_heads_), + static_cast(total_kv_sequence_length), + static_cast(qk_head_size)}); + std::vector present_v_shape({static_cast(batch_size), + static_cast(num_heads_), + static_cast(total_kv_sequence_length), + static_cast(v_head_size)}); Tensor* present_k = context->Output(1, present_k_shape); Tensor* present_v = context->Output(2, present_v_shape); @@ -138,6 +152,70 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V)); + if (std::is_same_v && + !disable_flash_ && + !is_unidirectional_ && + key_padding_mask == nullptr && + extra_add_qk == nullptr && + past_key == nullptr && + past_value == nullptr && + present_k == nullptr && + present_v == nullptr && + l2_cache_size_ > 0) { + MlasFlashAttentionThreadedArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.q_sequence_length = q_sequence_length; + args.kv_sequence_length = kv_sequence_length; + args.qk_head_size = qk_head_size; + args.v_head_size = v_head_size; + args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast(qk_head_size)) : scale_; + /* + q_block_size, kv_block_size correspond to Br, Bc in the FlashAttention paper. + Let M = l2_cache_size / sizeof(float) + In the FlashAttention kernel, there are 5 big matrices that we need to keep in L2 cache: + slice of Q -- [Br, qk_head_size] + slice of K -- [Bc, qk_head_size] + slice of V -- [Bc, v_head_size] + result of QK -- [Br, Bc] + temporary output (same shape as QKV) -- [Br, v_head_size] + The total size of these matrices is (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc + By taking Bc = M / (4 * (qk_head_size + v_head_size)), and Br = min(Bc, qk_head_size + v_head_size), we have + (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc + <= 2 * Bc * (qk_head_size + v_head_size) + Br * Bc + <= 2 * Bc * (qk_head_size + v_head_size) + M/4 + <= 2 * M/4 + M/4 = M * (3/4) + + We leave 1/4 of the L2 cache for + 1. storing small tensors l and m + 2. instruction (code) + */ + args.kv_block_size = l2_cache_size_ / (static_cast(sizeof(float)) * 4 * (qk_head_size + v_head_size)); + args.kv_block_size = std::max(args.kv_block_size, 1); // avoid kv_block_size = 0 + args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size); + args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length); // No point to have kv_block_size > kv_sequence_length + args.q_block_size = std::min(args.q_block_size, q_sequence_length); // No point to have q_block_size > q_sequence_length + + auto* tp = context->GetOperatorThreadPool(); + args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + args.buffer_size_per_thread = (static_cast(args.q_block_size) * 2 + + static_cast(args.q_block_size) * static_cast(args.kv_block_size) + + static_cast(args.q_block_size) * static_cast(args.v_head_size)) * + sizeof(float); + size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count; + IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, buffer_bytes); + + args.buffer = reinterpret_cast(buffer.get()); + + args.query = Q.Get().Data(); + args.key = K.Get().Data(); + args.value = V.Get().Data(); + args.output = output->MutableData(); + + MlasFlashAttention(&args, tp); + return Status::OK(); + } + // Compute the attention score and apply the score to V return ApplyAttention(Q.GetMutable()->MutableData(), K.GetMutable()->MutableData(), diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index fb7da78a5c0a5..8a9bef1b2bf0d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -19,6 +19,8 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { int num_heads_; // number of attention heads float mask_filter_value_; bool is_unidirectional_; + bool disable_flash_; + int l2_cache_size_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h new file mode 100644 index 0000000000000..714d962dfb34e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rotary_helper { + +template +Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp, + int batch_size, + int sequence_length, + int num_heads, + int kv_num_heads, + int head_size, + const T* input, + T* output) { + int seq_stride = head_size; + int head_stride = sequence_length * seq_stride; + int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride; + + const int loop_len = batch_size * sequence_length * kv_num_heads; + const double cost = static_cast(head_size); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / kv_num_heads) / sequence_length); + const int s = static_cast((ptr / kv_num_heads) % sequence_length); + const int n = static_cast(ptr % kv_num_heads); + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + const T* input_data = input + block_offset; + T* output_data = output + block_offset; + for (int i = 0; i < head_size; i++) { + output_data[i] = input_data[i]; + } + } + }); + return Status::OK(); +} + +} // namespace rotary_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/cdist.cc b/onnxruntime/contrib_ops/cpu/cdist.cc index d0ed81a9a6dc3..736dbcfede2fc 100644 --- a/onnxruntime/contrib_ops/cpu/cdist.cc +++ b/onnxruntime/contrib_ops/cpu/cdist.cc @@ -67,7 +67,7 @@ static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, co threadpool); #else // the performance of this isn't great as the eigen matmul is single threaded by default - // if you're on x86 and care about performance try MKL first. if there's a good enough argument for optimising this + // if you're on x86 and care about performance try MKL first. if there's a good enough argument for optimizing this // we can look into it in the future. ORT_UNUSED_PARAMETER(threadpool); diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index e8ca4370135cc..90a51fda0b188 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -21,6 +21,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); @@ -281,6 +282,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/crop.h b/onnxruntime/contrib_ops/cpu/crop.h index 0fd0a5c49b3b4..3b72ef429c1f7 100644 --- a/onnxruntime/contrib_ops/cpu/crop.h +++ b/onnxruntime/contrib_ops/cpu/crop.h @@ -6,7 +6,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/inverse.cc b/onnxruntime/contrib_ops/cpu/inverse.cc index 355b036e36d0a..54bd99d209574 100644 --- a/onnxruntime/contrib_ops/cpu/inverse.cc +++ b/onnxruntime/contrib_ops/cpu/inverse.cc @@ -53,7 +53,7 @@ struct Inverse::ComputeImpl { void operator()(const Tensor* input, Tensor* output, int64_t batch_num, int64_t rows, int64_t cols) const { auto batch_offset = batch_num * rows * cols; - // Direct cast to half as it just as MLFloat16 containes only uint16_t + // Direct cast to half as it just as MLFloat16 contains only uint16_t const auto* input_data = reinterpret_cast(input->Data() + batch_offset); auto* output_data = reinterpret_cast(output->MutableData() + batch_offset); diff --git a/onnxruntime/contrib_ops/cpu/murmur_hash3.cc b/onnxruntime/contrib_ops/cpu/murmur_hash3.cc index ec504d215920f..000c590f32616 100644 --- a/onnxruntime/contrib_ops/cpu/murmur_hash3.cc +++ b/onnxruntime/contrib_ops/cpu/murmur_hash3.cc @@ -8,6 +8,8 @@ /* Modifications Copyright (c) Microsoft. */ #include "contrib_ops/cpu/murmur_hash3.h" +#include +#include // Platform-specific functions and macros @@ -60,11 +62,31 @@ inline uint64_t rotl64(uint64_t x, int8_t r) { // handle aligned reads, do the conversion here FORCE_INLINE uint32_t getblock(const uint32_t* p, int i) { - return p[i]; + if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { + return p[i]; + } else { + const uint8_t* c = (const uint8_t*)&p[i]; + return (uint32_t)c[0] | + (uint32_t)c[1] << 8 | + (uint32_t)c[2] << 16 | + (uint32_t)c[3] << 24; + } } FORCE_INLINE uint64_t getblock(const uint64_t* p, int i) { - return p[i]; + if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { + return p[i]; + } else { + const uint8_t* c = (const uint8_t*)&p[i]; + return (uint64_t)c[0] | + (uint64_t)c[1] << 8 | + (uint64_t)c[2] << 16 | + (uint64_t)c[3] << 24 | + (uint64_t)c[4] << 32 | + (uint64_t)c[5] << 40 | + (uint64_t)c[6] << 48 | + (uint64_t)c[7] << 56; + } } //----------------------------------------------------------------------------- @@ -204,13 +226,35 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const { int input_num_bytes = static_cast(input_element_bytes); ORT_ENFORCE(input_num_bytes % 4 == 0); const auto input_end = input + input_count * input_num_bytes; - while (input != input_end) { - MurmurHash3_x86_32(input, - input_num_bytes, - seed_, - output); - input += input_num_bytes; - ++output; + + if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { + while (input != input_end) { + MurmurHash3_x86_32(input, + input_num_bytes, + seed_, + output); + input += input_num_bytes; + ++output; + } + } else { + // Big endian platform require byte swapping. + auto raw_data = std::make_unique(input_num_bytes); + char* raw_data_ptr = raw_data.get(); + while (input != input_end) { + memcpy(raw_data_ptr, input, input_num_bytes); + char* start_byte = raw_data_ptr; + char* end_byte = start_byte + input_num_bytes - 1; + for (size_t count = 0; count < static_cast(input_num_bytes / 2); ++count) { + std::swap(*start_byte++, *end_byte--); + } + + MurmurHash3_x86_32(raw_data_ptr, + input_num_bytes, + seed_, + output); + input += input_num_bytes; + ++output; + } } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index 7e343d85f4048..b28f3758f89b5 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -40,6 +40,13 @@ void Dequantize4BitsKernelReOrder( } T* output_i = output + out_y * out_cols + out_x; uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) { + const uint8_t* c = (const uint8_t*)(&quant_value); + quant_value = (uint32_t)c[0] | + (uint32_t)c[1] << 8 | + (uint32_t)c[2] << 16 | + (uint32_t)c[3] << 24; + } const int remain_x = std::min(8, out_cols - out_x); const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1)); for (int i = 0; i < remain_x; i++) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc index ed6071b40feba..de1798e54874f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc @@ -15,7 +15,7 @@ #include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc new file mode 100644 index 0000000000000..e337f41a8688d --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/sparse/sparse_attention.h" +#include "contrib_ops/cpu/sparse/sparse_attention_helper.h" +#include "contrib_ops/cpu/bert/rotary_helper.h" +#include "contrib_ops/cpu/bert/attention_utils.h" +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/graph/onnx_protobuf.h" +#include "core/common/safeint.h" +#include "core/platform/threadpool.h" + +#include +#include + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_TYPED_KERNEL_EX( + SparseAttention, + kMSDomain, + 1, + float, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + SparseAttention); + +template +SparseAttention::SparseAttention(const OpKernelInfo& info) : OpKernel(info), SparseAttentionBase(info) { +} + +template +Status SparseAttention::Compute(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* past_key = context->Input(3); + const Tensor* past_value = context->Input(4); + const Tensor* block_row_indices = context->Input(5); + const Tensor* block_col_indices = context->Input(6); + const Tensor* total_seq_len = context->Input(7); + const Tensor* total_key_lengths = context->Input(8); + const Tensor* cos_cache = context->Input(9); + const Tensor* sin_cache = context->Input(10); + + SparseAttentionParameters parameters = {}; + + // Parameters from node attribute shall be set before calling CheckInputs + parameters.sparse_block_size = sparse_block_size_; + parameters.num_heads = num_heads_; + parameters.kv_num_heads = kv_num_heads_; + parameters.scale = scale_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + ORT_RETURN_IF_ERROR(sparse_attention_helper::CheckInputs(¶meters, + query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + block_row_indices, + block_col_indices, + total_key_lengths, + total_seq_len)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + int q_hidden_size = parameters.hidden_size; + + std::vector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(q_hidden_size); + Tensor* output = context->Output(0, output_shape); + + constexpr bool past_present_share_buffer = true; // Only supports share buffer for past and present for now. + parameters.past_present_share_buffer = past_present_share_buffer; + + int head_size = parameters.head_size; + const int cache_length = past_present_share_buffer + ? parameters.max_cache_sequence_length + : parameters.total_sequence_length; + std::vector present_k_shape({static_cast(batch_size), + static_cast(kv_num_heads_), + static_cast(cache_length), + static_cast(head_size)}); + std::vector present_v_shape({static_cast(batch_size), + static_cast(kv_num_heads_), + static_cast(cache_length), + static_cast(head_size)}); + Tensor* present_key = context->Output(1, present_k_shape); + Tensor* present_value = context->Output(2, present_v_shape); + + // Check past and present share buffer. + if (past_present_share_buffer) { + ORT_ENFORCE(past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw()); + } + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + auto element_type = DataTypeImpl::GetType(); + OrtValue Q; + OrtValue K; + OrtValue V; + + const bool packed_qkv = parameters.is_packed_qkv; + if (packed_qkv) { + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q)); + } else { + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, num_heads_, sequence_length, head_size, query, Q)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH( + allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V)); + } + + if (do_rotary_) { + rotary_embedding_helper::RotaryParameters rotary_params = {}; + rotary_params.batch_size = batch_size; + rotary_params.sequence_length = sequence_length; + rotary_params.hidden_size = q_hidden_size; + rotary_params.head_size = head_size; + rotary_params.rotary_embedding_dim = parameters.rotary_dim; + rotary_params.num_heads = num_heads_; + rotary_params.max_sequence_length = sequence_length; // unused + rotary_params.seq_stride = head_size; + rotary_params.head_stride = sequence_length * rotary_params.seq_stride; + rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * + rotary_params.head_stride; + rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0; + rotary_params.transposed = true; + auto* tp = context->GetOperatorThreadPool(); + + const bool is_prompt = parameters.total_sequence_length == parameters.sequence_length; + std::vector pos_ids(is_prompt ? 1 : batch_size * sequence_length); + if (is_prompt) { + pos_ids[0] = static_cast(0); + } else if (sequence_length == 1) { + for (int b = 0; b < batch_size; b++) { + pos_ids[b] = static_cast(total_key_lengths->Data()[b]) - 1; + } + } else { + // This supports a rare case that sequence_length > 1 when it is not prompt. + for (int b = 0; b < batch_size; b++) { + for (int s = 0; s < sequence_length; s++) { + pos_ids[b * sequence_length + s] = static_cast(total_key_lengths->Data()[b]) - + (sequence_length - s); + } + } + } + + const T* q_input; + const T* k_input; + T* q_rotary; + T* k_rotary; + if (packed_qkv) { + OrtValue RotaryQKV; + TensorShape qkv_shape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}); + Tensor::InitOrtValue(element_type, qkv_shape, allocator, RotaryQKV); + q_input = Q.Get().Data(); + k_input = q_input + num_heads_ * sequence_length * head_size; + q_rotary = RotaryQKV.GetMutable()->MutableData(); + k_rotary = q_rotary + num_heads_ * sequence_length * head_size; + Q = RotaryQKV; + } else { + OrtValue RotaryQ; + TensorShape q_shape({batch_size, num_heads_, sequence_length, head_size}); + Tensor::InitOrtValue(element_type, q_shape, allocator, RotaryQ); + OrtValue RotaryK; + TensorShape k_shape({batch_size, kv_num_heads_, sequence_length, head_size}); + Tensor::InitOrtValue(element_type, k_shape, allocator, RotaryK); + q_input = Q.Get().Data(); + k_input = K.Get().Data(); + q_rotary = RotaryQ.GetMutable()->MutableData(); + k_rotary = RotaryK.GetMutable()->MutableData(); + Q = RotaryQ; + K = RotaryK; + } + + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, + pos_ids.data(), cos_cache->Data(), + sin_cache->Data(), q_rotary, rotary_interleaved_)); + + rotary_params.num_heads = kv_num_heads_; + rotary_params.hidden_size = parameters.kv_hidden_size; + if (!packed_qkv) { + rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride; + } + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, + pos_ids.data(), cos_cache->Data(), + sin_cache->Data(), k_rotary, rotary_interleaved_)); + if (packed_qkv) { + const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; + T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; + ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV(tp, + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.kv_num_heads, + parameters.head_size, + v_input, + v_rotary)); + } + } + + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + // Compute the attention score and apply the score to V + return ApplyAttention(Q.Get().Data(), packed_qkv ? nullptr : K.Get().Data(), + packed_qkv ? nullptr : V.Get().Data(), past_key, past_value, + output, present_key, present_value, + total_key_lengths, block_row_indices, block_col_indices, parameters, allocator, context); +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h new file mode 100644 index 0000000000000..4267d85c0e35d --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/sparse/sparse_attention_base.h" + +namespace onnxruntime { +namespace contrib { + +template +class SparseAttention final : public OpKernel, public SparseAttentionBase { + public: + SparseAttention(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h new file mode 100644 index 0000000000000..cf66bd8407126 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -0,0 +1,390 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_helper.h" + +#include "core/common/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/utils/dump_tensor.h" + +namespace onnxruntime { +namespace contrib { + +class SparseAttentionBase { + protected: + SparseAttentionBase(const OpKernelInfo& info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + + int64_t sparse_block_size = 0; + ORT_ENFORCE(info.GetAttr("sparse_block_size", &sparse_block_size).IsOK()); + sparse_block_size_ = static_cast(sparse_block_size); + } + + int num_heads_; // number of attention heads of Q + int kv_num_heads_; // number of attention heads of K or V + float scale_; // the scaling factor applied before softmax + bool do_rotary_; // whether or not to use rotary embeddings + bool rotary_interleaved_; + int sparse_block_size_; + + template + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxN_kvxSxH + const T* V, // V data with shape BxN_kvxSxH + const Tensor* past_key, // past K input tensor + const Tensor* past_value, // past V input tensor + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor + Tensor* present_value, // present V output tensor + const Tensor* total_key_lengths, // total key lengths tensor + const Tensor* block_row_indices, // block row indices + const Tensor* block_col_indices, // block column indices + SparseAttentionParameters& parameters, // attention parameters + AllocatorPtr allocator, // allocator for temporary tensors + OpKernelContext* context) const { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int head_size = parameters.head_size; + const bool packed_qkv = parameters.is_packed_qkv; + + int past_buffer_sequence_length = static_cast(past_key->Shape().GetDims()[2]); + int present_buffer_sequence_length = static_cast(present_key->Shape().GetDims()[2]); + + // Allocate a buffer to store Softmax(QK) + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * parameters.total_sequence_length * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); + + bool past_present_share_buffer = parameters.past_present_share_buffer; + assert(past_present_share_buffer); + + auto* tp = context->GetOperatorThreadPool(); + + const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + ComputeAttentionProbs( + static_cast(attention_probs), Q, k, total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, + past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, + block_row_indices->Data(), block_col_indices->Data(), parameters, tp); + + // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) + const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + ComputeVxAttentionScore( + output->MutableData(), static_cast(attention_probs), v, + total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, + past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); + + return Status::OK(); + } + + private: + // Helper function to compute the attention probs. It does 2 things: + // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) + // attention_probs(B, N, S, T) = Softmax(attention_probs) + template + void ComputeAttentionProbs( + T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // query start pointer + const T* K, // key start pointer + const int32_t* total_key_lengths, // total key sequence lengths (past + new) + int batch_size, // batch size + int sequence_length, // sequence length of query or new key + int total_sequence_length, // maximum past_sequence_length + sequence_length + int past_buffer_sequence_length, // sequence length of past_key or past_value + int present_buffer_sequence_length, // sequence length of present_key or present_value + int head_size, // head size of query + const T* past_key, // past key + T* present_key, // present key + bool past_present_share_buffer, // whether past_key and present_key share the buffer + bool packed_qkv, // whether Q, K, V are packed + const int32_t* block_row_indices, // block row indices + const int32_t* block_col_indices, // block column indices + SparseAttentionParameters& parameters, // parameters + ThreadPool* tp) const { // thread pool + const bool is_prompt = (total_sequence_length == sequence_length); + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const int kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H + const size_t kv_input_chunk_length = q_input_chunk_length; + const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; + const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; + + const int loop_len = batch_size * num_heads_; + const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + TensorOpCost unit_cost; + const ptrdiff_t probs_matrix_bytes = + SafeInt(sequence_length) * total_sequence_length * sizeof(T); + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length); + unit_cost.bytes_loaded = + static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T)); + unit_cost.bytes_stored = static_cast(probs_matrix_bytes); + + unit_cost.bytes_loaded += static_cast(probs_matrix_bytes); + unit_cost.bytes_stored += static_cast(probs_matrix_bytes); + + // Cost to concatenate current key to cache (assume past and present share buffer). + double bytes_to_copy_key = static_cast(sizeof(T) * sequence_length * head_size); + unit_cost.bytes_loaded += bytes_to_copy_key; + unit_cost.bytes_stored += bytes_to_copy_key; + + DUMP_CPU_TENSOR_INIT(); + DUMP_CPU_TENSOR("block_row_indices", block_row_indices, parameters.num_sparse_layout, parameters.stride_row_indices); + DUMP_CPU_TENSOR("block_col_indices", block_col_indices, parameters.num_sparse_layout, parameters.stride_col_indices); + + // Check whether each layout has sparse (has zero in lower triangular) + std::vector layout_has_sparse(parameters.num_sparse_layout); + for (int layout_index = 0; layout_index < parameters.num_sparse_layout; layout_index++) { + int nonzero_elements = block_row_indices[(layout_index + 1) * parameters.stride_row_indices - 1]; + int dense_nonzero = (parameters.stride_row_indices * (parameters.stride_row_indices - 1)) / 2; + layout_has_sparse[layout_index] = nonzero_elements < dense_nonzero; + DUMP_STRING("layout_has_sparse[", layout_index, "]=", layout_has_sparse[layout_index]); + } + + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + DUMP_STRING("batch_size=", batch_size, ",num_heads=", num_heads_, ",loop_len=", loop_len, ",begin=", begin, ",end=", end); + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i) / num_heads_; + const int head_index = static_cast(i) % num_heads_; + const int past_seq_len = is_prompt ? 0 : (static_cast(total_key_lengths[batch_index]) - sequence_length); + const size_t past_chunk_length = static_cast(past_seq_len) * head_size; + const int total_seq_len = total_key_lengths[batch_index]; + + const ptrdiff_t output_offset = SafeInt(i) * sequence_length * total_sequence_length; + T* output = attention_probs + output_offset; + + const T* k; + if (packed_qkv) { + k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + k = K + kv_input_chunk_length * (i / kv_num_heads_factor); + } + + // Concatenate past_k + k -> present_k + // TODO: avoid copying mutiple times for a group. + k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + i / kv_num_heads_factor); + + // Compute Q*K' + AttentionMask + // original transposed each iteration + // A: Q (B x N x) S x H (B x N x) S x H S x H + // B: K' (B x N x) T x H (B x N x) H x T H x T + // C: attention_probs (B x N x) S x T (B x N x) S x T S x T + const T* q; + if (packed_qkv) { + q = Q + packed_batch_stride * batch_index + q_input_chunk_length * head_index; + } else { + q = Q + q_input_chunk_length * i; + } + + DUMP_STRING("i=", i, ",batch_index=", batch_index, ",head_index=", head_index, + ",past_seq_len=", past_seq_len, ",total_seq_len=", total_seq_len, ",packed_qkv=", packed_qkv); + DUMP_CPU_TENSOR("Q", q, sequence_length, head_size); + DUMP_CPU_TENSOR("K", k, total_seq_len, head_size); + + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seq_len, head_size, alpha, q, + head_size, k, head_size, 0.0f /*bata*/, output, total_seq_len, + nullptr); + + DUMP_CPU_TENSOR("QK", output, sequence_length, total_seq_len); + + // Compute Softmax for causal and output result in place. + T* output_softmax = output; + + int layout_id = head_index % parameters.num_sparse_layout; + bool is_sparse_layout = layout_has_sparse[layout_id]; + + DUMP_STRING("layout_id=", layout_id, ",is_sparse_layout=", is_sparse_layout); + + if (!is_sparse_layout) { // dense + for (int q_id = 0; q_id < sequence_length; q_id++) { + int causal_length = past_seq_len + q_id + 1; + ComputeAttentionSoftmaxInplace(output_softmax, 1, causal_length, nullptr); + for (int remain_seq_id = causal_length; remain_seq_id < total_seq_len; remain_seq_id++) { + output_softmax[remain_seq_id] = 0.f; + } + output_softmax += total_seq_len; + } + } else { // sparse + int q_id = 0; + bool has_sparse = false; + std::vector mask(parameters.max_sequence_length); + + const int32_t* layout_row_indices = block_row_indices + layout_id * parameters.stride_row_indices; + const int32_t* layout_col_indices = block_col_indices + layout_id * parameters.stride_col_indices; + do { + int q_abs_position = past_seq_len + q_id; + int causal_length = q_abs_position + 1; + + // Update mask when query token is the first or at the boundary of sparse block. + if (q_id == 0 || q_abs_position % parameters.sparse_block_size == 0) { + int row_in_sparse_layout = q_abs_position / parameters.sparse_block_size; + int start_in_col_indices = layout_row_indices[row_in_sparse_layout]; + int end_in_col_indices = layout_row_indices[row_in_sparse_layout + 1]; + int nonzero_blocks = end_in_col_indices - start_in_col_indices; + has_sparse = (nonzero_blocks != row_in_sparse_layout + 1); + + DUMP_STRING("q_id=", q_id, + ",q_abs_position=", q_abs_position, + ",sparse_block_size=", parameters.sparse_block_size, + ",row_in_sparse_layout=", row_in_sparse_layout, + ",start_in_col_indices=", start_in_col_indices, + ",end_in_col_indices=", end_in_col_indices, + ",nonzero_blocks=", nonzero_blocks, + ",has_sparse=", has_sparse); + + // Expand attention mask for current row of q_id + if (has_sparse) { + int block_aligned_length = q_abs_position / parameters.sparse_block_size * parameters.sparse_block_size + parameters.sparse_block_size; + DUMP_STRING("block_aligned_length=", block_aligned_length); + + std::fill_n(mask.begin(), block_aligned_length, 0); + for (int j = start_in_col_indices; j < end_in_col_indices; j++) { + int col_in_sparse_layout = layout_col_indices[j]; + + int offset = col_in_sparse_layout * parameters.sparse_block_size; + for (int s = 0; s < parameters.sparse_block_size; s++, offset++) { + mask[offset] = 1; + } + } + + DUMP_CPU_TENSOR("mask", mask, block_aligned_length); + } + } + + // Update inline according to attention mask. + if (has_sparse) { + for (int s = 0; s < causal_length; s++) { + if (mask[s] == 0) + output_softmax[s] = std::numeric_limits::lowest(); + } + } + ComputeAttentionSoftmaxInplace(output_softmax, 1, causal_length, nullptr); + + for (int remain_seq_id = causal_length; remain_seq_id < total_seq_len; remain_seq_id++) { + output_softmax[remain_seq_id] = 0.f; + } + + output_softmax += total_seq_len; + q_id++; + + } while (q_id < sequence_length); + } + + DUMP_CPU_TENSOR("softmax", output, sequence_length, total_seq_len); + } + }); + } + + template + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH + const T* attention_probs, // Softmax of Q*K' with size BxNxSxT + const T* V, // v value with size BxN_kvxSxH + const int32_t* total_key_lengths, // total sequence lengths + int batch_size, // batch size + int sequence_length, // sequence length + int total_sequence_length, // maximum past_sequence_length + sequence_length + int past_buffer_sequence_length, // sequence length in past state + int present_buffer_sequence_length, // sequence length in past state + int head_size, // head size of Q, K, V + int hidden_size, // hidden size of Output + const T* past_value, // past value only + T* present_value, // present value only + bool past_present_share_buffer, // whether past_key and present_key share the buffer + bool packed_qkv, // whether Q, K, V are packed + ThreadPool* tp) const { + const bool is_prompt = sequence_length == total_sequence_length; + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const int kv_num_heads_factor = num_heads_ / kv_num_heads_; + + const int kv_input_chunk_length = sequence_length * head_size; // S x H + const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; + const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; + + // The cost of Gemm. + TensorOpCost unit_cost; + // Here we use total_sequence_length to estimate total_key_lengths[batch_index] used in GEMM. + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length); + unit_cost.bytes_loaded = static_cast(SafeInt(sequence_length + head_size) * + total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(sequence_length * head_size * sizeof(T)); + + if (present_value) { + double bytes_to_copy_value = static_cast(sizeof(T) * sequence_length * head_size); + unit_cost.bytes_loaded += bytes_to_copy_value; + unit_cost.bytes_stored += bytes_to_copy_value; + } + + DUMP_CPU_TENSOR_INIT(); + + ThreadPool::TryParallelFor( + tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + DUMP_STRING("batch_size=", batch_size, ",num_heads=", num_heads_, ",begin=", begin, ",end=", end); + + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i / num_heads_); + const int head_index = static_cast(i % num_heads_); + const int past_seq_len = is_prompt ? 0 : (static_cast(total_key_lengths[batch_index]) - sequence_length); + const size_t past_chunk_length = static_cast(past_seq_len) * head_size; + const int total_seq_len = total_key_lengths[batch_index]; + + DUMP_STRING("i=", i, ",batch_index=", batch_index, ",head_index=", head_index, + ",past_seq_len=", past_seq_len, ",total_seq_len=", total_seq_len, ",packed_qkv=", packed_qkv); + + const T* v; + if (packed_qkv) { + v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + v = V + kv_input_chunk_length * (i / kv_num_heads_factor); + } + + // Concatenate past_v + v -> present_v + v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + i / kv_num_heads_factor); + + DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size); + + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_seq_len * i; + + DUMP_CPU_TENSOR("attention_probs", attention_probs + attention_probs_offset, sequence_length, total_seq_len); + + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seq_len, + 1.f, /*alpha*/ + attention_probs + attention_probs_offset, total_seq_len, v, + head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); + + DUMP_CPU_TENSOR("out", attention_probs + attention_probs_offset, sequence_length, head_size); + } + }); + } +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h similarity index 98% rename from onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h rename to onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h index a5f1d50e618af..ca69370b4ce17 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h @@ -21,7 +21,7 @@ Status CheckInputs(void* params, const Tensor* sin_cache, const Tensor* block_row_indices, const Tensor* block_col_indices, - const Tensor* seqlens_k_total, + const Tensor* total_key_lengths, const Tensor* total_seq_len) { // No packing for q/k/v: // query (batch_size, sequence_length, num_heads * head_size) @@ -36,7 +36,7 @@ Status CheckInputs(void* params, // past_value (batch_size, kv_num_heads, max_cache_sequence_length, head_size) // block_row_indices (num_layout, max_blocks + 1), where max_blocks = max_sequence_length / sparse_block_size // block_col_indices (num_layout, max_nnz) - // seqlens_k_total (batch_size) when do_rotary is True, optional otherwise + // total_key_lengths (batch_size) // total_seq_len (1) // cos_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true. // sin_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true. @@ -197,7 +197,7 @@ Status CheckInputs(void* params, } // Check the shape of total_key_sequence_lengths. We do not check the values here. - const auto& k_len_dim = seqlens_k_total->Shape().GetDims(); + const auto& k_len_dim = total_key_lengths->Shape().GetDims(); if (k_len_dim.size() != 1 && k_len_dim[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_total_sequence_lengths must have shape (batch_size)."); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 688b7d6341ae8..12fae5ccf0983 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -26,7 +26,7 @@ #include "core/framework/TensorSeq.h" #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/beam_search.h" #include "contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/sequences.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 8c1ceec62fec5..3bdb274d7d5ab 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -8,7 +8,7 @@ #include "core/providers/cpu/math/softmax_shared.h" #include "core/providers/cpu/generator/random.h" #include "core/common/safeint.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_scorer.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 8f778c57bb416..7f99a808f4428 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -10,7 +10,7 @@ #endif #include -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/generation_shared.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 2b8b26f0a06fc..30bf3aa0a1212 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/allocator.h" #include "core/framework/ort_value.h" #include "contrib_ops/cpu/utils/debug_macros.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc index 788eab1b672d0..a107889afd76a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc @@ -26,7 +26,7 @@ #include "core/framework/session_options.h" #include "core/framework/TensorSeq.h" #include "core/framework/ort_value.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/greedy_search.h" #include "contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/sequences.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 99c9474a2ca4d..440a07e14a6cc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/generation_shared.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index 83aa99ff4d50a..d675ba742e03b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -7,7 +7,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_base.h" #include "contrib_ops/cpu/utils/dump_tensor.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index 487a35c55a85f..bde591626bb83 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/allocator.h" #include "core/framework/feeds_fetches_manager.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc index 443d69d494703..34a1da99316a2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cpu/utils/dump_tensor.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 4264ceff042fb..9037e58aaf31f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/utils/dump_tensor.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index 887a2b5cc5196..51473c0c931b9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index f3da01c952f55..bf866d67ffc0d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_whisper_decoder.h" #include "contrib_ops/cpu/utils/dump_tensor.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc index 8480edc405e53..ff5f256e7bb70 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h" #include "contrib_ops/cpu/transformers/subgraph_whisper_encoder.h" diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index 3c255879df199..2782a59d4326d 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -37,6 +37,8 @@ class IConsoleDumper { virtual void Print(const char* name, int index, bool end_line) const = 0; virtual void Print(const char* name, const std::string& value, bool end_line) const = 0; + virtual void Print(const std::string& value) const = 0; + protected: bool is_enabled_; }; diff --git a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h index 37a9b0160ade9..d5cbaa0a3e6b7 100644 --- a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h +++ b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h @@ -1,4 +1,5 @@ #pragma once +#include "core/common/make_string.h" // #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc) @@ -14,9 +15,11 @@ #if DUMP_CPU_TENSOR_LEVEL > 0 #define DUMP_CPU_TENSOR_INIT() onnxruntime::contrib::CpuTensorConsoleDumper cpu_dumper #define DUMP_CPU_TENSOR(...) cpu_dumper.Print(__VA_ARGS__) +#define DUMP_STRING(...) cpu_dumper.Print(::onnxruntime::MakeString(__VA_ARGS__)) #else #define DUMP_CPU_TENSOR_INIT() #define DUMP_CPU_TENSOR(...) +#define DUMP_STRING(...) #endif #if DUMP_CPU_TENSOR_LEVEL > 1 diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc index 3a5deef35d6d6..87a9cd3965763 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc @@ -1,18 +1,38 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include "contrib_ops/cpu/utils/dump_tensor.h" +#include +#include +#include +#include #include "core/framework/print_tensor_utils.h" #include "contrib_ops/cpu/utils/debug_macros.h" +#include "core/platform/env_var_utils.h" namespace onnxruntime { namespace contrib { #if DUMP_CPU_TENSOR_LEVEL > 0 +// Environment variable to enable/disable dumping +constexpr const char* kEnableCpuTensorDumper = "ORT_ENABLE_CPU_DUMP"; + +// Environment variable to enable/disable dumping thread id +constexpr const char* kDumpThreadId = "ORT_DUMP_THREAD_ID"; + +// To avoid dumping at the same time from multiple threads +static std::mutex s_mutex; + +static bool s_output_thread_id = false; + template void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1) { + std::unique_lock lock(s_mutex); + + if (s_output_thread_id) + std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl; + if (nullptr != name) { std::cout << std::string(name) << std::endl; } @@ -26,6 +46,11 @@ void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1) { template void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) { + std::unique_lock lock(s_mutex); + + if (s_output_thread_id) + std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl; + if (nullptr != name) { std::cout << std::string(name) << std::endl; } @@ -93,6 +118,21 @@ void DumpCpuTensor(const char* name, const Tensor& tensor) { DumpCpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } +CpuTensorConsoleDumper::CpuTensorConsoleDumper() { + is_enabled_ = ParseEnvironmentVariableWithDefault(kEnableCpuTensorDumper, 1) != 0; + s_output_thread_id = ParseEnvironmentVariableWithDefault(kDumpThreadId, 0) != 0; +} + +void CpuTensorConsoleDumper::Print(const std::string& value) const { + if (!is_enabled_) + return; + + std::unique_lock lock(s_mutex); + if (s_output_thread_id) + std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl; + std::cout << value << std::endl; +} + void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { if (!is_enabled_) return; @@ -185,6 +225,8 @@ void CpuTensorConsoleDumper::Print(const char* name, const OrtValue& value) cons void CpuTensorConsoleDumper::Print(const char* name, int index, bool end_line) const { if (!is_enabled_) return; + + std::unique_lock lock(s_mutex); std::cout << std::string(name) << "[" << index << "]"; if (end_line) { @@ -196,6 +238,7 @@ void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, b if (!is_enabled_) return; + std::unique_lock lock(s_mutex); std::cout << std::string(name) << "=" << value; if (end_line) { @@ -204,6 +247,12 @@ void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, b } #else +CpuTensorConsoleDumper::CpuTensorConsoleDumper() { +} + +void CpuTensorConsoleDumper::Print(const std::string&) const { +} + void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const { } @@ -254,7 +303,6 @@ void CpuTensorConsoleDumper::Print(const char*, int, bool) const { void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const { } - #endif } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h index d902806fd0d18..f102eae6ec709 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include #include #include "core/framework/ort_value.h" #include "contrib_ops/cpu/utils/console_dumper.h" @@ -11,7 +12,7 @@ namespace contrib { class CpuTensorConsoleDumper : public IConsoleDumper { public: - CpuTensorConsoleDumper() = default; + CpuTensorConsoleDumper(); virtual ~CpuTensorConsoleDumper() {} void Print(const char* name, const float* tensor, int dim0, int dim1) const override; void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; @@ -33,6 +34,14 @@ class CpuTensorConsoleDumper : public IConsoleDumper { void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; void Print(const char* name, const std::string& value, bool end_line) const override; + + void Print(const std::string& value) const override; + + // Output a vector with a threshold for max number of elements to output. Default threshold 0 means no limit. + template + void Print(const char* name, const std::vector& vec, size_t max_count = 0) const { + this->Print(name, vec.data(), 1, static_cast(std::min(max_count, vec.size()))); + } }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 3e6edb162360d..d9907f09121d0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -145,7 +145,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = num_splits; + parameters.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 5c13bb731ce28..150079cdf157a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -334,6 +334,11 @@ Status FlashAttention( contrib::AttentionParameters& parameters, AttentionData& data, float scale) { + ORT_UNUSED_PARAMETER(device_prop); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(parameters); + ORT_UNUSED_PARAMETER(data); + ORT_UNUSED_PARAMETER(scale); return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor"); } #endif diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 36fd7708de04b..56836bdda197c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/allocator.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -176,6 +176,13 @@ Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const T* qkv_buffer, T* present); +template +Status LaunchStridedCopy( + cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + template Status LaunchStridedCopy(cudaStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu index 1466f5fcfe0be..66e56e701c558 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu @@ -12,23 +12,27 @@ namespace cuda { template __global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides // coord (b,n,s,h) -) { + T* out, longlong4 out_strides, // coord (b,n,s,h) + const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) { const int h = threadIdx.x; const int n = threadIdx.y; const int s = blockIdx.x; const int b = blockIdx.y; + + const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b]; + const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b]; + if (h < H) { - const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; - const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w; + const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w; out[out_offset] = in[in_offset]; } } template __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides // coord (b,n,s,h) -) { + T* out, longlong4 out_strides, // coord (b,n,s,h) + const int* in_seqlens_offset, const int* out_seqlens_offset) { // Use when (H*)*num_heads > 1024 int h = threadIdx.x; const int n = threadIdx.y; @@ -37,9 +41,12 @@ __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, const int h_step = blockDim.x; + const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b]; + const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b]; + while (h < H) { - const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; - const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w; + const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w; out[out_offset] = in[in_offset]; h += h_step; } @@ -77,10 +84,11 @@ template using ToBytes = typename ToByteType::T; template -Status LaunchStridedCopy(cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides, // coord (b,n,s,h) - int max_threads_per_block) { +Status LaunchStridedCopy( + cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block) { int batch_size = in_shape.x; int num_heads = in_shape.y; int sequence_length = in_shape.z; @@ -102,11 +110,13 @@ Status LaunchStridedCopy(cudaStream_t stream, if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } else if (0 == (head_size % 2)) { // pack 2 element together using Bytes = ToBytes; @@ -120,27 +130,44 @@ Status LaunchStridedCopy(cudaStream_t stream, if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } else { using Bytes = ToBytes; if (head_size * num_heads <= max_threads_per_block) { const dim3 block(head_size, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), head_size, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), head_size, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } return CUDA_CALL(cudaGetLastError()); } +template +Status LaunchStridedCopy(cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block) { + const int* in_seqlens_offset = nullptr; + const int* out_seqlens_offset = nullptr; + return LaunchStridedCopy( + stream, in, in_shape, in_strides, in_seqlens_offset, + out, out_strides, out_seqlens_offset, + max_threads_per_block); +} + template Status LaunchStridedCopy( cudaStream_t stream, const float* in, int4 in_shape, longlong4 in_strides, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index ceee17c2a2d01..ee49f362564a6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -282,7 +282,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // gemm_query_buffer in col-base: (h2, S*B) - // calcualte k, v + // calculate k, v n = 2 * hidden_size; k = hidden_size; if (!has_layer_state_ || !use_past_) { diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu index ae53eca541fa5..8a17e945df3f3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu @@ -141,7 +141,7 @@ __global__ void EmbedLayerNormKernel( } __syncthreads(); - // 2. load pos/segment/word embeddings and add them toghether + // 2. load pos/segment/word embeddings and add them together // offset into embeddings is given by word_id * hidden_size const int position_offset = position_id * hidden_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h new file mode 100644 index 0000000000000..5d94190ecbeb9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h @@ -0,0 +1,67 @@ +#include +#include +#include +#include +#include "utils.h" + +namespace onnxruntime { +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Alibi { + const float alibi_slope; + const int max_seqlen_k, max_seqlen_q; + + __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) + : alibi_slope(alibi_slope), max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q){}; + + template + __forceinline__ __device__ void apply_alibi(Tensor& tensor, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + } + } + } else { // Bias depends on both row_idx and col_idx +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + } + } + } + } +}; + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h index 811b1be7d4315..dde6143153e8e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h @@ -12,22 +12,36 @@ struct BlockInfo { template __device__ BlockInfo(const Params& params, const int bidb) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), - sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), - actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative + ? -1 + : params.cu_seqlens_k[bidb]), + actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr + ? params.seqlen_q + : params.cu_seqlens_q[bidb + 1] - sum_s_q) // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. , - seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), - actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr + ? params.seqlen_k + : (params.is_seqlens_k_cumulative + ? params.cu_seqlens_k[bidb + 1] - sum_s_k + : params.cu_seqlens_k[bidb])), + actual_seqlen_k(params.seqused_k + ? params.seqused_k[bidb] + : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } template - inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + __forceinline__ __device__ + index_t + q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; } template - inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + __forceinline__ __device__ + index_t + k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; } @@ -41,6 +55,5 @@ struct BlockInfo { //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace flash } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index cbe536c6ce45a..0463d3795b446 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -16,7 +16,7 @@ constexpr int D_DIM = 2; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { - using index_t = uint32_t; + using index_t = int64_t; // The QKV matrices. void* __restrict__ q_ptr = nullptr; void* __restrict__ k_ptr = nullptr; @@ -79,6 +79,9 @@ struct Flash_fwd_params : public Qkv_params { int* __restrict__ cu_seqlens_q = nullptr; int* __restrict__ cu_seqlens_k = nullptr; + // If provided, the actual length of each k sequence. + int* __restrict__ seqused_k = nullptr; + int* __restrict__ blockmask = nullptr; // The K_new and V_new matrices. @@ -100,6 +103,11 @@ struct Flash_fwd_params : public Qkv_params { // The indices to index into the KV cache. int* __restrict__ cache_batch_idx = nullptr; + // Paged KV cache + int* __restrict__ block_table = nullptr; + index_t block_table_batch_stride = 0; + int page_block_size = 0; + // Local window size int window_size_left = -1; int window_size_right = -1; @@ -115,6 +123,9 @@ struct Flash_fwd_params : public Qkv_params { int num_splits = 0; // For split-KV version + void* __restrict__ alibi_slopes_ptr = nullptr; + index_t alibi_slopes_batch_stride = 0; + const cudaDeviceProp* dprops = nullptr; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 0f58a74c4d2fd..90f0b94cafce8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -31,6 +31,7 @@ void set_params_fprop(Flash_fwd_params& params, void* out, void* cu_seqlens_q_d, void* cu_seqlens_k_d, + void* seqused_k, void* p_d, void* softmax_lse_d, float softmax_scale, @@ -82,6 +83,7 @@ void set_params_fprop(Flash_fwd_params& params, params.cu_seqlens_q = static_cast(cu_seqlens_q_d); params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); // P = softmax(QK^T) params.p_ptr = p_d; @@ -105,7 +107,7 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; - // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates // local and causal, meaning when we have local window size params.is_causal = is_causal; if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { @@ -123,24 +125,25 @@ void set_params_fprop(Flash_fwd_params& params, params.is_seqlens_k_cumulative = true; } -size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { +size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) { size_t bytes = sizeof(float) * batch_size * num_heads * seqlen; return bytes; } -size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) { +size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q) { size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; return bytes; } -size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) { +size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, + size_t seqlen_q, size_t head_size_rounded) { size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; return bytes; } void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { FP16_SWITCH(!params.is_bf16, [&] { - FWD_HEADDIM_SWITCH(params.d, [&] { + HEADDIM_SWITCH(params.d, [&] { if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 run_mha_fwd_(params, stream); } else { @@ -156,15 +159,15 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split // splits as that would incur more HBM reads/writes. // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. -int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, - int max_splits) { +size_t num_splits_heuristic(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, + size_t head_size, size_t num_SMs, size_t max_splits) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); - const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + const size_t block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const size_t num_n_blocks = (seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. - const int num_m_blocks = (seqlen_q + 64 - 1) / 64; - int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks; + const size_t num_m_blocks = (seqlen_q + 64 - 1) / 64; + size_t batch_nheads_mblocks = batch_size * num_heads * num_m_blocks; // If we have enough to almost fill the SMs, then just use 1 split if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; @@ -173,15 +176,15 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + auto ceildiv = [](size_t a, size_t b) { return (a + b - 1) / b; }; // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks // (i.e. it's 11 splits anyway). // So we check if the number of blocks per split is the same as the previous num_splits. - auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + auto is_split_eligible = [&ceildiv, &num_n_blocks](size_t num_splits) { return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); }; - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + for (size_t num_splits = 1; num_splits <= max_splits; num_splits++) { if (!is_split_eligible(num_splits)) { efficiency.push_back(0.f); } else { @@ -194,7 +197,7 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea efficiency.push_back(eff); } } - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + for (size_t num_splits = 1; num_splits <= max_splits; num_splits++) { if (!is_split_eligible(num_splits)) { continue; } @@ -207,25 +210,39 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea } // Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) -std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, - int head_size, int num_SMs) { - int max_splits = 128; +std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, + size_t num_heads, size_t head_size, size_t num_SMs) { + size_t max_splits = 128; // split kv buffers - int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, - num_SMs, max_splits); + size_t num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); if (num_splits > 1) { // softmax_lse_accum buffer - int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + size_t softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); // out_accum buffer - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, 32); - int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; }; + const size_t head_size_rounded = round_multiple(head_size, 32); + size_t out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; } else { return {0, 0, 0}; } } +// void set_params_alibi(Flash_fwd_params ¶ms, void* alibi_slopes, int batch_size, int num_heads){ +// if (alibi_slopes != nullptr) { +// // TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); +// // CHECK_DEVICE(alibi_slopes); +// // TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); +// // TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) +// || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads})); +// params.alibi_slopes_ptr = alibi_slopes; +// params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? num_heads : 0; // TODO: flag for bool +// } else { +// params.alibi_slopes_ptr = nullptr; +// } +// } + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size @@ -262,7 +279,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, q, k, v, out, /*cu_seqlens_q*/ nullptr, /*cu_seqlens_k*/ nullptr, - nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, softmax_lse, softmax_scale, is_causal, @@ -289,6 +307,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, params.oaccum_ptr = nullptr; } + params.alibi_slopes_ptr = nullptr; + run_mha_fwd(params, stream); return Status::OK(); } @@ -301,6 +321,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, void* out, // half (total_q, num_heads, head_size) int* cu_seqlens_q, // int (batch_size + 1) int* cu_seqlens_k, // int (batch_size + 1) + void* seqused_k, // batch_size; If given, use this many elements of each batch element's keys. + int* block_table, // batch_size x max_num_blocks_per_seq void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) int batch_size, int num_heads, @@ -310,11 +332,14 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int max_seqlen_k, float softmax_scale, bool is_causal, - bool is_bf16) { + bool is_bf16, + int max_num_blocks_per_seq, + int page_block_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + const bool paged_KV = block_table != nullptr; Flash_fwd_params params; set_params_fprop(params, @@ -326,7 +351,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, q, k, v, out, cu_seqlens_q, cu_seqlens_k, - nullptr, + seqused_k, + /*p_ptr=*/nullptr, softmax_lse, softmax_scale, is_causal, @@ -340,11 +366,25 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, params.oaccum_ptr = nullptr; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; + params.alibi_slopes_ptr = nullptr; + if (paged_KV) { + params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table_batch_stride = max_num_blocks_per_seq; + // params.num_blocks = num_blocks; + params.page_block_size = page_block_size; + params.k_batch_stride = page_block_size * num_heads_k * head_size; + params.v_batch_stride = page_block_size * num_heads_k * head_size; + } else { + params.block_table = nullptr; + params.block_table_batch_stride = 0; + // params.num_blocks = 0; + params.page_block_size = 1; + } run_mha_fwd(params, stream); return Status::OK(); } -bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k) { +bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) { bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; bool is_sm90 = dprops.major == 9 && dprops.minor == 0; return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); @@ -364,6 +404,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* seqlens_k_, // batch_size void* rotary_cos, // seqlen_ro x (rotary_dim / 2) void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int* block_table, // batch_size x max_num_blocks_per_seq int batch_size, int num_heads, int num_heads_k, @@ -381,11 +422,14 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded int local_window_size, bool is_rotary_interleaved, - bool is_packed_qkv) { + bool is_packed_qkv, + int max_num_blocks_per_seq, + int page_block_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + const bool paged_KV = block_table != nullptr; // In kv-cache case, seqlen_k_max as kv sequence length Flash_fwd_params params; @@ -398,6 +442,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, q, kcache, vcache, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, /*p_ptr=*/nullptr, softmax_lse, softmax_scale, @@ -461,6 +506,21 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.oaccum_ptr = nullptr; } + params.alibi_slopes_ptr = nullptr; + if (paged_KV) { + params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table_batch_stride = max_num_blocks_per_seq; + // params.num_blocks = num_blocks; + params.page_block_size = page_block_size; + params.k_batch_stride = page_block_size * num_heads_k * head_size; + params.v_batch_stride = page_block_size * num_heads_k * head_size; + } else { + params.block_table = nullptr; + params.block_table_batch_stride = 0; + // params.num_blocks = 0; + params.page_block_size = 1; + } + // Only split kernel supports appending to KV cache run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 24891bcc4d499..4c59561449851 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -66,6 +66,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, void* out, // half (total_q, num_heads, v_head_size) int* cu_seqlens_q, // int (batch_size + 1) int* cu_seqlens_k, // int (batch_size + 1) + void* seqused_k, // batch_size; If given, only this many elements of each batch element's keys are used. + int* block_table, // batch_size x max_num_blocks_per_seq void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) int batch_size, int num_heads, @@ -75,7 +77,9 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int max_seqlen_k, float softmax_scale, bool is_causal, - bool is_bf16); + bool is_bf16, + int max_num_blocks_per_seq = 0, + int page_block_size = 1); Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -87,8 +91,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size - void* rotary_sin, // seqlen_ro x (rotary_dim / 2) void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int* block_table, // batch_size x max_num_blocks_per_seq int batch_size, int num_heads, int num_heads_k, @@ -106,14 +111,16 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded int local_window_size = -1, bool is_rotary_interleaved = false, - bool is_packed_qkv = false); + bool is_packed_qkv = false, + int max_num_blocks_per_seq = 0, + int page_block_size = 1); -size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); +size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads); -std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, - int head_size, int num_SMs); +std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, + size_t head_size, size_t num_SMs); -bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); +bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k); } // namespace flash } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu index 431eb2bd69def..1ef1ce251ecba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_bf16_sm80.cu @@ -8,9 +8,9 @@ namespace onnxruntime { namespace flash { -template<> +template <> void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); + run_mha_fwd_hdim128(params, stream); } } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_bf16_sm80.cu index 0cb48272dec3f..52ff792c6edcb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_bf16_sm80.cu @@ -8,9 +8,9 @@ namespace onnxruntime { namespace flash { -template<> +template <> void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); + run_mha_fwd_hdim160(params, stream); } } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu index 142e922f71031..3bdc5e4b0443f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_bf16_sm80.cu @@ -8,9 +8,9 @@ namespace onnxruntime { namespace flash { -template<> +template <> void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); + run_mha_fwd_hdim192(params, stream); } } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_bf16_sm80.cu index 2142b1c343110..e4972875d0512 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_bf16_sm80.cu @@ -8,9 +8,9 @@ namespace onnxruntime { namespace flash { -template<> +template <> void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); + run_mha_fwd_hdim224(params, stream); } } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu index 751363184e23a..59568b0bb03ce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_bf16_sm80.cu @@ -8,9 +8,9 @@ namespace onnxruntime { namespace flash { -template<> +template <> void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); + run_mha_fwd_hdim256(params, stream); } } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu index ebf0236435971..ad3d4df7dfc85 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_bf16_sm80.cu @@ -8,9 +8,9 @@ namespace onnxruntime { namespace flash { -template<> +template <> void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); + run_mha_fwd_hdim32(params, stream); } } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu index 166bb2a0072f4..006416458c91b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_bf16_sm80.cu @@ -8,9 +8,9 @@ namespace onnxruntime { namespace flash { -template<> +template <> void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); + run_mha_fwd_hdim64(params, stream); } } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu index c8760b8168db6..d5a273a3f4163 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_bf16_sm80.cu @@ -8,9 +8,9 @@ namespace onnxruntime { namespace flash { -template<> +template <> void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); + run_mha_fwd_hdim96(params, stream); } } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index 1fac03882b4b1..1c8a93674a80b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -7,20 +7,26 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-variable" #pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4267) +#pragma warning(disable : 4100) +#pragma warning(disable : 4101) // equivalent to GCC's -Wunused-variable +#pragma warning(disable : 4189) // equivalent to GCC's -Wunused-but-set-variable #endif -#include #include #include #include #include -#include #include "contrib_ops/cuda/bert/flash_attention/block_info.h" #include "contrib_ops/cuda/bert/flash_attention/kernel_traits.h" #include "contrib_ops/cuda/bert/flash_attention/utils.h" #include "contrib_ops/cuda/bert/flash_attention/softmax.h" +#include "contrib_ops/cuda/bert/flash_attention/mask.h" +#include "contrib_ops/cuda/bert/flash_attention/rotary.h" namespace onnxruntime { namespace flash { @@ -28,60 +34,7 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, - Tensor2& acc_o, float softmax_scale_log2) { - if (Is_first) { - flash::template reduce_max(scores, scores_max); - flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - flash::reduce_sum(scores, scores_sum); - } else { - cute::Tensor scores_max_prev = make_fragment_like(scores_max); - cute::copy(scores_max, scores_max_prev); - flash::template reduce_max(scores, scores_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); -#pragma unroll - for (int mi = 0; mi < cute::size(scores_max); ++mi) { - float scores_max_cur = !Check_inf - ? scores_max(mi) - : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - scores_sum(mi) *= scores_scale; -#pragma unroll - for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scores_scale; - } - } - flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - cute::Tensor scores_sum_cur = make_fragment_like(scores_sum); - flash::reduce_sum(scores, scores_sum_cur); -#pragma unroll - for (int mi = 0; mi < cute::size(scores_sum); ++mi) { - scores_sum(mi) += scores_sum_cur(mi); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void write_softmax_to_gmem( - cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { - // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) - cute::Layout l = tOrP.layout(); - cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); - CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{}); - CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); -#pragma unroll - for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { - cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -97,25 +50,30 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; + // constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + const int n_block_min = !Is_local + ? 0 + : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); - // We exit early and write 0 to gO and gLSE. + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= n_block_min) { - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); + if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -151,53 +109,55 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - cute::Shape, cute::Int>{}, - make_stride(params.q_row_stride, _1{})); - cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - cute::Shape, cute::Int>{}, - make_stride(params.k_row_stride, _1{})); - cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - cute::Shape, cute::Int>{}, - make_stride(params.v_row_stride, _1{})); - cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), - cute::Shape, cute::Int>{}, - make_stride(params.seqlen_k_rounded, _1{})); - - cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; // We move K and V to the last block. + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{})); + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{})); + Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; - cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : cute::size(sQ)), - typename Kernel_traits::SmemLayoutKV{}); - cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{}); - cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; - auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); - cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); - cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // MMA, MMA_M, MMA_K + Tensor tSgS = thr_mma.partition_C(gP); + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K // // Copy Atom retiling @@ -205,51 +165,46 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); - cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); - cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // TODO: this might need to change if we change the mma instruction in SM70 - cute::Tensor scores_max = make_tensor(cute::Shape(acc_o)>>{}); - cute::Tensor scores_sum = make_fragment_like(scores_max); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); // // PREDICATES // // Construct identity layout for sQ and sK - cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) // Allocate predicate tensors for k - cute::Tensor tQpQ = make_tensor(make_shape(cute::size<2>(tQsQ))); - cute::Tensor tKVpKV = make_tensor(make_shape(cute::size<2>(tKsK))); + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); // Set predicates for k bounds if (!Is_even_K) { #pragma unroll - for (int k = 0; k < cute::size(tQpQ); ++k) { + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll - for (int k = 0; k < cute::size(tKVpKV); ++k) { + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } // Prologue - cute::Tensor tQrQ = make_fragment_like(tQgQ); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); @@ -260,27 +215,34 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi if (Kernel_traits::Share_Q_K_smem) { flash::cp_async_wait<0>(); __syncthreads(); - cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { flash::cp_async_wait<1>(); __syncthreads(); - cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); } clear(acc_o); + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr + ? 0.0f + : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, + params.window_size_left, params.window_size_right, alibi_slope); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. @@ -292,22 +254,23 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) + ? cute::ceil_div(kBlockM, kBlockN) + : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); // Advance gV if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); } cute::cp_async_fence(); @@ -316,31 +279,13 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi smem_thr_copy_Q, smem_thr_copy_K); // if (cute::thread0()) { print(acc_s); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. - if (!Is_causal && !Is_local) { - if (!Is_even_MN) { - flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); - } - } else { - // I can't get the stride from idx_row - flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); flash::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -348,22 +293,15 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - // Convert scores from fp32 to fp16/bf16 - cute::Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - // if (Return_softmax) { - // cute::Tensor tOrP_copy = make_fragment_like(tOrP); - // copy(tOrP, tOrP_copy); - // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); - // tPgP.data() = tPgP.data() + (-kBlockN); - // } - - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { @@ -374,13 +312,11 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // These are the iterations where we don't need masking on S for (; n_block >= n_block_min; --n_block) { - cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); - // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -390,64 +326,36 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { - flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); - cute::Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - // if (Return_softmax) { - // cute::Tensor tOrP_copy = make_fragment_like(tOrP); - // copy(tOrP, tOrP_copy); - // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); - // tPgP.data() = tPgP.data() + (-kBlockN); - // } - - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - cute::Tensor lse = make_fragment_like(scores_sum); -#pragma unroll - for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = inv_sum; -#pragma unroll - for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scale; - } - } + Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax); // Convert acc_o from fp32 to fp16/bf16 - cute::Tensor rO = flash::convert_type(acc_o); - cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); - cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) // sO has the same size as sQ, so we don't need to sync here. if (Kernel_traits::Share_Q_K_smem) { @@ -456,33 +364,35 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - cute::Shape, cute::Int>{}, - make_stride(params.o_row_stride, _1{})); - cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - cute::Shape>{}, cute::Stride<_1>{}); + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) - cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); __syncthreads(); - cute::Tensor tOrO = make_tensor(cute::shape(tOgO)); + Tensor tOrO = make_tensor(shape(tOgO)); cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - cute::Tensor caccO = make_identity_tensor(cute::Shape, cute::Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(cute::size<0>(taccOcO))::value == 4); + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. - cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO_row(0)) == 0) { #pragma unroll - for (int mi = 0; mi < cute::size(lse); ++mi) { + for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO_row(mi)); if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); @@ -491,13 +401,13 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } // Construct identity layout for sO - cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts - cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - cute::Tensor tOpO = make_tensor(make_shape(cute::size<2>(tOgO))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); if (!Is_even_K) { #pragma unroll - for (int k = 0; k < cute::size(tOpO); ++k) { + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } @@ -508,8 +418,10 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, + const int m_block, const int n_split_idx, + const int num_n_splits) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; @@ -527,8 +439,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons using GmemTiledCopyO = std::conditional_t< !Split, - typename Kernel_traits::GmemTiledCopyOaccum, - typename Kernel_traits::GmemTiledCopyO>; + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum>; using ElementO = std::conditional_t; const BlockInfo binfo(params, bidb); @@ -591,15 +503,29 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); + const int* block_table = params.block_table == nullptr + ? nullptr + : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr + ? 0 + : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr + ? 0 + : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); @@ -649,10 +575,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - // TODO: this might need to change if we change the mma instruction in SM70 - Tensor scores_max = make_tensor(Shape(acc_o)>>{}); - Tensor scores_sum = make_fragment_like(scores_max); - // // PREDICATES // @@ -732,10 +654,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { flash::copy_w_min_idx( tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); if (params.rotary_dim == 0) { flash::copy_w_min_idx( @@ -757,19 +680,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); } } - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + const int offset_diff = block_table_offset_next - block_table_offset_cur; + tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; + tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + } + } } // Need this before we can read in K again, so that we'll see the updated K values. __syncthreads(); - if (n_block_max > n_block_copy_min) { - tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; - tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; - } + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; } // Read Q from gmem to smem, optionally apply rotary embedding. - Tensor tQrQ = make_fragment_like(tQgQ); if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, @@ -818,6 +752,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons clear(acc_o); + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f + : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. @@ -838,7 +778,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Advance gV if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads @@ -852,22 +800,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons smem_thr_copy_Q, smem_thr_copy_K); // if (cute::thread0()) { print(acc_s); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. - if (!Is_causal && !Is_local) { - if (!Is_even_MN) { - flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); - } - } else { - flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); flash::cp_async_wait<0>(); __syncthreads(); @@ -876,7 +810,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons if (n_block > n_block_min) { // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. @@ -885,18 +827,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } - // Convert scores from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - // if (cute::thread0()) { print(scores); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { @@ -912,7 +853,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); @@ -924,51 +873,36 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons __syncthreads(); if (n_block > n_block_min) { // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { - flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - // if (cute::thread0()) { print(acc_o_rowcol); } - Tensor lse = make_fragment_like(scores_sum); -#pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = inv_sum; -#pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scale; - } - } - // if (cute::thread0()) { print(lse); } - // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning @@ -1041,13 +975,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); - // __syncthreads(); - // if (cute::thread0()) { print(tOgOaccum); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1063,12 +995,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1077,7 +1009,7 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1121,7 +1053,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { if (row < kMaxSplits) { sLSE[row][col] = lse; } - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } } // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } __syncthreads(); @@ -1129,7 +1061,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, - // 16 rows, so each time we load we can load 8 rows). + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; // static_assert(kThreadsPerSplit <= 32); static_assert(kRowsPerLoadTranspose <= 32); @@ -1218,11 +1150,11 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); } } - // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); } } tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; } - // if (cute::thread0()) { print(tOrO); } + // if (cute::thread0()) { print_tensor(tOrO); } Tensor rO = flash::convert_type(tOrO); // Write to gO @@ -1256,4 +1188,6 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { #if defined(__GNUC__) #pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) #endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 87d189a803f8a..b1941df75be2c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -10,33 +10,42 @@ namespace onnxruntime { namespace flash { -template -__global__ void flash_fwd_kernel(Flash_fwd_params params) { - static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn(params); +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ #else - (void)params; +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ + template \ + __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) { +#if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // Enforce constraints + flash::compute_attn(params); +#else + FLASH_UNSUPPORTED_ARCH #endif } -template -__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn_splitkv(params); +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) { +#if defined(ARCH_SUPPORTS_FLASH) + flash::compute_attn_splitkv(params); #else - (void)params; + FLASH_UNSUPPORTED_ARCH #endif } -template -__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { static_assert(Log_max_splits >= 1); flash::combine_attn_seqk_parallel(params); -#else - (void)params; -#endif } template @@ -52,25 +61,27 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<(smem_size), stream>>>(params); + }); }); }); }); @@ -87,20 +98,25 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_Local_Const, [&] { + BOOL_SWITCH(params.num_splits > 1, SplitConst, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV_Const, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // If Append_KV_Const, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_Local_Const, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, Is_Local_Const && !Is_causal, Has_alibi, + IsEvenMNConst && !Append_KV_Const && IsEvenKConst && !Is_Local_Const && Kernel_traits::kHeadDim <= 128, + IsEvenKConst, SplitConst, Append_KV_Const >; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); + } + kernel<<(smem_size), stream>>>(params); + }); }); }); }); @@ -113,7 +129,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { // If headdim is divisible by 64, then we set kBlockM = 8, etc. constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { @@ -135,8 +151,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int kBlockM = 64; // Fixed for all head dimensions - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + constexpr static int kBlockM = 64; // Fixed for all head dimensions + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + // Also for headdim 160 with block size 64 x 128 after the rotary addition. + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); run_flash_splitkv_fwd>(params, stream); } @@ -163,8 +182,8 @@ void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 96; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 96; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { @@ -240,7 +259,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 192; + constexpr static int Headdim = 192; BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -254,7 +273,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { constexpr static int Headdim = 224; - int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index 52a4e56491c5e..cb08dbc853a91 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -25,7 +25,7 @@ struct Flash_kernel_traits { #endif using ElementAccum = float; - using index_t = uint32_t; + using index_t = int64_t; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 using MMA_Atom_Arch = std::conditional_t< @@ -76,7 +76,6 @@ struct Flash_fwd_kernel_traits : public Base { typename Base::MMA_Atom_Arch, Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group Tile, _16, _16>>; - using SmemLayoutAtomQ = decltype(composition(Swizzle{}, // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 Layout>, @@ -89,19 +88,9 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape, Int>{})); - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, - Stride<_1, Int>>; - using SmemLayoutAtomVtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape, Int>{})); - // Maybe the VtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomVtransposedNoSwizzle{}, - Shape, Int>{})); - // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype(composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomO = decltype(composition(Swizzle{}, Layout, Int>, @@ -112,10 +101,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; - static constexpr int kSmemQCount = cute::size(SmemLayoutQ{}); - static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; - static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); @@ -127,8 +114,8 @@ struct Flash_fwd_kernel_traits : public Base { // to the same banks. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = cute::Layout, cute::Int>, - cute::Stride, _1>>; + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. @@ -136,27 +123,19 @@ struct Flash_fwd_kernel_traits : public Base { Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, DefaultCopy>; - using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - cute::Layout>{})); // Val layout, 8 vals per read + Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - cute::Layout>{})); // Val layout, 8 vals per store - static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = cute::Layout, cute::Int>, - cute::Stride, _1>>; - - using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomP{}, - cute::Layout>{})); // Val layout, 8 vals per store + Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, - cute::Layout, // Thread layout, 8 threads per row - cute::Stride<_8, _1>>, - cute::Layout, // Thread layout, 16 threads per row - cute::Stride<_16, _1>>>; + Layout, // Thread layout, 8 threads per row + Stride<_8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride<_16, _1>>>; using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store @@ -205,84 +184,61 @@ struct Flash_bwd_kernel_traits : public Base { using TiledMmaSdP = TiledMMA< typename Base::MMA_Atom_Arch, - cute::Layout, cute::Int, _1>>, + Layout, Int, _1>>, Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; using TiledMmadKV = TiledMMA< typename Base::MMA_Atom_Arch, - cute::Layout, cute::Int, _1>>, + Layout, Int, _1>>, Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch, - cute::Layout, cute::Int, _1>>, // 2x4x1 or 4x2x1 thread group + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; using SmemLayoutAtomQdO = decltype(composition(Swizzle{}, - cute::Layout>, - cute::Stride, _1>>{})); + Layout>, + Stride, _1>>{})); using SmemLayoutQdO = decltype(tile_to_shape( SmemLayoutAtomQdO{}, - cute::make_shape(cute::Int{}, cute::Int{}))); + make_shape(Int{}, Int{}))); using SmemLayoutAtomKV = decltype(composition(Swizzle{}, - cute::Layout, cute::Int>, - cute::Stride, _1>>{})); + Layout, Int>, + Stride, _1>>{})); using SmemLayoutKV = decltype(tile_to_shape( // SmemLayoutAtomQdO{}, SmemLayoutAtomKV{}, - cute::make_shape(cute::Int{}, cute::Int{}))); - - using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, - Stride<_1, Int>>; - using SmemLayoutAtomKtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); - using SmemLayoutKtransposed = decltype(tile_to_shape( - SmemLayoutAtomKtransposed{}, - make_shape(Int{}, Int{}))); - // Maybe the KtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomKtransposedNoSwizzle{}, - make_shape(Int{}, Int{}))); - // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + make_shape(Int{}, Int{}))); + + using SmemLayoutKtransposed = decltype(composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // static constexpr int kPBlockN = kBlockN; - static_assert(kBlockN >= 64); + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. - static constexpr int kPBlockN = 64; + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); static constexpr int kSwizzlePdS = 3; using SmemLayoutAtomPdS = decltype(composition(Swizzle{}, - cute::Layout, cute::Int>, - cute::Stride, _1>>{})); + Layout, Int>, + Stride, _1>>{})); using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, - cute::make_shape(cute::Int{}, cute::Int{}))); - using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, - Stride<_1, Int>>; - using SmemLayoutAtomPdStransposed = decltype(composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); - using SmemLayoutPdStransposed = decltype(tile_to_shape( - SmemLayoutAtomPdStransposed{}, - make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomPdStransposedNoSwizzle{}, - make_shape(Int{}, Int{}))); - // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposed = decltype(composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + using SmemCopyAtomPdS = Copy_Atom; - using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, - Stride<_1, Int>>; - using SmemLayoutAtomQdOtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); - using SmemLayoutQdOtransposed = decltype(tile_to_shape( - SmemLayoutAtomQdOtransposed{}, - make_shape(Int{}, Int{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomQdOtransposedNoSwizzle{}, - make_shape(Int{}, Int{}))); - // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + using SmemLayoutQdOtransposed = decltype(composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); using SmemLayoutAtomdKV = decltype(composition(Swizzle{}, Layout>, @@ -300,25 +256,18 @@ struct Flash_bwd_kernel_traits : public Base { make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; - static constexpr int kSmemQdOCount = cute::size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ - static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; - static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{}); - static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{}); - static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{}); - // static constexpr int kSmemdPsumCount = kBlockM; - static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); - static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); - static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); - static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - // static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); - static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + kSmemdSSize + kSmemPSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); @@ -326,8 +275,8 @@ struct Flash_bwd_kernel_traits : public Base { // to affect speed in practice. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = cute::Layout, cute::Int>, - cute::Stride, _1>>; + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. @@ -337,30 +286,30 @@ struct Flash_bwd_kernel_traits : public Base { DefaultCopy>; using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - cute::Layout>{})); // Val layout, 8 vals per read + Layout>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - cute::Layout>{})); // Val layout, 8 vals per store + Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - cute::Layout>{})); // Val layout, 8 vals per store + Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - cute::Layout>{})); // Val layout, 8 vals per store + Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, - cute::Layout, // Thread layout, 8 threads per row - cute::Stride<_8, _1>>, - cute::Layout, // Thread layout, 16 threads per row - cute::Stride<_16, _1>>>; + Layout, // Thread layout, 8 threads per row + Stride<_8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride<_16, _1>>>; using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomdQaccum{}, - cute::Layout>{})); // Val layout, 4 vals per store + Layout>{})); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, - cute::Layout, // Thread layout, 8 threads per row - cute::Stride<_32, _1>>{}, - cute::Layout>{})); // Val layout, 1 val per store + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store }; } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h new file mode 100644 index 0000000000000..b225e5e3be559 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h @@ -0,0 +1,208 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +namespace onnxruntime { +namespace flash { + +using namespace cute; + +template +__forceinline__ __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { +// Without the "make_coord" we get wrong results +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +__forceinline__ __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +__forceinline__ __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +__forceinline__ __device__ void apply_mask_causal_w_idx( + Tensor& tensor, Tensor const& idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +struct Mask { + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope = 0.f) + : max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q), window_size_left(window_size_left), window_size_right(window_size_right), alibi_slope(!Has_alibi ? 0.0 : alibi_slope){}; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor& tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } else { +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; +}; + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/rotary.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/rotary.h new file mode 100644 index 0000000000000..dfc14ab4b4406 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/rotary.h @@ -0,0 +1,154 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "contrib_ops/cuda/bert/flash_attention/utils.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace onnxruntime { +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 8017f83bbb01d..3c205378f0177 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -7,8 +7,7 @@ #include -#include -#include +#include #include "contrib_ops/cuda/bert/flash_attention/utils.h" @@ -20,7 +19,7 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// template -__device__ inline void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { +__device__ __forceinline__ void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); @@ -35,7 +34,7 @@ __device__ inline void thread_reduce_(Tensor const& tensor, Te } template -__device__ inline void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { +__device__ __forceinline__ void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { CUTE_STATIC_ASSERT_V(size(dst) == size(src)); #pragma unroll for (int i = 0; i < size(dst); i++) { @@ -44,26 +43,26 @@ __device__ inline void quad_allreduce_(Tensor& dst, Tensor -__device__ inline void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { thread_reduce_(tensor, summary, op); quad_allreduce_(summary, summary, op); } template -__device__ inline void reduce_max(Tensor const& tensor, Tensor& max) { +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor& max) { MaxOp max_op; reduce_(tensor, max, max_op); } -template -__device__ inline void reduce_sum(Tensor const& tensor, Tensor& sum) { +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor& sum) { SumOp sum_op; - reduce_(tensor, sum, sum_op); + thread_reduce_(tensor, sum, sum_op); } // Apply the exp to all the elements. template -inline __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { +__forceinline__ __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); @@ -85,7 +84,7 @@ inline __device__ void scale_apply_exp2(Tensor& tensor, Tensor // Apply the exp to all the elements. template -inline __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { +__forceinline__ __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); @@ -115,103 +114,71 @@ inline __device__ void max_scale_exp2_sum(Tensor& tensor, Tens } } -template -inline __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, - const int col_idx_offset_ = 0) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; -#pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; -#pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= max_seqlen_k) { -// Without the "make_coord" we get wrong results -#pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -INFINITY; - } - } - } - } -} +//////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride, - const int window_size_left, const int window_size_right) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - // const int row_idx_offset = row_idx_offset_ + lane_id / 4; - const int row_idx_offset = row_idx_offset_; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; -#pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; -#pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); - const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); -#pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; -#pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } +template +struct Softmax { + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax(){}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0& acc_s, Tensor1& acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; } } - // if (cute::thread0()) { - // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); - // print(tensor(make_coord(i, mi), _)); - // // print(tensor(_, j + nj * size<1, 0>(tensor))); - // } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); } - } -} + }; -template -inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { - // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 - apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, - max_seqlen_q, warp_row_stride, -1, 0); -} - -template -inline __device__ void apply_mask_causal_w_idx( - Tensor& tensor, Tensor const& idx_rowcol, - const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 2, "Only support 2D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); - CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); -#pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); -#pragma unroll - for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { - if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -INFINITY; + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; } } - // if (cute::thread0()) { - // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); - // print(tensor(_, make_coord(j, ni))); - // // print(tensor(_, j + ni * size<1, 0>(tensor))); - // } - } -} + return lse; + }; +}; } // namespace flash } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h index 5b70988949bbd..02bd7effd7da6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h @@ -23,6 +23,37 @@ } \ }() +#define FLASHATTENTION_DISABLE_ALIBI // TEMP: Remove if we enable alibi +#ifdef FLASHATTENTION_DISABLE_ALIBI +#define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K +#define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else +#define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL +#define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define LOCAL_SWITCH BOOL_SWITCH +#endif + #define FP16_SWITCH(COND, ...) \ [&] { \ if (COND) { \ @@ -34,7 +65,7 @@ } \ }() -#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ +#define HEADDIM_SWITCH(HEADDIM, ...) \ [&] { \ if (HEADDIM <= 32) { \ constexpr static int kHeadDim = 32; \ diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 7aefd4799bc4d..9ef75120881e4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -27,10 +27,10 @@ namespace flash { //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ uint32_t relu2(const uint32_t x); +__forceinline__ __device__ uint32_t relu2(const uint32_t x); template <> -inline __device__ uint32_t relu2(const uint32_t x) { +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { uint32_t res; const uint32_t zero = 0u; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -52,7 +52,7 @@ inline __device__ uint32_t relu2(const uint32_t x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 template <> -inline __device__ uint32_t relu2(const uint32_t x) { +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { uint32_t res; const uint32_t zero = 0u; asm volatile("max.bf16x2 %0, %1, %2;\n" @@ -67,10 +67,10 @@ inline __device__ uint32_t relu2(const uint32_t x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 template -inline __device__ uint32_t convert_relu2(const float2 x); +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); template <> -inline __device__ uint32_t convert_relu2(const float2 x) { +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { uint32_t res; const uint32_t a = reinterpret_cast(x.x); const uint32_t b = reinterpret_cast(x.y); @@ -81,7 +81,7 @@ inline __device__ uint32_t convert_relu2(const float2 x) { } template <> -inline __device__ uint32_t convert_relu2(const float2 x) { +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { uint32_t res; const uint32_t a = reinterpret_cast(x.x); const uint32_t b = reinterpret_cast(x.y); @@ -97,20 +97,20 @@ inline __device__ uint32_t convert_relu2(const float2 x) { template struct MaxOp { - __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; } }; template <> struct MaxOp { // This is slightly faster - __device__ inline float operator()(float const& x, float const& y) { return max(x, y); } + __device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct SumOp { - __device__ inline T operator()(T const& x, T const& y) { return x + y; } + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -119,7 +119,7 @@ template struct Allreduce { static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); template - static __device__ inline T run(T x, Operator& op) { + static __device__ __forceinline__ T run(T x, Operator& op) { constexpr int OFFSET = THREADS / 2; x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); return Allreduce::run(x, op); @@ -131,7 +131,7 @@ struct Allreduce { template <> struct Allreduce<2> { template - static __device__ inline T run(T x, Operator& op) { + static __device__ __forceinline__ T run(T x, Operator& op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; } @@ -143,10 +143,10 @@ template -inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, - Tensor4 const& tCsB, TiledMma tiled_mma, - TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, - ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { +__forceinline__ __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K @@ -178,9 +178,9 @@ inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 template -inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, - ThrCopy smem_thr_copy_B) { +__forceinline__ __device__ void gemm_rs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K @@ -200,42 +200,48 @@ inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) template -inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting - // "int_tuple.hpp(74): error: conversion to inaccessible base class" - // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); - return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) -// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. template -inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { using X = Underscore; - static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); - static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); static_assert(mma_shape_K == 8 || mma_shape_K == 16); - constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; - auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - // TD [2023-08-13]: Same error as above on Cutlass 3.2 - // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - // get<0, 1>(l), - // get<1, 1, 1>(l)); - return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), - get<1>(get<0>(l)), - get<1>(get<1>(get<1>(l)))); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +template +__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); }; //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ auto convert_type(Tensor const& tensor) { +__forceinline__ __device__ auto convert_type(Tensor const& tensor) { using From_type = typename Engine::value_type; constexpr int numel = decltype(size(tensor))::value; cutlass::NumericArrayConverter convert_op; @@ -247,7 +253,7 @@ inline __device__ auto convert_type(Tensor const& tensor) { //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void relu_(Tensor& tensor) { +__forceinline__ __device__ void relu_(Tensor& tensor) { constexpr int numel = decltype(size(tensor))::value; static_assert(numel % 2 == 0); using value_t = typename Engine::value_type; @@ -263,7 +269,7 @@ inline __device__ void relu_(Tensor& tensor) { // On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction template -inline __device__ auto convert_type_relu(Tensor const& tensor) { +__forceinline__ __device__ auto convert_type_relu(Tensor const& tensor) { using From_type = typename Engine::value_type; static_assert(std::is_same_v || std::is_same_v); static_assert(std::is_same_v); @@ -304,9 +310,9 @@ CUTE_HOST_DEVICE void cp_async_wait() { template -inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, - Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, const int max_MN = 0) { +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, const int max_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -363,138 +369,5 @@ inline __device__ void copy_w_min_idx(Tensor const& S, //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_rotary_interleaved(Tensor const& S, - Tensor& D, - Tensor const& Cos, - Tensor const& Sin, - Tensor const& identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K - static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); -#pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { -#pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - cute::copy(Cos(_, m, k), rCos(_, m, k)); - cute::copy(Sin(_, m, k), rSin(_, m, k)); - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); -#pragma unroll - for (int i = 0; i < size<0>(rS) / 2; ++i) { - float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); - float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); - S_fp32(2 * i) = real; - S_fp32(2 * i + 1) = imag; - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void copy_rotary_contiguous(Tensor const& S, - Tensor& D, - Tensor const& Cos, - Tensor const& Sin, - Tensor const& identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - Tensor rS_other = make_fragment_like(rS(_, 0, 0)); -#pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { -#pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; - Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); - cute::copy(gS_other, rS_other); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } - Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); - Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); - cute::copy(gCos, rCos(_, m, k)); - cute::copy(gSin, rSin(_, m, k)); - // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor S_other_fp32 = convert_type(rS_other); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); -#pragma unroll - for (int i = 0; i < size<0>(rS); ++i) { - S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); } - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace flash } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 0c26f04edef99..3b6ad238cc826 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -141,7 +141,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = num_splits; + parameters.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 62974d12003fe..3099b52cce13e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -577,7 +577,7 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp } // Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen, +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; int b = tid / seqlen; @@ -592,7 +592,7 @@ __global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, } // Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < batch_size) { position_ids[tid] = seqlens_k[tid]; @@ -600,7 +600,7 @@ __global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, } // Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { const int seqlen = parameters.sequence_length; const int batch_size = parameters.batch_size; @@ -675,7 +675,7 @@ Status FlashAttention( bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, query, present_key, present_value, key, value, data.output, - reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu index d610051c77e50..2c251246267b7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu @@ -148,7 +148,7 @@ __launch_bounds__(blockSize) // following the Softmax. // // For now zero-out only [row_index - 2*attention_window, row_index + 2*attention_window], - // we can even be more agressive and reduce the zeroing out window size since + // we can even be more aggressive and reduce the zeroing out window size since // each row has entries in 3 blocks (3*attention_window size instead of 4*attention_window) int zero_start = row_index - 2 * attention_window; if (zero_start < 0) { diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 5ae7c149fa05c..ba8b00df07e06 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -166,7 +166,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = num_splits; + parameters.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index a18744d29b1db..3e168189be3d5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -631,6 +631,8 @@ Status FlashAttention( data.output, cu_seqlens_q, cu_seqlens_k, + nullptr, // seqused_k + nullptr, // block_table softmax_lse_buffer, batch_size, num_heads, @@ -640,7 +642,7 @@ Status FlashAttention( sequence_length, scale, false, // is causal - false // is bf16 + false // is bf16 )); DUMP_TENSOR_INIT(); diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index 1b28b288f3d7c..ad0a83c9cde65 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -25,8 +25,9 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int64_t* position_ids, // (1) or BxS const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int position_ids_format, - const bool interleaved, const int batch_stride, const int seq_stride, - const int head_stride) { + const bool interleaved, + int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous +) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently @@ -40,10 +41,8 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH return; } - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - - const T* input_data = input + block_offset; - T* output_data = output + block_offset; + const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; + T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; if (i >= rotary_embedding_dim) { output_data[i] = input_data[i]; @@ -77,34 +76,58 @@ template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, const T* cos_cache, const T* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool is_input_bnsh_format) { + int4 in_strides; + int4 out_strides; + if (is_input_bnsh_format) { + int in_head_stride = sequence_length * head_size; + int out_head_stride = sequence_length * head_size; + in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1}; + out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1}; + } else { + int in_head_stride = head_size; + int out_head_stride = head_size; + in_strides = int4{sequence_length * num_heads * in_head_stride, in_head_stride, num_heads * in_head_stride, 1}; + out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1}; + } + return LaunchRotaryEmbeddingKernel( + stream, output, input, position_ids, + cos_cache, sin_cache, batch_size, + sequence_length, num_heads, head_size, + rotary_embedding_dim, max_sequence_length, + position_ids_format, interleaved, + max_threads_per_block, + in_strides, out_strides); +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const T* cos_cache, const T* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, + int4 in_strides, int4 out_strides // strides in bnsh coord +) { // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block"); + // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1) + ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must contiguous"); int tpb = (head_size + 31) / 32 * 32; const dim3 block(tpb); const dim3 grid(sequence_length, batch_size, num_heads); - // Default input tensor shape is [batch, seq, hidden_size] - int head_stride = head_size; - int seq_stride = num_heads * head_stride; - int batch_stride = sequence_length * seq_stride; - if (is_input_bnsh_format) { - seq_stride = head_size; - head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; - } - assert(head_size <= max_threads_per_block); RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size, rotary_embedding_dim, position_ids_format, - interleaved, batch_stride, seq_stride, head_stride); + interleaved, in_strides, out_strides); return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index 6053814b835bb..dd0ac6a6e3274 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -28,6 +28,26 @@ Status LaunchRotaryEmbeddingKernel( const int max_threads_per_block, const bool is_input_bnsh_format); +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + int4 in_strides, + int4 out_strides); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h index 2c9dc3689f882..116b9fb80da4d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h @@ -41,7 +41,7 @@ struct Gmem_params { // Hidden dim per head int32_t d; - // array of length b+1 holding prefix sum of actual sequence lenghts. + // array of length b+1 holding prefix sum of actual sequence lengths. int32_t* cu_seqlens; }; @@ -69,7 +69,7 @@ struct Fused_multihead_attention_params_mhca { // See https://confluence.nvidia.com/pages/viewpage.action?pageId=302779721 for details. bool enable_i2f_trick; - // array of length b+1 holding prefix sum of actual sequence lenghts + // array of length b+1 holding prefix sum of actual sequence lengths int32_t* cu_seqlens; // use C/32 Format. diff --git a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu index 666ec3a993235..9de3d48417b34 100644 --- a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu +++ b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu @@ -29,7 +29,7 @@ namespace onnxruntime { namespace cuda { namespace collective { -#if defined(USE_MPI) || defined(USE_NCCL) +#if defined(USE_MPI) || defined(ORT_USE_NCCL) using namespace onnxruntime; using namespace onnxruntime::cuda; diff --git a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h index 3ca3c1dd166af..4721fb11ec86d 100644 --- a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h +++ b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h @@ -25,7 +25,7 @@ namespace onnxruntime { namespace cuda { namespace collective { -#if defined(USE_MPI) || defined(USE_NCCL) +#if defined(USE_MPI) || defined(ORT_USE_NCCL) constexpr size_t WARP_SIZE = 32; constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24; diff --git a/onnxruntime/contrib_ops/cuda/collective/ipc_utils.cc b/onnxruntime/contrib_ops/cuda/collective/ipc_utils.cc index b4e602228ee63..fefad449bb5b5 100644 --- a/onnxruntime/contrib_ops/cuda/collective/ipc_utils.cc +++ b/onnxruntime/contrib_ops/cuda/collective/ipc_utils.cc @@ -23,7 +23,7 @@ namespace onnxruntime { namespace cuda { namespace collective { -#if defined(USE_MPI) || defined(USE_NCCL) +#if defined(USE_MPI) || defined(ORT_USE_NCCL) using namespace onnxruntime; diff --git a/onnxruntime/contrib_ops/cuda/collective/ipc_utils.h b/onnxruntime/contrib_ops/cuda/collective/ipc_utils.h index cda0f3437b25f..44f352168b411 100644 --- a/onnxruntime/contrib_ops/cuda/collective/ipc_utils.h +++ b/onnxruntime/contrib_ops/cuda/collective/ipc_utils.h @@ -24,7 +24,7 @@ namespace onnxruntime { namespace cuda { namespace collective { -#if defined(USE_MPI) || defined(USE_NCCL) +#if defined(USE_MPI) || defined(ORT_USE_NCCL) struct CudaDeleter { void operator()(void* ptr) const noexcept { @@ -86,4 +86,4 @@ GetCustomAllReduceWorkspace(int rank, int world_size, size_t input_size, IPCMemo } // namespace collective } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h index 10b6f7dd56ef8..49646637b635e 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h @@ -5,7 +5,7 @@ #include "core/providers/cuda/cuda_kernel.h" -#if defined(ORT_USE_NCCL) || defined(ORT_USE_MPI) +#if defined(ORT_USE_NCCL) || defined(USE_MPI) #ifndef USE_ROCM #include "custom_reduce_impl.h" #include "ipc_utils.h" diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc index 20c936e1b6718..e1549ad295664 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -4,7 +4,7 @@ #include "sharding_spec.h" #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/framework/tensor_shape.h" #include diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 5abc50a61c9a3..5b47273bc8f2f 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -37,7 +37,7 @@ class DeviceMesh { // corresponding sharding spec is a string "S[1]S[0]". // If that 2-D tensor's value is np.array([[5, 6], [7, 8]]), // GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization the sharding - // proccess. + // process. // - Start with a 2-D device mesh [[0, 1], [2, 3]] and // a 2-D tensor [[5, 6], [7, 8]] // - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]] diff --git a/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h b/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h index 9cd96d5c62c94..3d21b12f9b55c 100644 --- a/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h +++ b/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h @@ -33,10 +33,10 @@ struct CufftPlanInfo { // see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function template struct ParamsHash { - // Params must be a POD because we read out its memory - // contenst as char* when hashing + // Params must be a trivial type because we read out its memory + // contents as char* when hashing - static_assert(std::is_pod::value, "Params is not POD"); + static_assert(std::is_trivial::value, "Params is not a trivial type"); size_t operator()(const T& params) const { auto ptr = reinterpret_cast(¶ms); uint32_t value = 0x811C9DC5; @@ -50,10 +50,10 @@ struct ParamsHash { template struct ParamsEqual { - // Params must be a POD because we read out its memory - // contenst as char* when comparing + // Params must be a trivial type because we read out its memory + // contents as char* when comparing - static_assert(std::is_pod::value, "Params is not POD"); + static_assert(std::is_trivial::value, "Params is not a trivial type"); bool operator()(const T& a, const T& b) const { auto ptr1 = reinterpret_cast(&a); diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc index 4b524dcf795a2..65bec758ae525 100644 --- a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc +++ b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc @@ -87,7 +87,7 @@ Status FFTBase::DoFFT(OpKernelContext* context, const Tensor* X, bool complex int64_t batch_size = (batch_ndim == 0 ? 1 : input_shape.SizeToDimension(batch_ndim)); // infer output shape - // copy the input shape up to the second last dimention + // copy the input shape up to the second last dimension std::vector output_dims, signal_dims; int i = 0; for (; i < batch_ndim + signal_ndim_ - 1; ++i) { diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 28ab27ee33d10..07c5de2fe8d8c 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -207,7 +207,7 @@ Status GemmFloat8::ComputeGemm( #endif case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; #endif default: @@ -219,7 +219,7 @@ Status GemmFloat8::ComputeGemm( compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; break; case CUDA_R_32F: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; default: ORT_THROW("Unable to determine computeType in operator GemmFloat8."); diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 7d3f6eb9295d8..865a1dc29ce47 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -3,7 +3,7 @@ #include "contrib_ops/cuda/sparse/sparse_attention_impl.h" #include "contrib_ops/cuda/sparse/sparse_attention.h" -#include "contrib_ops/cuda/sparse/sparse_attention_helper.h" +#include "contrib_ops/cpu/sparse/sparse_attention_helper.h" #include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h" #include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h" #include "core/platform/env_var_utils.h" diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 09d2dba7d203a..e047bd948434d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -53,7 +53,7 @@ namespace GenerationCudaDeviceHelper { // e.g In the case of past(fp32) -> cast to fp16 -> Attention(fp16), the reorder // function will use the fp32 chunk size and cause the model silently generates // the incorrect results. -// TODO: Fix this issue. Either retrive the Attention op type from the graph or +// TODO: Fix this issue. Either retrieve the Attention op type from the graph or // check the type of past state as graph input should be same as Attention op type. // It might be better to forcefully require the same type since cast node generates // extra overhead. diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index 6b712ccfbeb7e..0fe2d7ccb1f76 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -7,7 +7,7 @@ #include "core/providers/cpu/tensor/utils.h" #include "core/providers/cuda/cuda_common.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/generation_shared.h" diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index fb7af3cfdd54f..e10c2ec63fd51 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -202,6 +202,10 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } +void CudaTensorConsoleDumper::Print(const std::string& value) const { + std::cout << value << std::endl; +} + void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { if (is_enabled_) DumpGpuTensor(name, tensor, dim0, dim1, true); @@ -325,6 +329,10 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else + +void CudaTensorConsoleDumper::Print(const std::string&) const { +} + void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { } diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 0f25e85bb97d7..6ad0ad9a67b75 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -46,6 +46,8 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; void Print(const char* name, const std::string& value, bool end_line) const override; + + void Print(const std::string& value) const override; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 7bc3414c89978..11899feb6e1dc 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -14,6 +14,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGe class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention); +// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu); @@ -23,6 +25,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, Simp class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); template <> + KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; return info; @@ -37,6 +40,8 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/js/layer_norm.cc b/onnxruntime/contrib_ops/js/layer_norm.cc index 814543a9905e0..ec4603cc69de4 100644 --- a/onnxruntime/contrib_ops/js/layer_norm.cc +++ b/onnxruntime/contrib_ops/js/layer_norm.cc @@ -8,6 +8,19 @@ namespace onnxruntime { namespace contrib { namespace js { +// LayerNormalization used to be a contrib op +// that (incorrectly) used kOnnxDomain so we need to version it +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 1, + 16, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", onnxruntime::js::JsepSupportedFloatTypes()) + .TypeConstraint("U", onnxruntime::js::JsepSupportedFloatTypes()), + onnxruntime::js::LayerNorm); + ONNX_OPERATOR_KERNEL_EX( SimplifiedLayerNormalization, kOnnxDomain, diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 3164e8c211099..349df045becf2 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -169,6 +169,13 @@ Status ClassifyAttentionMode(AttentionType type, const std::vector& past, const std::vector& present); +template +Status LaunchStridedCopy( + hipStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + template Status LaunchStridedCopy(hipStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu new file mode 100644 index 0000000000000..7a16eb38181aa --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -0,0 +1,526 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/rocm/bert/group_query_attention.h" +#include "contrib_ops/rocm/bert/group_query_attention_helper.h" +#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" + +#ifdef USE_COMPOSABLE_KERNEL_CK_TILE +#include "ck_tile/core/numeric/integer.hpp" +#include "fmha_fwd.hpp" +#endif + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 6), \ + GroupQueryAttention); + +// REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +// REGISTER_KERNEL_TYPED(BFloat16) + +template +std::string GetCkFmhaDataTypeString(); + +template <> +std::string GetCkFmhaDataTypeString() { + return "fp16"; +} + +template <> +std::string GetCkFmhaDataTypeString() { + return "bf16"; +} + +__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = seqlens[idx] + inc; + } +} + +Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqlens_inc_kernel<<>>(seqlens, out, num_elems, inc); + return HIP_CALL(hipGetLastError()); +} + +__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = idx * length_per_seq; + } + if (idx == 0) { + out[num_elems] = num_elems * length_per_seq; + } +} + +Status LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqstart_init_kernel<<>>(out, num_elems, length_per_seq); + return HIP_CALL(hipGetLastError()); +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, + int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size = parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks = (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return HIP_CALL(hipGetLastError()); +} + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : RocmKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_past_bsnh_ = false; + is_unidirectional_ = true; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { +#if USE_COMPOSABLE_KERNEL_CK_TILE + auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); + const Tensor* query = ctx->Input(0); + const Tensor* key = ctx->Input(1); + const Tensor* value = ctx->Input(2); + const Tensor* past_key = ctx->Input(3); + const Tensor* past_value = ctx->Input(4); + const Tensor* seqlens_k = ctx->Input(5); + const Tensor* total_seqlen = ctx->Input(6); + const Tensor* cos_cache = ctx->Input(7); + const Tensor* sin_cache = ctx->Input(8); + + auto& device_prop = GetDeviceProp(); + std::call_once( + arch_checking_, + [](const hipDeviceProp_t& device_prop) { + if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && + std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " + << "CDNA2 and CDNA3 archs."; + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention running on an unsuppoted GPU may result in " + << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error."; + } + }, + device_prop); + + GroupQueryAttentionParameters parameters; + using HipT = typename ToHipType::MappedType; + + const int max_thr_per_blk = device_prop.maxThreadsPerBlock; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + is_past_bsnh_, + scale_, + max_thr_per_blk)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + // parameters.zeros_count = kZerosCount; + // parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = ctx->Output(0, output_shape); + Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + + int4 past_shape; + std::vector present_dims; + Strides present_strides; + Strides past_strides; + if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + past_shape = { + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; + past_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); + present_dims = { + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; + present_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + } else { // BNSH + past_shape = { + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; + past_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); + present_dims = { + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; + present_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); + } + TensorShape present_shape(present_dims); + Tensor* present_key = ctx->Output(1, present_shape); + Tensor* present_value = ctx->Output(2, present_shape); + + Strides query_strides; + Strides key_strides; + Strides value_strides; + int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord + const HipT* query_ptr = reinterpret_cast(query->DataRaw()); + const HipT* key_ptr; + const HipT* value_ptr; + if (!parameters.is_packed_qkv) { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); + value_strides = key_strides; + key_ptr = reinterpret_cast(key->DataRaw()); + value_ptr = reinterpret_cast(value->DataRaw()); + } else { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + value_strides = query_strides; + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key_ptr = query_ptr + key_offset; + value_ptr = key_ptr + value_offset; + } + + IAllocatorUniquePtr rotary_q_tmp; + IAllocatorUniquePtr rotary_k_tmp; + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); + + rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); + rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); + auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, + reinterpret_cast(seqlens_k->DataRaw()), + reinterpret_cast(rotary_position_ids_tmp.get()), + hip_stream, max_thr_per_blk)); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + query_strides.ForBNSHCoord(), + rotary_q_strides.ForBNSHCoord())); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + key_strides.ForBNSHCoord(), + rotary_k_strides.ForBNSHCoord())); + query_ptr = reinterpret_cast(rotary_q_tmp.get()); + key_ptr = reinterpret_cast(rotary_k_tmp.get()); + query_strides = rotary_q_strides; + key_strides = rotary_k_strides; + } + + const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; + IAllocatorUniquePtr seqlens_k_tmp; + + // build present kv cache + auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); + auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); + if (parameters.is_prompt) { + // copy prompt kv to present kv + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); + const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); + parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: + if (!parameters.kv_share_buffer) { + // copy past to present, + // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are + // not the same, aka, can not be as simple as strided + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + // In the case of share buffer + ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); + ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); + } + // then append new kv to present + size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + + // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. + // we should call fmha with total sequence lengths + seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); + seqlens_k_ptr = seqlens_k_tmp.get(); + } + static_assert(std::is_same_v); + + const float scale = parameters.scale == 0.0f + ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + bias_enum bias_type = bias_enum::no_bias; + + mask_info mask = [&]() { + if (local_window_size_ != -1) { + mask_info ret; + ret.type = mask_enum::window_generic; + ret.left = local_window_size_; + ret.right = parameters.is_unidirectional ? 0 : -1; + // ret.x = kv_sequence_length - (sequence_length - ret.left); + // ret.y = sequence_length + (ret.right - kv_sequence_length); + return ret; + } + + if (parameters.is_prompt && is_unidirectional_) { + return mask_info::decode("t", sequence_length, kv_sequence_length); + } + + return mask_info::decode("0", sequence_length, kv_sequence_length); + }(); + + auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_q_tmp.get(), batch_size, + query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_k_tmp.get(), batch_size, + present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); + + fmha_fwd_args args{ + query_ptr, + present_key->DataRaw(), + present_value->DataRaw(), + nullptr, // bias, alibi/element + nullptr, // lse, logsumexp buffer + output->MutableDataRaw(), + seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode + seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode + seqlens_k_ptr, // seqlen_k_ptr, for group mode + sequence_length, // seqlen_q, for batch mode + kv_sequence_length, // seqlen_k, for batch mode + parameters.batch_size, // batch + parameters.sequence_length, // max_seqlen_q + parameters.head_size, // hdim_q + parameters.head_size, // hdim_v + parameters.num_heads, + parameters.kv_num_heads, + scale, + 1.0f, // scale_p of squant, useless + 1.0f, // scale_o of squant, useless + static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S + batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 + static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S + static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N + 0, // nhead_stride_bias + batch_size, // nhead_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B + static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B + 0, // batch_stride_bias + num_heads * batch_size, // batch_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B + mask.left, // window_size_left + mask.right, // window_size_right + static_cast(mask.type)}; + +#if 0 + std::cout + << "\n sequence_length:" << sequence_length + << "\n kv_sequence_length:" << kv_sequence_length + << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache + << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; + + std::cout + << "\n q_ptr:" << args.q_ptr + << "\n k_ptr:" << args.k_ptr + << "\n v_ptr:" << args.v_ptr + << "\n bias_ptr:" << args.bias_ptr + << "\n lse_ptr:" << args.lse_ptr + << "\n o_ptr:" << args.o_ptr + << "\n seqstart_q_ptr:" << args.seqstart_q_ptr + << "\n seqstart_k_ptr:" << args.seqstart_k_ptr + << "\n seqlen_k_ptr:" << args.seqlen_k_ptr + << "\n seqlen_q:" << args.seqlen_q + << "\n seqlen_k:" << args.seqlen_k + << "\n batch:" << args.batch + << "\n max_seqlen_q:" << args.max_seqlen_q + << "\n hdim_q:" << args.hdim_q + << "\n hdim_v:" << args.hdim_v + << "\n nhead_q:" << args.nhead_q + << "\n nhead_k:" << args.nhead_k + << "\n scale_s:" << args.scale_s + << "\n scale_p:" << args.scale_p + << "\n scale_o:" << args.scale_o + << "\n stride_q:" << args.stride_q + << "\n stride_k:" << args.stride_k + << "\n stride_v:" << args.stride_v + << "\n stride_bias:" << args.stride_bias + << "\n stride_o:" << args.stride_o + << "\n nhead_stride_q:" << args.nhead_stride_q + << "\n nhead_stride_k:" << args.nhead_stride_k + << "\n nhead_stride_v:" << args.nhead_stride_v + << "\n nhead_stride_bias:" << args.nhead_stride_bias + << "\n nhead_stride_lse:" << args.nhead_stride_lse + << "\n nhead_stride_o:" << args.nhead_stride_o + << "\n batch_stride_q:" << args.batch_stride_q + << "\n batch_stride_k:" << args.batch_stride_k + << "\n batch_stride_v:" << args.batch_stride_v + << "\n batch_stride_bias:" << args.batch_stride_bias + << "\n batch_stride_lse:" << args.batch_stride_lse + << "\n batch_stride_o:" << args.batch_stride_o + << "\n window_size_left:" << args.window_size_left + << "\n window_size_right:" << args.window_size_right + << "\n mask_type:" << args.mask_type + << std::endl; +#endif + + fmha_fwd_traits traits{ + parameters.head_size, + parameters.head_size, // v head size + GetCkFmhaDataTypeString(), + !parameters.is_prompt, // true, // is_group_mode + true, // is_v_rowmajor ? dim is fastest : seq is fastest + mask.type, + bias_type, + false, // has_lse + false, // do_fp8_static_quant, aka, squant + }; + + ck_tile::stream_config stream_config{ + hip_stream, + false // time_kernel + }; + + auto duration = fmha_fwd(traits, args, stream_config); + if (duration < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); + } + HIP_RETURN_IF_ERROR(hipGetLastError()); + + return Status::OK(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); +#endif +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h new file mode 100644 index 0000000000000..ce0de1f761aa5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class GroupQueryAttention final : public RocmKernel { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool is_unidirectional_; + bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; + + private: + static std::once_flag arch_checking_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu index 1e175b37b02d8..b65841b359647 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu @@ -78,7 +78,7 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { auto m = !transA_ ? a_shape[0] : a_shape[1]; auto k = !transA_ ? a_shape[1] : a_shape[0]; - ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatiable + ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatible auto n = !transB_ ? b_shape[1] : b_shape[0]; TensorShapeVector output_shape = {m, n}; diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7e5e7d7ee076d..4284b4254f485 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -71,6 +71,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); @@ -227,6 +229,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.h b/onnxruntime/core/codegen/mti/mti_tvm_utils.h index e85489dc1cac2..c2a14106c1686 100644 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.h +++ b/onnxruntime/core/codegen/mti/mti_tvm_utils.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include #include "core/codegen/mti/common.h" diff --git a/onnxruntime/core/codegen/mti/tensor/concat_ops.cc b/onnxruntime/core/codegen/mti/tensor/concat_ops.cc index 625a91bc6456e..3394d5b7e00a2 100644 --- a/onnxruntime/core/codegen/mti/tensor/concat_ops.cc +++ b/onnxruntime/core/codegen/mti/tensor/concat_ops.cc @@ -4,7 +4,7 @@ #include "core/codegen/mti/tensor/concat_ops.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/codegen/mti/tensor/gather.cc b/onnxruntime/core/codegen/mti/tensor/gather.cc index 3ea6ebf466209..152b3981f1623 100644 --- a/onnxruntime/core/codegen/mti/tensor/gather.cc +++ b/onnxruntime/core/codegen/mti/tensor/gather.cc @@ -4,7 +4,7 @@ #include "core/codegen/mti/tensor/gather.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/codegen/mti/tensor/slice.cc b/onnxruntime/core/codegen/mti/tensor/slice.cc index 7c73be2b5234f..6cbab43584d4b 100644 --- a/onnxruntime/core/codegen/mti/tensor/slice.cc +++ b/onnxruntime/core/codegen/mti/tensor/slice.cc @@ -5,7 +5,7 @@ #include "core/codegen/mti/mti_tvm_utils.h" #include -#include "core/common/gsl.h" +#include #include #include diff --git a/onnxruntime/core/codegen/mti/tensor/split.cc b/onnxruntime/core/codegen/mti/tensor/split.cc index 8dbbd8fdcc288..6ee366314858f 100644 --- a/onnxruntime/core/codegen/mti/tensor/split.cc +++ b/onnxruntime/core/codegen/mti/tensor/split.cc @@ -4,7 +4,7 @@ #include "core/codegen/mti/tensor/split.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/codegen/mti/tensor/tile.cc b/onnxruntime/core/codegen/mti/tensor/tile.cc index 60ef29f7ce705..2fef86adcbaea 100644 --- a/onnxruntime/core/codegen/mti/tensor/tile.cc +++ b/onnxruntime/core/codegen/mti/tensor/tile.cc @@ -3,7 +3,7 @@ #include "core/codegen/mti/tensor/tile.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace tvm_codegen { diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc index 5c2557142dd0e..88170bb56dd2d 100644 --- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc +++ b/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace tvm_codegen { -// In the cell computation, we don't have the "direction" dimention and sequence dimension, +// In the cell computation, we don't have the "direction" dimension and sequence dimension, // which have been processed outside of the cell. // Here we implement an LTSM cell. // For those args (inputs/outputs) of hidden states we put AFTER regular args (inputs/outputs) diff --git a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc index 888dddfd3dbd2..55892974aa33f 100644 --- a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc +++ b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc @@ -7,7 +7,7 @@ #include "core/codegen/passes/utils/codegen_context.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/common/gsl.h" +#include #include diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 0bee36e4d10b3..4c9e7e80db49b 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -61,7 +61,7 @@ class CPUIDInfo { /** * @brief Some ARMv8 power efficient core has narrower 64b load/store - * that needs specialized optimiztion in kernels + * that needs specialized optimization in kernels * @return whether the indicated core has narrower load/store device */ bool IsCoreArmv8NarrowLd(uint32_t coreId) const { @@ -73,7 +73,7 @@ class CPUIDInfo { /** * @brief Some ARMv8 power efficient core has narrower 64b load/store - * that needs specialized optimiztion in kernels + * that needs specialized optimization in kernels * @return whether the current core has narrower load/store device */ bool IsCurrentCoreArmv8NarrowLd() const { diff --git a/onnxruntime/core/common/helper.cc b/onnxruntime/core/common/helper.cc index 7b7073634989d..6a52db73df106 100644 --- a/onnxruntime/core/common/helper.cc +++ b/onnxruntime/core/common/helper.cc @@ -56,7 +56,7 @@ void PrintFinalMessage(const char* msg) { #else // TODO, consider changing the output of the error message from std::cerr to logging when the // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output - // might not be easily accesible on some systems such as mobile + // might not be easily accessible on some systems such as mobile // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS std::cerr << msg << std::endl; #endif diff --git a/onnxruntime/core/common/logging/capture.cc b/onnxruntime/core/common/logging/capture.cc index 3c23e15e5cc00..ac0d4d5fc707d 100644 --- a/onnxruntime/core/common/logging/capture.cc +++ b/onnxruntime/core/common/logging/capture.cc @@ -3,7 +3,7 @@ #include "core/common/logging/capture.h" #include "core/common/logging/logging.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace logging { diff --git a/onnxruntime/core/common/path.cc b/onnxruntime/core/common/path.cc deleted file mode 100644 index 8b74d2d8c9c1c..0000000000000 --- a/onnxruntime/core/common/path.cc +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/path.h" - -#include -#include - -namespace onnxruntime { - -namespace { - -constexpr auto k_dot = ORT_TSTR("."); -constexpr auto k_dotdot = ORT_TSTR(".."); - -constexpr std::array k_valid_path_separators{ - ORT_TSTR('/'), ORT_TSTR('\\')}; - -constexpr bool IsPreferredPathSeparator(PathChar c) { - return c == k_preferred_path_separator; -} - -PathString NormalizePathSeparators(const PathString& path) { - PathString result{}; - std::replace_copy_if( - path.begin(), path.end(), std::back_inserter(result), - [](PathChar c) { - return std::find( - k_valid_path_separators.begin(), - k_valid_path_separators.end(), - c) != k_valid_path_separators.end(); - }, - k_preferred_path_separator); - return result; -} - -// parse component and trailing path separator -PathString::const_iterator ParsePathComponent( - PathString::const_iterator begin, PathString::const_iterator end, - PathString::const_iterator& component_end, bool* has_trailing_separator) { - component_end = std::find_if(begin, end, IsPreferredPathSeparator); - const auto sep_end = std::find_if_not(component_end, end, IsPreferredPathSeparator); - if (has_trailing_separator) *has_trailing_separator = sep_end != component_end; - return sep_end; -} - -#ifdef _WIN32 - -Status ParsePathRoot( - const PathString& path, - PathString& root, bool& has_root_dir, size_t& num_parsed_chars) { - // assume NormalizePathSeparators() has been called - - // drive letter - if (path.size() > 1 && - (ORT_TSTR('a') <= path[0] && path[0] <= ORT_TSTR('z') || - ORT_TSTR('A') <= path[0] && path[0] <= ORT_TSTR('Z')) && - path[1] == ORT_TSTR(':')) { - const auto root_dir_begin = path.begin() + 2; - const auto root_dir_end = std::find_if_not(root_dir_begin, path.end(), IsPreferredPathSeparator); - - root = path.substr(0, 2); - has_root_dir = root_dir_begin != root_dir_end; - num_parsed_chars = std::distance(path.begin(), root_dir_end); - return Status::OK(); - } - - // leading path separator - auto curr_it = std::find_if_not(path.begin(), path.end(), IsPreferredPathSeparator); - const auto num_initial_seps = std::distance(path.begin(), curr_it); - - if (num_initial_seps == 2) { - // "\\server_name\share_name\" - // after "\\", parse 2 path components with trailing separators - PathString::const_iterator component_end; - bool has_trailing_separator; - curr_it = ParsePathComponent(curr_it, path.end(), component_end, &has_trailing_separator); - ORT_RETURN_IF_NOT(has_trailing_separator, "Failed to parse path root: ", ToUTF8String(path)); - curr_it = ParsePathComponent(curr_it, path.end(), component_end, &has_trailing_separator); - ORT_RETURN_IF_NOT(has_trailing_separator, "Failed to parse path root: ", ToUTF8String(path)); - - root.assign(path.begin(), component_end); - has_root_dir = true; - num_parsed_chars = std::distance(path.begin(), curr_it); - } else { - // "\", "" - root.clear(); - has_root_dir = num_initial_seps > 0; - num_parsed_chars = num_initial_seps; - } - - return Status::OK(); -} - -#else // POSIX - -Status ParsePathRoot( - const PathString& path, - PathString& root, bool& has_root_dir, size_t& num_parsed_chars) { - // assume NormalizePathSeparators() has been called - auto curr_it = std::find_if_not(path.begin(), path.end(), IsPreferredPathSeparator); - const auto num_initial_seps = std::distance(path.begin(), curr_it); - - if (num_initial_seps == 2) { - // "//root_name/" - // after "//", parse path component with trailing separator - PathString::const_iterator component_end; - bool has_trailing_separator; - curr_it = ParsePathComponent(curr_it, path.end(), component_end, &has_trailing_separator); - ORT_RETURN_IF_NOT(has_trailing_separator, "Failed to parse path root: ", ToUTF8String(path)); - - root.assign(path.begin(), component_end); - has_root_dir = true; - num_parsed_chars = std::distance(path.begin(), curr_it); - } else { - // "/", "" - root.clear(); - has_root_dir = num_initial_seps > 0; - num_parsed_chars = num_initial_seps; - } - - return Status::OK(); -} - -#endif - -} // namespace - -Status Path::Parse(const PathString& original_path_str, Path& path) { - Path result{}; - - // normalize separators - const PathString path_str = NormalizePathSeparators(original_path_str); - - // parse root - size_t root_length = 0; - ORT_RETURN_IF_ERROR(ParsePathRoot( - path_str, result.root_name_, result.has_root_dir_, root_length)); - - // parse components - PathString::const_iterator component_begin = path_str.begin() + root_length; - while (component_begin != path_str.end()) { - PathString::const_iterator component_end; - PathString::const_iterator next_component_begin = ParsePathComponent( - component_begin, path_str.end(), component_end, nullptr); - result.components_.emplace_back(component_begin, component_end); - component_begin = next_component_begin; - } - - path = std::move(result); - return Status::OK(); -} - -Path Path::Parse(const PathString& path_str) { - Path path{}; - const auto status = Parse(path_str, path); - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - return path; -} - -PathString Path::ToPathString() const { - PathString result = GetRootPathString(); - const size_t components_size = components_.size(); - for (size_t i = 0; i < components_size; ++i) { - result += components_[i]; - if (i + 1 < components_size) result += k_preferred_path_separator; - } - return result; -} - -PathString Path::GetRootPathString() const { - return has_root_dir_ ? root_name_ + k_preferred_path_separator : root_name_; -} - -bool Path::IsEmpty() const { - return !has_root_dir_ && root_name_.empty() && components_.empty(); -} - -bool Path::IsAbsolute() const { -#ifdef _WIN32 - return has_root_dir_ && !root_name_.empty(); -#else // POSIX - return has_root_dir_; -#endif -} - -Path Path::ParentPath() const { - Path parent{*this}; - if (!parent.components_.empty()) parent.components_.pop_back(); - return parent; -} - -Path& Path::Normalize() { - if (IsEmpty()) return *this; - - // handle . and .. - std::vector normalized_components{}; - for (const auto& component : components_) { - // ignore . - if (component == k_dot) continue; - - // handle .. which backtracks over previous component - if (component == k_dotdot) { - if (!normalized_components.empty() && - normalized_components.back() != k_dotdot) { - normalized_components.pop_back(); - continue; - } - } - - normalized_components.emplace_back(component); - } - - // remove leading ..'s if root dir present - if (has_root_dir_) { - const auto first_non_dotdot_it = std::find_if( - normalized_components.begin(), normalized_components.end(), - [](const PathString& component) { return component != k_dotdot; }); - normalized_components.erase( - normalized_components.begin(), first_non_dotdot_it); - } - - // if empty at this point, add a dot - if (!has_root_dir_ && root_name_.empty() && normalized_components.empty()) { - normalized_components.emplace_back(k_dot); - } - - components_ = std::move(normalized_components); - - return *this; -} - -Path& Path::Append(const Path& other) { - if (other.IsAbsolute() || - (!other.root_name_.empty() && other.root_name_ != root_name_)) { - return *this = other; - } - - if (other.has_root_dir_) { - has_root_dir_ = true; - components_.clear(); - } - - components_.insert( - components_.end(), other.components_.begin(), other.components_.end()); - - return *this; -} - -Path& Path::Concat(const PathString& value) { - auto first_separator = std::find_if(value.begin(), value.end(), - [](PathChar c) { - return std::find( - k_valid_path_separators.begin(), - k_valid_path_separators.end(), - c) != k_valid_path_separators.end(); - }); - ORT_ENFORCE(first_separator == value.end(), - "Cannot concatenate with a string containing a path separator. String: ", ToUTF8String(value)); - - if (components_.empty()) { - components_.push_back(value); - } else { - components_.back() += value; - } - return *this; -} - -Status RelativePath(const Path& src, const Path& dst, Path& rel) { - ORT_RETURN_IF_NOT( - src.GetRootPathString() == dst.GetRootPathString(), - "Paths must have the same root to compute a relative path. ", - "src root: ", ToUTF8String(src.GetRootPathString()), - ", dst root: ", ToUTF8String(dst.GetRootPathString())); - - const Path norm_src = src.NormalizedPath(), norm_dst = dst.NormalizedPath(); - const auto& src_components = norm_src.GetComponents(); - const auto& dst_components = norm_dst.GetComponents(); - - const auto min_num_components = std::min( - src_components.size(), dst_components.size()); - - const auto mismatch_point = std::mismatch( - src_components.begin(), src_components.begin() + min_num_components, - dst_components.begin()); - - const auto& common_src_components_end = mismatch_point.first; - const auto& common_dst_components_end = mismatch_point.second; - - std::vector rel_components{}; - rel_components.reserve( - (src_components.end() - common_src_components_end) + - (dst_components.end() - common_dst_components_end)); - - std::fill_n( - std::back_inserter(rel_components), - (src_components.end() - common_src_components_end), - k_dotdot); - - std::copy( - common_dst_components_end, dst_components.end(), - std::back_inserter(rel_components)); - - rel = Path(PathString{}, false, rel_components); - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/common/path.h b/onnxruntime/core/common/path.h deleted file mode 100644 index 732bbabe8ae3e..0000000000000 --- a/onnxruntime/core/common/path.h +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/common/common.h" -#include "core/common/path_string.h" - -namespace onnxruntime { - -#ifdef _WIN32 -constexpr PathChar k_preferred_path_separator = ORT_TSTR('\\'); -#else // POSIX -constexpr PathChar k_preferred_path_separator = ORT_TSTR('/'); -#endif - -// Note: We should use the std::filesystem library after upgrading to C++17. - -/** A filesystem path. */ -class Path { - public: - Path() = default; - Path(const Path&) = default; - Path& operator=(const Path&) = default; - Path(Path&&) = default; - Path& operator=(Path&&) = default; - - /** Parses a path from `path_str`. */ - static Status Parse(const PathString& path_str, Path& path); - /** Parses a path from `path_str`. Throws on failure. */ - static Path Parse(const PathString& path_str); - - /** Gets a string representation of the path. */ - PathString ToPathString() const; - /** Gets a string representation of the path's root path, if any. */ - PathString GetRootPathString() const; - /** Gets the path components following the path root. */ - const std::vector& GetComponents() const { return components_; } - - /** Whether the path is empty. */ - bool IsEmpty() const; - - /** Whether the path is absolute (refers unambiguously to a file location). */ - bool IsAbsolute() const; - /** Whether the path is relative (not absolute). */ - bool IsRelative() const { return !IsAbsolute(); } - - /** Returns a copy of the path without the last component. */ - Path ParentPath() const; - - /** - * Normalizes the path. - * A normalized path is one with "."'s and ".."'s resolved. - * Note: This is a pure path computation with no filesystem access. - */ - Path& Normalize(); - /** Returns a normalized copy of the path. */ - Path NormalizedPath() const { - Path p{*this}; - return p.Normalize(); - } - - /** - * Appends `other` to the path. - * The algorithm should model that of std::filesystem::path::append(). - */ - Path& Append(const Path& other); - - /** - * Concatenates the current path and the argument string. - * Unlike with Append() or operator/=, additional directory separators are never introduced. - */ - Path& Concat(const PathString& string); - - /** Equivalent to this->Append(other). */ - Path& operator/=(const Path& other) { - return Append(other); - } - /** Returns `a` appended with `b`. */ - friend Path operator/(Path a, const Path& b) { - return a /= b; - } - - friend Status RelativePath(const Path& src, const Path& dst, Path& rel); - - private: - Path(PathString root_name, bool has_root_dir, std::vector components) - : root_name_{std::move(root_name)}, - has_root_dir_{has_root_dir}, - components_{std::move(components)} { - } - - PathString root_name_{}; - bool has_root_dir_{false}; - std::vector components_{}; -}; - -/** - * Computes the relative path from `src` to `dst`. - * Note: This is a pure path computation with no filesystem access. - */ -Status RelativePath(const Path& src, const Path& dst, Path& rel); - -} // namespace onnxruntime diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index 10e117267e14b..7b62de799b6fc 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -636,7 +636,7 @@ bool ThreadPool::ShouldParallelize(const concurrency::ThreadPool* tp) { } int ThreadPool::DegreeOfParallelism(const concurrency::ThreadPool* tp) { - // When not using OpenMP, we parallelise over the N threads created by the pool + // When not using OpenMP, we parallelize over the N threads created by the pool // tp, plus 1 for the thread entering a loop. if (tp) { if (tp->force_hybrid_ || CPUIDInfo::GetCPUIDInfo().IsHybrid()) { diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc index 1eb3bbdb1237f..42dff12eaa2db 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc @@ -4,7 +4,7 @@ #include "core/flatbuffers/flatbuffers_utils.h" #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/flatbuffers/schema/ort.fbs.h" #include "core/graph/constants.h" #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/core/framework/config_options.cc b/onnxruntime/core/framework/config_options.cc index 1a4acb6dabf71..9fe5beafd6e7e 100644 --- a/onnxruntime/core/framework/config_options.cc +++ b/onnxruntime/core/framework/config_options.cc @@ -30,11 +30,11 @@ std::string ConfigOptions::GetConfigOrDefault(const std::string& config_key, } Status ConfigOptions::AddConfigEntry(const char* config_key, const char* config_value) noexcept { - std::string key(config_key); + std::string key = config_key; if (key.empty() || key.length() > 128) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Config key is empty or longer than maximum length 128"); - std::string val(config_value); + std::string val = config_value; if (val.length() > onnxruntime::kMaxStrLen) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Config value is longer than maximum length: ", diff --git a/onnxruntime/core/framework/data_transfer_utils.h b/onnxruntime/core/framework/data_transfer_utils.h index d54df49eeb9dd..eeec329544bc0 100644 --- a/onnxruntime/core/framework/data_transfer_utils.h +++ b/onnxruntime/core/framework/data_transfer_utils.h @@ -5,7 +5,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/tensor.h" diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index ec50bb7d6a5cb..7665a90448520 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -76,9 +76,9 @@ PathString MakeTensorFileName(const std::string& tensor_name, const NodeDumpOpti return path_utils::MakePathString(make_valid_name(tensor_name), dump_options.file_suffix, ".tensorproto"); } -void DumpTensorToFile(const Tensor& tensor, const std::string& tensor_name, const Path& file_path) { +void DumpTensorToFile(const Tensor& tensor, const std::string& tensor_name, const std::filesystem::path& file_path) { auto tensor_proto = utils::TensorToTensorProto(tensor, tensor_name); - const PathString file_path_str = file_path.ToPathString(); + const PathString file_path_str = file_path.native(); int output_fd; ORT_THROW_IF_ERROR(Env::Default().FileOpenWr(file_path_str, output_fd)); try { @@ -302,7 +302,7 @@ void DumpCpuTensor( break; } case NodeDumpOptions::DataDestination::TensorProtoFiles: { - const Path tensor_file = dump_options.output_dir / Path::Parse(MakeTensorFileName(tensor_metadata.name, dump_options)); + const std::filesystem::path tensor_file = dump_options.output_dir / MakeTensorFileName(tensor_metadata.name, dump_options); DumpTensorToFile(tensor, tensor_metadata.name, tensor_file); break; } @@ -411,11 +411,11 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() { } } - opts.output_dir = Path::Parse(ToPathString(Env::Default().GetEnvironmentVar(env_vars::kOutputDir))); + opts.output_dir = ToPathString(Env::Default().GetEnvironmentVar(env_vars::kOutputDir)); std::string sqlite_db_prefix = ParseEnvironmentVariableWithDefault(env_vars::kSqliteDbPrefix, "execution-trace"); - opts.sqlite_db_prefix = Path::Parse(ToPathString(sqlite_db_prefix)); + opts.sqlite_db_prefix = ToPathString(sqlite_db_prefix); // check for confirmation for dumping data to files for all nodes const bool is_input_or_output_requested = ((opts.dump_flags & NodeDumpOptions::DumpFlags::InputData) != 0) || diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.h b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.h index bde005fc204c8..6090a835aa060 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.h +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.h @@ -16,7 +16,6 @@ #pragma once -#include "core/common/path.h" #include "core/framework/op_kernel.h" #include "core/framework/session_state.h" #include "core/graph/graph.h" @@ -109,9 +108,9 @@ struct NodeDumpOptions { std::string file_suffix; // the output directory for dumped data files - Path output_dir; + std::filesystem::path output_dir; // the sqlite3 db to append dumped data - Path sqlite_db_prefix; + std::filesystem::path sqlite_db_prefix; // Total number of elements which trigger snippet rather than full array for Stdout. Value 0 disables snippet. int snippet_threshold; diff --git a/onnxruntime/core/framework/endian_utils.h b/onnxruntime/core/framework/endian_utils.h index b83977c1ac67f..6f084d058d007 100644 --- a/onnxruntime/core/framework/endian_utils.h +++ b/onnxruntime/core/framework/endian_utils.h @@ -5,7 +5,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/status.h" #include "core/common/common.h" diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index dc45cad692b6e..43fe92edc9dfe 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -25,30 +25,10 @@ Class for managing lookup of the execution providers in a session. */ class ExecutionProviders { public: - ExecutionProviders() = default; - - common::Status Add(const std::string& provider_id, const std::shared_ptr& p_exec_provider) { - // make sure there are no issues before we change any internal data structures - if (provider_idx_map_.find(provider_id) != provider_idx_map_.end()) { - auto status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Provider ", provider_id, " has already been registered."); - LOGS_DEFAULT(ERROR) << status.ErrorMessage(); - return status; - } - - // index that provider will have after insertion - auto new_provider_idx = exec_providers_.size(); - - ORT_IGNORE_RETURN_VALUE(provider_idx_map_.insert({provider_id, new_provider_idx})); - - // update execution provider options - auto providerOptions = p_exec_provider->GetProviderOptions(); - exec_provider_options_[provider_id] = providerOptions; - + ExecutionProviders() { #ifdef _WIN32 - LogProviderOptions(provider_id, providerOptions, false); - // Register callback for ETW capture state (rundown) - WindowsTelemetry::RegisterInternalCallback( + etw_callback_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( [this]( LPCGUID SourceId, ULONG IsEnabled, @@ -79,6 +59,36 @@ class ExecutionProviders { } } }); + WindowsTelemetry::RegisterInternalCallback(etw_callback_); +#endif + } + + ~ExecutionProviders() { +#ifdef _WIN32 + WindowsTelemetry ::UnregisterInternalCallback(etw_callback_); +#endif + } + + common::Status + Add(const std::string& provider_id, const std::shared_ptr& p_exec_provider) { + // make sure there are no issues before we change any internal data structures + if (provider_idx_map_.find(provider_id) != provider_idx_map_.end()) { + auto status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Provider ", provider_id, " has already been registered."); + LOGS_DEFAULT(ERROR) << status.ErrorMessage(); + return status; + } + + // index that provider will have after insertion + auto new_provider_idx = exec_providers_.size(); + + ORT_IGNORE_RETURN_VALUE(provider_idx_map_.insert({provider_id, new_provider_idx})); + + // update execution provider options + auto providerOptions = p_exec_provider->GetProviderOptions(); + exec_provider_options_[provider_id] = providerOptions; + +#ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); #endif exec_provider_ids_.push_back(provider_id); @@ -156,5 +166,9 @@ class ExecutionProviders { // Whether the CPU provider was implicitly added to a session for fallback (true), // or whether it was explicitly added by the caller. bool cpu_execution_provider_was_implicitly_added_ = false; + +#ifdef _WIN32 + WindowsTelemetry::EtwInternalCallback etw_callback_; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index 7c8f91c7dad34..c5bcd22888b7c 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers_fwd.h" #include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup #include "core/graph/graph_viewer.h" diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 90ee8a46f66a9..4f745b74abce7 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -637,7 +637,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide static Status CreateEpContextModel(const ExecutionProviders& execution_providers, const Graph& graph, - const std::string& ep_context_path, + const std::filesystem::path& ep_context_path, const logging::Logger& logger) { InlinedVector all_ep_context_nodes; for (const auto& ep : execution_providers) { @@ -658,22 +658,20 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return std::make_pair(false, static_cast(nullptr)); }; - onnxruntime::PathString context_cache_path; - PathString model_pathstring = graph.ModelPath().ToPathString(); + std::filesystem::path context_cache_path; + const std::filesystem::path& model_path = graph.ModelPath(); if (!ep_context_path.empty()) { - context_cache_path = ToPathString(ep_context_path); - } else if (!model_pathstring.empty()) { - context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + context_cache_path = ep_context_path; + } else if (!model_path.empty()) { + context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty"); } - { -#ifdef _WIN32 - std::wifstream fs(context_cache_path); -#else - std::ifstream fs(context_cache_path); -#endif - ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + if (std::filesystem::exists(context_cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", + context_cache_path, "' exist already."); } Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), diff --git a/onnxruntime/core/framework/kernel_lookup.h b/onnxruntime/core/framework/kernel_lookup.h index 2b4d3ce81623a..0dd17d2f4a624 100644 --- a/onnxruntime/core/framework/kernel_lookup.h +++ b/onnxruntime/core/framework/kernel_lookup.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup #include "core/framework/kernel_registry.h" #include "core/framework/kernel_type_str_resolver.h" diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 1868583f41ba0..201fda6d978b6 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -7,7 +7,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/common/status.h" #include "core/framework/kernel_type_str_resolver.h" diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.h b/onnxruntime/core/framework/kernel_type_str_resolver.h index fea2a6ef3a439..587be491b360a 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver.h +++ b/onnxruntime/core/framework/kernel_type_str_resolver.h @@ -13,7 +13,7 @@ #include "core/graph/onnx_protobuf.h" #endif // !defined(ORT_MINIMAL_BUILD) -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/common/status.h" #include "core/graph/op_identifier.h" diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.h b/onnxruntime/core/framework/kernel_type_str_resolver_utils.h index 3d06013e4fe77..5daab7c1159be 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.h +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.h @@ -5,7 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) -#include "core/common/gsl.h" +#include #include "core/common/status.h" #include "core/framework/kernel_type_str_resolver.h" #include "core/graph/op_identifier.h" diff --git a/onnxruntime/core/framework/model_metadef_id_generator.cc b/onnxruntime/core/framework/model_metadef_id_generator.cc index e51c6ebc29975..8b1d1f4f304c9 100644 --- a/onnxruntime/core/framework/model_metadef_id_generator.cc +++ b/onnxruntime/core/framework/model_metadef_id_generator.cc @@ -40,7 +40,7 @@ int ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_vi // prefer path the model was loaded from // this may not be available if the model was loaded from a stream or in-memory bytes - const auto& model_path_str = main_graph.ModelPath().ToPathString(); + const auto model_path_str = main_graph.ModelPath().string(); if (!model_path_str.empty()) { MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast(model_path_str.size()), hash[0], &hash); } else { diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 4e2f22dea164d..e2c06fbdfa621 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -277,7 +277,7 @@ const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpTyp const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); } NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); } -const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } +const std::filesystem::path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); } void NodeUnit::InitForSingleNode() { @@ -285,7 +285,7 @@ void NodeUnit::InitForSingleNode() { const auto& output_defs = target_node_.OutputDefs(); const auto& node_attrs = target_node_.GetAttributes(); auto qlinear_type = GetQLinearOpType(target_node_); - if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support + if (qlinear_type == QLinearOpType::Unknown) { // Not a Qlinear op, add all inputs / outputs auto add_all_io = [](std::vector& defs, const ConstPointerContainer>& node_defs) { @@ -351,6 +351,13 @@ void NodeUnit::InitForSingleNode() { NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 ? input_defs[2] : nullptr, axis}}); + } else if (IsVariadicQLinearOp(qlinear_type)) { + size_t input_num = (input_defs.size() - 2) / 3; + for (size_t i = 0; i < input_num; i++) { + inputs_.push_back(NodeUnitIODef{*input_defs[3 * i + 2], NodeUnitIODef::QuantParam{*input_defs[3 * i + 3], + input_defs[3 * i + 4]}}); + } + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[0], input_defs[1]}}); } else { ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); } diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index a168495f12ebf..e84e62479162f 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "core/graph/basic_types.h" #include "core/graph/graph.h" @@ -78,7 +79,7 @@ class NodeUnit { const std::string& Name() const noexcept; int SinceVersion() const noexcept; NodeIndex Index() const noexcept; - const Path& ModelPath() const noexcept; + const std::filesystem::path& ModelPath() const noexcept; ProviderType GetExecutionProviderType() const noexcept; const Node& GetNode() const noexcept { return target_node_; } diff --git a/onnxruntime/core/framework/op_node_proto_helper.cc b/onnxruntime/core/framework/op_node_proto_helper.cc index c3deb94300e78..ca9b74eafe4d4 100644 --- a/onnxruntime/core/framework/op_node_proto_helper.cc +++ b/onnxruntime/core/framework/op_node_proto_helper.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/graph/op.h" -#include "core/common/gsl.h" +#include using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 0453a7ecac81f..46bfc3630303c 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -7,7 +7,8 @@ #include #include #include -#include "core/common/gsl.h" +#include +#include #include "core/common/inlined_containers.h" #include "core/framework/config_options.h" #include "core/framework/ort_value.h" @@ -89,7 +90,7 @@ struct SessionOptions { // // If session config value is not set, it will be assumed to be ONNX // unless the filepath ends in '.ort' (case insensitive). - std::basic_string optimized_model_filepath; + std::filesystem::path optimized_model_filepath; // enable the memory pattern optimization. // The idea is if the input shapes are the same, we could trace the internal memory allocation diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index e318c9a8238c7..b1a7504b283c5 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -10,7 +10,7 @@ #include "core/common/flatbuffers.h" -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/inlined_containers.h" diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 6af78f18fb82f..4ecd61962d797 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -6,12 +6,13 @@ #include #include #include - +#include +#include #if defined(__wasm__) #include #endif -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/common/span_utils.h" @@ -111,8 +112,8 @@ namespace onnxruntime { namespace { // This function doesn't support string tensors -static Status UnpackTensorWithRawDataImpl(const void* raw_data, size_t raw_data_len, - size_t expected_num_elements, size_t element_size, +static Status UnpackTensorWithRawDataImpl(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + size_t element_size, /*out*/ unsigned char* p_data) { auto src = gsl::make_span(static_cast(raw_data), raw_data_len); auto dst = gsl::make_span(p_data, expected_num_elements * element_size); @@ -152,8 +153,8 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); \ \ - gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), \ - num_packed_pairs); \ + gsl::span src_span = \ + gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); \ gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); \ \ std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ @@ -165,7 +166,7 @@ DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2) DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const ORTCHAR_T* tensor_proto_dir, + const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size) { @@ -180,22 +181,15 @@ static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_prot const auto& location = external_data_info->GetRelPath(); - if (location == onnxruntime::utils::kTensorProtoMemoryAddressTag) { - external_file_path = location; - } else { - if (tensor_proto_dir != nullptr) { - external_file_path = onnxruntime::ConcatPathComponent(tensor_proto_dir, - external_data_info->GetRelPath()); - } else { - external_file_path = external_data_info->GetRelPath(); - } - } + external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) + : (tensor_proto_dir / location); ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); const size_t external_data_length = external_data_info->GetLength(); ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, - "TensorProto: ", tensor_proto.name(), " external data size mismatch. Computed size: ", - *&tensor_byte_size, ", external_data.length: ", external_data_length); + "TensorProto: ", tensor_proto.name(), + " external data size mismatch. Computed size: ", *&tensor_byte_size, + ", external_data.length: ", external_data_length); file_offset = external_data_info->GetOffset(); @@ -207,17 +201,13 @@ static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_prot // then uses the current directory instead. // This function does not unpack string_data of an initializer tensor Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const ORTCHAR_T* tensor_proto_dir, + const std::filesystem::path& tensor_proto_dir, std::vector& unpacked_tensor) { std::basic_string external_file_path; onnxruntime::FileOffsetType file_offset; SafeInt tensor_byte_size; - ORT_RETURN_IF_ERROR(GetExternalDataInfo( - tensor_proto, - tensor_proto_dir, - external_file_path, - file_offset, - tensor_byte_size)); + ORT_RETURN_IF_ERROR( + GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size)); unpacked_tensor.resize(tensor_byte_size); ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( @@ -229,12 +219,9 @@ Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto return Status::OK(); } -// TODO(unknown): Change the current interface to take Path object for model path -// so that validating and manipulating path for reading external data becomes easy -Status TensorProtoToOrtValueImpl(const Env& env, const ORTCHAR_T* model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - const MemBuffer* m, AllocatorPtr alloc, - OrtValue& value) { +Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, + AllocatorPtr alloc, OrtValue& value) { if (m && m->GetBuffer() == nullptr) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "MemBuffer has not been allocated."); } @@ -274,9 +261,91 @@ Status TensorProtoToOrtValueImpl(const Env& env, const ORTCHAR_T* model_path, namespace utils { +void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) { + tensor_proto.set_raw_data(std::move(param)); +} + +void ConvertRawDataInTensorProto(TensorProto* tensor) { + size_t element_size = 1; + char* bytes = NULL; + size_t num_elements = 0; + switch (tensor->data_type()) { + case TensorProto_DataType_FLOAT: + bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); + num_elements = tensor->float_data_size(); + element_size = sizeof(float); + break; + + case TensorProto_DataType_INT32: + bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); + num_elements = tensor->int32_data_size(); + element_size = sizeof(int32_t); + break; + + case TensorProto_DataType_UINT32: + bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); + num_elements = tensor->int32_data_size(); + element_size = sizeof(uint32_t); + break; + + case TensorProto_DataType_UINT8: + case TensorProto_DataType_INT8: + bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); + num_elements = tensor->int32_data_size(); + element_size = sizeof(uint8_t); + break; + + case TensorProto_DataType_UINT16: + case TensorProto_DataType_INT16: + case TensorProto_DataType_FLOAT16: + case TensorProto_DataType_BFLOAT16: + bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); + num_elements = tensor->int32_data_size(); + element_size = sizeof(uint16_t); + break; + + case TensorProto_DataType_UINT64: + bytes = reinterpret_cast(tensor->mutable_uint64_data()->mutable_data()); + num_elements = tensor->uint64_data_size(); + element_size = sizeof(uint64_t); + break; + + case TensorProto_DataType_DOUBLE: + bytes = reinterpret_cast(tensor->mutable_double_data()->mutable_data()); + num_elements = tensor->double_data_size(); + element_size = sizeof(double); + break; + + case TensorProto_DataType_INT64: + bytes = reinterpret_cast(tensor->mutable_int64_data()->mutable_data()); + num_elements = tensor->int64_data_size(); + element_size = sizeof(int64_t); + break; + + case TensorProto_DataType_COMPLEX64: + bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); + num_elements = tensor->float_data_size(); + element_size = sizeof(float); + break; + } + if (tensor->has_raw_data()) { + num_elements = (tensor->raw_data().size()) / element_size; + bytes = const_cast(tensor->mutable_raw_data()->c_str()); + } + for (size_t i = 0; i < num_elements; ++i) { + char* start_byte = bytes + i * element_size; + char* end_byte = start_byte + element_size - 1; + for (size_t count = 0; count < element_size / 2; ++count) { + std::swap(*start_byte++, *end_byte--); + } + } + return; +} + #if !defined(ORT_MINIMAL_BUILD) + static Status UnpackTensorWithExternalDataImpl(const ONNX_NAMESPACE::TensorProto& tensor, - const ORTCHAR_T* tensor_proto_dir, + const std::filesystem::path& tensor_proto_dir, size_t expected_num_elements, size_t element_size, /*out*/ unsigned char* p_data) { ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); @@ -292,7 +361,7 @@ static Status UnpackTensorWithExternalDataImpl(const ONNX_NAMESPACE::TensorProto template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, - const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements, + const std::filesystem::path& tensor_proto_dir, size_t expected_num_elements, /*out*/ T* p_data) { static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); @@ -300,34 +369,35 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, reinterpret_cast(p_data)); } -#define DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(INT4_TYPE) \ - template <> \ - Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \ - const ORTCHAR_T* tensor_proto_dir, size_t expected_num_elements, \ - /*out*/ INT4_TYPE* p_data) { \ - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ - \ - ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ - std::vector unpacked_tensor; \ - ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); \ - \ - size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ - ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); \ - \ - gsl::span src_span = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), \ - num_packed_pairs); \ - gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); \ - \ - std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ - \ - return Status::OK(); \ +#define DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(INT4_TYPE) \ + template <> \ + Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \ + const std::filesystem::path& tensor_proto_dir, \ + size_t expected_num_elements, /*out*/ INT4_TYPE* p_data) { \ + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ + \ + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ + std::vector unpacked_tensor; \ + ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); \ + \ + size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ + ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); \ + \ + gsl::span src_span = \ + gsl::make_span(reinterpret_cast(unpacked_tensor.data()), num_packed_pairs); \ + gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); \ + \ + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ + \ + return Status::OK(); \ } DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2) DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2) -#define INSTANTIATE_UNPACK_EXTERNAL_TENSOR(type) \ - template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto&, const ORTCHAR_T*, size_t, type*); +#define INSTANTIATE_UNPACK_EXTERNAL_TENSOR(type) \ + template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto&, const std::filesystem::path&, \ + size_t, type*); INSTANTIATE_UNPACK_EXTERNAL_TENSOR(float) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(double) @@ -352,7 +422,7 @@ INSTANTIATE_UNPACK_EXTERNAL_TENSOR(Float8E5M2FNUZ) template <> Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& /*tensor*/, - const ORTCHAR_T* /*tensor_proto_dir*/, size_t /*expected_num_elements*/, + const std::filesystem::path& /*tensor_proto_dir*/, size_t /*expected_num_elements*/, /*out*/ std::string* /*p_data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "External data type cannot be STRING."); } @@ -369,7 +439,8 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d /*out*/ T* p_data, size_t expected_num_elements) { \ if (nullptr == p_data) { \ const size_t size = raw_data != nullptr ? raw_data_len : tensor.field_size(); \ - if (size == 0) return Status::OK(); \ + if (size == 0) \ + return Status::OK(); \ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ } \ if (nullptr == p_data || Type != tensor.data_type()) { \ @@ -379,9 +450,9 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elements, p_data); \ } \ if (static_cast(tensor.field_size()) != expected_num_elements) \ - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, \ - "corrupted protobuf data: tensor shape size(", expected_num_elements, \ - ") does not match the data size(", tensor.field_size(), ") in proto"); \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "corrupted protobuf data: tensor shape size(", \ + expected_num_elements, ") does not match the data size(", tensor.field_size(), \ + ") in proto"); \ auto& data = tensor.field_name(); \ for (auto data_iter = data.cbegin(); data_iter != data.cend(); ++data_iter) \ *p_data++ = static_cast(*data_iter); \ @@ -409,7 +480,8 @@ template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* /*raw_data*/, size_t /*raw_data_len*/, /*out*/ std::string* p_data, size_t expected_size) { if (nullptr == p_data) { - if (tensor.string_data_size() == 0) return Status::OK(); + if (tensor.string_data_size() == 0) + return Status::OK(); return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); } if (ONNX_NAMESPACE::TensorProto_DataType_STRING != tensor.data_type()) { @@ -434,7 +506,8 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d /*out*/ bool* p_data, size_t expected_size) { if (nullptr == p_data) { const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); - if (size == 0) return Status::OK(); + if (size == 0) + return Status::OK(); return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); } if (ONNX_NAMESPACE::TensorProto_DataType_BOOL != tensor.data_type()) { @@ -461,7 +534,8 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d /*out*/ MLFloat16* p_data, size_t expected_size) { if (nullptr == p_data) { const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); - if (size == 0) return Status::OK(); + if (size == 0) + return Status::OK(); return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); } if (ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 != tensor.data_type()) { @@ -705,15 +779,12 @@ DEFINE_INT4_UNPACK_TENSOR_IMPL(UInt4x2, TensorProto_DataType_UINT4) // Uses the model path to construct the full path for loading external data. In case when model_path is empty // it uses current directory. template -Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, +Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const std::filesystem::path& model_path, /*out*/ T* p_data, size_t expected_num_elements) { #if !defined(ORT_MINIMAL_BUILD) if (HasExternalData(tensor)) { - return UnpackTensorWithExternalData( - tensor, - model_path.IsEmpty() ? nullptr : model_path.ParentPath().ToPathString().c_str(), - expected_num_elements, - p_data); + return UnpackTensorWithExternalData(tensor, model_path.parent_path(), + expected_num_elements, p_data); } #else ORT_UNUSED_PARAMETER(model_path); @@ -727,7 +798,7 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model // instantiate the UnpackTensor variant that supports external data #define INSTANTIATE_UNPACK_TENSOR(type) \ - template Status UnpackTensor(const ONNX_NAMESPACE::TensorProto&, const Path&, type* p_data, size_t); + template Status UnpackTensor(const ONNX_NAMESPACE::TensorProto&, const std::filesystem::path&, type* p_data, size_t); INSTANTIATE_UNPACK_TENSOR(float) INSTANTIATE_UNPACK_TENSOR(double) @@ -812,8 +883,8 @@ TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShape const auto& dims = tensor_shape_proto.dim(); std::vector tensor_shape_vec(static_cast(dims.size())); for (int i = 0; i < dims.size(); ++i) { - tensor_shape_vec[i] = HasDimValue(dims[i]) ? dims[i].dim_value() - : -1; /* symbolic dimensions are represented as -1 in onnxruntime*/ + tensor_shape_vec[i] = + HasDimValue(dims[i]) ? dims[i].dim_value() : -1; /* symbolic dimensions are represented as -1 in onnxruntime*/ } return TensorShape(std::move(tensor_shape_vec)); } @@ -838,7 +909,8 @@ ORT_API_STATUS_IMPL(OrtInitializeBufferForTensor, _In_opt_ void* input, size_t i enum ONNXTensorElementDataType type) { OrtStatus* status = nullptr; ORT_TRY { - if (type != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING || input == nullptr) return nullptr; + if (type != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING || input == nullptr) + return nullptr; size_t tensor_size = input_len / sizeof(std::string); std::string* ptr = reinterpret_cast(input); for (size_t i = 0, n = tensor_size; i < n; ++i) { @@ -846,16 +918,15 @@ ORT_API_STATUS_IMPL(OrtInitializeBufferForTensor, _In_opt_ void* input, size_t i } } ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); - }); + ORT_HANDLE_EXCEPTION([&]() { status = OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); }); } return status; } ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enum ONNXTensorElementDataType type) { - if (type != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING || input == nullptr) return; + if (type != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING || input == nullptr) + return; size_t tensor_size = input_len / sizeof(std::string); std::string* ptr = reinterpret_cast(input); using std::string; @@ -884,18 +955,18 @@ static void DeleteCharArray(void* param) noexcept { } #if !defined(__wasm__) -static Status GetFileContent( - const Env& env, const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, - void*& raw_buffer, OrtCallback& deleter) { +static Status GetFileContent(const Env& env, const std::filesystem::path& file_path, FileOffsetType offset, + size_t length, void*& raw_buffer, OrtCallback& deleter) { // query length if it is 0 if (length == 0) { - ORT_RETURN_IF_ERROR(env.GetFileLength(file_path, length)); + // The return type of std::filesystem::file_size is uintmax_t which could be bigger than size_t + length = narrow(std::filesystem::file_size(file_path)); } // first, try to map into memory { Env::MappedMemoryPtr mapped_memory{}; - auto status = env.MapFileIntoMemory(file_path, offset, length, mapped_memory); + auto status = env.MapFileIntoMemory(file_path.native().c_str(), offset, length, mapped_memory); if (status.IsOK()) { deleter = mapped_memory.get_deleter().callback; raw_buffer = mapped_memory.release(); @@ -905,8 +976,8 @@ static Status GetFileContent( // if that fails, try to copy auto buffer = std::make_unique(length); - ORT_RETURN_IF_ERROR(env.ReadFileIntoBuffer( - file_path, offset, length, gsl::make_span(buffer.get(), length))); + ORT_RETURN_IF_ERROR( + env.ReadFileIntoBuffer(file_path.native().c_str(), offset, length, gsl::make_span(buffer.get(), length))); deleter = OrtCallback{DeleteCharArray, buffer.get()}; raw_buffer = buffer.release(); @@ -914,20 +985,19 @@ static Status GetFileContent( } #endif -Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter) { +Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, + SafeInt& ext_data_len, OrtCallback& ext_data_deleter) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; - if (model_path != nullptr) { + if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); } - const ORTCHAR_T* t_prot_dir_s = tensor_proto_dir.size() == 0 ? nullptr : tensor_proto_dir.c_str(); std::basic_string external_data_file_path; FileOffsetType file_offset; SafeInt raw_data_safe_len = 0; - ORT_RETURN_IF_ERROR(GetExternalDataInfo(tensor_proto, t_prot_dir_s, external_data_file_path, file_offset, - raw_data_safe_len)); + ORT_RETURN_IF_ERROR( + GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len)); if (external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { // the value in location is the memory address of the data @@ -937,8 +1007,8 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, } else { #if defined(__wasm__) ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296, - "External initializer: ", tensor_proto.name(), - " offset: ", file_offset, " size to read: ", static_cast(raw_data_safe_len), + "External initializer: ", tensor_proto.name(), " offset: ", file_offset, + " size to read: ", static_cast(raw_data_safe_len), " are out of bounds or can not be read in full (>4GB)."); auto buffer = std::make_unique(raw_data_safe_len); @@ -969,7 +1039,8 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, } try { - // Copy the file data (fileData,offset,length) into WebAssembly memory (HEAPU8,buffer,length). + // Copy the file data (fileData,offset,length) into WebAssembly memory + // (HEAPU8,buffer,length). HEAPU8.set(fileData.subarray(offset, offset + length), buffer); return 0; } catch { @@ -996,22 +1067,19 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, default: err_msg = "Unknown error occurred in memory copy."; } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", external_data_file_path, "\", error: ", err_msg); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", external_data_file_path, + "\", error: ", err_msg); #else - size_t file_length; - // error reporting is inconsistent across platforms. Make sure the full path we attempted to open is included. - auto status = env.GetFileLength(external_data_file_path.c_str(), file_length); - if (!status.IsOK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileLength for ", ToUTF8String(external_data_file_path), - " failed:", status.ErrorMessage()); - } + // The GetFileContent function doesn't report error if the requested data range is invalid. Therefore we need to + // manually check file size first. + std::uintmax_t file_length = std::filesystem::file_size(external_data_file_path); SafeInt end_of_read(file_offset); end_of_read += raw_data_safe_len; - ORT_RETURN_IF(file_offset < 0 || end_of_read > narrow(file_length), - "External initializer: ", tensor_proto.name(), - " offset: ", file_offset, " size to read: ", static_cast(raw_data_safe_len), - " given file_length: ", file_length, " are out of bounds or can not be read in full."); + ORT_RETURN_IF(file_offset < 0 || static_cast(end_of_read) > file_length, + "External initializer: ", tensor_proto.name(), " offset: ", file_offset, + " size to read: ", static_cast(raw_data_safe_len), " given file_length: ", file_length, + " are out of bounds or can not be read in full."); ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path.c_str(), file_offset, raw_data_safe_len, ext_data_buf, ext_data_deleter)); ext_data_len = raw_data_safe_len; @@ -1021,11 +1089,10 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, return Status::OK(); } -#define CASE_PROTO(X, Y) \ - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ - ORT_RETURN_IF_ERROR( \ - UnpackTensor(tensor_proto, raw_data, raw_data_len, \ - (Y*)preallocated, static_cast(tensor_size))); \ +#define CASE_PROTO(X, Y) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + ORT_RETURN_IF_ERROR( \ + UnpackTensor(tensor_proto, raw_data, raw_data_len, (Y*)preallocated, static_cast(tensor_size))); \ break; /** @@ -1036,15 +1103,15 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, * @param tensor pre-allocated tensor object, where we store the data * @return */ -Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - Tensor& tensor) { +Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor) { // Validate tensor compatibility TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); if (tensor_shape != tensor.Shape()) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "TensorProtoToTensor() tensor shape mismatch!"); } - const DataTypeImpl* const source_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); + const DataTypeImpl* const source_type = + DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); if (source_type->Size() > tensor.DataType()->Size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TensorProto type ", DataTypeImpl::ToString(source_type), " can not be written into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); @@ -1125,15 +1192,13 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, return Status::OK(); } -Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - const MemBuffer& m, OrtValue& value) { +Status TensorProtoToOrtValue(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, OrtValue& value) { return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, &m, nullptr, value); } -Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - AllocatorPtr alloc, OrtValue& value) { +Status TensorProtoToOrtValue(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, AllocatorPtr alloc, OrtValue& value) { return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, nullptr, alloc, value); } @@ -1177,11 +1242,6 @@ ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto } ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name) { - // Given we are using the raw_data field in the protobuf, this will work only for little-endian format. - if constexpr (endian::native != endian::little) { - ORT_THROW("Big endian not supported"); - } - // Set name, dimensions, type, and data of the TensorProto. ONNX_NAMESPACE::TensorProto tensor_proto; @@ -1200,14 +1260,14 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std: *mutable_string_data->Add() = *f; } } else { - tensor_proto.set_raw_data(tensor.DataRaw(), tensor.SizeInBytes()); + utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } return tensor_proto; } common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, - const Path& model_path, + const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& tensor, const std::string& tensor_name) { ORT_RETURN_IF_NOT(node.attribute_size() > 0, "Constant node: ", node.name(), " has no data attributes"); @@ -1255,8 +1315,8 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n ORT_UNUSED_PARAMETER(model_path); #endif default: - ORT_THROW("Unsupported attribute value type of ", constant_attribute.type(), - " in 'Constant' node '", node.name(), "'"); + ORT_THROW("Unsupported attribute value type of ", constant_attribute.type(), " in 'Constant' node '", node.name(), + "'"); } // set name last in case attribute type was tensor (would copy over name) @@ -1266,7 +1326,7 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n } common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, - const Path& model_path, + const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& tensor) { return ConstantNodeProtoToTensorProto(node, model_path, tensor, node.output(0)); } @@ -1274,9 +1334,11 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n #if !defined(DISABLE_SPARSE_TENSORS) static Status CopySparseData(size_t n_sparse_elements, const ONNX_NAMESPACE::TensorProto& indices, - const Path& model_path, - gsl::span dims, - std::function copier) { + const std::filesystem::path& model_path, + gsl::span + dims, + std::function + copier) { Status status = Status::OK(); TensorShape indices_shape(indices.dims().data(), indices.dims().size()); const auto elements = narrow(indices_shape.Size()); @@ -1293,7 +1355,8 @@ static Status CopySparseData(size_t n_sparse_elements, ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); indices_data = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); } else { - ORT_RETURN_IF_NOT(indices.int64_data_size() == static_cast(elements), "Sparse indices int64 data size does not match expected"); + ORT_RETURN_IF_NOT(indices.int64_data_size() == static_cast(elements), + "Sparse indices int64 data size does not match expected"); indices_data = gsl::make_span(indices.int64_data().data(), elements); } break; @@ -1307,7 +1370,8 @@ static Status CopySparseData(size_t n_sparse_elements, unpack_buffer.clear(); unpack_buffer.shrink_to_fit(); } else { - ORT_RETURN_IF_NOT(indices.int32_data_size() == static_cast(elements), "Sparse indices int32 data size does not match expected"); + ORT_RETURN_IF_NOT(indices.int32_data_size() == static_cast(elements), + "Sparse indices int32 data size does not match expected"); indices_values.insert(indices_values.cend(), indices.int32_data().cbegin(), indices.int32_data().cend()); } indices_data = gsl::make_span(indices_values); @@ -1346,8 +1410,9 @@ static Status CopySparseData(size_t n_sparse_elements, break; } default: - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, - "Invalid SparseTensor indices. Should one of the following types: int8, int16, int32 or int64"); + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_GRAPH, + "Invalid SparseTensor indices. Should one of the following types: int8, int16, int32 or int64"); } if (indices_shape.NumDimensions() == 1) { @@ -1385,15 +1450,15 @@ static Status CopySparseData(size_t n_sparse_elements, ORT_ENFORCE(cur_index == indices_data.end()); } else { - status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Invalid SparseTensor indices. Should be rank 0 or 1. Got:", - indices_shape); + status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Invalid SparseTensor indices. Should be rank 0 or 1. Got:", indices_shape); } return status; } common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse, - const Path& model_path, + const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& dense) { Status status = Status::OK(); @@ -1434,59 +1499,50 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT switch (element_size) { case 1: { status = CopySparseData( - n_sparse_elements, - indices, model_path, dims, - [sparse_data, dense_data](size_t from_idx, size_t to_idx) { + n_sparse_elements, indices, model_path, dims, [sparse_data, dense_data](size_t from_idx, size_t to_idx) { static_cast(dense_data)[to_idx] = static_cast(sparse_data)[from_idx]; }); break; } case 2: { - status = CopySparseData( - n_sparse_elements, - indices, model_path, dims, - [sparse_data, dense_data](size_t from_idx, size_t to_idx) { - const auto* src = static_cast(sparse_data) + from_idx; - auto* dst = static_cast(dense_data) + to_idx; - memcpy(dst, src, sizeof(uint16_t)); - }); + status = CopySparseData(n_sparse_elements, indices, model_path, dims, + [sparse_data, dense_data](size_t from_idx, size_t to_idx) { + const auto* src = static_cast(sparse_data) + from_idx; + auto* dst = static_cast(dense_data) + to_idx; + memcpy(dst, src, sizeof(uint16_t)); + }); break; } case 4: { - status = CopySparseData( - n_sparse_elements, - indices, model_path, dims, - [sparse_data, dense_data](size_t from_idx, size_t to_idx) { - const auto* src = static_cast(sparse_data) + from_idx; - auto* dst = static_cast(dense_data) + to_idx; - memcpy(dst, src, sizeof(uint32_t)); - }); + status = CopySparseData(n_sparse_elements, indices, model_path, dims, + [sparse_data, dense_data](size_t from_idx, size_t to_idx) { + const auto* src = static_cast(sparse_data) + from_idx; + auto* dst = static_cast(dense_data) + to_idx; + memcpy(dst, src, sizeof(uint32_t)); + }); break; } case 8: { - status = CopySparseData( - n_sparse_elements, - indices, model_path, dims, - [sparse_data, dense_data](size_t from_idx, size_t to_idx) { - const auto* src = static_cast(sparse_data) + from_idx; - auto* dst = static_cast(dense_data) + to_idx; - memcpy(dst, src, sizeof(uint64_t)); - }); + status = CopySparseData(n_sparse_elements, indices, model_path, dims, + [sparse_data, dense_data](size_t from_idx, size_t to_idx) { + const auto* src = static_cast(sparse_data) + from_idx; + auto* dst = static_cast(dense_data) + to_idx; + memcpy(dst, src, sizeof(uint64_t)); + }); break; } default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Element_size of: ", element_size, " is not supported.", " type: ", type); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Element_size of: ", element_size, " is not supported.", + " type: ", type); } ORT_RETURN_IF_ERROR(status); } - dense.set_raw_data(std::move(dense_data_storage)); - + utils::SetRawDataInTensorProto(dense, std::move(dense_data_storage)); } else { // No request for std::string status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported sparse tensor data type of ", @@ -1521,9 +1577,7 @@ inline void CopyElement(void* dst, const void* src, int64_t dst_index, } template -static void SetIndices(gsl::span gathered_indices, - std::string& raw_indices, - TensorProto& indices) { +static void SetIndices(gsl::span gathered_indices, std::string& raw_indices, TensorProto& indices) { raw_indices.resize(gathered_indices.size() * sizeof(T)); auto* ind_dest = reinterpret_cast(raw_indices.data()); size_t dest_index = 0; @@ -1533,7 +1587,17 @@ static void SetIndices(gsl::span gathered_indices, } else { auto* dst = ind_dest + dest_index; T v = static_cast(src_index); - memcpy(dst, &v, sizeof(T)); + if constexpr (endian::native != endian::little) { + auto src = gsl::make_span(static_cast( + reinterpret_cast(&v)), + sizeof(T)); + auto dest = gsl::make_span(static_cast( + reinterpret_cast(dst)), + sizeof(T)); + onnxruntime::utils::SwapByteOrderCopy(sizeof(T), src, dest); + } else { + memcpy(dst, &v, sizeof(T)); + } } ++dest_index; } @@ -1541,8 +1605,7 @@ static void SetIndices(gsl::span gathered_indices, } static void SparsifyGeneric(const void* dense_raw_data, size_t n_dense_elements, size_t element_size, - IsZeroFunc is_zero, CopyElementFunc copy, - TensorProto& values, TensorProto& indices, + IsZeroFunc is_zero, CopyElementFunc copy, TensorProto& values, TensorProto& indices, size_t& nnz) { auto advance = [element_size](const void* start, size_t elements) -> const void* { return (reinterpret_cast(start) + elements * element_size); @@ -1585,13 +1648,13 @@ static void SparsifyGeneric(const void* dense_raw_data, size_t n_dense_elements, } } else { indices.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); - indices.set_raw_data(std::string()); + utils::SetRawDataInTensorProto(indices, std::string()); } nnz = gathered_indices.size(); } common::Status DenseTensorToSparseTensorProto(const ONNX_NAMESPACE::TensorProto& dense_proto, - const Path& model_path, + const std::filesystem::path& model_path, ONNX_NAMESPACE::SparseTensorProto& result) { ORT_ENFORCE(HasDataType(dense_proto), "Must have a valid data type"); @@ -1623,28 +1686,28 @@ common::Status DenseTensorToSparseTensorProto(const ONNX_NAMESPACE::TensorProto& void* dense_data = dense_raw_data.data(); switch (element_size) { case 1: { - SparsifyGeneric(dense_data, n_dense_elements, element_size, - IsZero, CopyElement, values, indices, nnz); + SparsifyGeneric(dense_data, n_dense_elements, element_size, IsZero, CopyElement, values, + indices, nnz); break; } case 2: { - SparsifyGeneric(dense_data, n_dense_elements, element_size, - IsZero, CopyElement, values, indices, nnz); + SparsifyGeneric(dense_data, n_dense_elements, element_size, IsZero, CopyElement, values, + indices, nnz); break; } case 4: { - SparsifyGeneric(dense_data, n_dense_elements, element_size, - IsZero, CopyElement, values, indices, nnz); + SparsifyGeneric(dense_data, n_dense_elements, element_size, IsZero, CopyElement, values, + indices, nnz); break; } case 8: { - SparsifyGeneric(dense_data, n_dense_elements, element_size, - IsZero, CopyElement, values, indices, nnz); + SparsifyGeneric(dense_data, n_dense_elements, element_size, IsZero, CopyElement, values, + indices, nnz); break; } default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Element_size of: ", element_size, " is not supported.", " data_type: ", data_type); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Element_size of: ", element_size, " is not supported.", + " data_type: ", data_type); } // Fix up shapes @@ -1664,42 +1727,40 @@ template common::Status GetSizeInBytesFromTensorProto(const ONN size_t* out); template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out); -#define CASE_UNPACK(TYPE, ELEMENT_TYPE, DATA_SIZE) \ - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ - SafeInt tensor_byte_size; \ - size_t element_count = 0; \ - if (initializer.has_raw_data()) { \ - tensor_byte_size = initializer.raw_data().size(); \ - element_count = tensor_byte_size / sizeof(ELEMENT_TYPE); \ - } else { \ - element_count = initializer.DATA_SIZE(); \ - tensor_byte_size = element_count * sizeof(ELEMENT_TYPE); \ - } \ - unpacked_tensor.resize(tensor_byte_size); \ - return onnxruntime::utils::UnpackTensor( \ - initializer, \ - initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ - initializer.has_raw_data() ? initializer.raw_data().size() : 0, \ - reinterpret_cast(unpacked_tensor.data()), element_count); \ - break; \ - } - -#define CASE_UNPACK_INT4(TYPE, ELEMENT_TYPE, DATA_SIZE) \ - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ - TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \ - size_t element_count = static_cast(tensor_shape.Size()); \ - size_t packed_element_count = ELEMENT_TYPE::CalcNumInt4Pairs(element_count); \ - unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \ - return onnxruntime::utils::UnpackTensor( \ - initializer, \ - initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ - initializer.has_raw_data() ? initializer.raw_data().size() : 0, \ - reinterpret_cast(unpacked_tensor.data()), element_count); \ - break; \ +#define CASE_UNPACK(TYPE, ELEMENT_TYPE, DATA_SIZE) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ + SafeInt tensor_byte_size; \ + size_t element_count = 0; \ + if (initializer.has_raw_data()) { \ + tensor_byte_size = initializer.raw_data().size(); \ + element_count = tensor_byte_size / sizeof(ELEMENT_TYPE); \ + } else { \ + element_count = initializer.DATA_SIZE(); \ + tensor_byte_size = element_count * sizeof(ELEMENT_TYPE); \ + } \ + unpacked_tensor.resize(tensor_byte_size); \ + return onnxruntime::utils::UnpackTensor(initializer, \ + initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ + initializer.has_raw_data() ? initializer.raw_data().size() : 0, \ + reinterpret_cast(unpacked_tensor.data()), element_count); \ + break; \ + } + +#define CASE_UNPACK_INT4(TYPE, ELEMENT_TYPE, DATA_SIZE) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ + TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \ + size_t element_count = static_cast(tensor_shape.Size()); \ + size_t packed_element_count = ELEMENT_TYPE::CalcNumInt4Pairs(element_count); \ + unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \ + return onnxruntime::utils::UnpackTensor(initializer, \ + initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ + initializer.has_raw_data() ? initializer.raw_data().size() : 0, \ + reinterpret_cast(unpacked_tensor.data()), element_count); \ + break; \ } Status UnpackInitializerData(const onnx::TensorProto& initializer, - const Path& model_path, + const std::filesystem::path& model_path, std::vector& unpacked_tensor) { // TODO, if std::vector does not use a custom allocator, the default std::allocator will // allocation the memory aligned to std::max_align_t, need look into allocating @@ -1707,7 +1768,7 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, if (initializer.data_location() == TensorProto_DataLocation_EXTERNAL) { ORT_RETURN_IF_ERROR(ReadExternalDataForTensor( initializer, - (model_path.IsEmpty() || model_path.ParentPath().IsEmpty()) ? nullptr : model_path.ParentPath().ToPathString().c_str(), + model_path.parent_path(), unpacked_tensor)); return Status::OK(); } @@ -1737,16 +1798,14 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, default: break; } - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Unsupported type: ", initializer.data_type()); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported type: ", initializer.data_type()); } #undef CASE_UNPACK -Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, - std::vector& unpacked_tensor) { +Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, std::vector& unpacked_tensor) { ORT_RETURN_IF(initializer.data_location() == TensorProto_DataLocation_EXTERNAL, "The given initializer contains external data"); - return UnpackInitializerData(initializer, Path(), unpacked_tensor); + return UnpackInitializerData(initializer, std::filesystem::path(), unpacked_tensor); } } // namespace utils diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 000502ba47594..aabfc0487f3e0 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -5,10 +5,11 @@ #include #include +#include +#include #ifndef SHARED_PROVIDER #include "core/common/common.h" -#include "core/common/path.h" #include "core/common/status.h" #include "core/common/safeint.h" #include "core/framework/endian_utils.h" @@ -19,6 +20,46 @@ #include "core/graph/onnx_protobuf.h" #include "core/platform/env.h" +namespace onnxruntime { +namespace utils { +/** + * This function is used to convert the endianess of Tensor data. + * Mostly, will be used in big endian system to support the model file + * generated on little endian system. + * @param initializer given initializer tensor + * @returns None + */ +void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* initializer); + +/** + * Wrapper function for set_raw_data. + * First calls the set_raw_data and then calls ConvertRawDataInTensorProto + * under big endian system. + * @param tensor_proto given initializer tensor + * @param raw_data source raw_data pointer + * @param raw_data_len length of raw_data + * @returns None + */ +template +void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, T1* raw_data, T2 raw_data_len) { + using namespace ONNX_NAMESPACE; + tensor_proto.set_raw_data(raw_data, raw_data_len); + if constexpr (endian::native != endian::little) { + utils::ConvertRawDataInTensorProto((ONNX_NAMESPACE::TensorProto*)&tensor_proto); + } +} + +/** + * Overload Wrapper function for set_raw_data handling string object. + * Forward the string object to set_raw_data. + * @param tensor_proto given initializer tensor + * @param param string object reference + * @returns None + */ +void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param); +} // namespace utils +} // namespace onnxruntime + namespace ONNX_NAMESPACE { class TensorProto; class TensorShapeProto; @@ -41,23 +82,23 @@ TensorShape GetTensorShapeFromTensorProto(const ONNX_NAMESPACE::TensorProto& ten /** * deserialize a TensorProto into a preallocated memory buffer on CPU. * \param tensor_proto_path A local file path of where the 'input' was loaded from. - * Can be NULL if the tensor proto doesn't have external data or it was loaded from + * Can be empty if the tensor proto doesn't have external data or it was loaded from * the current working dir. This path could be either a relative path or an absolute path. * \return Status::OK on success with 'value' containing the Tensor in CPU based memory. */ -common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_proto_path, +common::Status TensorProtoToOrtValue(const Env& env, const std::filesystem::path& tensor_proto_path, const ONNX_NAMESPACE::TensorProto& input, const MemBuffer& m, OrtValue& value); /** * deserialize a TensorProto into a buffer on CPU allocated using 'alloc'. * \param tensor_proto_path A local file path of where the 'input' was loaded from. - * Can be NULL if the tensor proto doesn't have external data or it was loaded from + * Can be empty if the tensor proto doesn't have external data or it was loaded from * the current working dir. This path could be either a relative path or an absolute path. * \param alloc Allocator to use for allocating the buffer. Must allocate CPU based memory. * \return Status::OK on success with 'value' containing the Tensor in CPU based memory. */ -common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_proto_path, +common::Status TensorProtoToOrtValue(const Env& env, const std::filesystem::path& tensor_proto_path, const ONNX_NAMESPACE::TensorProto& input, AllocatorPtr alloc, OrtValue& value); @@ -69,7 +110,7 @@ common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_pro * @param tensorp destination empty tensor * @return */ -common::Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, +common::Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor); @@ -100,7 +141,7 @@ constexpr const ORTCHAR_T* kTensorProtoMemoryAddressTag = ORT_TSTR("*/_ORT_MEM_A // Given a tensor proto with external data obtain a pointer to the data and its length. // The ext_data_deleter argument is updated with a callback that owns/releases the data. -common::Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, +common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter); @@ -113,11 +154,11 @@ common::Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_ // model_path is used for contructing full path for external_data // tensor_name specifies the name for the new TensorProto TensorProto common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, - const Path& model_path, + const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& tensor, const std::string& tensor_name); common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, - const Path& model_path, + const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& tensor); #if !defined(DISABLE_SPARSE_TENSORS) @@ -126,7 +167,7 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n // The resulting TensorProto will contain the data as raw data. // model_path is used for contructing full path for external_data common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse, - const Path& model_path, + const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& dense); #if !defined(ORT_MINIMAL_BUILD) @@ -135,7 +176,7 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT // The resulting SparseTensorProto will contain the data as raw data // model_path is used for contructing full path for external_data common::Status DenseTensorToSparseTensorProto(const ONNX_NAMESPACE::TensorProto& dense, - const Path& model_path, + const std::filesystem::path& model_path, ONNX_NAMESPACE::SparseTensorProto& sparse); #endif // !ORT_MINIMAL_BUILD #endif // !defined(DISABLE_SPARSE_TENSORS) @@ -446,7 +487,7 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d // Uses the model path to construct the full path for loading external data. In case when model_path is empty // it uses current directory. template -Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, +Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const std::filesystem::path& model_path, /*out*/ T* p_data, size_t expected_size); /** @@ -458,7 +499,7 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model * @returns Status::OK() if data is unpacked successfully */ common::Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, - const Path& model_path, + const std::filesystem::path& model_path, std::vector& unpacked_tensor); /** diff --git a/onnxruntime/core/framework/transpose_helper.h b/onnxruntime/core/framework/transpose_helper.h index c34d5ef3f27f6..e33044117f89a 100644 --- a/onnxruntime/core/framework/transpose_helper.h +++ b/onnxruntime/core/framework/transpose_helper.h @@ -37,7 +37,7 @@ We fall back to the default implementation in all other cases, and if the input #include "core/framework/tensor.h" #include "core/platform/threadpool.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { bool IsTransposeMovingSingleAxis(gsl::span permutations, size_t& from, size_t& to); diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 9c282210d2169..9eed0249711f9 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -61,6 +61,7 @@ bool ProviderIsCpuBased(const std::string& provider_type) { provider_type == onnxruntime::kVitisAIExecutionProvider || provider_type == onnxruntime::kOpenVINOExecutionProvider || provider_type == onnxruntime::kNnapiExecutionProvider || + provider_type == onnxruntime::kVSINPUExecutionProvider || provider_type == onnxruntime::kAclExecutionProvider || provider_type == onnxruntime::kArmNNExecutionProvider || provider_type == onnxruntime::kRknpuExecutionProvider || diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 2a14ba1db4bb7..7272a949f7218 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1254,7 +1254,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present_value", "Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).", "T") - .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain integer type.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { SparseAttentionTypeAndShapeInference(ctx, 3); diff --git a/onnxruntime/core/graph/contrib_ops/onnx_deprecated_operators.cc b/onnxruntime/core/graph/contrib_ops/onnx_deprecated_operators.cc index be39a2d7ec2b2..b1b7cf346a27c 100644 --- a/onnxruntime/core/graph/contrib_ops/onnx_deprecated_operators.cc +++ b/onnxruntime/core/graph/contrib_ops/onnx_deprecated_operators.cc @@ -395,7 +395,7 @@ ONNX_CONTRIB_OPERATOR_SET_SCHEMA( const auto input_rank = input_shape.dim_size(); if (input_rank != 4) fail_shape_inference("Input's shape must be 4-D"); - // parse necessary attributes for futher processing + // parse necessary attributes for further processing std::vector border; bool border_present = getRepeatedAttribute(ctx, "border", border); if (!border_present || border.size() != 4) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 0c1d79532f120..442a0db933d65 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -11,7 +11,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" @@ -98,7 +98,8 @@ static Status MergeShapeInfo(const std::string& output_name, #endif } ORT_CATCH(const ONNX_NAMESPACE::InferenceError& ex) { - // if this model was not created with the latest onnx version, allow the shape inferencing failure (strict == false). + // if this model was not created with the latest onnx version, allow the shape inferencing failure + // (strict == false). // we do this to have strict testing of the latest inferencing to detect bugs, but lenient shape inferencing for // older models in case later changes to the ONNX shape inferencing or ORT break them. if (!strict) { @@ -114,7 +115,8 @@ static Status MergeShapeInfo(const std::string& output_name, } #if !defined(DISABLE_OPTIONAL_TYPE) else if (utils::HasOptionalTensorType(source)) { - ONNX_NAMESPACE::UnionShapeInfo(utils::GetShape(source), *utils::GetMutableOptionalTypeProto(target)->mutable_tensor_type()); + ONNX_NAMESPACE::UnionShapeInfo(utils::GetShape(source), + *utils::GetMutableOptionalTypeProto(target)->mutable_tensor_type()); } #endif @@ -401,7 +403,8 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu const auto& input_tensor_elem_type = input_tensor_type.elem_type(); const auto& current_tensor_elem_type = current_type.tensor_type().elem_type(); - ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types)); + ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, + override_types)); if (utils::HasShape(input_tensor_type)) { if (utils::HasShape(current_type)) { @@ -420,7 +423,8 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu const auto input_tensor_elem_type = input_tensor_type.elem_type(); const auto current_tensor_elem_type = current_type.sparse_tensor_type().elem_type(); - ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types)); + ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, + override_types)); if (utils::HasShape(input_tensor_type)) { if (utils::HasShape(current_type)) { @@ -440,7 +444,8 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu // Check for homogeneity within optional type if (is_input_type_optional_tensor_type != is_current_type_optional_tensor_type) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Optional Type mismatch. Expected: ", ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(current_type), + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Optional Type mismatch. Expected: ", + ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(current_type), " . Got: ", ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(input_type)); } @@ -453,7 +458,8 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu const auto& input_tensor_elem_type = input_tensor_type.elem_type(); const auto& current_tensor_elem_type = optional_current_type.tensor_type().elem_type(); - ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types)); + ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, + override_types)); if (utils::HasShape(optional_input_type.tensor_type())) { if (utils::HasShape(optional_current_type.tensor_type())) { @@ -560,7 +566,7 @@ void Node::SetPriority(int priority) noexcept { priority_ = priority; } -const Path& Node::ModelPath() const noexcept { +const std::filesystem::path& Node::ModelPath() const noexcept { return graph_->ModelPath(); } @@ -1193,6 +1199,15 @@ Graph::Graph(const Model& owning_model, const gsl::not_null tensor{graph_proto_->add_initializer()}; auto status = utils::ConstantNodeProtoToTensorProto(node, model_path, *tensor); + if constexpr (endian::native != endian::little) { + const AttributeProto& attrib = node.attribute(0); + if (attrib.type() == AttributeProto_AttributeType_SPARSE_TENSOR) { + const TensorProto& sparse_values = node.attribute(0).sparse_tensor().values(); + if ((!(sparse_values.has_raw_data())) && tensor->has_raw_data()) { + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); + } + } + } ORT_ENFORCE(status.IsOK(), status.ToString()); // Ensure initializers are also graph inputs. if (ir_version_ < 4) { @@ -1203,7 +1218,8 @@ Graph::Graph(const Model& owning_model, #if !defined(DISABLE_SPARSE_TENSORS) if (node.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { auto p = sparse_tensor_names_.emplace(tensor->name()); - ORT_ENFORCE(p.second, "Duplicate constant node sparse initializer name: '", tensor->name(), "' Model is invalid."); + ORT_ENFORCE(p.second, "Duplicate constant node sparse initializer name: '", tensor->name(), + "' Model is invalid."); } #endif } @@ -1533,8 +1549,10 @@ void Graph::AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_ *dst_arg_pointer = src_arg; } - nodes_[src_node_index]->MutableRelationships().output_edges.insert(Node::EdgeEnd(*nodes_[dst_node_index], src_arg_slot, dst_arg_slot)); - nodes_[dst_node_index]->MutableRelationships().input_edges.insert(Node::EdgeEnd(*nodes_[src_node_index], src_arg_slot, dst_arg_slot)); + nodes_[src_node_index]->MutableRelationships().output_edges.insert(Node::EdgeEnd(*nodes_[dst_node_index], + src_arg_slot, dst_arg_slot)); + nodes_[dst_node_index]->MutableRelationships().input_edges.insert(Node::EdgeEnd(*nodes_[src_node_index], + src_arg_slot, dst_arg_slot)); } void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_slot, int dst_arg_slot) { @@ -1573,13 +1591,15 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s ORT_THROW("Argument mismatch when removing edge."); } - nodes_[dst_node_index]->MutableRelationships().input_edges.erase(Node::EdgeEnd(*nodes_[src_node_index], src_arg_slot, dst_arg_slot)); - nodes_[src_node_index]->MutableRelationships().output_edges.erase(Node::EdgeEnd(*nodes_[dst_node_index], src_arg_slot, dst_arg_slot)); + nodes_[dst_node_index]->MutableRelationships().input_edges.erase(Node::EdgeEnd(*nodes_[src_node_index], + src_arg_slot, dst_arg_slot)); + nodes_[src_node_index]->MutableRelationships().output_edges.erase(Node::EdgeEnd(*nodes_[dst_node_index], + src_arg_slot, dst_arg_slot)); } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) -GSL_SUPPRESS(es.84) // ignoring return value from unordered_map::insert causes noisy complaint +GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint Status Graph::BuildConnections(std::unordered_set& outer_scope_node_args_consumed) { // recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs if (!resolve_context_.nodes_with_subgraphs.empty()) { @@ -2356,7 +2376,7 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, #endif // ENABLE_TRAINING -GSL_SUPPRESS(es.84) // noisy warning about ignoring return value from insert(...) +GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...) Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { nodes_in_topological_order_.clear(); std::unordered_set downstream_nodes; // nodes downstream of the node we're currently checking @@ -2429,7 +2449,8 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { const NodeIndex idx = iter->Index(); // the input to this node is also downstream of this node if (downstream_nodes.find(idx) != downstream_nodes.end()) { - Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, "This is an invalid model. Error: the graph is not acyclic."); + Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, + "This is an invalid model. Error: the graph is not acyclic."); return status; } @@ -2444,7 +2465,8 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { return Status::OK(); } - return Status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, "This is an invalid model. Error: the graph is not acyclic."); + return Status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, + "This is an invalid model. Error: the graph is not acyclic."); } bool FullyDefinedType(const TypeProto& type_proto) { @@ -2487,11 +2509,13 @@ bool FullyDefinedType(const TypeProto& type_proto) { // parameters are the Graph instance for the subgraph, the input types from the control flow node that contains // the subgraph, and the vector to write the output from the inferencing. using SubgraphInferencingFunc = - std::function&, std::vector&, const Graph::ResolveOptions&)>; + std::function&, std::vector&, + const Graph::ResolveOptions&)>; class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer { public: - GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func, const Graph::ResolveOptions& options) + GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func, + const Graph::ResolveOptions& options) : node_(node), graph_(graph), inferencing_func_(inferencing_func), options_(options) { } @@ -2720,7 +2744,7 @@ Status Graph::UpdateShapeInference(Node& node) { } // Implementation of type-inference and type-checking for a single node -GSL_SUPPRESS(f.23) // spurious warning about inferred_type never being checked for null +GSL_SUPPRESS(f .23) // spurious warning about inferred_type never being checked for null Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const ResolveOptions& options) { auto& node_name = node.Name(); @@ -2796,7 +2820,8 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso // The type-parameter T is bound to different values for different inputs. Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, "Type Error: Type parameter (" + op_formal_parameter.GetTypeStr() + - ") of Optype (" + op.Name() + ") bound to different types (" + *(param_to_type_iter->second) + + ") of Optype (" + op.Name() + ") bound to different types (" + + *(param_to_type_iter->second) + " and " + *(input_def->Type()) + " in node (" + node_name + ")."); return status; @@ -2930,7 +2955,8 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso *merge_target.mutable_sparse_tensor_type()->mutable_shape() = *output_def->Shape(); } #endif - auto status = MergeShapeInfo(output_def->Name(), onnx_inferred_type, merge_target, strict_shape_type_inference_, logger_); + auto status = MergeShapeInfo(output_def->Name(), onnx_inferred_type, merge_target, + strict_shape_type_inference_, logger_); if (!status.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node_name, " ", status.ErrorMessage()); } @@ -3025,7 +3051,8 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { ctx.set_opset_imports(DomainToVersionMap()); ctx.set_schema_registry(schema_registry_.get()); // Set the parent directory of model path to load external tensors if exist - ctx.set_model_dir(ToUTF8String(ModelPath().ParentPath().ToPathString())); + // ONNX expects a UTF-8 string here. + ctx.set_model_dir(ToUTF8String(ModelPath().parent_path().native())); LexicalScopeContext parent; if (parent_node_) { @@ -3175,7 +3202,7 @@ Status Graph::VerifyInputAndInitializerNames() { } for (auto& initializer_pair : name_to_initial_tensor_) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) inputs_and_initializers.insert(initializer_pair.first); // Initializers are expected to be included in inputs (according to ONNX spec). // onnxruntime relaxes this constraint. No duplicate-name check here. @@ -3370,7 +3397,7 @@ const std::string& Graph::Description() const noexcept { return graph_proto_->doc_string(); } -const Path& Graph::ModelPath() const { +const std::filesystem::path& Graph::ModelPath() const { return owning_model_.ModelPath(); } @@ -3411,7 +3438,8 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { SetGraphResolveNeeded(); } else { #if !defined(DISABLE_SPARSE_TENSORS) - ORT_ENFORCE(sparse_tensor_names_.count(tensor_name) == 0, "sparse_tensor_names_ not in sync with name_to_initial_tensor_"); + ORT_ENFORCE(sparse_tensor_names_.count(tensor_name) == 0, + "sparse_tensor_names_ not in sync with name_to_initial_tensor_"); #endif } @@ -3447,7 +3475,8 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi return true; }; - ORT_RETURN_IF_NOT(!is_external || utils::HasExternalData(old_initializer), "Trying to replace non-external initializer with external data"); + ORT_RETURN_IF_NOT(!is_external || utils::HasExternalData(old_initializer), + "Trying to replace non-external initializer with external data"); ORT_RETURN_IF_NOT(dims_eq(), "Replacement tensor's dimensions do not match."); ORT_RETURN_IF_NOT(old_initializer.data_type() == new_initializer.data_type(), @@ -3525,7 +3554,8 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( ORT_ENFORCE(existing_entry != mutable_initializers.pointer_end(), "graph_proto_ is not in sync with name_to_initial_tensor_"); (**existing_entry).clear_data_location(); - const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType(); + const DataTypeImpl* const type = + DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType(); TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer); auto tensor = Tensor(type, tensor_shape, tensor_buffer, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); @@ -3695,6 +3725,12 @@ SaveInputsOutputsToOrtFormat(flatbuffers::FlatBufferBuilder& builder, const std: common::Status Graph::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& fbs_graph) const { + if constexpr (endian::native != endian::little) { + auto& tens = GetAllInitializedTensors(); + for (auto& [name, tensor_p] : tens) { + utils::ConvertRawDataInTensorProto(const_cast(tensor_p)); + } + } auto inputs = SaveInputsOutputsToOrtFormat(builder, graph_inputs_including_initializers_); auto outputs = SaveInputsOutputsToOrtFormat(builder, graph_outputs_); @@ -3971,21 +4007,17 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { return result; } -ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std::string& external_file_name, - const PathString& destination_file_path, +ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, + const std::filesystem::path& model_file_path, size_t initializer_size_threshold) const { GraphProto result; ToGraphProtoInternal(result); + ORT_ENFORCE(external_file_path.is_relative()); + // If model_file_path is just a file name without a path separator, for example: "model.onnx". Its parent path could + // be empty. Else, save external data file in same directory as the model. + const std::filesystem::path modified_external_file_path = model_file_path.parent_path() / external_file_path; - Path parent_path = Path::Parse(destination_file_path).ParentPath(); - Path external_file_path = Path::Parse(ToPathString(external_file_name)); - // Check if parent_path is relative path (length = 0) - if (parent_path.ToPathString().length()) { - // Save external data file in same directory as model - external_file_path = parent_path.Append(external_file_path); - } - - std::ofstream external_stream(external_file_path.ToPathString(), std::ofstream::out | std::ofstream::binary); + std::ofstream external_stream(modified_external_file_path, std::ofstream::out | std::ofstream::binary); ORT_ENFORCE(external_stream.is_open()); int64_t external_offset = 0; @@ -4022,7 +4054,7 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std output_proto->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); ONNX_NAMESPACE::StringStringEntryProto* location = output_proto->add_external_data(); location->set_key("location"); - location->set_value(external_file_name); + location->set_value(ToUTF8String(external_file_path.native())); ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto->add_external_data(); offset->set_key("offset"); offset->set_value(std::to_string(external_offset)); @@ -4249,7 +4281,7 @@ void Graph::ComputeOverridableInitializers() { #if !defined(ORT_MINIMAL_BUILD) -GSL_SUPPRESS(es.84) // warning about ignoring return value from insert(...) +GSL_SUPPRESS(es .84) // warning about ignoring return value from insert(...) Status Graph::SetGraphInputsOutputs() { // If loaded from a model file, we start from the specified inputs and // outputs set earlier by InitializeStateFromModelFileGraphProto(). @@ -4435,7 +4467,7 @@ Status Graph::PopulateNodeArgToProducerConsumerLookupsFromNodes() { } // calling private ctor -GSL_SUPPRESS(r.11) +GSL_SUPPRESS(r .11) gsl::not_null Graph::AllocateNode() { ORT_ENFORCE(nodes_.size() < static_cast(std::numeric_limits::max())); std::unique_ptr new_node(new Node(nodes_.size(), *this)); @@ -4448,7 +4480,7 @@ gsl::not_null Graph::AllocateNode() { return gsl::not_null{node}; } -// TODO: Does this need (and maybe AllocateNode) to be threadsafe so nodes_ and num_of_nodes_ managed more carefully? +// TODO(s): Does this need (and maybe AllocateNode) to be threadsafe so nodes_ and num_of_nodes_ managed more carefully? bool Graph::ReleaseNode(NodeIndex index) { if (index >= nodes_.size()) { return false; @@ -4565,13 +4597,13 @@ void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_n auto dst_idx = input_edge.GetDstArgIndex(); // if this input is an input of the fused node add an edge for that - if (dst_idx < (int)node->InputDefs().size()) { + if (dst_idx < static_cast(node->InputDefs().size())) { auto it = input_indexes.find(node->InputDefs()[dst_idx]->Name()); if (it != input_indexes.cend()) { AddEdge(producer_idx, new_node_idx, src_idx, it->second); } } else { - int dst_implicit_input_idx = dst_idx - (int)node->InputDefs().size(); + int dst_implicit_input_idx = dst_idx - static_cast(node->InputDefs().size()); ORT_ENFORCE(dst_implicit_input_idx < (int)node->ImplicitInputDefs().size()); auto it = input_indexes.find(node->ImplicitInputDefs()[dst_implicit_input_idx]->Name()); if (it != input_indexes.cend()) { @@ -5015,7 +5047,8 @@ Status Graph::InlineFunction(Node& callnode) { // Remove output edges. Requirement for RemoveNode() below. auto output_edges = callnode.GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator for (const auto& output_edge : output_edges) { - RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex()); + RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), + output_edge.GetDstArgIndex()); } // create a uniq_identifier to append to every node name and intermediate input\outputs diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 7dfdba687517f..922759b02e75f 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -28,7 +28,7 @@ SaveDims(flatbuffers::FlatBufferBuilder& builder, const DimsFieldType& dims) { Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, const TensorProto& initializer, - const Path& model_path, + const std::filesystem::path& model_path, flatbuffers::Offset& fbs_tensor, const ExternalDataWriter& external_writer) { auto name = SaveStringToOrtFormat(builder, initializer.has_name(), initializer.name()); @@ -85,7 +85,7 @@ Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, #if !defined(DISABLE_SPARSE_TENSORS) Status SaveSparseInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, const ONNX_NAMESPACE::SparseTensorProto& initializer, - const Path& model_path, + const std::filesystem::path& model_path, flatbuffers::Offset& fbs_sparse_tensor) { // values const auto& values = initializer.values(); @@ -126,7 +126,7 @@ Status SaveSparseInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, Status SaveAttributeOrtFormat(flatbuffers::FlatBufferBuilder& builder, const AttributeProto& attr_proto, flatbuffers::Offset& fbs_attr, - const Path& model_path, + const std::filesystem::path& model_path, const onnxruntime::Graph* subgraph) { auto name = SaveStringToOrtFormat(builder, attr_proto.has_name(), attr_proto.name()); auto doc_string = SaveStringToOrtFormat(builder, attr_proto.has_doc_string(), attr_proto.doc_string()); diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.h b/onnxruntime/core/graph/graph_flatbuffers_utils.h index 33eba34fbaff0..224d966500e18 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.h +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/common/flatbuffers.h" @@ -71,13 +72,13 @@ constexpr uint32_t kMinimumSizeForExternalData = 64; /// if the initializer contains kMinimumSizeForExternalData bytes or more, and not string data. Status SaveInitializerOrtFormat( flatbuffers::FlatBufferBuilder& builder, const ONNX_NAMESPACE::TensorProto& initializer, - const Path& model_path, flatbuffers::Offset& fbs_tensor, + const std::filesystem::path& model_path, flatbuffers::Offset& fbs_tensor, const ExternalDataWriter& external_writer = nullptr); #if !defined(DISABLE_SPARSE_TENSORS) Status SaveSparseInitializerOrtFormat( flatbuffers::FlatBufferBuilder& builder, const ONNX_NAMESPACE::SparseTensorProto& initializer, - const Path& model_path, flatbuffers::Offset& fbs_sparse_tensor); + const std::filesystem::path& model_path, flatbuffers::Offset& fbs_sparse_tensor); #endif // !defined(DISABLE_SPARSE_TENSORS) // Convert a given AttributeProto into fbs::Attribute @@ -86,7 +87,7 @@ Status SaveSparseInitializerOrtFormat( // instead of the GraphProto in attr_proto Status SaveAttributeOrtFormat( flatbuffers::FlatBufferBuilder& builder, const ONNX_NAMESPACE::AttributeProto& attr_proto, - flatbuffers::Offset& fbs_attr, const Path& model_path, + flatbuffers::Offset& fbs_attr, const std::filesystem::path& model_path, const onnxruntime::Graph* subgraph); /// diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index aefad28eb37e8..eb0fb22346f37 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -21,7 +21,21 @@ void GraphViewerToProto(const GraphViewer& graph_view, *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); } - for (const auto* value_info : graph_view.GetValueInfo()) { + std::unordered_set value_info_ = graph_view.GetValueInfo(); + + // Reserve memory for the vector to avoid reallocations + std::vector value_info_sorted; + value_info_sorted.reserve(value_info_.size()); + + value_info_sorted.assign(value_info_.begin(), value_info_.end()); + auto sort_predicate = [](const NodeArg* v1, const NodeArg* v2) { + return v1->Name() < v2->Name(); + }; + + // This ensures consistent ordering of value_info entries in the output graph + std::sort(value_info_sorted.begin(), value_info_sorted.end(), sort_predicate); + + for (const auto* value_info : value_info_sorted) { *(graph_proto.mutable_value_info()->Add()) = value_info->ToProto(); } diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 13620f4d8b3bb..221bc01f5d15d 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -2,12 +2,14 @@ // Licensed under the MIT License. #include "core/graph/graph_utils.h" - -#include - #include "core/graph/graph.h" #include "core/common/logging/logging.h" +#include +#include +#include +#include + namespace onnxruntime { namespace graph_utils { @@ -56,7 +58,8 @@ static bool CanUpdateImplicitInputNameInSubgraph(const Node& node, } for (auto& subgraph_node : subgraph->Nodes()) { - // recurse if this node also consumes removed_output_name as an implicit input (i.e. there are multiple levels of nested + // recurse if this node also consumes removed_output_name as an implicit input (i.e. there are multiple levels + // of nested // subgraphs, and at least one level lower uses removed_output_name as an implicit input const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs(); if (!subgraph_node_implicit_inputs.empty()) { @@ -464,13 +467,13 @@ static bool IsOnlyOneOutputUsed(const Graph& graph, const Node& node, const std: // a) there's only 1, and b) it's the same as any output consumed by another node auto output_indexes = graph.GetNodeOutputsInGraphOutputs(node); auto num_graph_outputs = output_indexes.size(); - if (num_graph_outputs > 1) + if (num_graph_outputs > 1) { return false; - else if (num_graph_outputs == 1) { - if (first_output != unassigned) + } else if (num_graph_outputs == 1) { + if (first_output != unassigned) { // an output is consumed by other nodes, so make sure the same output is providing the graph output return output_indexes.front() == first_output; - else { + } else { // graph output only as no other nodes are consuming the output, so just update the output_name output_name = &node.OutputDefs()[output_indexes.front()]->Name(); } @@ -678,7 +681,8 @@ const Node* FirstParentByType(const Node& node, const std::string& parent_type) return nullptr; } -void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx) { +void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, + int replacement_output_idx) { // get the output edges from node for output_idx std::vector output_edges = GraphEdge::GetNodeOutputEdges(node, output_idx); @@ -726,7 +730,9 @@ void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input) { "Can only add a new input at the end of the current ones."); target.MutableInputDefs().push_back(&new_input); - assert(target.MutableInputArgsCount().size() > static_cast(target_input_idx)); // expect existing entry for all possible inputs + + // expect existing entry for all possible inputs + assert(target.MutableInputArgsCount().size() > static_cast(target_input_idx)); target.MutableInputArgsCount()[target_input_idx] = 1; } @@ -798,7 +804,8 @@ bool FindPath(const Node& node, bool is_input_edge, gsl::spanOpType() << "->" << edge.op_type; + LOGS(logger, WARNING) << "Failed since multiple edges matched:" << current_node->OpType() << "->" + << edge.op_type; return false; } edge_found = &(*it); @@ -821,7 +828,8 @@ bool FindPath(const Node& node, bool is_input_edge, gsl::span edges_to_match, std::vector>& result, const logging::Logger& logger) { +bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span edges_to_match, + std::vector>& result, const logging::Logger& logger) { result.clear(); std::vector edge_ends; @@ -830,9 +838,10 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span Node& { - return *graph.GetNode(edge_end->GetNode().Index()); - }); + std::transform(edge_ends.begin(), edge_ends.end(), std::back_inserter(result), + [&graph](const Node::EdgeEnd* edge_end) -> Node& { + return *graph.GetNode(edge_end->GetNode().Index()); + }); return true; } diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 319e055200cca..0b713196203d6 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -8,6 +8,9 @@ #include "core/graph/onnx_protobuf.h" #include "core/graph/graph.h" +#include +#include + namespace onnxruntime { namespace graph_utils { @@ -247,7 +250,8 @@ e.g. Node A produces outputs A1 and A2. to replace B1 (output index 0 for node B) with A2 (output index 1 for node A) as input to the downstream node C. The edge that existed between B and C for B1 will be removed, and replaced with an edge between A and C for A2. */ -void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx); +void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, + int replacement_output_idx); /** Replace the input to a node with a NodeArg. @remarks The replacement only updates the node's input definition and does not create any edges, @@ -302,11 +306,13 @@ inline void FinalizeNodeFusion(Graph& graph, The output definitions and edges from the last node in 'nodes' will be moved to replacement_node. All nodes in 'nodes' will be removed. */ -inline void FinalizeNodeFusion(Graph& graph, gsl::span> nodes, Node& replacement_node) { +inline void FinalizeNodeFusion(Graph& graph, gsl::span> nodes, + Node& replacement_node) { FinalizeNodeFusion(graph, nodes, replacement_node, replacement_node); } -inline void FinalizeNodeFusion(Graph& graph, std::initializer_list> nodes, Node& replacement_node) { +inline void FinalizeNodeFusion(Graph& graph, std::initializer_list> nodes, + Node& replacement_node) { FinalizeNodeFusion(graph, AsSpan(nodes), replacement_node, replacement_node); } @@ -357,17 +363,23 @@ struct EdgeEndToMatch { It is recommended to match path from bottom to top direction to avoid such issue. It is because each node input (dst_arg_index) only accepts one input edge. */ -bool FindPath(const Node& node, bool is_input_edge, gsl::span edges_to_match, std::vector& result, const logging::Logger& logger); +bool FindPath(const Node& node, bool is_input_edge, gsl::span edges_to_match, + std::vector& result, const logging::Logger& logger); -inline bool FindPath(const Node& node, bool is_input_edge, std::initializer_list edges_to_match, std::vector& result, const logging::Logger& logger) { +inline bool FindPath(const Node& node, bool is_input_edge, std::initializer_list edges_to_match, + std::vector& result, const logging::Logger& logger) { return FindPath(node, is_input_edge, AsSpan(edges_to_match), result, logger); } /** Same as FindPath above, but return the references of matched Node */ -bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span edges_to_match, std::vector>& result, const logging::Logger& logger); +bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span edges_to_match, + std::vector>& result, const logging::Logger& logger); -inline bool FindPath(Graph& graph, const Node& node, bool is_input_edge, std::initializer_list edges_to_match, std::vector>& result, const logging::Logger& logger) { +inline bool FindPath(Graph& graph, const Node& node, bool is_input_edge, + std::initializer_list edges_to_match, + std::vector>& result, + const logging::Logger& logger) { return FindPath(graph, node, is_input_edge, AsSpan(edges_to_match), result, logger); } diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index c639eeac5ea42..1842c2b4a0d1f 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -332,7 +332,7 @@ const std::vector& GraphViewer::GetNodesInTopologicalOrder(ExecutionO } const std::vector& GraphViewer::GetRootNodes() const { - // TODO: See if we need to calculate the root_nodes_ of the filtered graph. + // TODO(somebody): See if we need to calculate the root_nodes_ of the filtered graph. // GetRootNodes is only used by parallel executor currently, and isn't relevant to the usage of a filtered graph. ORT_ENFORCE(filter_info_ == nullptr, "Not supported with filtered graph."); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index b3935e69ad7b1..ee4d9f9154971 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -20,7 +20,7 @@ #endif #include "core/util/protobuf_parsing_utils.h" -#include "core/common/gsl.h" +#include #include "core/platform/env.h" @@ -81,7 +81,7 @@ Model::Model(const std::string& graph_name, const std::vector& model_local_functions, const logging::Logger& logger, const ModelOptions& options) - : model_path_(Path::Parse(model_path)) { + : model_path_(model_path) { model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); model_proto_.mutable_graph()->set_name(graph_name); model_metadata_ = model_metadata; @@ -140,12 +140,13 @@ Model::Model(const std::string& graph_name, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), + func.name()), std::move(func_template_ptr)); } // need to call private ctor so can't use make_shared - GSL_SUPPRESS(r.11) + GSL_SUPPRESS(r .11) graph_.reset(new Graph(*this, model_proto_.mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry, logger, options.strict_shape_type_inference)); } @@ -159,7 +160,7 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path, Model::Model(ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger, const ModelOptions& options) - : model_path_(Path::Parse(model_path)) { + : model_path_(model_path) { if (!utils::HasGraph(model_proto)) { ORT_THROW("ModelProto does not have a graph."); } @@ -269,11 +270,13 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr)); + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), + func.name()), + std::move(func_template_ptr)); } // create instance. need to call private ctor so can't use make_unique - GSL_SUPPRESS(r.11) + GSL_SUPPRESS(r .11) graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry, logger, options.strict_shape_type_inference)); } @@ -378,8 +381,8 @@ ModelProto Model::ToProto() const { return result; } -ModelProto Model::ToGraphProtoWithExternalInitializers(const std::string& external_file_name, - const PathString& file_path, +ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, + const std::filesystem::path& file_path, size_t initializer_size_threshold) const { ModelProto result(model_proto_); const auto& graph = *graph_; @@ -425,7 +428,7 @@ Status Model::Load(const ModelProto& model_proto, } // need to call private ctor so can't use make_shared - GSL_SUPPRESS(r.11) + GSL_SUPPRESS(r .11) auto status = Status::OK(); ORT_TRY { @@ -465,7 +468,7 @@ Status Model::Load(ModelProto&& model_proto, } // need to call private ctor so can't use make_shared - GSL_SUPPRESS(r.11) + GSL_SUPPRESS(r .11) auto status = Status::OK(); ORT_TRY { model = std::make_unique(std::move(model_proto), model_path, local_registries, logger, options); @@ -512,7 +515,7 @@ static Status LoadModelHelper(const T& file_path, Loader loader) { } if (!status.IsOK()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); return status; } @@ -554,7 +557,8 @@ static Status SaveModel(Model& model, const T& file_path) { const file_path = UTF8ToString($2); const bytes = new Uint8Array(buffer_size); bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size)); - if (typeof process == 'object' && typeof process.versions == 'object' && typeof process.versions.node == 'string') { + if (typeof process == 'object' && typeof process.versions == 'object' && + typeof process.versions.node == 'string') { // Node.js require('fs').writeFileSync(file_path, bytes); } else { @@ -585,7 +589,7 @@ static Status SaveModel(Model& model, const T& file_path) { }); } if (!status.IsOK()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); return status; } @@ -593,16 +597,14 @@ static Status SaveModel(Model& model, const T& file_path) { #endif } -#ifdef _WIN32 -Status Model::Save(Model& model, const std::wstring& file_path) { +Status Model::Save(Model& model, const PathString& file_path) { return SaveModel(model, file_path); } -#endif template static Status SaveModelWithExternalInitializers(Model& model, const T& file_path, - const std::string& external_file_name, + const std::filesystem::path& external_file_name, size_t initializer_size_threshold) { int fd = 0; Status status = Env::Default().FileOpenWr(file_path, fd); @@ -618,7 +620,7 @@ static Status SaveModelWithExternalInitializers(Model& model, }); } if (!status.IsOK()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); return status; } @@ -630,20 +632,16 @@ Status Model::Load(const PathString& file_path, return LoadModel(file_path, model_proto); } -GSL_SUPPRESS(r.30) // spurious warnings. p_model is potentially reset in the internal call to Load -GSL_SUPPRESS(r.35) +GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load +GSL_SUPPRESS(r .35) Status Model::Load(const PathString& file_path, std::shared_ptr& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger, const ModelOptions& options) { return LoadModel(file_path, p_model, local_registries, logger, options); } -Status Model::Save(Model& model, const std::string& file_path) { - return SaveModel(model, file_path); -} - -Status Model::SaveWithExternalInitializers(Model& model, const PathString& file_path, - const std::string& external_file_name, +Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, + const std::filesystem::path& external_file_name, size_t initializer_size_threshold) { return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold); } @@ -759,8 +757,8 @@ Status Model::Save(Model& model, int p_fd) { Status Model::SaveWithExternalInitializers(Model& model, int fd, - const PathString& file_path, - const std::string& external_file_name, + const std::filesystem::path& file_path, + const std::filesystem::path& external_file_name, size_t initializer_size_threshold) { if (fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); @@ -768,7 +766,8 @@ Status Model::SaveWithExternalInitializers(Model& model, ORT_RETURN_IF_ERROR(model.MainGraph().Resolve()); - auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold); + auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path, + initializer_size_threshold); google::protobuf::io::FileOutputStream output(fd); const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); if (result) { diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 6f4b7f4f9f00b..728af727ac83b 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -7,10 +7,10 @@ #include #include #include +#include #include "core/common/flatbuffers.h" -#include "core/common/path.h" #include "core/graph/graph_viewer.h" #include "core/graph/ort_format_load_options.h" #include "core/session/onnxruntime_c_api.h" @@ -174,7 +174,7 @@ class Model { const ModelMetaData& MetaData() const noexcept; // Gets the path from which the model was loaded, if any. - const Path& ModelPath() const noexcept { return model_path_; } + const std::filesystem::path& ModelPath() const noexcept { return model_path_; } // Get model's main graph. Graph& MainGraph() noexcept; @@ -187,27 +187,24 @@ class Model { // Get model's serialization proto data. // Save initializer larger than the given threshold (in bytes) into an external binary file // with the given name. This function is useful to avoid hitting the size limit of protobuf files. - ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name, - const PathString& file_path, + ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, + const std::filesystem::path& file_path, size_t initializer_size_threshold) const; -#ifdef _WIN32 - static common::Status Save(Model& model, const std::wstring& file_path); -#endif - static common::Status Save(Model& model, const std::string& file_path); + static common::Status Save(Model& model, const PathString& file_path); static common::Status Save(Model& model, int fd); // Save the model to file using an external file for initializers larger than the given threshold (in bytes). static common::Status SaveWithExternalInitializers(Model& model, - const PathString& file_path, - const std::string& external_file_name, + const std::filesystem::path& file_path, + const std::filesystem::path& external_file_path, size_t initializer_size_threshold); static common::Status SaveWithExternalInitializers(Model& model, int fd, - const PathString& file_path, - const std::string& external_file_name, + const std::filesystem::path& file_path, + const std::filesystem::path& external_file_path, size_t initializer_size_threshold); static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto); @@ -332,7 +329,7 @@ class Model { ModelMetaData model_metadata_; // Path to model file. May be empty. - const Path model_path_; + const std::filesystem::path model_path_; // Main graph of the model. std::unique_ptr graph_; diff --git a/onnxruntime/core/graph/node_attr_utils.h b/onnxruntime/core/graph/node_attr_utils.h index 9433cfabc974e..638cebe6a3203 100644 --- a/onnxruntime/core/graph/node_attr_utils.h +++ b/onnxruntime/core/graph/node_attr_utils.h @@ -5,7 +5,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/graph/onnx_protobuf.h" #include "core/graph/basic_types.h" diff --git a/onnxruntime/core/graph/op.h b/onnxruntime/core/graph/op.h index 2a720a3ccad45..295b8cb95f48c 100644 --- a/onnxruntime/core/graph/op.h +++ b/onnxruntime/core/graph/op.h @@ -3,12 +3,13 @@ #pragma once -#include -#include #include "core/graph/onnx_protobuf.h" #include "core/common/status.h" #include "core/graph/constants.h" +#include +#include + namespace onnxruntime { using AttrType = ONNX_NAMESPACE::AttributeProto_AttributeType; using NodeAttributes = std::unordered_map; @@ -29,19 +30,18 @@ AttributeProto_AttributeType_GRAPHS = 10, AttributeProto_AttributeType_SPARSE_TENSOR = 22, AttributeProto_AttributeType_SPARSE_TENSORS = 23, */ -static constexpr const char* kAttrTypeStrings[] = - { - "UNDEFINED", - "FLOAT", - "INT", - "STRING", - "TENSOR", - "GRAPH", - "FLOATS", - "INTS", - "STRINGS", - "TENSORS", - "GRAPHS"}; +static constexpr const char* kAttrTypeStrings[] = { + "UNDEFINED", + "FLOAT", + "INT", + "STRING", + "TENSOR", + "GRAPH", + "FLOATS", + "INTS", + "STRINGS", + "TENSORS", + "GRAPHS"}; class TypeUtils { public: diff --git a/onnxruntime/core/graph/runtime_optimization_record_container.cc b/onnxruntime/core/graph/runtime_optimization_record_container.cc index acd85b909e5be..2d0e1076ee372 100644 --- a/onnxruntime/core/graph/runtime_optimization_record_container.cc +++ b/onnxruntime/core/graph/runtime_optimization_record_container.cc @@ -4,15 +4,13 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #include "core/graph/runtime_optimization_record_container.h" - -#include - -#include "core/common/gsl.h" - #include "core/flatbuffers/flatbuffers_utils.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/graph/op_identifier_utils.h" +#include +#include + namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) @@ -182,7 +180,8 @@ Status RuntimeOptimizationRecordContainer::LoadFromOrtFormat( } ORT_RETURN_IF_NOT(optimizer_name_to_records.emplace(optimizer_name, std::move(records)).second, - "Attempting to load runtime optimization records for a previously loaded optimizer: ", optimizer_name); + "Attempting to load runtime optimization records for a previously loaded optimizer: ", + optimizer_name); } optimizer_name_to_records_ = std::move(optimizer_name_to_records); diff --git a/onnxruntime/core/graph/schema_registry.cc b/onnxruntime/core/graph/schema_registry.cc index 4dc714bd8af79..a7d94f4571d96 100644 --- a/onnxruntime/core/graph/schema_registry.cc +++ b/onnxruntime/core/graph/schema_registry.cc @@ -99,7 +99,7 @@ common::Status OnnxRuntimeOpSchemaRegistry::RegisterOpSchemaInternal(ONNX_NAMESP << "than the operator set version " << ver_range_it->second.opset_version << std::endl; return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostream.str()); } - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) map_[op_name][op_domain].emplace(std::make_pair(ver, op_schema)); return common::Status::OK(); } @@ -161,7 +161,8 @@ void SchemaRegistryManager::RegisterRegistry(std::shared_ptrGetLatestOpsetVersions(is_onnx_only); @@ -172,7 +173,7 @@ void SchemaRegistryManager::GetDomainToVersionMapForRegistries(DomainToVersionMa // If the map doesn't yet contain this domain, insert it with this registry's value. // Otherwise, merge the existing range in the map. if (iter == domain_version_map.end()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) domain_version_map.insert(local_domain); } else { iter->second = std::max(iter->second, local_domain.second); @@ -194,7 +195,7 @@ DomainToVersionMap SchemaRegistryManager::GetLastReleasedOpsetVersions(bool is_o continue; auto it = domain_version_map.find(domain.first); if (it == domain_version_map.end()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) domain_version_map.insert(std::make_pair(domain.first, domain.second)); } else { it->second = std::max(it->second, domain.second); @@ -217,7 +218,7 @@ DomainToVersionMap SchemaRegistryManager::GetLatestOpsetVersions(bool is_onnx_on continue; auto it = domain_version_map.find(domain.first); if (it == domain_version_map.end()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) domain_version_map.insert(std::make_pair(domain.first, domain.second.second)); } else { it->second = std::max(it->second, domain.second.second); @@ -271,7 +272,7 @@ void SchemaRegistryManager::GetSchemaAndHistory( } if (new_version < version) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) unchecked_registry_indices.insert(unchecked_registry_indices.end(), checked_registry_indices.begin(), checked_registry_indices.end()); diff --git a/onnxruntime/core/language_interop_ops/language_interop_ops.cc b/onnxruntime/core/language_interop_ops/language_interop_ops.cc deleted file mode 100644 index b40ee08479055..0000000000000 --- a/onnxruntime/core/language_interop_ops/language_interop_ops.cc +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "language_interop_ops.h" -#include "core/framework/tensorprotoutils.h" -#include "core/platform/env.h" -#include "core/session/inference_session.h" -#include "pyop/pyop.h" -#include - -namespace onnxruntime { - -void LoadInterOp(const std::basic_string& model_uri, InterOpDomains& domains, const InterOpLogFunc& log_func) { - int fd; - - // match the error message from model.cc to keep the nodejs tests happy. - // as this is deprecated just cut-and-paste equivalent code for now. - auto status = Env::Default().FileOpenRd(model_uri, fd); - if (!status.IsOK()) { - if (status.Category() == common::SYSTEM) { - switch (status.Code()) { - case ENOENT: - status = ORT_MAKE_STATUS(ONNXRUNTIME, NO_SUCHFILE, "Load model ", ToUTF8String(model_uri), - " failed. File doesn't exist"); - break; - case EINVAL: - status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Load model ", ToUTF8String(model_uri), " failed"); - break; - default: - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "system error number ", status.Code()); - } - } - } - - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - - google::protobuf::io::FileInputStream f(fd); - f.SetCloseOnDelete(true); - ONNX_NAMESPACE::ModelProto model_proto; - ORT_ENFORCE(model_proto.ParseFromZeroCopyStream(&f), "Failed to parse model proto"); - LoadInterOp(model_proto, domains, log_func); -} - -void LoadInterOp(const ONNX_NAMESPACE::ModelProto& model_proto, InterOpDomains& domains, const InterOpLogFunc& log_func) { - LoadInterOp(model_proto.graph(), domains, log_func); -} - -void LoadInterOp(const ONNX_NAMESPACE::GraphProto& graph_proto, InterOpDomains& domains, const InterOpLogFunc& log_func) { - for (int i = 0; i < graph_proto.node_size(); ++i) { - const auto& node_proto = graph_proto.node(i); - if (node_proto.op_type() == "PyOp") { - auto pyop_domain = Ort::CustomOpDomain(node_proto.domain().c_str()); - pyop_domain.Add(LoadPyOp(node_proto, log_func)); - domains.push_back(std::move(pyop_domain)); - } else { - for (int j = 0, limit = node_proto.attribute_size(); j < limit; ++j) { - const auto& attr = node_proto.attribute(j); - if (utils::HasGraph(attr)) { - LoadInterOp(attr.g(), domains, log_func); // load pyop in subgraph - } - } // for - } // else - } // for -} -} // namespace onnxruntime diff --git a/onnxruntime/core/language_interop_ops/language_interop_ops.h b/onnxruntime/core/language_interop_ops/language_interop_ops.h deleted file mode 100644 index 2ab5945b17bc2..0000000000000 --- a/onnxruntime/core/language_interop_ops/language_interop_ops.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#pragma once -#include -#include -#include -#include "core/graph/graph.h" -#include "core/session/onnxruntime_cxx_api.h" - -namespace onnxruntime { -using InterOpLogFunc = std::function; -using InterOpDomains = std::vector; -void LoadInterOp(const std::basic_string& model_uri, InterOpDomains& domains, const InterOpLogFunc& log_func); -void LoadInterOp(const ONNX_NAMESPACE::ModelProto& model_proto, InterOpDomains& domains, const InterOpLogFunc& log_func); -void LoadInterOp(const ONNX_NAMESPACE::GraphProto& graph_proto, InterOpDomains& domains, const InterOpLogFunc& log_func); -} // namespace onnxruntime diff --git a/onnxruntime/core/language_interop_ops/pyop/pyop.cc b/onnxruntime/core/language_interop_ops/pyop/pyop.cc deleted file mode 100644 index ccbe4c0d83006..0000000000000 --- a/onnxruntime/core/language_interop_ops/pyop/pyop.cc +++ /dev/null @@ -1,399 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "pyop.h" - -#ifdef _WIN32 -#define LIB_PYOP "onnxruntime_pywrapper.dll" -#define LOAD_PYOP_LIB(n, v, m) ORT_ENFORCE((v = LoadLibraryA(n)) != nullptr, m) -#else -#ifdef __APPLE__ -#define LIB_PYOP "./libonnxruntime_pywrapper.dylib" -#else -#define LIB_PYOP "./libonnxruntime_pywrapper.so" -#endif -#define LOAD_PYOP_LIB(n, v, m) ORT_ENFORCE((v = dlopen(n, RTLD_NOW | RTLD_GLOBAL)) != nullptr, m) -#include "dlfcn.h" -#endif - -#include "core/framework/tensorprotoutils.h" -#include "core/platform/env.h" -#ifdef _DEBUG -#undef _DEBUG -#include -#define _DEBUG -#else -#include -#endif - -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include "numpy/arrayobject.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace onnxruntime { - -PyOpLibProxy& PyOpLibProxy::GetInstance() { - static PyOpLibProxy proxy; - return proxy; -} - -class Scope { - public: - Scope(const std::vector& objs = {}) : objs_(objs) { - mtx_.lock(); - } - ~Scope() { - for (auto obj : objs_) { - Py_XDECREF(obj); - } - mtx_.unlock(); - } - void Add(PyObject* obj) { - objs_.push_back(obj); - } - - private: - static std::mutex mtx_; - std::vector objs_; -}; - -PyOpLibProxy::PyOpLibProxy() { - Scope scope; - Py_Initialize(); - if (_import_array() < 0) { - return; - } - auto path_list = PySys_GetObject("path"); // do not release it - if (nullptr == path_list || !PyList_Check(path_list) || - PyList_Append(path_list, PyUnicode_FromString(".")) != 0) { - return; - } - initialized_ = true; -} - -PyOpLibProxy::~PyOpLibProxy() { - if (initialized_) { - Py_Finalize(); - } -} - -std::mutex Scope::mtx_; - -const char* PyOpLibProxy::GetLastErrorMessage(std::string& err) { - Scope scope; - if (PyErr_Occurred()) { - PyObject *type, *value, *trace; - PyErr_Fetch(&type, &value, &trace); - if (nullptr != value) { - auto pyVal = PyObject_Repr(value); - scope.Add(pyVal); - auto pyStr = PyUnicode_AsEncodedString(pyVal, "utf-8", "Error ~"); - scope.Add(pyStr); - err = PyBytes_AS_STRING(pyStr); - } - PyErr_Restore(type, value, trace); - } - return err.c_str(); -} - -int32_t PyOpLibProxy::GetGil() const { - return PyGILState_Ensure(); -} - -void PyOpLibProxy::PutGil(int32_t state) const { - PyGILState_Release((PyGILState_STATE)state); -} - -PyObject* MakePyObj(const void* data, int32_t type, const std::vector& dim) { - std::vector np_dim; - for (auto d : dim) { - np_dim.push_back(static_cast(d)); - } - auto pyObj = static_cast(PyArray_EMPTY(static_cast(np_dim.size()), np_dim.data(), type, 0)); - auto data_len = std::accumulate(begin(np_dim), end(np_dim), - static_cast(PyArray_DescrFromType(type)->elsize), - std::multiplies()); - auto np_array = reinterpret_cast(pyObj); - memcpy(PyArray_DATA(np_array), data, data_len); - return pyObj; -} - -bool ExtractOutput(PyObject* pyObj, - std::vector>& outputs, - std::vector& outputs_elem_size, - std::vector>& outputs_dim) { - if (!PyArray_Check(pyObj)) { - return false; - } - - outputs_dim.push_back({}); - auto np_array = reinterpret_cast(pyObj); - outputs_elem_size.push_back(static_cast(PyArray_ITEMSIZE(np_array))); - - for (int i = 0; i < PyArray_NDIM(np_array); ++i) { - outputs_dim.back().push_back(PyArray_SHAPE(np_array)[i]); - } - - auto data_len = std::accumulate(begin(outputs_dim.back()), - end(outputs_dim.back()), - static_cast(outputs_elem_size.back()), - std::multiplies()); - - outputs.push_back(std::unique_ptr(new char[data_len])); - memcpy(static_cast(outputs.back().get()), PyArray_DATA(np_array), data_len); - return true; -} - -void* PyOpLibProxy::NewInstance(const char* module, const char* class_name, - const std::unordered_map& args) { - Scope scope; - auto pyModule = PyImport_ImportModule(module); - if (nullptr == pyModule) { - return nullptr; - } - - scope.Add(pyModule); - auto pyClass = PyObject_GetAttrString(pyModule, class_name); - if (nullptr == pyClass) { - return nullptr; - } - - scope.Add(pyClass); - auto empty_args = PyTuple_New(0); - scope.Add(empty_args); - auto named_args = PyDict_New(); - scope.Add(named_args); - for (const auto& iter : args) { - PyDict_SetItemString(named_args, iter.first.c_str(), PyUnicode_FromString(iter.second.c_str())); - } - - return PyObject_Call(pyClass, empty_args, named_args); -} - -void PyOpLibProxy::ReleaseInstance(void* instance) { - Scope scope({static_cast(instance)}); -} - -bool PyOpLibProxy::InvokePythonFunc(void* raw_inst, - const char* function, - const std::vector& inputs, - const std::vector& inputs_type, - const std::vector>& inputs_dim, - std::vector>& outputs, - std::vector& outputs_elem_size, - std::vector>& outputs_dim, - std::function logging_func) { - Scope scope; - auto instance = static_cast(raw_inst); - if (nullptr == instance || nullptr == function) { - logging_func("InvokePythonFunc: found invalid instance or function"); - return false; - } - - auto pyFunc = PyObject_GetAttrString(instance, function); - if (nullptr == pyFunc) { - logging_func("InvokePythonFunc: failed to create function object"); - return false; - } - - scope.Add(pyFunc); - auto pyArgs = PyTuple_New(inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { - PyTuple_SetItem(pyArgs, i, MakePyObj(inputs[i], inputs_type[i], inputs_dim[i])); - } - - scope.Add(pyArgs); - auto pyResult = PyObject_CallObject(pyFunc, pyArgs); - if (nullptr == pyResult) { - logging_func("InvokePythonFunc: no result"); - return false; - } - - scope.Add(pyResult); - if (PyArray_Check(pyResult)) { - ExtractOutput(pyResult, outputs, outputs_elem_size, outputs_dim); - } else if (PyTuple_Check(pyResult)) { - for (int32_t i = 0; i < PyTuple_Size(pyResult); ++i) { - if (!ExtractOutput(PyTuple_GetItem(pyResult, i), outputs, outputs_elem_size, outputs_dim)) { - logging_func("InvokePythonFunc: failed to extract output"); - return false; - } - } - } else { - logging_func("InvokePythonFunc: returned value must be numpy(s)"); - return false; - } - return true; -} // bool InvokePythonFunc - -PyCustomKernel::PyCustomKernel(const OnnxAttrs& attrs, - const std::string& module, - const std::string& class_name, - const std::string& compute, - PyOpLogFunc logging_func) : attrs_(attrs), module_(module), class_name_(class_name), compute_(compute), logging_func_(logging_func) { - std::string err; - auto state = PyOpLibProxy::GetInstance().GetGil(); - ORT_ENFORCE(PyOpLibProxy::GetInstance().Initialized(), "Py library not properly initialized."); - instance_ = PyOpLibProxy::GetInstance().NewInstance(module.c_str(), class_name_.c_str(), attrs_); - PyOpLibProxy::GetInstance().PutGil(state); - ORT_ENFORCE(nullptr != instance_, PyOpLibProxy::GetInstance().GetLastErrorMessage(err)); -} - -PyCustomKernel::~PyCustomKernel() { - if (nullptr != instance_) { - auto state = PyOpLibProxy::GetInstance().GetGil(); - PyOpLibProxy::GetInstance().ReleaseInstance(instance_); - PyOpLibProxy::GetInstance().PutGil(state); - instance_ = nullptr; - } -} - -// Do nothing since Custom Op does not trigger shape inference -void PyCustomKernel::GetOutputShape(OrtKernelContext*, size_t, OrtTensorTypeAndShapeInfo*) {} - -void PyCustomKernel::Compute(OrtKernelContext* context) { - ORT_ENFORCE(nullptr != context); - - Ort::KernelContext ctx(context); - const auto inputs_count = ctx.GetInputCount(); - - std::vector inputs; - std::vector> outputs; - std::vector inputs_type, outputs_elem_size; - std::vector> inputs_dim, outputs_dim; - - inputs.reserve(inputs_count); - inputs_dim.reserve(inputs_count); - for (size_t i = 0; i < inputs_count; ++i) { - auto value = ctx.GetInput(i); - ORT_ENFORCE(value.IsTensor(), "input must be a tensor"); - - inputs.push_back(value.GetTensorRawData()); - - auto type_and_shape = value.GetTensorTypeAndShapeInfo(); - inputs_type.push_back(GetNumpyType(type_and_shape.GetElementType())); - auto shape = type_and_shape.GetShape(); - inputs_dim.push_back(std::move(shape)); - } - - std::string err; - auto state = PyOpLibProxy::GetInstance().GetGil(); - ORT_ENFORCE(PyOpLibProxy::GetInstance().InvokePythonFunc(instance_, compute_.c_str(), inputs, inputs_type, - inputs_dim, outputs, outputs_elem_size, - outputs_dim, logging_func_), - PyOpLibProxy::GetInstance().GetLastErrorMessage(err)); // ORT_ENFORCE - PyOpLibProxy::GetInstance().PutGil(state); - - for (size_t i = 0; i < outputs.size(); ++i) { - auto ort_output = ctx.GetOutput(i, outputs_dim[i].data(), outputs_dim[i].size()); - auto output_mem_addr = ort_output.GetTensorMutableData(); - auto output_len = std::accumulate(begin(outputs_dim[i]), end(outputs_dim[i]), static_cast(outputs_elem_size[i]), std::multiplies()); - memcpy(output_mem_addr, outputs[i].get(), output_len); - } -} - -int32_t PyCustomKernel::GetNumpyType(int32_t elem_type) const { - int32_t numpy_type; - namespace on = ONNX_NAMESPACE; - switch (elem_type) { - case on::TensorProto_DataType_BOOL: - numpy_type = 0; - break; - case on::TensorProto_DataType_INT8: - numpy_type = 1; - break; - case on::TensorProto_DataType_UINT8: - numpy_type = 2; - break; - case on::TensorProto_DataType_INT16: - numpy_type = 3; - break; - case on::TensorProto_DataType_UINT16: - numpy_type = 4; - break; - case on::TensorProto_DataType_INT32: - numpy_type = 5; - break; - case on::TensorProto_DataType_UINT32: - numpy_type = 6; - break; - case on::TensorProto_DataType_INT64: - numpy_type = 9; - break; - case on::TensorProto_DataType_UINT64: - numpy_type = 10; - break; - case on::TensorProto_DataType_FLOAT: - numpy_type = 11; - break; - case on::TensorProto_DataType_DOUBLE: - numpy_type = 12; - break; - default: - ORT_THROW("Input primitive type not supported: ", elem_type); - } - return numpy_type; -} - -PyCustomOp::PyCustomOp(const OnnxAttrs& attrs, - const OnnxTypes& inputs_type, - const OnnxTypes& outputs_type, - const std::string& module, - const std::string& class_name, - const std::string& compute, - PyOpLogFunc logging_func) : attrs_(attrs), inputs_type_(inputs_type), outputs_type_(outputs_type), module_(module), class_name_(class_name), compute_(compute), logging_func_(logging_func) { OrtCustomOp::version = ORT_API_VERSION; } - -void* PyCustomOp::CreateKernel(const OrtApi&, const OrtKernelInfo*) const { - return new PyCustomKernel(attrs_, module_, class_name_, compute_, logging_func_); -} - -const char* PyCustomOp::GetName() const { return "PyOp"; } - -size_t PyCustomOp::GetInputTypeCount() const { return inputs_type_.size(); } -ONNXTensorElementDataType PyCustomOp::GetInputType(size_t index) const { return inputs_type_[index]; } - -size_t PyCustomOp::GetOutputTypeCount() const { return outputs_type_.size(); } -ONNXTensorElementDataType PyCustomOp::GetOutputType(size_t index) const { return outputs_type_[index]; } - -PyCustomOp* LoadPyOp(const ONNX_NAMESPACE::NodeProto& node_proto, PyOpLogFunc log_func) { - OnnxAttrs onnx_attrs; - OnnxTypes input_types, output_types; - std::string module, class_name, compute = "compute"; - for (int j = 0; j < node_proto.attribute_size(); ++j) { - const auto& attr = node_proto.attribute(j); - if (utils::HasString(attr)) { - if (attr.name() == "module") - module = attr.s(); - else if (attr.name() == "class_name") - class_name = attr.s(); - else if (attr.name() == "compute") - compute = attr.s(); - else - onnx_attrs[attr.name()] = attr.s(); - } else if (attr.ints_size() > 0) { - if (attr.name() == "input_types") { - for (int k = 0; k < attr.ints_size(); ++k) { - input_types.push_back(static_cast(attr.ints(k))); - } - } else if (attr.name() == "output_types") { - for (int k = 0; k < attr.ints_size(); ++k) { - output_types.push_back(static_cast(attr.ints(k))); - } - } - } - } // for - ORT_ENFORCE(module != "", "PyOp module not specified"); - ORT_ENFORCE(class_name != "", "PyOp class name not specified"); - ORT_ENFORCE(!input_types.empty(), "PyOp node inputs not specified"); - ORT_ENFORCE(!output_types.empty(), "PyOp node outputs not specified"); - return new PyCustomOp(onnx_attrs, input_types, output_types, module, class_name, compute, log_func); -} -} // namespace onnxruntime diff --git a/onnxruntime/core/language_interop_ops/pyop/pyop.h b/onnxruntime/core/language_interop_ops/pyop/pyop.h deleted file mode 100644 index 049a247aec469..0000000000000 --- a/onnxruntime/core/language_interop_ops/pyop/pyop.h +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/platform/env.h" -#define LOAD_PYOP_SYM(n, v, m) ORT_ENFORCE(Env::Default().GetSymbolFromLibrary(handle_, n, reinterpret_cast(&v)) == Status::OK(), m) - -#include "core/session/onnxruntime_cxx_api.h" -#include -#include -#include -#ifdef _WIN32 -#include -#else -#define HMODULE void* -#endif - -namespace ONNX_NAMESPACE { -class NodeProto; -} - -namespace onnxruntime { - -using OnnxTypes = std::vector; -using OnnxAttrs = std::unordered_map; -using PyOpLogFunc = std::function; - -class PyOpLibProxy { - public: - static PyOpLibProxy& GetInstance(); - void ReleaseInstance(void*); - bool InvokePythonFunc(void*, - const char*, - const std::vector&, - const std::vector&, - const std::vector>&, - std::vector>&, - std::vector&, - std::vector>&, - std::function); - const char* GetLastErrorMessage(std::string&); - void* NewInstance(const char*, const char*, const OnnxAttrs&); - bool Initialized() const { return initialized_; } - int32_t GetGil() const; - void PutGil(int32_t) const; - - private: - PyOpLibProxy(); - ~PyOpLibProxy(); - bool initialized_ = false; -}; - -struct PyCustomKernel { - PyCustomKernel(const OnnxAttrs& attrs, - const std::string& module, - const std::string& class_name, - const std::string& compute, - PyOpLogFunc logging_func); - ~PyCustomKernel(); - void GetOutputShape(OrtKernelContext*, size_t, OrtTensorTypeAndShapeInfo*); - void Compute(OrtKernelContext* context); - int32_t GetNumpyType(int32_t elem_type) const; - - private: - OnnxAttrs attrs_; - std::string module_; - std::string class_name_; - std::string compute_; - void* instance_ = nullptr; - PyOpLogFunc logging_func_; -}; - -struct PyCustomOp : Ort::CustomOpBase { - PyCustomOp( - const OnnxAttrs& attrs, - const OnnxTypes& inputs_type, - const OnnxTypes& outputs_type, - const std::string& module, - const std::string& class_name, - const std::string& compute = "compute", - PyOpLogFunc logging_func = [](const char*) {}); - void* CreateKernel(const OrtApi&, const OrtKernelInfo*) const; - const char* GetName() const; - size_t GetInputTypeCount() const; - ONNXTensorElementDataType GetInputType(size_t index) const; - size_t GetOutputTypeCount() const; - ONNXTensorElementDataType GetOutputType(size_t index) const; - - private: - OnnxAttrs attrs_; - OnnxTypes inputs_type_; - OnnxTypes outputs_type_; - std::string module_; - std::string class_name_; - std::string compute_; - PyOpLogFunc logging_func_; -}; // struct PyCustomOp - -PyCustomOp* LoadPyOp(const ONNX_NAMESPACE::NodeProto& node_proto, PyOpLogFunc log_func); -} // namespace onnxruntime diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h index 38795291b0328..9c7e00fff88ef 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -35,7 +35,7 @@ * * @file quantb_gemm.h * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel. -*/ + */ #pragma once @@ -145,7 +145,6 @@ template < typename PermuteDLayout = layout::NoPermute> class QuantBGemm { public: - using ElementA = ElementA_; using LayoutA = LayoutA_; using TensorRefA = TensorRef; @@ -189,34 +188,33 @@ class QuantBGemm { /// Define the kernel using GemmKernel = typename kernel::DefaultQuantBGemm< - ElementA, - LayoutA, - kAlignmentA, - ElementB, - LayoutB, - kAlignmentB, - ElementQScale, - ElementQOffset, - LayoutQMeta, - QuantBlocking, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - kStages, - kSplitKSerial, - Operator, - GatherA, - GatherB, - ScatterD, - PermuteDLayout - >::GemmKernel; + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementQScale, + ElementQOffset, + LayoutQMeta, + QuantBlocking, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + GatherA, + GatherB, + ScatterD, + PermuteDLayout>::GemmKernel; /// Argument structure struct Arguments { @@ -237,9 +235,9 @@ class QuantBGemm { // split-K parallelism (etc.) are not yet supported, keeping this for future extension int split_k_slices{1}; // For gather+scatter operations - int const *gather_A_indices{nullptr}; - int const *gather_B_indices{nullptr}; - int const *scatter_D_indices{nullptr}; + int const* gather_A_indices{nullptr}; + int const* gather_B_indices{nullptr}; + int const* scatter_D_indices{nullptr}; // // Methods @@ -247,49 +245,47 @@ class QuantBGemm { /// Default ctor CUTLASS_HOST_DEVICE - Arguments(): problem_size(0, 0, 0) {} + Arguments() : problem_size(0, 0, 0) {} /// Constructs an Arguments structure CUTLASS_HOST_DEVICE Arguments( - GemmCoord problem_size_, - TensorRef ref_A_, - TensorRef ref_B_, - TensorRef ref_Qscale_, - TensorRef ref_C_, - TensorRef ref_D_, - typename EpilogueOutputOp::Params epilogue_ = - typename EpilogueOutputOp::Params()): - problem_size(problem_size_), - ref_A(ref_A_), - ref_B(ref_B_), - ref_Qscale(ref_Qscale_), - ref_C(ref_C_), - ref_D(ref_D_), - epilogue(epilogue_) { - assert(!kHasQOffset); + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()) : problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(!kHasQOffset); } CUTLASS_HOST_DEVICE Arguments( - GemmCoord problem_size_, - TensorRef ref_A_, - TensorRef ref_B_, - TensorRef ref_Qscale_, - TensorRef ref_Qoffset_, - TensorRef ref_C_, - TensorRef ref_D_, - typename EpilogueOutputOp::Params epilogue_ = - typename EpilogueOutputOp::Params()): - problem_size(problem_size_), - ref_A(ref_A_), - ref_B(ref_B_), - ref_Qscale(ref_Qscale_), - ref_Qoffset(ref_Qoffset_), - ref_C(ref_C_), - ref_D(ref_D_), - epilogue(epilogue_) { - assert(kHasQOffset); + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()) : problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_Qoffset(ref_Qoffset_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(kHasQOffset); } }; @@ -299,24 +295,22 @@ class QuantBGemm { public: /// Constructs the GEMM. - QuantBGemm() { } + QuantBGemm() {} /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const &args) { - + static Status can_implement(Arguments const& args) { if (!kSplitKSerial && args.split_k_slices > 1) { return Status::kErrorInvalidProblem; } Status status = GemmKernel::can_implement( - args.problem_size, - args.ref_A.non_const_ref(), - args.ref_B.non_const_ref(), - args.ref_Qscale.non_const_ref(), - args.ref_Qoffset.non_const_ref(), - args.ref_C.non_const_ref(), - args.ref_D - ); + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D); if (status != Status::kSuccess) { return status; @@ -326,20 +320,18 @@ class QuantBGemm { } /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args) { - + static size_t get_workspace_size(Arguments const& args) { size_t bytes = 0; // Determine grid shape ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.split_k_slices); + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); if (kSplitKSerial && args.split_k_slices > 1) { - bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); } @@ -347,15 +339,14 @@ class QuantBGemm { } /// Initializes GEMM state from arguments. - Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { // Determine grid shape ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.split_k_slices); + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); if (kSplitKSerial) { if (args.split_k_slices > 1) { @@ -372,7 +363,6 @@ class QuantBGemm { } } } else { - if (args.split_k_slices > 1) { return Status::kErrorInvalidProblem; } @@ -380,27 +370,25 @@ class QuantBGemm { // Initialize the Params structure params_ = typename GemmKernel::Params{ - args.problem_size, - grid_shape, - args.ref_A.non_const_ref(), - args.ref_B.non_const_ref(), - args.ref_Qscale.non_const_ref(), - args.ref_Qoffset.non_const_ref(), - args.ref_C.non_const_ref(), - args.ref_D, - args.epilogue, - static_cast(workspace), - args.gather_A_indices, - args.gather_B_indices, - args.scatter_D_indices - }; + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices}; return Status::kSuccess; } /// Lightweight update given a subset of arguments - Status update(Arguments const &args, void *workspace = nullptr) { - + Status update(Arguments const& args, void* workspace = nullptr) { if (kSplitKSerial && args.split_k_slices > 1) { if (!workspace) { return Status::kErrorWorkspaceNull; @@ -414,14 +402,13 @@ class QuantBGemm { params_.ref_C.reset(args.ref_C.non_const_ref().data()); params_.ref_D.reset(args.ref_D.data()); params_.output_op = args.epilogue; - params_.semaphore = static_cast(workspace); + params_.semaphore = static_cast(workspace); return Status::kSuccess; } /// Runs the kernel using initialized state. Status run(cudaStream_t stream = nullptr) { - ThreadblockSwizzle threadblock_swizzle; dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); @@ -457,10 +444,9 @@ class QuantBGemm { /// Runs the kernel using initialized state. Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) { - + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { @@ -471,11 +457,10 @@ class QuantBGemm { } }; - //////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass //////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h index 2f4460bb59e9f..f471411730a3d 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -69,7 +69,7 @@ #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -#endif //CUTLASS_ARCH_WMMA_ENABLED +#endif // CUTLASS_ARCH_WMMA_ENABLED //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -139,13 +139,11 @@ template < /// Permute operand B typename PermuteBLayout = layout::NoPermute, /// - typename Enable = void -> + typename Enable = void> struct DefaultQuantBGemm; //////////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Ampere Architecture @@ -204,8 +202,7 @@ template < /// Permute operand A typename PermuteALayout, /// Permute operand B - typename PermuteBLayout -> + typename PermuteBLayout> struct DefaultQuantBGemm { - - static_assert((platform::is_same::value - || platform::is_same>::value), - "Epilogue in the kernel level must be row major"); + static_assert((platform::is_same::value || platform::is_same>::value), + "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma< diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h index 6e5ad8f406147..72b2cf641e132 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -59,13 +59,12 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -> + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. + > struct QuantBGemm { - using Mma = Mma_; using Epilogue = Epilogue_; using OutputOp = typename Epilogue::OutputOp; @@ -96,55 +95,53 @@ struct QuantBGemm { typename Epilogue::OutputTileIterator::Params params_D; typename Epilogue::OutputTileIterator::TensorRef ref_D; typename OutputOp::Params output_op; - int *semaphore; + int* semaphore; int gemm_k_size; // how many k vectors are processed by this threadblock // For gather+scatter operations - int const *gather_A_indices; - int const *gather_B_indices; - int const *scatter_D_indices; + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; // // Methods // CUTLASS_HOST_DEVICE - Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + Params() : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} CUTLASS_HOST_DEVICE Params( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorQScale::TensorRef ref_QScale, - typename Mma::IteratorQOffset::TensorRef ref_QOffset, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, - typename OutputOp::Params output_op = typename OutputOp::Params(), - int *workspace = nullptr, - int const *gather_A_indices = nullptr, - int const *gather_B_indices = nullptr, - int const *scatter_D_indices = nullptr - ): - problem_size(problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(ref_A.layout()), - ref_A(ref_A), - params_B(ref_B.layout()), - ref_B(ref_B), - params_QScale(ref_QScale.layout()), - ref_QScale(ref_QScale), - params_QOffset(ref_QOffset.layout()), - ref_QOffset(ref_QOffset), - params_C(ref_C.layout()), - ref_C(ref_C), - params_D(ref_D.layout()), - ref_D(ref_D), - output_op(output_op), - gather_A_indices(gather_A_indices), - gather_B_indices(gather_B_indices), - scatter_D_indices(scatter_D_indices) { + cutlass::gemm::GemmCoord const& problem_size, + cutlass::gemm::GemmCoord const& grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int* workspace = nullptr, + int const* gather_A_indices = nullptr, + int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) : problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_QScale(ref_QScale.layout()), + ref_QScale(ref_QScale), + params_QOffset(ref_QOffset.layout()), + ref_QOffset(ref_QOffset), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); @@ -165,42 +162,41 @@ struct QuantBGemm { // CUTLASS_HOST_DEVICE - QuantBGemm() { } + QuantBGemm() {} /// Determines whether kernel satisfies alignment CUTLASS_HOST_DEVICE static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorQScale::TensorRef ref_QScale, - typename Mma::IteratorQOffset::TensorRef ref_QOffset, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D) { - + cutlass::gemm::GemmCoord const& problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { // TODO check problem_size K, N must be multiple of QuantBlocking static int const kAlignmentA = (platform::is_same>::value) - ? 32 + ? 32 : (platform::is_same>::value) - ? 64 - : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = (platform::is_same>::value) - ? 32 + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 : (platform::is_same>::value) - ? 64 - : Mma::IteratorB::AccessType::kElements; + ? 64 + : Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = (platform::is_same>::value) - ? 32 + ? 32 : (platform::is_same>::value) - ? 64 - : Epilogue::OutputTileIterator::kElementsPerAccess; + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; if (!TensorRef_aligned(ref_A, kAlignmentA)) { return Status::kErrorMisalignedOperand; @@ -237,7 +233,7 @@ struct QuantBGemm { return Status::kErrorMisalignedOperand; } - if constexpr(kHasQOffset) { + if constexpr (kHasQOffset) { if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) { return Status::kErrorMisalignedOperand; } @@ -256,8 +252,7 @@ struct QuantBGemm { /// Executes one GEMM CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - + void operator()(Params const& params, SharedStorage& shared_storage) { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; @@ -266,26 +261,24 @@ struct QuantBGemm { // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { return; } // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, }; cutlass::MatrixCoord tb_offset_B{ - (threadblock_tile_offset.k() * params.gemm_k_size) / 2, - (threadblock_tile_offset.n() * Mma::Shape::kN) / 2 - }; + (threadblock_tile_offset.k() * params.gemm_k_size) / 2, + (threadblock_tile_offset.n() * Mma::Shape::kN) / 2}; // Problem size is a function of threadblock index in the K dimension int problem_size_k = min( - params.problem_size.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; @@ -295,20 +288,20 @@ struct QuantBGemm { // Construct iterators to A and B operands typename Mma::IteratorA iterator_A( - params.params_A, - params.ref_A.data(), - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A, - params.gather_A_indices); + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); typename Mma::IteratorB iterator_B( - params.params_B, - params.ref_B.data(), - {problem_size_k/2, params.problem_size.n()/2}, - thread_idx, - tb_offset_B, - params.gather_B_indices); + params.params_B, + params.ref_B.data(), + {problem_size_k / 2, params.problem_size.n() / 2}, + thread_idx, + tb_offset_B, + params.gather_B_indices); const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; @@ -318,24 +311,23 @@ struct QuantBGemm { assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n())); cutlass::MatrixCoord tb_offset_QScale{ - threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), - threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn) - }; + threadblock_tile_offset.k() * (params.gemm_k_size / Mma::QuantBlocking::kRow), + threadblock_tile_offset.n() * (Mma::Shape::kN / Mma::QuantBlocking::kColumn)}; typename Mma::IteratorQScale iterator_QScale( - params.params_QScale, - params.ref_QScale.data(), - {qscale_k, qscale_n}, - thread_idx, - tb_offset_QScale, - nullptr); + params.params_QScale, + params.ref_QScale.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale, + nullptr); typename Mma::IteratorQOffset iterator_QOffset( - params.params_QOffset, - params.ref_QOffset.data(), - {qscale_k, qscale_n}, - thread_idx, - tb_offset_QScale); + params.params_QOffset, + params.ref_QOffset.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -371,11 +363,10 @@ struct QuantBGemm { threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - //assume identity swizzle + // assume identity swizzle MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN - ); + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); @@ -384,7 +375,6 @@ struct QuantBGemm { // If performing a reduction via split-K, fetch the initial synchronization if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - // Fetch the synchronization lock initially but do not block. semaphore.fetch(); @@ -394,40 +384,36 @@ struct QuantBGemm { // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( - params.params_C, - params.ref_C.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.scatter_D_indices - ); + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - params.ref_D.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.scatter_D_indices - ); + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); // Wait on the semaphore - this latency may have been covered by iterator construction if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. if (threadblock_tile_offset.k()) { iterator_C = iterator_D; } semaphore.wait(threadblock_tile_offset.k()); - } // Execute the epilogue operator to update the destination tensor. @@ -438,14 +424,11 @@ struct QuantBGemm { // if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - // The final threadblock resets the semaphore for subsequent grids. lock = 0; - } - else { + } else { // Otherwise, the semaphore is incremented lock = threadblock_tile_offset.k() + 1; } @@ -457,6 +440,6 @@ struct QuantBGemm { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h index 0af604f090e1f..02b439fbf59eb 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h @@ -113,8 +113,7 @@ template < /// Permute operand A typename PermuteALayout = layout::NoPermute, /// Permute operand B - typename PermuteBLayout = layout::NoPermute - > + typename PermuteBLayout = layout::NoPermute> struct DefaultQuantBMma; //////////////////////////////////////////////////////////////////////////////// @@ -164,19 +163,16 @@ template < /// Permute operand A typename PermuteALayout, /// Permute operand B - typename PermuteBLayout - > + typename PermuteBLayout> struct DefaultQuantBMma { - - static_assert(platform::is_same::value - || platform::is_same>::value, - "simt epilogue must be row major"); + kAlignmentB, ElementQScale, ElementQOffset, + LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, + arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator, false, + GatherA, GatherB, PermuteALayout, PermuteBLayout> { + static_assert(platform::is_same::value || platform::is_same>::value, + "simt epilogue must be row major"); static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) @@ -208,7 +204,7 @@ struct DefaultQuantBMma; using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; // Define iterators over tiles from the quant scales @@ -225,8 +221,8 @@ struct DefaultQuantBMma; using IteratorQOffset = cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< - typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, - 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, + 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage< @@ -241,8 +237,8 @@ struct DefaultQuantBMma::value || is_complex::value) -> + bool IsComplex = false // (is_complex::value || is_complex::value) + > struct DefaultQuantBMmaCore; //////////////////////////////////////////////////////////////////////////////// @@ -173,10 +173,10 @@ template < /// Cache operation of operand B cutlass::arch::CacheOperation::Kind CacheOpB> struct DefaultQuantBMmaCore { + layout::RowMajor, ElementB_, layout::ColumnMajor, + ElementQScale_, ElementQOffset_, LayoutQMeta_, QuantBlocking_, + ElementC_, LayoutC_, arch::OpClassTensorOp, Stages, + Operator_, false, CacheOpA, CacheOpB> { using Shape = Shape_; using WarpShape = WarpShape_; using InstructionShape = InstructionShape_; @@ -239,7 +239,7 @@ struct DefaultQuantBMmaCore::value, Shape::kK>; using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK/2>; + sizeof_bits::value, Shape::kK / 2>; // // Iterators to write to shared memory @@ -259,14 +259,14 @@ struct DefaultQuantBMmaCore, kThreads, + layout::PitchLinearShape, kThreads, layout::PitchLinearShape, kAccessSizeInBits / sizeof_bits::value>; /// Shared memory iterator to B operand using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, + MatrixShape, ElementB, SmemLayoutB, 1, IteratorThreadMapB>; using SmemLayoutQScale = LayoutQMeta; @@ -278,9 +278,9 @@ struct DefaultQuantBMmaCore 0, "QuantBlocking too big to fit in a thread block!"); static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, - "Only support single column or row quantize blocking!"); + "Only support single column or row quantize blocking!"); static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, - "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); + "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); /// Threadblock-level quantization meta data shape in pitch-linear layout using TBQPitchLinearShape = typename std::conditional< @@ -303,7 +303,7 @@ struct DefaultQuantBMmaCore; using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator< - ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; + ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; static int const kElementsPerAccessQOffset = (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous @@ -316,7 +316,7 @@ struct DefaultQuantBMmaCore; using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator< - ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; + ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; // // Warp-level matrix multiply operator @@ -330,7 +330,7 @@ struct DefaultQuantBMmaCore, - MatrixShape<0, 0>, WarpCount::kK>; + MatrixShape<0, 0>, WarpCount::kK>; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h index 6f27a692a3a2e..1e02e1264e09a 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -22,7 +22,6 @@ namespace cutlass { namespace transform { namespace threadblock { - //////////////////////////////////////////////////////////////////////////////// /// Optional 2-D matrix data loader, when element is std::monostate, the @@ -43,9 +42,8 @@ template < /// Number of threads in the threadblock, when provided, the iterator /// will utilize the higher numbered threads int kThreadBlockSize_ = -1> -class OptionalPredicatedTileAccessIterator{ +class OptionalPredicatedTileAccessIterator { public: - using Shape = Shape_; using Element = Element_; using Layout = Layout_; @@ -56,9 +54,9 @@ class OptionalPredicatedTileAccessIterator{ static constexpr int kThreadblockSize = kThreadBlockSize_; static_assert(!std::is_same::value, - "Disabled Iterator failed to match the specialized version below."); + "Disabled Iterator failed to match the specialized version below."); static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, - "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); using Base = PredicatedTileAccessIterator; @@ -72,7 +70,7 @@ class OptionalPredicatedTileAccessIterator{ static constexpr int kAccessesPerVector = Base::kAccessesPerVector; CUTLASS_HOST_DEVICE - static int flip_thread_id(int thread_id){ + static int flip_thread_id(int thread_id) { if constexpr (kThreadblockSize > 0) { return kThreadblockSize - 1 - thread_id; } @@ -80,17 +78,17 @@ class OptionalPredicatedTileAccessIterator{ } public: - Base base_; + Base base_; /// Default constructor - OptionalPredicatedTileAccessIterator(): base_() {}; + OptionalPredicatedTileAccessIterator() : base_(){}; /// Constructs a TileIterator from its precomputed state, threadblock offset, /// and thread ID CUTLASS_HOST_DEVICE OptionalPredicatedTileAccessIterator( /// Precomputed parameters object - Params const ¶ms, + Params const& params, /// Pointer to start of tensor Pointer pointer, /// Extent of tensor @@ -98,14 +96,14 @@ class OptionalPredicatedTileAccessIterator{ /// ID of each participating thread int thread_id, /// Initial offset of threadblock - TensorCoord const &threadblock_offset) + TensorCoord const& threadblock_offset) : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {} /// Construct a PredicatedTileAccessIterator with zero threadblock offset CUTLASS_HOST_DEVICE OptionalPredicatedTileAccessIterator( /// Precomputed parameters object - Params const ¶ms, + Params const& params, /// Pointer to start of tensor Pointer pointer, /// Extent of tensor @@ -129,19 +127,19 @@ class OptionalPredicatedTileAccessIterator{ /// Advances an iterator along logical dimensions of matrix in units of whole tiles CUTLASS_DEVICE void add_tile_offset( - TensorCoord const &tile_offset) { + TensorCoord const& tile_offset) { base_.add_tile_offset(tile_offset); } /// Returns a pointer CUTLASS_HOST_DEVICE - AccessType *get() const { + AccessType* get() const { return base_.get(); } /// Increment and return an instance to self. CUTLASS_HOST_DEVICE - OptionalPredicatedTileAccessIterator &operator++() { + OptionalPredicatedTileAccessIterator& operator++() { ++base_; return *this; } @@ -168,13 +166,13 @@ class OptionalPredicatedTileAccessIterator{ /// Sets the predicate mask, overriding value stored in predicate iterator CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { + void set_mask(Mask const& mask) { base_.set_mask(mask); } /// Gets the mask CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { + void get_mask(Mask& mask) { base_.get_mask(mask); } @@ -198,9 +196,8 @@ template < typename ThreadMap_, typename AccessType_, int kThreadBlockSize_> -class OptionalPredicatedTileAccessIterator{ +class OptionalPredicatedTileAccessIterator { public: - using Shape = Shape_; using Element = std::monostate; using Layout = Layout_; @@ -225,14 +222,14 @@ class OptionalPredicatedTileAccessIterator::value * ThreadMap_::kElementsPerAccess / 8> -class OptionalRegularTileAccessIterator{ + sizeof_bits::value* ThreadMap_::kElementsPerAccess / 8> +class OptionalRegularTileAccessIterator { public: - using Shape = Shape_; using Element = Element_; using Layout = Layout_; @@ -59,9 +58,9 @@ class OptionalRegularTileAccessIterator{ static constexpr int kThreadblockSize = ThreadblockSize_; static_assert(!std::is_same::value, - "Disabled Iterator failed to match the specialized template"); + "Disabled Iterator failed to match the specialized template"); static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, - "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); using Base = RegularTileAccessIterator; @@ -71,7 +70,7 @@ class OptionalRegularTileAccessIterator{ using AccessType = typename Base::AccessType; CUTLASS_HOST_DEVICE - static int flip_thread_id(int thread_id){ + static int flip_thread_id(int thread_id) { if constexpr (kThreadblockSize > 0) { return kThreadblockSize - 1 - thread_id; } @@ -79,15 +78,14 @@ class OptionalRegularTileAccessIterator{ } private: - Base base_; public: /// Construct a TileIterator with zero threadblock offset CUTLASS_HOST_DEVICE OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) + int thread_id ///< ID of each participating thread + ) : base_(ref, flip_thread_id(thread_id)) {} /// Overrides the internal iteration index @@ -104,13 +102,13 @@ class OptionalRegularTileAccessIterator{ /// Returns a pointer CUTLASS_DEVICE - AccessType *get() const { + AccessType* get() const { return base_.get(); } /// Advances to the next tile in memory. CUTLASS_HOST_DEVICE - OptionalRegularTileAccessIterator &operator++() { + OptionalRegularTileAccessIterator& operator++() { ++base_; return *this; } @@ -134,7 +132,7 @@ class OptionalRegularTileAccessIterator{ /// Below two classes map col/row major to the pitch linear coordinates used /// in this base class. CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { + void add_tile_offset(TensorCoord const& coord) { base_.add_tile_offset(coord); } }; @@ -151,9 +149,8 @@ template < int ThreadblockSize_, int Alignment> class OptionalRegularTileAccessIterator{ + AdvanceRank, ThreadMap_, ThreadblockSize_, Alignment> { public: - using Shape = Shape_; using Element = std::monostate; using Layout = Layout_; @@ -169,15 +166,14 @@ class OptionalRegularTileAccessIterator -struct QuantBLayoutDebug{ +struct QuantBLayoutDebug { static constexpr bool debug_smem = true; static constexpr bool debug_fragment = true; ElementWeight* smem_b_ptr_; @@ -77,15 +77,14 @@ struct QuantBLayoutDebug{ int lane_id_; int block_id_; - template - CUTLASS_DEVICE - static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + template + CUTLASS_DEVICE static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id) { static_assert(Size % 4 == 0, "Size must be multiple of 4"); - if constexpr (debug_fragment){ - if (block_id == 1 && warp_id == 0){ + if constexpr (debug_fragment) { + if (block_id == 1 && warp_id == 0) { const Element* ptr = reinterpret_cast(&frag); - for (int i = 0; i < Size/4; i++, ptr+=4){ - if constexpr(std::is_integral::value){ + for (int i = 0; i < Size / 4; i++, ptr += 4) { + if constexpr (std::is_integral::value) { printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n", threadIdx.x, label, i, ptr[0], ptr[1], ptr[2], ptr[3]); @@ -99,21 +98,19 @@ struct QuantBLayoutDebug{ } } - template - CUTLASS_DEVICE - static void print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + template + CUTLASS_DEVICE static void print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id) { constexpr int I8Size = Size * cutlass::sizeof_bits::value / 8; static_assert(I8Size % 2 == 0, "Size must be multiple of 4"); - if constexpr (debug_fragment){ - if (block_id == 1 && warp_id == 0){ + if constexpr (debug_fragment) { + if (block_id == 1 && warp_id == 0) { const uint8_t* ptr = reinterpret_cast(&frag); - for (int i = 0; i < I8Size/2; i++, ptr+=2){ + for (int i = 0; i < I8Size / 2; i++, ptr += 2) { printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4); } } } } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -121,8 +118,9 @@ struct QuantBLayoutDebug{ /// Dummy type when quant offset is not used, to avoid compilation error, /// and reduce runtime footprint /// -struct DummyType{ +struct DummyType { std::monostate dummy_; + public: DummyType() = default; @@ -137,7 +135,7 @@ struct DummyType{ } }; -} +} // namespace ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -239,8 +237,9 @@ class QuantBMmaBase { Shape::kN / QuantBlocking::kColumn>; using BufTypeQOffset = std::conditional_t, - DummyType>; + AlignedBuffer, + DummyType>; + public: // // Data members @@ -259,7 +258,6 @@ class QuantBMmaBase { BufTypeQOffset operand_QOffset; public: - // // Methods // @@ -306,7 +304,7 @@ class QuantBMmaBase { CUTLASS_HOST_DEVICE TensorRefQOffset operand_QOffset_ref() { - if constexpr (!kHasQOffset){ + if constexpr (!kHasQOffset) { return TensorRefQOffset(); } else { return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()}; @@ -315,7 +313,6 @@ class QuantBMmaBase { }; protected: - // // Data members // @@ -329,25 +326,21 @@ class QuantBMmaBase { /// Iterator to load a warp-scoped tile of quant scales from shared memory typename Operator::IteratorQMeta warp_tile_iterator_QScale_; -public: - + public: /// Construct from tensor references CUTLASS_DEVICE QuantBMmaBase( ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, + SharedStorage& shared_storage, ///< ID within the threadblock int thread_idx, ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx - ): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), - warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), - shared_storage.operand_QOffset_ref(), lane_idx) - {} + int lane_idx) : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), + shared_storage.operand_QOffset_ref(), lane_idx) {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -397,9 +390,8 @@ template < int Stages, /// Used for partial specialization typename Enable = bool> -class QuantBMmaMultistage : - public QuantBMmaBase { -public: +class QuantBMmaMultistage : public QuantBMmaBase { + public: ///< Base class using Base = QuantBMmaBase; ///< Size of the Gemm problem - concept: gemm::GemmShape<> @@ -452,7 +444,6 @@ class QuantBMmaMultistage : /// Internal structure exposed for introspection. struct Detail { - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -497,11 +488,8 @@ class QuantBMmaMultistage : }; private: - - // Structure encapsulating pipeline state live from one iteration to the next struct PipeState { - using WarpLoadedFragmentA = typename Operator::FragmentA; using WarpLoadedFragmentB = typename Operator::FragmentB; using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; @@ -521,14 +509,12 @@ class QuantBMmaMultistage : WarpLoadedFragmentQScale warp_loaded_frag_QScale_; using WarpLoadedFragmentQOffset = typename std::conditional::type; + typename Operator::FragmentQOffset, + std::monostate>::type; WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_; }; - private: - // // Data members // @@ -559,43 +545,39 @@ class QuantBMmaMultistage : /// Shared memory pointers for debug dumping static constexpr bool debug_layout = false; using LayoutDebugType = typename std::conditional, - std::monostate>::type; + QuantBLayoutDebug, + std::monostate>::type; LayoutDebugType layout_debug_; -public: - + public: /// Construct from tensor references CUTLASS_DEVICE QuantBMmaMultistage( ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, + typename Base::SharedStorage& shared_storage, ///< ID within the threadblock int thread_idx, ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), - smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), - should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), - should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), - smem_write_stage_idx_(0), - smem_read_stage_idx_(0) - { + int lane_idx) : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), + smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), + should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), + should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: // _m: the warp's position within the threadblock along the M dimension // _n: the warp's position within the threadblock along the N dimension // _k: the warp's position within the threadblock along the K dimension - if constexpr(debug_layout){ + if constexpr (debug_layout) { layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data(); layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data(); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data(); } else { layout_debug_.smem_qoffset_ptr_ = nullptr; @@ -622,8 +604,7 @@ class QuantBMmaMultistage : /// Advance shared memory read-iterators to the next stage CUTLASS_DEVICE - void advance_smem_read_stage() - { + void advance_smem_read_stage() { ++smem_read_stage_idx_; if (smem_read_stage_idx_ == Base::kStages) { @@ -639,11 +620,10 @@ class QuantBMmaMultistage : /// Advance global memory read-iterators and shared memory write-iterators to the stage CUTLASS_DEVICE void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B, - IteratorQScale &iterator_QScale, - IteratorQOffset &iterator_QOffset) - { + IteratorA& iterator_A, + IteratorB& iterator_B, + IteratorQScale& iterator_QScale, + IteratorQOffset& iterator_QOffset) { // Advance global iterators iterator_A.add_tile_offset({0, 1}); iterator_B.add_tile_offset({1, 0}); @@ -675,7 +655,7 @@ class QuantBMmaMultistage : } CUTLASS_DEVICE - void copy_qscale_tiles(IteratorQScale &iterator_QScale){ + void copy_qscale_tiles(IteratorQScale& iterator_QScale) { // Quant scale matrix is 1/block_size of the B matrix, for a 64x64 warp tile, // it's only 64x64/block_size elements. For blocking size 16 ~ 64, it only // takes 4 ~ 16 cp.async instructions to load. One warp has 32 threads, so @@ -687,41 +667,41 @@ class QuantBMmaMultistage : "Quant scale should 1 access per vector!"); // Async Copy for quantization scale - typename IteratorQScale::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorQScale::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_QScale_.get()); constexpr int kSrcBytes = sizeof_bits::value * - IteratorQScale::ThreadMap::kElementsPerAccess / 8; + IteratorQScale::ThreadMap::kElementsPerAccess / 8; cutlass::arch::cp_async( dst_ptr, iterator_QScale.get(), iterator_QScale.valid()); } CUTLASS_DEVICE - void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) { + void copy_qoffset_tiles(IteratorQOffset& iterator_QOffset) { static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); - if constexpr(kHasQOffset) { + if constexpr (kHasQOffset) { // Async Copy for quantization offset - typename IteratorQOffset::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorQOffset::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_QOffset_.get()); constexpr int kSrcBytes = sizeof_bits::value * IteratorQOffset::ThreadMap::kElementsPerAccess / 8; cutlass::arch::cp_async( - dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); } } CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start = 0) { auto group_start_A = group_start * Detail::kAccessesPerGroupA; iterator_A.set_iteration_index(group_start_A * @@ -732,8 +712,8 @@ class QuantBMmaMultistage : CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_A_.get()); int const kSrcBytes = sizeof_bits::value * @@ -763,8 +743,8 @@ class QuantBMmaMultistage : CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_B_.get()); int const kSrcBytes = sizeof_bits::value * @@ -789,16 +769,15 @@ class QuantBMmaMultistage : /// the global fragments needed by the first kStages-1 threadblock mainloop iterations CUTLASS_DEVICE void prologue( - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory - IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + IteratorA& iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB& iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale& iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset& iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int& gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Issue several complete stages CUTLASS_PRAGMA_UNROLL for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); @@ -810,8 +789,8 @@ class QuantBMmaMultistage : // Async Copy for operand A CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_A_.get()); CUTLASS_PRAGMA_UNROLL @@ -838,8 +817,8 @@ class QuantBMmaMultistage : // Async Copy for operand B CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_B_.get()); CUTLASS_PRAGMA_UNROLL @@ -862,8 +841,8 @@ class QuantBMmaMultistage : static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!"); static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should 1 access per vector!"); - typename IteratorQScale::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorQScale::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_QScale_.get()); constexpr int kSrcBytes = @@ -881,13 +860,13 @@ class QuantBMmaMultistage : // Async Copy for quantization offset static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); - typename IteratorQOffset::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorQOffset::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_QOffset_.get()); constexpr int kSrcBytes = sizeof_bits::value * - IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; cutlass::arch::cp_async( dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); @@ -901,22 +880,20 @@ class QuantBMmaMultistage : } } - /// Wait until we have at least one completed global fetch stage CUTLASS_DEVICE - void gmem_wait() - { + void gmem_wait() { // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) cutlass::arch::cp_async_wait(); __syncthreads(); - if constexpr(debug_layout) { + if constexpr (debug_layout) { if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) { - if (threadIdx.x == 0){ + if (threadIdx.x == 0) { printf("stage: %d\n", smem_write_stage_idx_); } cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount); } } @@ -926,13 +903,13 @@ class QuantBMmaMultistage : /// Perform a threadblock mainloop iteration of matrix multiply-accumulate CUTLASS_DEVICE void mac_loop_iter( - PipeState &pipe_state, ///< [in|out] loop-carried pipeline state - FragmentC &accum, ///< [in|out] destination accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory - IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + PipeState& pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC& accum, ///< [in|out] destination accumulator tile + IteratorA& iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB& iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale& iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset& iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int& gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL @@ -960,35 +937,34 @@ class QuantBMmaMultistage : iterator_B, (warp_mma_k + 1) % Base::kWarpGemmIterations); - if constexpr(debug_layout) { - if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + if constexpr (debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0) { printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); } LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } warp_mma_.transform( - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_loaded_frag_QScale_, - pipe_state.warp_loaded_frag_QOffset_); + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout) { + if constexpr (debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } // Execute the current warp-tile of MMA operations if (Detail::kStagedAccumulation) { warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_); if (warp_mma_k == 0) { plus plus_accum; @@ -997,11 +973,10 @@ class QuantBMmaMultistage : } } else { warp_mma_( - accum, - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum); } if (warp_mma_k == 0) { @@ -1025,51 +1000,50 @@ class QuantBMmaMultistage : iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); } // Wait until we have at least one completed global fetch stage gmem_wait(); } - } } /// Specialized mainloop iteration of matrix multiply-accumulate, for small M CUTLASS_DEVICE void mac_loop_iter_small_m( - PipeState &pipe_state, ///< [in|out] loop-carried pipeline state - FragmentC &accum, ///< [in|out] destination accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory - IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + PipeState& pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC& accum, ///< [in|out] destination accumulator tile + IteratorA& iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB& iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale& iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset& iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int& gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { // In the case of small M, memory latency dominates. We try to move uses far // from their definitions to hide latency. - if constexpr(debug_layout) { - if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + if constexpr (debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0) { printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); } LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } warp_mma_.transform( - pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_loaded_frag_QScale_, - pipe_state.warp_loaded_frag_QOffset_); + pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout) { + if constexpr (debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } @@ -1096,11 +1070,10 @@ class QuantBMmaMultistage : // Execute the current warp-tile of MMA operations if (Detail::kStagedAccumulation) { warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_); if (warp_mma_k == 0) { plus plus_accum; @@ -1109,11 +1082,10 @@ class QuantBMmaMultistage : } } else { warp_mma_( - accum, - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum); } // The second-to-last warp-tile also moves to the next global fetch stage @@ -1130,7 +1102,7 @@ class QuantBMmaMultistage : iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); } @@ -1140,21 +1112,19 @@ class QuantBMmaMultistage : // Wait until we have at least one completed global fetch stage gmem_wait(); } - } } - /// Perform the specified number of threadblock mainloop iterations of matrix /// multiply-accumulate. Assumes prologue has been initiated. CUTLASS_DEVICE void gemm_iters( - int gemm_k_iterations, ///< number of threadblock mainloop iterations - FragmentC &accum, ///< [in|out] accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory - IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC& accum, ///< [in|out] accumulator tile + IteratorA& iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB& iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale& iterator_QScale, ///< [in|out] iterator over QScale operand in global memory + IteratorQOffset& iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory { PipeState pipe_state; @@ -1162,7 +1132,7 @@ class QuantBMmaMultistage : iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); - if constexpr(kHasQOffset) { + if constexpr (kHasQOffset) { iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); } @@ -1183,26 +1153,26 @@ class QuantBMmaMultistage : copy_tiles_and_advance(iterator_A, iterator_B, 0); - if constexpr(Shape::kM > 32) { + if constexpr (Shape::kM > 32) { // the case of bigger m - if constexpr(debug_layout) { - if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + if constexpr (debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0) { printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); } LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } warp_mma_.transform( - pipe_state.warp_transformed_frag_B_[0], - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_loaded_frag_QScale_, - pipe_state.warp_loaded_frag_QOffset_); + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout) { + if constexpr (debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } else { @@ -1218,24 +1188,24 @@ class QuantBMmaMultistage : // Mainloop CUTLASS_GEMM_LOOP for (; gemm_k_iterations > (-Base::kStages + 1);) { - if constexpr(Shape::kM > 32) { + if constexpr (Shape::kM > 32) { mac_loop_iter( - pipe_state, - accum, - iterator_A, - iterator_B, - iterator_QScale, - iterator_QOffset, - gemm_k_iterations); + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); } else { mac_loop_iter_small_m( - pipe_state, - accum, - iterator_A, - iterator_B, - iterator_QScale, - iterator_QOffset, - gemm_k_iterations); + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); } } @@ -1248,17 +1218,15 @@ class QuantBMmaMultistage : cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); __syncthreads(); - } - /// Perform a threadblock-scoped matrix multiply-accumulate CUTLASS_DEVICE void operator()( ///< problem size of GEMM int gemm_k_iterations, ///< destination accumulator tile - FragmentC &accum, + FragmentC& accum, ///< iterator over A operand in global memory IteratorA iterator_A, ///< iterator over B operand in global memory @@ -1268,8 +1236,7 @@ class QuantBMmaMultistage : ///< Iterator over quant offsets in global memory IteratorQOffset iterator_QOffset, ///< initial value of accumulator - FragmentC const &src_accum) { - + FragmentC const& src_accum) { // Prologue (start fetching iterations of global fragments into shared memory) prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations); diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h index 2c49888c94504..858e616e9b5be 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h @@ -101,9 +101,9 @@ struct DefaultQuantBMmaTensorOp { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index 0b9ac0cb0e087..ce00c74103d70 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -33,21 +33,21 @@ //////////////////////////////////////////////////////////////////////////////// -namespace{ +namespace { -struct b32_pair{ +struct b32_pair { uint32_t a; uint32_t b; }; -struct fp16_quad{ +struct fp16_quad { cutlass::half_t a; cutlass::half_t b; cutlass::half_t c; cutlass::half_t d; }; -struct b16_quad{ +struct b16_quad { int16_t a; int16_t b; int16_t c; @@ -66,17 +66,15 @@ static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); /// Convert packed 4b weights into fp16(weight + 16) /// Current bit hacking only supports fp16, need to add bf16 later. /// -template -CUTLASS_DEVICE -void weights2Half(cutlass::Array const &weights, - cutlass::Array& dest) -{ +template +CUTLASS_DEVICE void weights2Half(cutlass::Array const& weights, + cutlass::Array& dest) { static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); uint32_t* dest_pair = reinterpret_cast(dest.data()); const uint32_t* w_oct = reinterpret_cast(weights.data()); CUTLASS_PRAGMA_UNROLL - for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){ + for (int oct_idx = 0; oct_idx < Size / 8; oct_idx++, w_oct++, dest_pair += 4) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) // static_cast(16 + weight) @@ -88,7 +86,7 @@ void weights2Half(cutlass::Array const &weights, " shl.b32 %1, %4, 2;\n" " shr.u32 %2, %4, 2;\n" " shr.u32 %3, %4, 6;\n" - " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" @@ -100,10 +98,9 @@ void weights2Half(cutlass::Array const &weights, assert(0); #endif } - } -} // namespace +} // namespace //////////////////////////////////////////////////////////////////////////////// @@ -117,18 +114,17 @@ namespace warp { // Since operand B is quantized on a per block basis, it's one meta data per block. template < - /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) - typename WarpShapeB_, - /// Block dimensions of the blockwise quantization. So the actual meta data - /// warp shape is WarpShapeB_ / BlockingShape_ - typename BlockingShape_, - /// Underlying matrix multiply operator (concept: arch::Mma) - typename ArchMmaOperator_, - /// Number of threads participating in one matrix operation - int Threads> -class QuantBMetaMmaTile{ -public: - + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTile { + public: using WarpShapeB = WarpShapeB_; using BlockingShape = BlockingShape_; using ArchMmaOperator = ArchMmaOperator_; @@ -169,8 +165,9 @@ class QuantBMetaMmaTile{ /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension /// exceeds the tile depth, so two tiles share the same meta data static int const kTilesPerMma = ((kBTilesPerMma == 2) && - (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) - ? 2 : 1; + (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) + ? 2 + : 1; /// stride to reach the meta data for the next CoreTile on the K dimension static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; @@ -190,24 +187,21 @@ class QuantBMetaMmaTile{ CUTLASS_DEVICE static MatrixCoord lane_position(int lane_id) { - if constexpr(kNumBsPerCoreTileFragement == 2 - && kBTilesPerMma == 2 - && BlockingShape::kRow == 1){ + if constexpr (kNumBsPerCoreTileFragement == 2 && kBTilesPerMma == 2 && BlockingShape::kRow == 1) { // Optimize for a special case of: // 16b gemm (kNumBsPerCoreTileFragement == 2) // 2 B operand tiles per mma (kBTilesPerMma == 2) // (1,n) quantization blocking // The scale and offset tensors are prepacked to reduce the number of load instructions. return make_Coord((lane_id % CoreTile::kContiguous) * 4, - lane_id / CoreTile::kContiguous); + lane_id / CoreTile::kContiguous); } else { return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement, - lane_id / CoreTile::kContiguous); + lane_id / CoreTile::kContiguous); } } }; - //////////////////////////////////////////////////////////////////////////////// /// This tile iterator is to load quantization meta data for operand B from @@ -222,25 +216,25 @@ class QuantBMetaMmaTile{ /// out the operand B layout in the tensor core. /// template < - /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) - typename WarpShapeB_, - /// Block dimensions of the blockwise quantization. So the actual meta data - /// warp shape is WarpShapeB_ / BlockingShape_ - typename BlockingShape_, - /// Data type of the quant scales - typename ElementScale_, - /// Layout of the quant scales - typename LayoutScale_, - /// Data type of quant offsets - typename ElementOffset_, - /// Layout of quant offsets - typename LayoutOffset_, - /// Underlying matrix multiply operator (concept: arch::Mma) - typename ArchMmaOperator_, - /// Number of threads participating in one matrix operation - int Threads, - /// Number of partitions along K dimension - int PartitionsK_ = 1> + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the quant scales + typename ElementScale_, + /// Layout of the quant scales + typename LayoutScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Layout of quant offsets + typename LayoutOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> class QuantBMetaMmaTensorOpTileIterator; //////////////////////////////////////////////////////////////////////////////// @@ -248,25 +242,24 @@ class QuantBMetaMmaTensorOpTileIterator; /// Specialization for column major layout template < - /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) - typename WarpShapeB_, - /// Block dimensions of the blockwise quantization. So the actual meta data - /// warp shape is WarpShapeB_ / BlockingShape_ - typename BlockingShape_, - /// Data type of the meta data elements - typename ElementScale_, - /// Data type of quant offsets - typename ElementOffset_, - /// Underlying matrix multiply operator (concept: arch::Mma) - typename ArchMmaOperator_, - /// Number of threads participating in one matrix operation - int Threads> + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> class QuantBMetaMmaTensorOpTileIterator{ -public: - + ElementScale_, cutlass::layout::ColumnMajor, + ElementOffset_, cutlass::layout::ColumnMajor, + ArchMmaOperator_, Threads, 1> { + public: using WarpShapeB = WarpShapeB_; using BlockingShape = BlockingShape_; using ElementScale = ElementScale_; @@ -277,7 +270,7 @@ class QuantBMetaMmaTensorOpTileIterator::value); static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1, - "Only support row blocking for column major layout"); + "Only support row blocking for column major layout"); using MetaTile = QuantBMetaMmaTile; @@ -316,44 +309,39 @@ class QuantBMetaMmaTensorOpTileIterator; using FragmentOffset = typename std::conditional, - std::monostate>::type; + Array, + std::monostate>::type; using AccessTypeScale = Array; using AccessTypeOffset = Array; -private: - - ElementScale *pointer_; + private: + ElementScale* pointer_; Layout layout_; - ElementOffset *pointer_offset_; + ElementOffset* pointer_offset_; Layout layout_offset_; TensorCoord lane_position_; -public: - + public: CUTLASS_DEVICE - QuantBMetaMmaTensorOpTileIterator() { } + QuantBMetaMmaTensorOpTileIterator() {} CUTLASS_DEVICE QuantBMetaMmaTensorOpTileIterator( - TensorRefScale const &ref, - TensorRefOffset const &ref_offset, - int lane_idx - ): - pointer_(ref.data()), - layout_(ref.layout()), - pointer_offset_(ref_offset.data()), - layout_offset_(ref_offset.layout()), - lane_position_(MetaTile::lane_position(lane_idx)){} + TensorRefScale const& ref, + TensorRefOffset const& ref_offset, + int lane_idx) : pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) {} /// Loads a fragment CUTLASS_HOST_DEVICE - void load(FragmentScale &frag, FragmentOffset &frag_offset) { - if constexpr(kNumBsPerCoreTileFragement == 2 - && kBTilesPerMma == 2){ + void load(FragmentScale& frag, FragmentOffset& frag_offset) { + if constexpr (kNumBsPerCoreTileFragement == 2 && kBTilesPerMma == 2) { // Optimize for a special case of: // 16b gemm (kNumBsPerCoreTileFragement == 2) // 2 B operand tiles per mma (kBTilesPerMma == 2) @@ -362,19 +350,19 @@ class QuantBMetaMmaTensorOpTileIterator *dst_ptr = reinterpret_cast*>(frag.data()); + Array* dst_ptr = reinterpret_cast*>(frag.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ - Array *src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride) { + Array* src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); *dst_ptr = *src_ptr; dst_ptr++; } - if constexpr(kHasOffset){ - Array *dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); + if constexpr (kHasOffset) { + Array* dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ - Array *src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride) { + Array* src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); *dst_ptr_offset = *src_ptr_offset; dst_ptr_offset++; } @@ -388,21 +376,21 @@ class QuantBMetaMmaTensorOpTileIterator(frag.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride) { CUTLASS_PRAGMA_UNROLL - for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride) { AccessTypeScale* src_ptr = reinterpret_cast(pointer_ + layout_({r, c})); *dst_ptr = *src_ptr; dst_ptr++; } } - if constexpr(kHasOffset){ + if constexpr (kHasOffset) { AccessTypeOffset* dst_ptr = reinterpret_cast(frag_offset.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride) { CUTLASS_PRAGMA_UNROLL - for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride) { AccessTypeOffset* src_ptr = reinterpret_cast(pointer_offset_ + layout_offset_({r, c})); *dst_ptr = *src_ptr; dst_ptr++; @@ -413,18 +401,17 @@ class QuantBMetaMmaTensorOpTileIterator - CUTLASS_HOST_DEVICE - static Array debug_expand(Array const &frag){ + CUTLASS_HOST_DEVICE static Array debug_expand(Array const& frag) { Array ret; int out_idx = 0; CUTLASS_PRAGMA_UNROLL - for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + for (int n_out = 0; n_out < kMmaIterationsB; n_out++) { int n_idx = n_out / kNRepeats; CUTLASS_PRAGMA_UNROLL - for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++) { int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); CUTLASS_PRAGMA_UNROLL - for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++) { int elem_idx = elem_out_idx / BlockingShape::kRow; int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; ret[out_idx] = frag[idx]; @@ -436,17 +423,17 @@ class QuantBMetaMmaTensorOpTileIterator const &weights, - Array& dest){ + static void dequant(FragmentScale const& scales, + FragmentOffset const& fragment_offsets, + Array const& weights, + Array& dest) { static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm."); static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); // First convert 4b weight into fp16(weight + 16) weights2Half(weights, dest); - if constexpr(kBTilesPerMma == 2){ + if constexpr (kBTilesPerMma == 2) { // Optimize for a special case of: // 2 B operand tiles per mma (kBTilesPerMma == 2) // (1,n) quantization blocking (BlockingShape::kRow == 1) @@ -454,28 +441,30 @@ class QuantBMetaMmaTensorOpTileIterator(dest.data()); const b64* scales_ptr = reinterpret_cast(scales.data()); [[maybe_unused]] const ElementOffset* fragment_offsets_ptr = nullptr; - if constexpr(kHasOffset) { fragment_offsets_ptr = fragment_offsets.data(); } + if constexpr (kHasOffset) { + fragment_offsets_ptr = fragment_offsets.data(); + } CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++) { // dequantize: d = scale * (weight - offset) // to use FMA, d = scale * weight + (scale * (-offset)) [[maybe_unused]] b64 offsets{0}; - if constexpr(kHasOffset) { + if constexpr (kHasOffset) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) const uint32_t* p = reinterpret_cast(fragment_offsets_ptr); asm volatile( "{\n\t" - " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands + " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands // static_cast(-16 - offset) // input [d, b, c, a], - " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 - " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 - " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" - " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) " mul.rn.f16x2 %1, %3, rb1;\n" "}\n" : "=r"(offsets.pair.a), "=r"(offsets.pair.b) @@ -492,25 +481,25 @@ class QuantBMetaMmaTensorOpTileIteratorpair.a), "r"(scales_ptr->pair.b)); #else - offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); - offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); - offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); - offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); + offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16 - 8); + offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16 - 8); + offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16 - 8); + offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16 - 8); #endif } CUTLASS_PRAGMA_UNROLL - for (int n_r = 0; n_r < kNRepeats; n_r++){ + for (int n_r = 0; n_r < kNRepeats; n_r++) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) asm volatile( "{\n\t" - " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) + " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) " fma.rn.f16x2 %1, %3, %1, %5;\n" "}\n" : "+r"(dest_pair[0]), "+r"(dest_pair[1]) @@ -529,75 +518,70 @@ class QuantBMetaMmaTensorOpTileIterator(-16 - static_cast(fragment_offsets[idx])); } else { - offset = s * static_cast(-16-8); + offset = s * static_cast(-16 - 8); } dest[out_idx] = s * dest[out_idx] + offset; out_idx++; } } } - } - } /// Advances the pointer CUTLASS_HOST_DEVICE - QuantBMetaMmaTensorOpTileIterator &operator++() { + QuantBMetaMmaTensorOpTileIterator& operator++() { // This is for operand B, so advance on the K dimension lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); return *this; } CUTLASS_DEVICE - QuantBMetaMmaTensorOpTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { + QuantBMetaMmaTensorOpTileIterator& add_tile_offset( + TensorCoord const& tile_offset) { int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; lane_position_ += TensorCoord(rows, columns); return *this; } - }; - //////////////////////////////////////////////////////////////////////////////// /// Specialization for row major layout template < - /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) - typename WarpShapeB_, - /// Block dimensions of the blockwise quantization. So the actual meta data - /// warp shape is WarpShapeB_ / BlockingShape_ - typename BlockingShape_, - /// Data type of the meta data elements - typename ElementScale_, - /// Data type of quant offsets - typename ElementOffset_, - /// Underlying matrix multiply operator (concept: arch::Mma) - typename ArchMmaOperator_, - /// Number of threads participating in one matrix operation - int Threads> + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> class QuantBMetaMmaTensorOpTileIterator{ -public: - + ElementScale_, cutlass::layout::RowMajor, + ElementOffset_, cutlass::layout::RowMajor, + ArchMmaOperator_, Threads, 1> { + public: using WarpShapeB = WarpShapeB_; using BlockingShape = BlockingShape_; using ElementScale = ElementScale_; @@ -608,7 +592,7 @@ class QuantBMetaMmaTensorOpTileIterator::value); static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1, - "Only support column blocking for row major layout"); + "Only support column blocking for row major layout"); using MetaTile = QuantBMetaMmaTile; @@ -647,40 +631,35 @@ class QuantBMetaMmaTensorOpTileIterator; using FragmentOffset = typename std::conditional, - std::monostate>::type; + Array, + std::monostate>::type; -private: - - ElementScale *pointer_; + private: + ElementScale* pointer_; Layout layout_; - ElementOffset *pointer_offset_; + ElementOffset* pointer_offset_; Layout layout_offset_; TensorCoord lane_position_; -public: - + public: CUTLASS_DEVICE - QuantBMetaMmaTensorOpTileIterator() { } + QuantBMetaMmaTensorOpTileIterator() {} CUTLASS_DEVICE QuantBMetaMmaTensorOpTileIterator( - TensorRefScale const &ref, - TensorRefOffset const &ref_offset, - int lane_idx - ): - pointer_(ref.data()), - layout_(ref.layout()), - pointer_offset_(ref_offset.data()), - layout_offset_(ref_offset.layout()), - lane_position_(MetaTile::lane_position(lane_idx)) - {} + TensorRefScale const& ref, + TensorRefOffset const& ref_offset, + int lane_idx) : pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) {} /// Loads a fragment CUTLASS_HOST_DEVICE - void load(FragmentScale &frag, FragmentOffset &frag_offset) { + void load(FragmentScale& frag, FragmentOffset& frag_offset) { const int row = lane_position_.row() / BlockingShape::kRow; const int column = lane_position_.column() / BlockingShape::kColumn; static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile"); @@ -688,34 +667,33 @@ class QuantBMetaMmaTensorOpTileIterator - CUTLASS_HOST_DEVICE - static Array debug_expand(Array const &frag){ + CUTLASS_HOST_DEVICE static Array debug_expand(Array const& frag) { Array ret; int out_idx = 0; CUTLASS_PRAGMA_UNROLL - for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + for (int n_out = 0; n_out < kMmaIterationsB; n_out++) { int n_idx = n_out / kNRepeats; CUTLASS_PRAGMA_UNROLL - for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++) { int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); CUTLASS_PRAGMA_UNROLL - for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++) { int elem_idx = elem_out_idx / BlockingShape::kRow; int col = elem_idx + mma_tile_idx * kCoreTileFragementSize; int idx = col * kMmaIterations + n_idx; @@ -728,10 +706,10 @@ class QuantBMetaMmaTensorOpTileIterator const &weights, - Array& dest){ + static void dequant(FragmentScale const& scales, + FragmentOffset const& offsets, + Array const& weights, + Array& dest) { static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1"); static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now."); @@ -742,30 +720,30 @@ class QuantBMetaMmaTensorOpTileIterator(scales.data()); uint32_t* addon_ptr = reinterpret_cast(addon); - if constexpr(kHasOffset){ + if constexpr (kHasOffset) { const uint32_t* p = reinterpret_cast(offsets.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) asm volatile( - "{\n\t" - " .reg .b32 rb0, rb1, rb2;\n" + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" - // offset from [d, c, b, a] --> [d, b, c, a] - " prmt.b32 rb2, %4, rb0, 0x3120;\n" + // offset from [d, c, b, a] --> [d, b, c, a] + " prmt.b32 rb2, %4, rb0, 0x3120;\n" - // static_cast(-16 - offset) - // input [d, b, c, a], - " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 - " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 - " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 - " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" - " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) - " mul.rn.f16x2 %1, %3, rb1;\n" - "}\n" - : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) - : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), - "r"(p[0])); + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); #else assert(0); #endif @@ -775,17 +753,17 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) asm volatile( - "{\n\t" - " .reg .b32 rb0;\n" - " mov.u32 rb0, 0xce00ce00;\n" - " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) - " mul.rn.f16x2 %1, %3, rb0;\n" - "}\n" - : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) - : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); #else assert(0); #endif @@ -794,7 +772,7 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) const uint32_t* scales_ptr = reinterpret_cast(scales.data()); uint32_t* addon_ptr = reinterpret_cast(addon); @@ -802,21 +780,20 @@ class QuantBMetaMmaTensorOpTileIterator(offsets.data()); asm volatile( - "{\n\t" - " .reg .b32 rb0, rb1, rb2;\n" - - // offset from [?, ?, b, a] --> [?, b, ?, a] - " prmt.b32 rb2, %2, rb0, 0x3120;\n" - - // static_cast(-16 - offset) - // input [d, b, c, a], - " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 - " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 - " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) - "}\n" - : "=r"(addon_ptr[0]) - : "r"(scales_ptr[0]) - "r"(p[0])); + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [?, ?, b, a] --> [?, b, ?, a] + " prmt.b32 rb2, %2, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0]) "r"(p[0])); #else assert(0); #endif @@ -825,32 +802,32 @@ class QuantBMetaMmaTensorOpTileIterator(scales.data()); uint32_t* addon_ptr = reinterpret_cast(addon); asm volatile( - "{\n\t" - " .reg .b32 rb0;\n" - " mov.u32 rb0, 0xce00ce00;\n" - " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) - "}\n" - : "=r"(addon_ptr[0]) - : "r"(scales_ptr[0])); + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0])); #else assert(0); #endif } } else { // kMmaIterationsB == 1 - if constexpr(kHasOffset){ + if constexpr (kHasOffset) { uint8_t zp = offsets[0]; addon[0] = scales[0] * static_cast(-16 - static_cast(zp)); } else { - addon[0] = scales[0] * static_cast(-16-8); + addon[0] = scales[0] * static_cast(-16 - 8); } } int out_idx = 0; CUTLASS_PRAGMA_UNROLL - for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + for (int n_out = 0; n_out < kMmaIterationsB; n_out++) { CUTLASS_PRAGMA_UNROLL - for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++) { dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out]; dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out]; out_idx += 2; @@ -860,24 +837,22 @@ class QuantBMetaMmaTensorOpTileIterator - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Data type of quant scales - typename ElementQScale_, - /// Layout of quant scales (concept: MatrixLayout) - typename SmemLayoutQScale_, - /// Data type of quant offsets - typename ElementQOffset_, - /// Layout of quant offsets (concept: MatrixLayout) - typename SmemLayoutQOffset_, - /// Blocking dimensions of quantization - typename QuantBlocking_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool -> + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Data type of quant scales + typename ElementQScale_, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale_, + /// Data type of quant offsets + typename ElementQOffset_, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset_, + /// Blocking dimensions of quantization + typename QuantBlocking_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> class QuantBMmaTensorOp { -public: + public: /// Shape of warp-level matrix operation (concept: GemmShape) using Shape = Shape_; @@ -157,13 +156,12 @@ class QuantBMmaTensorOp { /// Number of partitions along K dimension static int const kPartitionsK = PartitionsK_; -public: - + public: /// Iterates over the A operand in memory using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kA, ElementA, LayoutA, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; /// Storage for A tile using FragmentA = typename IteratorA::Fragment; @@ -174,8 +172,8 @@ class QuantBMmaTensorOp { /// Iterates over the B operand in memory using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kB, ElementB, LayoutB, - MatrixShape, + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; // warp B MatrixShape<64, 64>, // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>, @@ -184,7 +182,7 @@ class QuantBMmaTensorOp { // FragmentB::kElements 32 /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; // cutlass::Array + using FragmentB = typename IteratorB::Fragment; // cutlass::Array /// Storage for transformed B tile /// When loading weights, we packed 4 int4 weights into one 2-byte-element, when expanded @@ -196,8 +194,8 @@ class QuantBMmaTensorOp { /// Iterates over the C operand in memory using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; /// Storage for C tile using FragmentC = typename IteratorC::Fragment; @@ -218,26 +216,23 @@ class QuantBMmaTensorOp { // TODO This is an expanding iterator, it needs to replicate the quantization parameters // to all threads in the warp. using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< - MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, - ElementQOffset, SmemLayoutQOffset, - ArchMmaOperator, kThreadCount, kPartitionsK>; + MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, + ArchMmaOperator, kThreadCount, kPartitionsK>; using FragmentQScale = typename IteratorQMeta::FragmentScale; using FragmentQOffset = typename IteratorQMeta::FragmentOffset; /// Number of mma operations performed using MmaIterations = MatrixShape< - (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN - >; - -public: + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + public: /// Underlying matrix multiply operator (concept: arch::Mma) ArchMmaOperator mma; -public: - + public: // // Methods // @@ -249,113 +244,106 @@ class QuantBMmaTensorOp { /// Performs a warp-level matrix multiply-accumulate operation CUTLASS_DEVICE void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C - ) const { - + FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C) const { using MmaOperandA = typename ArchMmaOperator::FragmentA; using MmaOperandB = typename ArchMmaOperator::FragmentB; using MmaOperandC = typename ArchMmaOperator::FragmentC; D = C; - MmaOperandA const *ptr_A = reinterpret_cast(&A); - MmaOperandB const *ptr_B = reinterpret_cast(&B); - MmaOperandC *ptr_D = reinterpret_cast(&D); - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - // The visitation order is like - // _ - // | | | | - // | | | | - // |_| |_| - // - // Down Up Down Up - + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n], ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } else { - mma( + } else { + mma( ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n], ptr_D[m_serpentine + n * MmaIterations::kRow]); - } } } - #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - // The visitation order is like - // _________ - // _________| - // |_________ - // __________| - // - // Right Left Right Left - + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine], ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); } } - #else - assert(0); - #endif + } +#else + assert(0); +#endif } /// Transform the mma operands to the required types CUTLASS_DEVICE - void transform(TransformedFragmentB &dst_B, - FragmentB const &B, - FragmentQScale const &scales, - FragmentQOffset const &offsets) const { - - Array const *ptr_B = - reinterpret_cast const *>(&B); + void transform(TransformedFragmentB& dst_B, + FragmentB const& B, + FragmentQScale const& scales, + FragmentQOffset const& offsets) const { + Array const* ptr_B = + reinterpret_cast const*>(&B); IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// -//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" +// #include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index cdfd283899c8c..675f7c7a13e8c 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1825,3 +1825,35 @@ MlasNhwcAvgPool( ); #endif + +struct MlasFlashAttentionThreadedArgs { + int batch_size; + int num_heads; + int q_sequence_length; + int kv_sequence_length; + int qk_head_size; + int v_head_size; + int q_block_size; + int kv_block_size; + float scale; + int thread_count; + float* buffer; + size_t buffer_size_per_thread; + const float* query; + const float* key; + const float* value; + float* output; +}; + +/** + * @brief Per-thread worker function for fp32 Flash Attention + * @param thread_id Thread index + * @param args Arguments + * @return +*/ +void +MLASCALL +MlasFlashAttention( + MlasFlashAttentionThreadedArgs* args, + MLAS_THREADPOOL* ThreadPool +); diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 316344ad8c214..aec14070ffd55 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -358,3 +358,80 @@ MlasDequantizeBlockwise( int columns, MLAS_THREADPOOL* thread_pool ); + +/** + * @brief Blockwise 4 bits quantization. After quantization, the weights and zero points + * are packed row-wise. If zero_points is null, quantized type is int4 with default + * zero point 0, to align with DQ schema. Otherwise, quantized type is uint4. + * In int4/uint4, dst have the same shape as src, and zero_points have the same shape as scales. + * @tparam Tin + * @tparam qbits number of bits used for quantization, only 4 is supported + * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] + * @param scales points to the scales matrix, row major + * @param zero_points points to the zero_points matrix, row major + * @param dst points to the quantized matrix, shape [rows, columns] row major in qbits type. + * In uint8_t type, shape is [rows, columns * qbits / 8]. + * @param columnwise true when quantize elements in a column, false when quantize elements in a row. + * @param rows + * @param columns + * @param quant_block_size number of elements in a quantize block + * @param thread_pool + * @return the quantized type is signed. + */ +template +bool +MlasQDQQuantizeBlockwise( + const Tin* src, + Tin* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +/** + * @brief Transpose blockwise quantized tensors. The src tensors are row major. src weights and zero + * points are packed row-wise. The dst tensors are column major. dst weights and zero points + * are packed column-wise. + * dst_weights and dst_zero_points are in uint4. + * If src_weights is int4 and has src_zero_points, src_weights and src_zero_points are + * converted to uint4 by adding 8. + * If src_weights is int4 and no src_zero_points, src_weights is converted to uint4 by adding 8. + * src_zero_points is 0 and dst_zero_points is 8. + * If src_weights is uint4 and has src_zero_points, just transpose. + * If src_weights is uint4 and no src_zero_points, caller must allocate dst_zero_points with + * 0 values. Otherwise exception is thrown. + * @tparam Tin + * @tparam qbits number of bits used for quantization, only 4 is supported + * @tparam signed_quant true when quantized type is signed, false when quantized type is unsigned + * @param src_weights points to the quantized matrix, row major, shape [rows, columns] in qbits type. + * In uint8_t type, shape is [rows, columns * qbits / 8]. + * @param src_scales points to the scales matrix, row major + * @param src_zero_points points to the zero_points matrix, row major. Packed row-wise. + * @param dst_weights points to the quantized matrix, column major. Packed column-wise. + * @param dst_scales points to the scales matrix, column major + * @param dst_zero_points points to the zero_points matrix, column major. Packed column-wise. + * @param columnwise true when quantize elements in a column, false when quantize elements in a row. + * @param rows + * @param columns + * @param quant_block_size number of elements in a quantize block + * @param thread_pool + */ +template +void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const Tin* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + Tin* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); diff --git a/onnxruntime/core/mlas/lib/flashattn.cpp b/onnxruntime/core/mlas/lib/flashattn.cpp new file mode 100644 index 0000000000000..fe5402ed144aa --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn.cpp @@ -0,0 +1,167 @@ +#include + +#include "mlasi.h" + +void +MlasFlashAttentionThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionThreadedArgs* args = reinterpret_cast(argptr); + ptrdiff_t q_block_size = static_cast(args->q_block_size); + ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + ptrdiff_t batch_size = static_cast(args->batch_size); + ptrdiff_t num_heads = static_cast(args->num_heads); + ptrdiff_t q_sequence_length = static_cast(args->q_sequence_length); + ptrdiff_t kv_sequence_length = static_cast(args->kv_sequence_length); + ptrdiff_t qk_head_size = static_cast(args->qk_head_size); + ptrdiff_t v_head_size = static_cast(args->v_head_size); + float* buffer = args->buffer; + ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + ptrdiff_t thread_count = static_cast(args->thread_count); + const float* query = args->query; + const float* key = args->key; + const float* value = args->value; + float* output = args->output; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + ptrdiff_t q_chunk_count = (q_sequence_length + (q_block_size - 1)) / q_block_size; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + char* buffer_current_thread = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_current_thread); + float* m = l + q_block_size; + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + } + float* intermediate = m + q_block_size; + float* temp_output = intermediate + q_block_size * kv_block_size; + float negmax = 0; + + for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) { + /* + S = Q[batch_idx, head_idx, q_idx:q_idx+q_block_size, :] * (K[batch_idx, head_idx, ir:ir+kv_block_size, :]).T + old_m = m + m = max(m, rowmax(S)) + diff = old_m - m + S = exp(S - m) + l = exp(diff) * l + rowsum(S) + O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+kv_block_size, :] + */ + ptrdiff_t h = batch_idx * num_heads + head_idx; + const float* inputQ = query + (h * q_sequence_length + q_idx) * qk_head_size; + const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size; + const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size; + + size_t row_size_q_capped = static_cast(std::min(q_block_size, q_sequence_length - q_idx)); + size_t row_size_kv_capped = static_cast(std::min(kv_block_size, kv_sequence_length - ir)); + + MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans, + CBLAS_TRANSPOSE::CblasTrans, + row_size_q_capped, + row_size_kv_capped, + static_cast(qk_head_size), + args->scale, + inputQ, + static_cast(qk_head_size), + inputK, + static_cast(qk_head_size), + 0.0f, + intermediate, + row_size_kv_capped); + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q_capped); ++irow) { + float* p = intermediate + irow * row_size_kv_capped; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv_capped); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv_capped); +#endif + float m_diff = m[irow]; + m[irow] = std::max(m[irow], rowmax); // new m + negmax = -m[irow]; + m_diff -= m[irow]; // old - new (less than 0) + +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax); +#endif + + // Note: for ir == 0, there is actually no need to calculate exp_diff + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) { + temp_output[irow * v_head_size + icol] = exp_diff * temp_output[irow * v_head_size + icol]; + } + } else { + l[irow] = rowsum; + // When ir == 0, there is no need to scale the old result because it is zero. + } + } + MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans, + CBLAS_TRANSPOSE::CblasNoTrans, + row_size_q_capped, + static_cast(v_head_size), + row_size_kv_capped, + 1.0f, + intermediate, + row_size_kv_capped, + inputV, + static_cast(v_head_size), + ir == 0 ? 0.0f : 1.0f, + temp_output, + static_cast(v_head_size)); + } + + float* output_row = output + ((batch_idx * q_sequence_length + q_idx) * num_heads + head_idx) * v_head_size; + ptrdiff_t row_size_q_valid = std::min(q_block_size, q_sequence_length - q_idx); + // TODO: leverage advanced instruction sets + for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) { + for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) { + output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow]; + } + output_row += num_heads * v_head_size; + } + } +} + +void +MLASCALL +MlasFlashAttention( + MlasFlashAttentionThreadedArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + MlasExecuteThreaded( + MlasFlashAttentionThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool); +} diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 72eb35c894094..859b7c2f560a4 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -20,8 +20,15 @@ Module Name: #include #include -#if defined(MLAS_TARGET_POWER) && defined(__linux__) +#if defined(MLAS_TARGET_POWER) +#if defined(__linux__) #include +#elif defined(_AIX) +#define POWER_10 0x40000 +#define POWER_10_ANDUP (POWER_10) +#include +#define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) +#endif #endif #if defined(MLAS_TARGET_ARM64) @@ -554,6 +561,9 @@ Return Value: unsigned long hwcap2 = getauxval(AT_HWCAP2); bool HasP9Instructions = hwcap2 & PPC_FEATURE2_ARCH_3_00; +#elif defined(_AIX) + bool HasP9Instructions = __power_9_andup(); +#endif // __linux__ if (HasP9Instructions) { this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelVSX; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8KernelVSX; @@ -562,7 +572,11 @@ Return Value: #if defined(POWER10) #if (defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__== 10 && __GNUC_MINOR__ >= 2))) || \ (defined(__clang__) && (__clang_major__ >= 12)) +#if defined(__linux__) bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); +#elif defined(_AIX) + bool HasP10Instructions = (__power_10_andup() && __power_mma_version() == MMA_V31); +#endif // __linux__ if (HasP10Instructions) { this->GemmFloatKernel = MlasSgemmKernelPOWER10; this->GemmDoubleKernel = MlasDgemmKernelPOWER10; @@ -571,7 +585,6 @@ Return Value: #endif #endif -#endif // __linux__ #endif // MLAS_TARGET_POWER #if defined(MLAS_TARGET_LARCH64) @@ -676,7 +689,6 @@ MlasPlatformU8S8Overflow( } #endif - thread_local size_t ThreadedBufSize = 0; #ifdef _MSC_VER thread_local std::unique_ptr ThreadedBufHolder(nullptr, &_aligned_free); diff --git a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp index a67be1dbfa710..0f3bc1d579711 100644 --- a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp +++ b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp @@ -874,10 +874,18 @@ MlasQgemmStoreVectorMMA { size_t RowCount; __vector signed int vsum0, vsum1, vsum2, vsum3; +#if defined(_AIX) && defined(__clang__) + __vector signed int columnsum = *reinterpret_cast(&ColumnSumBuffer[pos]); +#else __vector signed int columnsum = *reinterpret_cast(&ColumnSumBuffer[pos]); +#endif C += VectorCount; if (ZeroPointB != nullptr) { +#if defined(_AIX) && defined(__clang__) + __vector signed int zeropoint = *reinterpret_cast(&ZeroPointB[pos]); +#else __vector signed int zeropoint = *reinterpret_cast(&ZeroPointB[pos]); +#endif if (ZeroMode) { for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index b5784ecb56d01..015d69de68766 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -314,14 +314,18 @@ struct Shape2D { }; -template +template struct BitsTraits { static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); static constexpr int kBits = qbits; - static constexpr int kMax = (1 << qbits) - 1; - static constexpr int kMid = 1 << (qbits - 1); + static constexpr int kMax = signed_quant ? (1 << (qbits -1)) - 1 : (1 << qbits) - 1; + static constexpr int kMid = signed_quant ? 0 : (1 << (qbits - 1)); + static constexpr int kMin = signed_quant ? -(1 << (qbits - 1)) : 0; static constexpr float kMaxFp = static_cast(kMax); + static constexpr float kMinFp = static_cast(kMin); + static constexpr float fullRange = kMaxFp - kMinFp; + static constexpr float halfRange = static_cast(kMid - kMin); // number of qbit elements to pack into whole bytes static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; @@ -331,53 +335,54 @@ struct BitsTraits { /** * @brief Rectify min/max from a set of weights, and convert to scale and zero point - * for quantization - * @tparam ScaleT type of scale, usually floating point of various bits - * @tparam qbits number of int bits used for zero point value + * for quantization. + * @tparam ScaleT type of scale, usually floating point of various bits + * @tparam qbits number of int bits used for zero point value + * @tparam signed_quant output quantized type is signed * @param[in] min * @param[in] max * @param[out] scale * @param[out] zp */ -template +template MLAS_FORCEINLINE void range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) { - constexpr int zp_max = BitsTraits::kMax; - constexpr float zp_max_fp = BitsTraits::kMaxFp; - min = std::min(min, 0.0f); max = std::max(max, 0.0f); - float scale_f = (max - min) / zp_max; + float scale_f = (max - min) / BitsTraits::fullRange; float zero_point_fp = min; if (scale_f != 0.0f) { - zero_point_fp = 0.f - min / scale_f; + zero_point_fp = BitsTraits::kMinFp - min / scale_f; } - if (zero_point_fp < 0.0f) { - zp = 0; - } else if (zero_point_fp > zp_max_fp) { - zp = zp_max; + if (zero_point_fp < BitsTraits::kMinFp) { + zp = static_cast(BitsTraits::kMin); + } else if (zero_point_fp > BitsTraits::kMaxFp) { + zp = static_cast(BitsTraits::kMax); } else { zp = (uint8_t)roundf(zero_point_fp); } scale = ScaleT(scale_f); } -template +/** + * @brief Rectify min/max from a set of symmetric weights, and convert + * to scale for quantization. + */ +template MLAS_FORCEINLINE void range2scale(float min, float max, ScaleT& scale) { - constexpr int mid_v = BitsTraits::kMid; - constexpr float mid_fp = static_cast(-mid_v); - max = fabsf(max) > fabsf(min) ? max : min; - - scale = ScaleT(max / mid_fp); + // !!Note: in the quantized space, abs of min -8 > abs of max 7. + // Therefore map the larger half FP space to [-8, 0]. + // Minus sign achieves this purpose. + scale = ScaleT(-max / BitsTraits::halfRange); }; @@ -400,7 +405,7 @@ struct BlockwiseQuantizer { static_assert(qbits == 4, "Only 4b block quantization is supported!"); using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; - using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; static MLAS_FORCEINLINE @@ -474,8 +479,8 @@ struct BlockwiseQuantizer { MlasTryBatchParallel( thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { - uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + uint8_t zp_bytes[BitsTraits::kPackSize]; + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); @@ -490,7 +495,7 @@ struct BlockwiseQuantizer { const int meta_col = c / QuantBlk::kColumn; // compute scale and zero point - for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { // scan a single block to extract range [min, max] float min = std::numeric_limits::max(); @@ -509,9 +514,9 @@ struct BlockwiseQuantizer { if (row_start < row_end) { const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; if (zero_points == nullptr) { - range2scale(min, max, scales[meta_idx]); + range2scale(min, max, scales[meta_idx]); } else { - range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); + range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); } } } @@ -533,7 +538,7 @@ struct BlockwiseQuantizer { const float v0 = static_cast(src[i * leadingDimension + j]); const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), - 0.0f, BitsTraits::kMaxFp); + 0.0f, BitsTraits::kMaxFp); uint8_t vi1 = (uint8_t)zp; if (i + 1 < r_end) { @@ -545,7 +550,7 @@ struct BlockwiseQuantizer { } const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, - BitsTraits::kMaxFp); + BitsTraits::kMaxFp); } // !! 4b specific code @@ -638,6 +643,650 @@ struct BlockwiseQuantizer { } }; +/** + * @brief Blockwise quantization methods for QDQ format. Input tensor is quantized along column + * or row. Scales and zeros are calculated. Based on qbits, consecutive quantized elements + * in memory are packed together, which means the packing is along the row. Quantized data + * are stored in row major, so the output tensor reserves same shape, in terms of qbits type, + * as the input tensor. + * If has zero points, quantized type is unsigned. Otherwise, quantized type is signed and the + * zero point is 0. + * The transposed outputs are used by MatMulNBits, so quant type becomes uint4 with default + * zp at 8. + * @tparam Tin source data type, e.g. fp32/fp16 + * @tparam qbits number of bits in each quantized element + * @tparam signed_quant quantized type is signed + */ +template +struct BlockwiseQDQQuantizer; + +template +struct BlockwiseQDQQuantizer { + static MLAS_FORCEINLINE uint8_t GetElem(uint8_t val, int32_t idx) + { + return (val >> (idx << 2)) & 0xF; + } + + static MLAS_FORCEINLINE uint8_t SetElem(uint8_t val, int32_t idx, uint8_t dst) + { + auto shift = idx << 2; + return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + } + + template + static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1) + { + if constexpr (add8) { + return ((v0 & 0xF) ^ 8) | (((v1 & 0xF) ^ 8) << 4); + } else { + return (v0 & 0xF) | ((v1 & 0xF) << 4); + } + } + + // If src is row major, then dst is column major. Transpose: + // | src0: low 4 bit | src0: high 4 bit | + // | src1: low 4 bit | src1: high 4 bit | + // --> + // | dst0: low 4 bit | dst1: low 4 bit | + // | dst0: high 4 bit| dst1: high 4 bit | + // If src is column major, then dst is row major. Transpose: + // | src0: low 4 bit | src1: low 4 bit | + // | src0: high 4 bit| src1: high 4 bit | + // --> + // | dst0: low 4 bit | dst0: high 4 bit | + // | dst1: low 4 bit | dst1: high 4 bit | + template + static MLAS_FORCEINLINE void Transpose(uint8_t src0, uint8_t src1, uint8_t& dst0, uint8_t& dst1) + { + if constexpr (add8) { + dst0 = ((src0 & 0xF) ^ 8) | (((src1 & 0xF) ^ 8) << 4); + dst1 = (((src0 & 0xF0) ^ 0x80) >> 4) | ((src1 & 0xF0) ^ 0x80); + } else { + dst0 = (src0 & 0xF) | ((src1 & 0xF) << 4); + dst1 = ((src0 & 0xF0) >> 4) | (src1 & 0xF0); + } + } + + static MLAS_FORCEINLINE uint8_t QuantizeV(Tin src, float reciprocal_scale, uint8_t zero_point) + { + return static_cast( + std::clamp( + static_cast( + std::roundf(static_cast(src) * reciprocal_scale) + ) + static_cast(zero_point), + BitsTraits<4, signed_quant>::kMin, + BitsTraits<4, signed_quant>::kMax + ) + ); + } + + /** + * @brief Quantize a matrix shape [rows, columns] column-wise. Scales and zero points are calculated. + * Quantized data are packed row-wise based on qbits. Quantized data are stored in row major + * so the output tensor reserves the shape, in terms output type. + * @param src the source matrix, row major: [rows * columns] + * @param scales the scales of quantized blocks, row major with shape: + * [ceil(rows/quant_block_size) * columns] + * @param zero_points the zero points of quantized blocks, packed. Same shape as scales in terms + * of output type. In uint8_t, the shape is: + * [ceil(columns * ceil(rows / quant_block_size) * qbits / 8)] + * @param dst the quantized weights, row major: [rows * columns] in terms of output type. + * In uint8_t, the shape is: [ceil(rows * columns * qbits / 8] + * @param rows number of rows in the source matrix + * @param columns number of columns in the source matrix. + * @param quant_block_size number of rows/columns quantized together + * @param thread_pool thread pool for parallel processing + */ + static void QuantizeColumnWise( + const Tin* src, + Tin* scales, + uint8_t* zero_points, + uint8_t* dst, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL* thread_pool + ) + { + ORT_ENFORCE(zero_points || signed_quant, "Unsigned quant with no zero points is not supported."); + // Must avoid multiple thread write to a single byte, which means the starting index + // of a thread block must be even. To achieve that, we need to customize the thread + // block size based on the parity of columns. + if (columns & 1) { + QuantizeColumnWisePackUnaligned( + src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool + ); + } else { + QuantizeColumnWisePackAligned( + src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool + ); + } + } + + + /** + * @brief Transpose quantized tensors, which has been column-wise quantized, for use in MatMulNbits. + * Since both src tensor and dst tensor are packed, it's not needed to consider sign + * during the unpacking/packing in transpose. + * @param src_weights The quantized weights, row major: [rows, columns] in qbits type. + * In uint8_t, size of [ceil(rows * columns * qbits / 8)]. + * @param src_scales [ceil(rows / quant_block_size), columns] + * @param src_zero_points [ceil(rows / quant_block_size), columns] in qbits type. In uint8_t, size of + * [ceil(ceil(rows / quant_block_size) * columns * qbits / 8 )]. + * @param dst_weights the transposed quantized weights, column major. In uint8_t, the shape is + * [columns, ceil(rows / quant_block_size), ceil(quant_block_size * qbits / 8)] + * @param dst_scales [columns, ceil(rows / quant_block_size)] + * @param dst_zero_points [columns, ceil(ceil(rows / quant_block_size) * qbits / 8)] in uint8_t. + * @param rows number of src rows in qbits type. + * @param columns number of src columns in qbits type. + * @param quant_block_size number of elements quantized together + * @param thread_pool thread pool for parallel processing + */ + static void TransposeColumnWiseQuantized( + const uint8_t* src_weights, + const Tin* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + Tin* dst_scales, + uint8_t* dst_zero_points, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL* thread_pool + ) + { + ORT_ENFORCE( + src_zero_points || signed_quant || dst_zero_points, + "Unsigned quant types without zero points must allocate zero points with value 0." + ); + // Must avoid multiple thread write to a single byte, which means the starting index + // of a thread block must be even. To achieve that, we need to customize the thread + // block size based on the parity of columns. + if (columns & 1) { + TransposeColumnWiseQuantizedPackUnaligned( + src_weights, src_scales, src_zero_points, + dst_weights, dst_scales, dst_zero_points, + rows, columns, quant_block_size, thread_pool + ); + } else { + TransposeColumnWiseQuantizedPackAligned( + src_weights, src_scales, src_zero_points, + dst_weights, dst_scales, dst_zero_points, + rows, columns, quant_block_size, thread_pool + ); + } + } + +private: + static void QuantizeColumnWisePackAligned( + const Tin* src, + Tin* scales, + uint8_t* zero_points, + uint8_t* dst, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL* thread_pool + ) + { + ORT_ENFORCE(columns % 2 == 0, "Columns must be multiple of 2."); + // Thread block is [quant_block_size, thread_blk_size]. thread_blk_size % 2 == 0. + constexpr int32_t thread_blk_size = 128; + const auto num_row_thread_blk = (rows + quant_block_size - 1) / quant_block_size; + const auto num_col_thread_blk = (columns + thread_blk_size - 1) / thread_blk_size; + const auto num_thread_blk = num_row_thread_blk * num_col_thread_blk; + constexpr auto minf = std::numeric_limits::lowest(); + constexpr auto maxf = std::numeric_limits::max(); + + MlasTryBatchParallel( + thread_pool, static_cast(num_thread_blk), + [&](ptrdiff_t thread_blk_idx) { + // !!warning!!: buffering the whole thread block + constexpr int32_t buffer_size = 128; + ORT_ENFORCE(buffer_size == thread_blk_size, "buffer size must be equal to thread block size."); + float reciprocal_scale_t[buffer_size]; + uint8_t zp_t[buffer_size]; + float vmin_t[buffer_size]; + float vmax_t[buffer_size]; + + const int32_t row_thread_blk_idx = static_cast(thread_blk_idx / num_col_thread_blk); + const int32_t col_thread_blk_idx = static_cast(thread_blk_idx % num_col_thread_blk); + const int32_t row_idx = row_thread_blk_idx * quant_block_size; + const int32_t col_idx = col_thread_blk_idx * buffer_size; + const int32_t row_size = std::min(quant_block_size, rows - row_idx); + const int32_t col_size = std::min(buffer_size, columns - col_idx); + // input_idx, scale_idx, col_size are aligned to 2 + auto input_idx = row_idx * columns + col_idx; + auto scale_idx = row_thread_blk_idx * columns + col_idx; + + Tin scale0_tt, scale1_tt; + uint8_t v0_tt, v1_tt; + + std::fill_n(vmin_t, buffer_size, maxf); + std::fill_n(vmax_t, buffer_size, minf); + + // calculate min/max + for (int32_t j = 0, input_idx_t = input_idx; j < row_size; ++j, input_idx_t += columns) { + // TODO(fajin): use SIMD + for (int32_t i = 0; i < col_size; i += 2) { + auto v0 = static_cast(src[input_idx_t + i]); + auto v1 = static_cast(src[input_idx_t + i + 1]); + vmin_t[i] = std::min(vmin_t[i], v0); + vmax_t[i] = std::max(vmax_t[i], v0); + vmin_t[i + 1] = std::min(vmin_t[i + 1], v1); + vmax_t[i + 1] = std::max(vmax_t[i + 1], v1); + } + } + + // calculate scale and zero point, and store + for (int32_t i = 0; i < col_size; i += 2) { + v0_tt = v1_tt = BitsTraits<4, signed_quant>::kMid; + + if (zero_points) { + range2scalezp(vmin_t[i], vmax_t[i], scale0_tt, v0_tt); + range2scalezp(vmin_t[i + 1], vmax_t[i + 1], scale1_tt, v1_tt); + zero_points[(scale_idx + i) >> 1] = Pack(v0_tt, v1_tt); + } else { + range2scale(vmin_t[i], vmax_t[i], scale0_tt); + range2scale(vmin_t[i + 1], vmax_t[i + 1], scale1_tt); + } + + scales[scale_idx + i] = scale0_tt; + scales[scale_idx + i + 1] = scale1_tt; + + float scalef0 = static_cast(scale0_tt); + reciprocal_scale_t[i] = scalef0 ? 1.0f / scalef0 : 0.0f; + zp_t[i] = v0_tt; + + float scalef1 = static_cast(scale1_tt); + reciprocal_scale_t[i + 1] = scalef1 ? 1.0f / scalef1 : 0.0f; + zp_t[i + 1] = v1_tt; + } + + // quantize and pack + for (int32_t j = 0, input_idx_t = input_idx; j < row_size; ++j, input_idx_t += columns) { + // TODO(fajin): use SIMD + for (int32_t i = 0; i < col_size; i += 2) { + v0_tt = QuantizeV(src[input_idx_t + i], reciprocal_scale_t[i], zp_t[i]); + v1_tt = QuantizeV(src[input_idx_t + i + 1], reciprocal_scale_t[i + 1], zp_t[i + 1]); + dst[(input_idx_t + i) >> 1] = Pack(v0_tt, v1_tt); + } + } + } + ); + } + + static void QuantizeColumnWisePackUnaligned( + const Tin* src, + Tin* scales, + uint8_t* zero_points, + uint8_t* dst, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL* thread_pool + ) + { + // Thread block is [quant_block_size * 2, columns], so the packed bytes do not cross threads. + constexpr auto minf = std::numeric_limits::lowest(); + constexpr auto maxf = std::numeric_limits::max(); + auto row_thread_blk_size = quant_block_size * 2; + auto num_row_thread_blk = (rows + row_thread_blk_size - 1) / (row_thread_blk_size); + + MlasTryBatchParallel( + thread_pool, static_cast(num_row_thread_blk), + [&](ptrdiff_t thread_blk_idx) { + constexpr int32_t buffer_size = 128; + float reciprocal_scale_t[buffer_size]; + uint8_t zp_t[buffer_size]; + float vmin_t[buffer_size]; + float vmax_t[buffer_size]; + + auto row_thread_blk_idx = static_cast(thread_blk_idx); + int32_t row_idx = row_thread_blk_idx * row_thread_blk_size; + int32_t row_idx_end = std::min(row_thread_blk_size + row_idx, rows); + auto input_idx = row_idx * columns; + auto scale_idx = row_thread_blk_idx * 2 * columns; + Tin scale0_tt, scale1_tt; + uint8_t v0_tt, v1_tt; + + for (; row_idx < row_idx_end; row_idx += quant_block_size) { + // per quant block row + auto quant_row_size = std::min(quant_block_size, row_idx_end - row_idx); + auto input_buffer_idx = input_idx; + auto scale_buffer_idx = scale_idx; + for (int32_t buffer_idx = 0; buffer_idx < columns; buffer_idx += buffer_size) { + // per buffer column + auto buffer_col_size = std::min(buffer_size, columns - buffer_idx); + + std::fill_n(vmin_t, buffer_size, maxf); + std::fill_n(vmax_t, buffer_size, minf); + // calculate min/max of [quant block, buffer] + auto input_idx_t = input_buffer_idx; + for (int32_t j = 0; j < quant_row_size; ++j, input_idx_t += columns) { + // TODO(fajin): use SIMD + for (int32_t i = 0; i < buffer_col_size; ++i) { + auto v = static_cast(src[input_idx_t + i]); + vmin_t[i] = std::min(vmin_t[i], v); + vmax_t[i] = std::max(vmax_t[i], v); + } + } + + // calculate scale and zero point + auto scale_buffer_idx_end = scale_buffer_idx + buffer_col_size; + int32_t col_idx = 0; + // leading unailgned zero points + if (scale_buffer_idx & 1) { + v0_tt = BitsTraits<4, signed_quant>::kMid; + if (zero_points) { + range2scalezp(vmin_t[0], vmax_t[0], scale0_tt, v0_tt); + zero_points[scale_buffer_idx >> 1] = SetElem( + v0_tt, 1, zero_points[scale_buffer_idx >> 1] + ); + } else { + range2scale(vmin_t[0], vmax_t[0], scale0_tt); + } + + scales[scale_buffer_idx] = scale0_tt; + + float scalef = static_cast(scale0_tt); + reciprocal_scale_t[0] = scalef ? 1.0f / scalef : 0.0f; + zp_t[0] = v0_tt; + + ++col_idx; + ++scale_buffer_idx; + } + // aligned zero points + for (; scale_buffer_idx < scale_buffer_idx_end - 1; col_idx += 2, scale_buffer_idx += 2) { + v0_tt = v1_tt = BitsTraits<4, signed_quant>::kMid; + if (zero_points) { + range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); + range2scalezp( + vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt, v1_tt + ); + zero_points[scale_buffer_idx >> 1] = Pack(v0_tt, v1_tt); + } else { + range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); + range2scale(vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt); + } + + scales[scale_buffer_idx] = scale0_tt; + scales[scale_buffer_idx + 1] = scale1_tt; + + float scalef0 = static_cast(scale0_tt); + reciprocal_scale_t[col_idx] = scalef0 ? 1.0f / scalef0 : 0.0f; + zp_t[col_idx] = v0_tt; + + float scalef1 = static_cast(scale1_tt); + reciprocal_scale_t[col_idx + 1] = scalef1 ? 1.0f / scalef1 : 0.0f; + zp_t[col_idx + 1] = v1_tt; + } + // tailing unaligned elements + if (scale_buffer_idx < scale_buffer_idx_end) { + v0_tt = BitsTraits<4, signed_quant>::kMid; + if (zero_points) { + range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); + zero_points[scale_buffer_idx >> 1] = SetElem( + v0_tt, 0, zero_points[scale_buffer_idx >> 1] + ); + } else { + range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); + } + + scales[scale_buffer_idx] = scale0_tt; + + float scalef = static_cast(scale0_tt); + reciprocal_scale_t[col_idx] = scalef ? 1.0f / scalef : 0.0f; + zp_t[col_idx] = v0_tt; + + ++scale_buffer_idx; + } + + // quantize and pack + input_idx_t = input_buffer_idx; + for (int32_t j = 0; j < quant_row_size; ++j, input_idx_t += columns) { + auto input_idx_t_start = input_idx_t; + auto input_idx_t_end = input_idx_t + buffer_col_size; + col_idx = 0; + // leading unaligned output + if (input_idx_t_start & 1) { + v1_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); + dst[input_idx_t_start >> 1] = SetElem(v1_tt, 1, dst[input_idx_t_start >> 1]); + + ++col_idx; + ++input_idx_t_start; + } + // aligned output + // TODO(fajin): use SIMD + for (; input_idx_t_start < input_idx_t_end - 1; col_idx += 2, input_idx_t_start += 2) { + v0_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); + v1_tt = QuantizeV( + src[input_idx_t_start + 1], reciprocal_scale_t[col_idx + 1], zp_t[col_idx + 1] + ); + + dst[input_idx_t_start >> 1] = Pack(v0_tt, v1_tt); + } + // tailing unaligned output + if (input_idx_t_start < input_idx_t_end) { + v0_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); + dst[input_idx_t_start >> 1] = SetElem(v0_tt, 0, dst[input_idx_t_start >> 1]); + } + } + + input_buffer_idx += buffer_size; + } + + input_idx += quant_block_size * columns; + scale_idx += columns; + } + } + ); + } + + static void TransposeColumnWiseQuantizedPackAligned( + const uint8_t* src_weights, // [rows, columns / 2] + const Tin* src_scales, // [ceil(rows / quant_block_size), columns] + const uint8_t* src_zero_points, // [ceil(rows / quant_block_size), columns / 2] + uint8_t* dst_weights, // [columns, ceil(rows / quant_block_size), ceil(quant_block_size / 2)] + Tin* dst_scales, // [columns, ceil(rows / quant_block_size)] + uint8_t* dst_zero_points, // [columns, ceil(ceil(rows / quant_block_size) / 2)] + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL* thread_pool + ) + { + ORT_ENFORCE(columns % 2 == 0, "Columns must be multiple of 2"); + + auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; + auto dst_bytes_per_quant_blk = (quant_block_size * 4 + 7) / 8; + // number of rows in transposed dst + auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; + auto packed_col_size = columns / 2; + + // weight transpose thread block is [dst_bytes_per_quant_blk, 2] on dst_Transpose. + // Map to src it is [quant_block_size, 1]. Both in uint8_t. + auto num_thread_blk = row_quant_blk_num * packed_col_size; + MlasTryBatchParallel( + thread_pool, static_cast(num_thread_blk), + [&](ptrdiff_t thread_blk_idx) { + uint8_t src0_t, src1_t; + uint8_t dst0_t, dst1_t; + + auto row_thread_blk_idx = static_cast(thread_blk_idx / packed_col_size); + auto col_thread_blk_idx = static_cast(thread_blk_idx % packed_col_size); + + auto dstT_row_idx = row_thread_blk_idx * dst_bytes_per_quant_blk; + auto dstT_col_idx = col_thread_blk_idx * 2; + auto dst_idx = dstT_col_idx * dstT_num_row + dstT_row_idx; + + auto src_row_idx = row_thread_blk_idx * quant_block_size; + auto src_row_end_idx = std::min(src_row_idx + quant_block_size, rows); + auto src_col_idx = col_thread_blk_idx; + auto src_idx = src_row_idx * packed_col_size + src_col_idx; + auto src_end_idx = src_row_end_idx * packed_col_size + src_col_idx; + + for (; src_idx < src_end_idx - packed_col_size; ++dst_idx) { + src0_t = src_weights[src_idx]; + src1_t = src_weights[src_idx + packed_col_size]; + src_idx += packed_col_size + packed_col_size; + Transpose(src0_t, src1_t, dst0_t, dst1_t); + dst_weights[dst_idx] = dst0_t; + dst_weights[dst_idx + dstT_num_row] = dst1_t; + } + + if (src_idx < src_end_idx) { + src0_t = src_weights[src_idx]; + src1_t = 0; + Transpose(src0_t, src1_t, dst0_t, dst1_t); + dst_weights[dst_idx] = dst0_t; + dst_weights[dst_idx + dstT_num_row] = dst1_t; + } + } + ); + + // Transpose scales. Thread block is [row_quant_blk_num, 1] on dst_Transpose. + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t thread_blk_idx) { + auto col_thread_blk_idx = static_cast(thread_blk_idx); + auto src_idx = col_thread_blk_idx; + auto dst_idx = col_thread_blk_idx * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + dst_scales[dst_idx] = src_scales[src_idx]; + } + } + ); + + if (src_zero_points) { + // Transpose zero points. Thread block is [ceil(row_quant_blk_num / 2), 2] + // on dst_Transpose. Map to src it is [row_quant_blk_num, 1]. Both in uint8_t. + auto dst_zp_row_num = (row_quant_blk_num + 1) / 2; + MlasTryBatchParallel( + thread_pool, static_cast(packed_col_size), + [&](ptrdiff_t thread_blk_idx) { + uint8_t src0_t, src1_t; + uint8_t dst0_t, dst1_t; + + auto col_thread_blk_idx = static_cast(thread_blk_idx); + auto src_idx = col_thread_blk_idx; + auto src_end_idx = row_quant_blk_num * packed_col_size + col_thread_blk_idx; + auto dst_idx = col_thread_blk_idx * 2 * dst_zp_row_num; + + for (; src_idx < src_end_idx - packed_col_size; ++dst_idx) { + src0_t = src_zero_points[src_idx]; + src1_t = src_zero_points[src_idx + packed_col_size]; + Transpose(src0_t, src1_t, dst0_t, dst1_t); + dst_zero_points[dst_idx] = dst0_t; + dst_zero_points[dst_idx + dst_zp_row_num] = dst1_t; + src_idx += packed_col_size + packed_col_size; + } + + if (src_idx < src_end_idx) { + src0_t = src_zero_points[src_idx]; + src1_t = 0; + Transpose(src0_t, src1_t, dst0_t, dst1_t); + dst_zero_points[dst_idx] = dst0_t; + dst_zero_points[dst_idx + dst_zp_row_num] = dst1_t; + } + } + ); + } + } + + static void TransposeColumnWiseQuantizedPackUnaligned( + const uint8_t* src_weights, // size of [ceil(rows * columns / 2)] + const Tin* src_scales, // [ceil(rows / quant_block_size), columns] + const uint8_t* src_zero_points, // size of [ceil(ceil(rows / quant_block_size) * columns / 2)] + uint8_t *dst_weights, // [columns, ceil(rows / quant_block_size), ceil(quant_block_size / 2)] + Tin* dst_scales, // [columns, ceil(rows / quant_block_size)] + uint8_t* dst_zero_points, // [columns, ceil(ceil(rows / quant_block_size) / 2)] + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL* thread_pool) + { + auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; + auto dst_bytes_per_quant_blk = (quant_block_size * 4 + 7) / 8; + // number of rows in transposed dst + auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; + + // weight transpose thread block is [dst_bytes_per_quant_blk, 1] on dst_Transpose in uint8_t. + // Map to src it is [quant_block_size, 1] in int4. + auto num_thread_blk = row_quant_blk_num * columns; + MlasTryBatchParallel( + thread_pool, static_cast(num_thread_blk), + [&](ptrdiff_t thread_blk_idx) { + uint8_t src0_t, src1_t; + + auto row_thread_blk_idx = static_cast(thread_blk_idx / columns); + auto col_thread_blk_idx = static_cast(thread_blk_idx % columns); + + auto dstT_row_idx = row_thread_blk_idx * dst_bytes_per_quant_blk; + auto dst_idx = col_thread_blk_idx * dstT_num_row + dstT_row_idx; + + auto src_row_idx = row_thread_blk_idx * quant_block_size; + auto src_row_end_idx = std::min(src_row_idx + quant_block_size, rows); + auto src_idx = src_row_idx * columns + col_thread_blk_idx; + auto src_end_idx = src_row_end_idx * columns + col_thread_blk_idx; + + for (; src_idx < src_end_idx - columns; ++dst_idx) { + src0_t = GetElem(src_weights[src_idx >> 1], src_idx & 1); + src1_t = GetElem(src_weights[(src_idx + columns) >> 1], (src_idx + columns) & 1); + dst_weights[dst_idx] = Pack(src0_t, src1_t); + src_idx += columns + columns; + } + + if (src_idx < src_end_idx) { + src0_t = GetElem(src_weights[src_idx >> 1], src_idx & 1); + dst_weights[dst_idx] = Pack(src0_t, 0); + } + } + ); + + // Transpose scales. Thread block is [row_quant_blk_num, 1] on dst_Transpose. + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t thread_blk_idx) { + auto col_thread_blk_idx = static_cast(thread_blk_idx); + auto src_idx = col_thread_blk_idx; + auto dst_idx = col_thread_blk_idx * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + dst_scales[dst_idx] = src_scales[src_idx]; + } + } + ); + + if (src_zero_points) { + // Transpose zero points. Thread block is [ceil(row_quant_blk_num / 2), 1] on dst_Transpose in uint8_t. + // Map to src it is [row_quant_blk_num, 1] in int4. + auto dst_zp_row_num = (row_quant_blk_num + 1) / 2; + MlasTryBatchParallel( + thread_pool, static_cast(columns), + [&](ptrdiff_t thread_blk_idx) { + uint8_t src0_t, src1_t; + + auto col_thread_blk_idx = static_cast(thread_blk_idx); + auto src_idx = col_thread_blk_idx; + auto src_end_idx = row_quant_blk_num * columns + col_thread_blk_idx; + auto dst_idx = col_thread_blk_idx * dst_zp_row_num; + + for (; src_idx < src_end_idx - columns; ++dst_idx) { + src0_t = GetElem(src_zero_points[src_idx >> 1], src_idx & 1); + src1_t = GetElem(src_zero_points[(src_idx + columns) >> 1], (src_idx + columns) & 1); + dst_zero_points[dst_idx] = Pack(src0_t, src1_t); + src_idx += columns + columns; + } + + if (src_idx < src_end_idx) { + src0_t = GetElem(src_zero_points[src_idx >> 1], src_idx & 1); + dst_zero_points[dst_idx] = Pack(src0_t, 0); + } + } + ); + } + } +}; template void @@ -1068,8 +1717,7 @@ MlasDequantizeBlockwise( } } -template -void +template void MlasDequantizeBlockwise( float* dst, const uint8_t* src, @@ -1080,4 +1728,147 @@ MlasDequantizeBlockwise( int rows, int columns, MLAS_THREADPOOL* thread_pool - ); +); + +template +bool +MlasQDQQuantizeBlockwise( + const Tin* src, + Tin* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +) +{ + if (columnwise) { + if (zero_points) { + BlockwiseQDQQuantizer::QuantizeColumnWise( + src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool + ); + return false; + } else { + BlockwiseQDQQuantizer::QuantizeColumnWise( + src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool + ); + return true; + } + } else { + ORT_THROW("Row-wise MlasQDQQuantizeBlockwise is not implemented"); + } +} + +template bool +MlasQDQQuantizeBlockwise( + const float* src, + float* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template bool +MlasQDQQuantizeBlockwise( + const MLAS_FP16* src, + MLAS_FP16* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template +void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const Tin* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + Tin* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +) +{ + if (columnwise) { + BlockwiseQDQQuantizer::TransposeColumnWiseQuantized( + src_weights, src_scales, src_zero_points, dst_weights, dst_scales, dst_zero_points, + rows, columns, quant_block_size, thread_pool + ); + } else { + ORT_THROW("Row-wise MlasQDQTransposeBlockwiseQuantized is not implemented"); + } +} + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const MLAS_FP16* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + MLAS_FP16* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index 75c17a6b5a177..127aea9029b65 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -894,7 +894,7 @@ MlasGemmQuantGetDispatch( if (!AIsSigned) { GemmQuantDispatch = &MlasGemmU8X8DispatchWasmSimd; } -#elif defined(MLAS_TARGET_POWER) && defined(__linux__) && defined(POWER10) && \ +#elif defined(MLAS_TARGET_POWER) && (defined(__linux__) || defined(_AIX)) && defined(POWER10) && \ ((defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__== 10 && __GNUC_MINOR__ >= 2))) || \ (defined(__clang__) && (__clang_major__ >= 12))) if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchPOWER10) { diff --git a/onnxruntime/core/mlas/lib/qlmul.cpp b/onnxruntime/core/mlas/lib/qlmul.cpp index 38818e1190d21..4a6d57db0d211 100644 --- a/onnxruntime/core/mlas/lib/qlmul.cpp +++ b/onnxruntime/core/mlas/lib/qlmul.cpp @@ -325,12 +325,20 @@ MlasQLinearMulKernel( } while (N >= 4) { - __vector int32_t IntegerAVector {InputA[0], InputA[1], InputA[2], InputA[3]}; +#if defined(_AIX) && defined(__clang__) + __vector int IntegerAVector {InputA[0], InputA[1], InputA[2], InputA[3]}; +#else + __vector int32_t IntegerAVector {InputA[0], InputA[1], InputA[2], InputA[3]}; +#endif auto IntegerVector = vec_sub(IntegerAVector, ZeroPointAVector); auto ValueAVector = vec_mul(ScaleAVector, vec_ctf(IntegerVector, 0)); if (!IsScalarB) { - __vector int32_t IntegerBVector {InputB[0], InputB[1], InputB[2], InputB[3]}; +#if defined(_AIX) && defined(__clang__) + __vector int IntegerBVector {InputB[0], InputB[1], InputB[2], InputB[3]}; +#else + __vector int32_t IntegerBVector {InputB[0], InputB[1], InputB[2], InputB[3]}; +#endif IntegerVector = vec_sub(IntegerBVector, ZeroPointBVector); ValueBVector = vec_mul(ScaleBVector, vec_ctf(IntegerVector, 0)); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 4b852be951c91..81789386a3200 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -16,10 +16,11 @@ Module Name: --*/ #include "sqnbitgemm.h" -#include "sqnbitgemm_q8_block.h" #include +#include "sqnbitgemm_q8_block.h" + namespace { @@ -80,7 +81,7 @@ MlasIsSQNBitGemmAvailable( Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { - return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr && + return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr; } default: { @@ -372,15 +373,17 @@ SQ4BitGemm_CompFp32( if (bias) { AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); } + if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, RowsHandled, CountN, ldc ); } c_blk += ldc * RowsHandled; a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; } } @@ -431,36 +434,6 @@ SQ4BitGemm_CompInt8( const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const std::byte* a_row = QuantA; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } - return; - } - - // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. - // TODO Replace it with an optimized implementation. size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); @@ -473,21 +446,24 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - for (size_t m = 0; m < RangeCountM; ++m) { - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias ); if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc ); } - c_blk += ldc; - a_row += lda; + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; } } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index effb59b250ca0..8321dcc217e9a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -184,7 +184,6 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. - * This kernel handles the special case where M, the number of rows of A and C, is 1. * * @param BlkLen Number of values in a block. * @param QuantA Supplies the quantized A matrix. @@ -193,25 +192,31 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param QuantBScale Supplies the quantized B matrix block scale values. * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. + * @param CountM Number of rows of A and C to process, an upper bound. + * @param CountN Number of columns of B and C to process. * @param CountK Number of columns of A and rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param BlockCountK Number of blocks in one row of A and one column of B. + * @param ldc Number of elements between adjacent rows of C. * @param Bias Bias vector of length N. + * + * @return The number of rows of A and C that were processed, at most CountM. */ - typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)( + typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, + size_t CountM, size_t CountN, size_t CountK, - size_t BlockStrideQuantB, + size_t BlockCountK, + size_t ldc, const float* Bias ); - SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr; + SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index be573381c39c3..0922f5ef646be 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -434,6 +434,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( } } +size_t +SQ4BitGemmKernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + MLAS_UNREFERENCED_PARAMETER(ldc); + + if (CountM == 0) { + return 0; + } + + SQ4BitGemmM1Kernel_CompInt8_avx2( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + + return 1; +} + template MLAS_FORCEINLINE void ComputeDotProducts_BlkLen16_CompFp32_avx2( @@ -1109,7 +1147,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 0099b61d8196e..b868906760701 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -239,7 +239,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 27310d8253342..6477a2019b21a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -237,6 +237,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } } +size_t +SQ4BitGemmKernel_CompInt8_avx512vnni( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + MLAS_UNREFERENCED_PARAMETER(ldc); + + if (CountM == 0) { + return 0; + } + + SQ4BitGemmM1Kernel_CompInt8_avx512vnni( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + + return 1; +} + void MLASCALL MlasQ80BlkQuantRow_avx512( size_t BlkLen, @@ -260,7 +298,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni; d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index cfc0564cd041f..706e08fc467ba 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -158,17 +158,19 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( const size_t BlockStrideQuantB ); -void -SQ4BitGemmM1Kernel_CompInt8_avx2( +size_t +SQ4BitGemmKernel_CompInt8_avx2( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, + size_t CountM, size_t CountN, size_t CountK, - size_t BlockStrideQuantB, + size_t BlockCountK, + size_t ldc, const float* Bias ); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 6d1864794f943..3f32cc6c5312d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.h + sqnbitgemm_kernel_neon.cpp Abstract: @@ -17,20 +17,22 @@ Module Name: #include -#include #include -#include #include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" -// -// Quantized B data packing function implementation. -// +namespace sqnbitgemm_neon +{ namespace { +// +// Quantized B data packing function implementation. +// + size_t SQ4BitGemmPackQuantBDataSize( size_t N, @@ -134,7 +136,7 @@ SQ4BitGemmPerGemmWorkspaceSize( { MLAS_UNREFERENCED_PARAMETER(N); - switch(ComputeType) { + switch (ComputeType) { case CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -167,1316 +169,7 @@ SQ4BitGemmPerGemmWorkspaceAlignment( } // namespace -// -// General helpers. -// - -namespace -{ - -template -MLAS_FORCEINLINE void -UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) -{ - (f(Indices), ...); -} - -template -MLAS_FORCEINLINE void -UnrolledLoop(IterationFn&& f) -{ - UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); -} - -MLAS_FORCEINLINE void -Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) -{ - // aN: aN_0 aN_1 aN_2 aN_3 - - float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 - float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 - float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 - float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 - - // a0_0 a1_0 a2_0 a3_0 - a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); - // a0_1 a1_1 a2_1 a3_1 - a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); - // a0_2 a1_2 a3_2 a3_2 - a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); - // a0_3 a1_3 a2_3 a3_3 - a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); -} - -MLAS_FORCEINLINE float32x4_t -FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) -{ - Transpose4x4(a0, a1, a2, a3); - return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); -} - -template -MLAS_FORCEINLINE void -LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) -{ - static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); - - assert(count <= Capacity); - - size_t vi = 0; // vector index - - // handle 4 values at a time - while (count > 3) { - dst[vi] = vld1q_f32(src); - - vi += 1; - src += 4; - count -= 4; - } - - // handle remaining values - if (count > 0) { - dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); - - if (count > 1) { - dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); - - if (count > 2) { - dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); - } - } - } -} - -} // namespace - -// -// CompFp32 kernel implementation. -// - -namespace -{ - -namespace fp32_conversion -{ - -// Manual conversion to float takes place in two steps: -// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. -// This target float range is convenient because the 4-bit source values can be placed directly into the -// target float bits. -// 2. Subtract the conversion offset of 16 from the float result. - -// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. -constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; -// sign|exponent|partial mantissa -// +|131: 2^4|~~~~ <- 4 bits go here - -const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); - -constexpr float offset = 16.0f; - -} // namespace fp32_conversion - -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkBitWidth4_CompFp32( - size_t BlkLen, - const float* ARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* SumPtr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* BiasPtr -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkLen = 16; - - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - - assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); - - const uint8x8_t LowMask = vdup_n_u8(0x0F); - - float32x4_t acc[NCols]{}; - - const std::byte* QuantBData = QuantBDataColPtr; - const float* QuantBScale = QuantBScaleColPtr; - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint is true - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - float scale[NCols]; - UnrolledLoop( - [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } - ); - - [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset. - // only used if HasZeroPoint is true - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const std::byte zp_packed = - QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[i] = fp32_conversion::offset + std::to_integer(zp); - }); - } - - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { - // load A row vector elements - - // load `SubBlkLen` elements from A, padded with 0's if there aren't enough - const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); - float32x4_t av[4]{}; - LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); - - // load B column vectors - uint8x8_t bv_packed[NCols]; - const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset - ); - }); - - uint8x8_t bv_u8[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - - // shift left 3 and widen to 16 bits - uint16x8_t bv_u16[NCols][2]; - UnrolledLoop([&](size_t i) { - constexpr int shift = 3; - bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); - bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); - }); - - // combine 4 bits with float high half template - UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); - }); - - // `SubBlkLen` floats of B - float32x4_t bv[NCols][4]; - - // shift left 16, widen to 32 bits, and reinterpret as float - UnrolledLoop([&](size_t i) { - constexpr int shift = 16; - bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); - bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); - - bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); - bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); - }); - - // subtract float conversion offset and zero point - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } else { - const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } - - // multiply by scale - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(scale[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); - }); - - // c[m,n] += a[m,k] * b[k,n] - UnrolledLoop<4>([&](size_t j) { - UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); - }); - } - - // increment pointers to next block - QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - QuantBScale += 1; - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } - - if constexpr (NCols == 4) { - float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - - if (BiasPtr != nullptr) { - sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); - } - - vst1q_f32(SumPtr, sum); - } else { - for (size_t i = 0; i < NCols; ++i) { - SumPtr[i] = vaddvq_f32(acc[i]); - if (BiasPtr != nullptr) { - SumPtr[i] += BiasPtr[i]; - } - } - } -} - -template -void -SQ4BitGemmM1Kernel_CompFp32_Impl( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; - - const float* ARowPtr = A; - float* CRowPtr = C; - - const size_t BlockCountK = BlockStrideQuantB; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - int64_t nblk = static_cast(CountN) - NCols; - - while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompFp32( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next `NCols` columns - - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? NCols : 0; - SumPtr += NCols; - - nblk -= NCols; - } - - // left over columns less than `NCols`? - nblk += NCols; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -void -SQ4BitGemmM1Kernel_CompFp32( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompFp32_Impl( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompFp32_Impl( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } -} - -// Block dequantize a 16 x NCols section of B from column major source to row major destination. -template -MLAS_FORCEINLINE void -Q4BitBlkDequantB_16xNCols( - const std::byte* QuantBDataPtr, - size_t StrideQuantBData, - const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns - [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns - // only used if HasZeroPoint is true - float* DstColPtr -) -{ - const uint8x8_t LowMask = vdup_n_u8(0x0F); - - // load B column vectors - uint8x8_t bv_packed[NCols]; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData - ); - }); - - uint8x8_t bv_u8[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - - // shift left 3 and widen to 16 bits - uint16x8_t bv_u16[NCols][2]; - UnrolledLoop([&](size_t i) { - constexpr int shift = 3; - bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); - bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); - }); - - // combine 4 bits with float high half template - UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); - }); - - // `SubBlkLen` floats of B - float32x4_t bv[NCols][4]; - - // shift left 16, widen to 32 bits, and reinterpret as float - UnrolledLoop([&](size_t i) { - constexpr int shift = 16; - bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); - bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); - - bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); - bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); - }); - - // subtract float conversion offset and zero point - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } else { - const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } - - // multiply by scale - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); - }); - - // write, transposed, 16 x NCols values - if constexpr (NCols == 4) { - UnrolledLoop<4>([&](size_t j) { - Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); - - vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); - vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); - vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); - vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); - }); - } else { - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { - DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); - DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); - DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); - DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); - }); - }); - } -} - -template -void -Q4BitBlkDequantBForSgemm_CompFp32_Impl( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -) -{ - constexpr size_t BlkBitWidth = 4; - - float* Dst = FpData; - - const std::byte* QuantBDataCol = QuantBData; - const float* QuantBScaleCol = QuantBScale; - [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true - - const size_t StrideQuantBData = BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true - MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); - - // - // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. - // - - // scales of blocks from 16 adjacent columns - float scale[16]; - // float conversion offsets (including zero point) of blocks from 16 adjacent columns - [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true - - size_t n_cols_remaining = CountN; - while (n_cols_remaining > 15) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { - for (size_t nn = 0; nn < 16; ++nn) { - scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; - - if constexpr (HasZeroPoint) { - const std::byte zp_packed = - QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; - const std::byte zp = ((k_blk_idx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[nn] = fp32_conversion::offset + std::to_integer(zp); - } - } - - const size_t kklen = std::min(CountK - k, BlkLen); - - for (size_t kk = 0; kk < kklen; kk += 16) { - constexpr size_t NCols = 4; - - const float* ScalePtr = &scale[0]; - const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; - - float* DstColPtr = Dst; - - for (size_t nn = 0; nn < 16; nn += NCols) { - const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; - - Q4BitBlkDequantB_16xNCols( - QuantBDataPtr, - StrideQuantBData, - ScalePtr, - OffsetPtr, - DstColPtr - ); - - ScalePtr += NCols; - if constexpr (HasZeroPoint) { - OffsetPtr += NCols; - } - DstColPtr += NCols; - } - - Dst += 16 * std::min(kklen - kk, size_t{16}); - } - } - - n_cols_remaining -= 16; - - QuantBDataCol += 16 * StrideQuantBData; - QuantBScaleCol += 16 * BlockStrideQuantB; - if constexpr (HasZeroPoint) { - QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; - } - } - - if (n_cols_remaining > 0) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { - for (size_t nn = 0; nn < n_cols_remaining; ++nn) { - scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; - - if constexpr (HasZeroPoint) { - const std::byte zp_packed = - QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; - const std::byte zp = ((k_blk_idx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[nn] = fp32_conversion::offset + std::to_integer(zp); - } - } - - const size_t kklen = std::min(CountK - k, BlkLen); - - for (size_t kk = 0; kk < kklen; kk += 16) { - // zero out the 16x16 block in Dst first to ensure zero padding - const float32x4_t zero_v = vdupq_n_f32(0.0f); - UnrolledLoop<16 * 4>([&](size_t i) { - vst1q_f32(Dst + 4 * i, zero_v); - }); - - const float* ScalePtr = &scale[0]; - const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; - - float* DstColPtr = Dst; - - for (size_t nn = 0; nn < n_cols_remaining; ++nn) { - const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; - - Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( - QuantBDataPtr, - StrideQuantBData, - ScalePtr, - OffsetPtr, - DstColPtr - ); - - ScalePtr += 1; - if constexpr (HasZeroPoint) { - OffsetPtr += 1; - } - DstColPtr += 1; - } - - Dst += 16 * std::min(kklen - kk, size_t{16}); - } - } - } -} - -void -Q4BitBlkDequantBForSgemm_CompFp32( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -) -{ - if (QuantBZeroPoint != nullptr) { - Q4BitBlkDequantBForSgemm_CompFp32_Impl( - BlkLen, - FpData, - QuantBData, - QuantBScale, - QuantBZeroPoint, - CountN, - CountK, - BlockStrideQuantB - ); - } else { - Q4BitBlkDequantBForSgemm_CompFp32_Impl( - BlkLen, - FpData, - QuantBData, - QuantBScale, - QuantBZeroPoint, - CountN, - CountK, - BlockStrideQuantB - ); - } -} - -// -// CompInt8 kernel implementation. -// - -template -MLAS_FORCEINLINE void -QuantizeBlock( - size_t BlkLen, - const float* A, - size_t ElementCount, - std::byte* QuantA -) -{ - static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); - - assert(BlkLen % SubBlkLen == 0); - - // - // Scan block values first to determine scale. - // - - float amax = 0.0f; // max of absolute values of A block - - size_t k; - for (k = 0; k < ElementCount; k += SubBlkLen) { - const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - - float32x4_t a[SubBlkLen / 4]{}; - LoadFloatData(A + k, SubBlkElementCount, a); - - float32x4_t abs_a[SubBlkLen / 4]; - UnrolledLoop([&](size_t i) { - abs_a[i] = vabsq_f32(a[i]); - }); - - // find amax of SubBlkLen elements - for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { - for (size_t i = 0; i < interval; ++i) { - abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); - } - } - - // update existing amax - amax = std::max(amax, vmaxvq_f32(abs_a[0])); - } - - constexpr float range_max = (1 << 7) - 1; - const float scale = amax / range_max; - const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; - - Q8BlkScale(QuantA) = scale; - - // - // Compute quantized block values. - // - - int8_t* QuantAData = Q8BlkData(QuantA); - - for (k = 0; k < ElementCount; k += SubBlkLen) { - const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - - float32x4_t a[SubBlkLen / 4]{}; - LoadFloatData(A + k, SubBlkElementCount, a); - - UnrolledLoop([&](size_t i) { - a[i] = vmulq_n_f32(a[i], scale_reciprocal); - }); - - int32x4_t a_s32[SubBlkLen / 4]; - UnrolledLoop([&](size_t i) { - a_s32[i] = vcvtaq_s32_f32(a[i]); - }); - - UnrolledLoop([&](size_t i) { - QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); - QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); - QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); - QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); - }); - } - - // - // Zero out any remaining sub-block elements. - // - - for (; k < BlkLen; k += SubBlkLen) { - const int8x16_t Zeros = vdupq_n_s8(0); - UnrolledLoop([&](size_t i) { - vst1q_s8(QuantAData + k + i * 16, Zeros); - }); - } -} - -void -QuantizeARow_CompInt8( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA -) -{ - const float* ADataBlkPtr = A; - std::byte* QuantABlkPtr = QuantA; - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); - - ADataBlkPtr += BlkLen; - QuantABlkPtr += Q8BlkSize(BlkLen); - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 16; - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 - ); - const int8x16_t bzp1 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] >> 4) : 8 - ); - - // load A - const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); - - // load B - const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - - const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); - const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); - - int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); - int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp0); - bv1 = vsubq_s8(bv1, bzp1); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); - const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen) * 2; - QuantBDataPtr += 8 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 - ); - - // load A - const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); - - // load B - const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); - - const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); - const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); - - int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp0); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 32; - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - const int8x16_t bzp1 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 - ); - - // load A - const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); - const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); - const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); - - int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); - int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); - - // subtract B zero point - bv_lo0 = vsubq_s8(bv_lo0, bzp0); - bv_hi0 = vsubq_s8(bv_hi0, bzp0); - bv_lo1 = vsubq_s8(bv_lo1, bzp1); - bv_hi1 = vsubq_s8(bv_hi1, bzp1); - - // quantized dot product - int32x4_t dot0{}, dot1{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); - dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen) * 2; - QuantBDataPtr += 16 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - - // load A - const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - - int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - - // subtract B zero point - bv_lo0 = vsubq_s8(bv_lo0, bzp0); - bv_hi0 = vsubq_s8(bv_hi0, bzp0); - - // quantized dot product - int32x4_t dot0{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen > 32); - assert(BlkLen % 32 == 0); - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - // process blocks in 32-element sub-blocks - const size_t SubBlksPerBlk = BlkLen / 32; - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { - // compute combined scale - const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantAPtr) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp = [&]() -> int8x16_t { - if constexpr (HasZeroPoint) { - return vdupq_n_s8( - ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) - : std::to_integer((*QuantBZeroPointPtr) >> 4) - ); - } else { - return vdupq_n_s8(8); - } - }(); - - const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); - - for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { - // load A - const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); - const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); - const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); - const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); - - int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); - int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp); - bv1 = vsubq_s8(bv1, bzp); - bv2 = vsubq_s8(bv2, bzp); - bv3 = vsubq_s8(bv3, bzp); - - // quantized dot product - int32x4_t dot0{}, dot1{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); - dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale); - - // increment block data pointers to next sub-block - QuantADataPtr += 16 * 4; - QuantBDataPtr += 16 * 2; - } - - // increment other block pointers - - QuantAPtr += Q8BlkSize(BlkLen); - QuantBScalePtr += 1; - - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; - } - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } -} - -void -SQ4BitGemmM1Kernel_CompInt8( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t /*CountK*/, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } -} - -} // namespace +} // namespace sqnbitgemm_neon // // Kernel dispatch structure definition. @@ -1485,17 +178,17 @@ SQ4BitGemmM1Kernel_CompInt8( const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h new file mode 100644 index 0000000000000..ef9345d7ac484 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h @@ -0,0 +1,144 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + SQNBitGemm ARM NEON kernels. + +--*/ + +#pragma once + +#include + +#include +#include +#include + +#include "mlasi.h" + +namespace sqnbitgemm_neon +{ + +// +// Function declarations for SQNBitGemm ARM NEON kernel entry points. +// Refer to the prototypes in sqnbitgemm.h for documentation. +// These are declared here so they can be used to initialize the +// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate +// files. +// + +// CompFp32 declarations + +void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +); + +void +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +); + +// CompInt8 declarations + +void +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +); + +size_t +SQ4BitGemmKernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + size_t ldc, + const float* Bias +); + +// +// General helpers. +// + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +template +MLAS_FORCEINLINE void +LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) +{ + static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); + + assert(count <= Capacity); + + size_t vi = 0; // vector index + + // handle 4 values at a time + while (count > 3) { + dst[vi] = vld1q_f32(src); + + vi += 1; + src += 4; + count -= 4; + } + + // handle remaining values + if (count > 0) { + dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); + + if (count > 1) { + dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); + + if (count > 2) { + dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); + } + } + } +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp new file mode 100644 index 0000000000000..ca64ebe3b1137 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -0,0 +1,646 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_fp32.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. + +--*/ + +#include + +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ + +namespace +{ + +// +// CompFp32 kernel implementation. +// + +MLAS_FORCEINLINE void +Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) +{ + // aN: aN_0 aN_1 aN_2 aN_3 + + float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 + float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 + float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 + float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 + + // a0_0 a1_0 a2_0 a3_0 + a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_1 a1_1 a2_1 a3_1 + a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_2 a1_2 a3_2 a3_2 + a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + // a0_3 a1_3 a2_3 a3_3 + a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); +} + +MLAS_FORCEINLINE float32x4_t +FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +{ + Transpose4x4(a0, a1, a2, a3); + return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); +} + +namespace fp32_conversion +{ + +// Manual conversion to float takes place in two steps: +// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. +// This target float range is convenient because the 4-bit source values can be placed directly into the +// target float bits. +// 2. Subtract the conversion offset of 16 from the float result. + +// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. +constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; +// sign|exponent|partial mantissa +// +|131: 2^4|~~~~ <- 4 bits go here + +const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); + +constexpr float offset = 16.0f; + +} // namespace fp32_conversion + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompFp32( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; + + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + float32x4_t acc[NCols]{}; + + const std::byte* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint is true + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + float scale[NCols]; + UnrolledLoop( + [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } + ); + + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset. + // only used if HasZeroPoint is true + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = fp32_conversion::offset + std::to_integer(zp); + }); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector elements + + // load `SubBlkLen` elements from A, padded with 0's if there aren't enough + const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); + float32x4_t av[4]{}; + LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(scale[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // c[m,n] += a[m,k] * b[k,n] + UnrolledLoop<4>([&](size_t j) { + UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); + }); + } + + // increment pointers to next block + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +template +void +SQ4BitGemmM1Kernel_CompFp32_Impl( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t NCols = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompFp32( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +} // namespace + +void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + } else { + constexpr bool HasZeroPoint = false; + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + } +} + +namespace +{ + +// Block dequantize a 16 x NCols section of B from column major source to row major destination. +template +MLAS_FORCEINLINE void +Q4BitBlkDequantB_16xNCols( + const std::byte* QuantBDataPtr, + size_t StrideQuantBData, + const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns + [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns + // only used if HasZeroPoint is true + float* DstColPtr +) +{ + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // write, transposed, 16 x NCols values + if constexpr (NCols == 4) { + UnrolledLoop<4>([&](size_t j) { + Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); + + vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); + vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); + vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); + vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); + }); + } else { + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { + DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); + DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); + DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); + DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); + }); + }); + } +} + +template +void +Q4BitBlkDequantBForSgemm_CompFp32_Impl( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +) +{ + constexpr size_t BlkBitWidth = 4; + + float* Dst = FpData; + + const std::byte* QuantBDataCol = QuantBData; + const float* QuantBScaleCol = QuantBScale; + [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true + MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + // + // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. + // + + // scales of blocks from 16 adjacent columns + float scale[16]; + // float conversion offsets (including zero point) of blocks from 16 adjacent columns + [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true + + size_t n_cols_remaining = CountN; + while (n_cols_remaining > 15) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < 16; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + constexpr size_t NCols = 4; + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < 16; nn += NCols) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += NCols; + if constexpr (HasZeroPoint) { + OffsetPtr += NCols; + } + DstColPtr += NCols; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + + n_cols_remaining -= 16; + + QuantBDataCol += 16 * StrideQuantBData; + QuantBScaleCol += 16 * BlockCountK; + if constexpr (HasZeroPoint) { + QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; + } + } + + if (n_cols_remaining > 0) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + // zero out the 16x16 block in Dst first to ensure zero padding + const float32x4_t zero_v = vdupq_n_f32(0.0f); + UnrolledLoop<16 * 4>([&](size_t i) { + vst1q_f32(Dst + 4 * i, zero_v); + }); + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += 1; + if constexpr (HasZeroPoint) { + OffsetPtr += 1; + } + DstColPtr += 1; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + } +} + +} // namespace + +void +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +) +{ + if (QuantBZeroPoint != nullptr) { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockCountK + ); + } else { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockCountK + ); + } +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp new file mode 100644 index 0000000000000..ec5cdbc75220a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -0,0 +1,1401 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_int8.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. + +--*/ + +#include + +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" +#include "sqnbitgemm_q8_block.h" + +namespace sqnbitgemm_neon +{ + +// +// CompInt8 kernel implementation. +// + +namespace +{ + +template +MLAS_FORCEINLINE void +QuantizeBlock( + size_t BlkLen, + const float* A, + size_t ElementCount, + std::byte* QuantA +) +{ + static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); + + assert(BlkLen % SubBlkLen == 0); + + // + // Scan block values first to determine scale. + // + + float amax = 0.0f; // max of absolute values of A block + + size_t k; + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + abs_a[i] = vabsq_f32(a[i]); + }); + + // find amax of SubBlkLen elements + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { + for (size_t i = 0; i < interval; ++i) { + abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); + } + } + + // update existing amax + amax = std::max(amax, vmaxvq_f32(abs_a[0])); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; + + Q8BlkScale(QuantA) = scale; + + // + // Compute quantized block values. + // + + int8_t* QuantAData = Q8BlkData(QuantA); + + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + UnrolledLoop([&](size_t i) { + a[i] = vmulq_n_f32(a[i], scale_reciprocal); + }); + + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + a_s32[i] = vcvtaq_s32_f32(a[i]); + }); + + UnrolledLoop([&](size_t i) { + QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); + QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); + QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); + QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); + }); + } + + // + // Zero out any remaining sub-block elements. + // + + for (; k < BlkLen; k += SubBlkLen) { + const int8x16_t Zeros = vdupq_n_s8(0); + UnrolledLoop([&](size_t i) { + vst1q_s8(QuantAData + k + i * 16, Zeros); + }); + } +} + +} // namespace + +void +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + const float* ADataBlkPtr = A; + std::byte* QuantABlkPtr = QuantA; + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); + + ADataBlkPtr += BlkLen; + QuantABlkPtr += Q8BlkSize(BlkLen); + } +} + +namespace +{ + +// +// The ComputeRxC functions compute an R row by C column tile of the output matrix. +// + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute4x2_BlkLen16( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK, + size_t StrideQuantA, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + size_t ldc +) +{ + constexpr size_t BlkLen = 16; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc00{}, acc01{}, acc10{}, acc11{}, acc20{}, acc21{}, acc30{}, acc31{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlkRow0 = QuantAPtr; + const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; + const std::byte* QuantABlkRow2 = QuantAPtr + StrideQuantA * 2; + const std::byte* QuantABlkRow3 = QuantAPtr + StrideQuantA * 3; + + const float QuantBScaleCol0 = *QuantBScalePtr; + const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); + + // compute combined scales + const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0; + const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; + const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; + const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; + const float scale20 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol0; + const float scale21 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol1; + const float scale30 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol0; + const float scale31 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol1; + + // load B zero point + int8_t bzp_col0; + int8_t bzp_col1; + if constexpr (HasZeroPoint) { + const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr; + const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint); + if ((k_blk_idx & 1) == 0) { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 & std::byte{0x0F}); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 & std::byte{0x0F}); + } else { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 >> 4); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 >> 4); + } + } else { + bzp_col0 = 8; + bzp_col1 = 8; + } + + const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); + const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); + const int8_t* QuantADataPtrRow2 = Q8BlkData(QuantABlkRow2); + const int8_t* QuantADataPtrRow3 = Q8BlkData(QuantABlkRow3); + + // TODO handling only 16 elements per accumulator at a time here, probably can do better + { + // load B + const uint8x8_t bv_packed_col0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x8_t bv_packed_col1 = vld1_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); + + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); + + int8x16_t bv_col0 = vreinterpretq_s8_u8( + vcombine_u8( + vand_u8(bv_packed_col0, LowMaskU8x8), + vshr_n_u8(bv_packed_col0, 4) + ) + ); + int8x16_t bv_col1 = vreinterpretq_s8_u8( + vcombine_u8( + vand_u8(bv_packed_col1, LowMaskU8x8), + vshr_n_u8(bv_packed_col1, 4) + ) + ); + + // subtract B zero point + bv_col0 = vsubq_s8(bv_col0, vdupq_n_s8(bzp_col0)); + bv_col1 = vsubq_s8(bv_col1, vdupq_n_s8(bzp_col1)); + + // rows 0 and 1 of A + { + // load A + const int8x16_t av_row0 = vld1q_s8(QuantADataPtrRow0 + 0); + const int8x16_t av_row1 = vld1q_s8(QuantADataPtrRow1 + 0); + + // quantized dot product + const int32x4_t dot00 = vdotq_s32(int32x4_t{}, av_row0, bv_col0); + const int32x4_t dot01 = vdotq_s32(int32x4_t{}, av_row0, bv_col1); + const int32x4_t dot10 = vdotq_s32(int32x4_t{}, av_row1, bv_col0); + const int32x4_t dot11 = vdotq_s32(int32x4_t{}, av_row1, bv_col1); + + // convert to float + const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); + const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); + const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); + const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + + // multiply by scale and update accumulator + acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); + acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); + acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); + acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + } + + // rows 2 and 3 of A + { + // load A + const int8x16_t av_row2 = vld1q_s8(QuantADataPtrRow2 + 0); + const int8x16_t av_row3 = vld1q_s8(QuantADataPtrRow3 + 0); + + // quantized dot product + const int32x4_t dot20 = vdotq_s32(int32x4_t{}, av_row2, bv_col0); + const int32x4_t dot21 = vdotq_s32(int32x4_t{}, av_row2, bv_col1); + const int32x4_t dot30 = vdotq_s32(int32x4_t{}, av_row3, bv_col0); + const int32x4_t dot31 = vdotq_s32(int32x4_t{}, av_row3, bv_col1); + + // convert to float + const float32x4_t dot_f32_20 = vcvtq_f32_s32(dot20); + const float32x4_t dot_f32_21 = vcvtq_f32_s32(dot21); + const float32x4_t dot_f32_30 = vcvtq_f32_s32(dot30); + const float32x4_t dot_f32_31 = vcvtq_f32_s32(dot31); + + // multiply by scale and update accumulator + acc20 = vfmaq_f32(acc20, dot_f32_20, vdupq_n_f32(scale20)); + acc21 = vfmaq_f32(acc21, dot_f32_21, vdupq_n_f32(scale21)); + acc30 = vfmaq_f32(acc30, dot_f32_30, vdupq_n_f32(scale30)); + acc31 = vfmaq_f32(acc31, dot_f32_31, vdupq_n_f32(scale31)); + } + } + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBDataPtr += 8; + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + SumPtr[ldc * 0 + 0] = vaddvq_f32(acc00); + SumPtr[ldc * 0 + 1] = vaddvq_f32(acc01); + SumPtr[ldc * 1 + 0] = vaddvq_f32(acc10); + SumPtr[ldc * 1 + 1] = vaddvq_f32(acc11); + SumPtr[ldc * 2 + 0] = vaddvq_f32(acc20); + SumPtr[ldc * 2 + 1] = vaddvq_f32(acc21); + SumPtr[ldc * 3 + 0] = vaddvq_f32(acc30); + SumPtr[ldc * 3 + 1] = vaddvq_f32(acc31); + + if (BiasPtr != nullptr) { + SumPtr[ldc * 0 + 0] += BiasPtr[0]; + SumPtr[ldc * 0 + 1] += BiasPtr[1]; + SumPtr[ldc * 1 + 0] += BiasPtr[0]; + SumPtr[ldc * 1 + 1] += BiasPtr[1]; + SumPtr[ldc * 2 + 0] += BiasPtr[0]; + SumPtr[ldc * 2 + 1] += BiasPtr[1]; + SumPtr[ldc * 3 + 0] += BiasPtr[0]; + SumPtr[ldc * 3 + 1] += BiasPtr[1]; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK, + size_t StrideQuantA, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + size_t ldc +) +{ + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc00{}, acc01{}, acc10{}, acc11{}, acc20{}, acc21{}, acc30{}, acc31{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlkRow0 = QuantAPtr; + const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; + const std::byte* QuantABlkRow2 = QuantAPtr + StrideQuantA * 2; + const std::byte* QuantABlkRow3 = QuantAPtr + StrideQuantA * 3; + + const float QuantBScaleCol0 = *QuantBScalePtr; + const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); + + // compute combined scales + const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0; + const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; + const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; + const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; + const float scale20 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol0; + const float scale21 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol1; + const float scale30 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol0; + const float scale31 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol1; + + // load B zero point + int8_t bzp_col0; + int8_t bzp_col1; + if constexpr (HasZeroPoint) { + const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr; + const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint); + if ((k_blk_idx & 1) == 0) { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 & std::byte{0x0F}); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 & std::byte{0x0F}); + } else { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 >> 4); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 >> 4); + } + } else { + bzp_col0 = 8; + bzp_col1 = 8; + } + + const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); + const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); + const int8_t* QuantADataPtrRow2 = Q8BlkData(QuantABlkRow2); + const int8_t* QuantADataPtrRow3 = Q8BlkData(QuantABlkRow3); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; ++sub_blk_idx) { + // load B + const uint8x16_t bv_packed_col0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed_col1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + int8x16_t bv_col0_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col0, LowMaskU8x16)); + int8x16_t bv_col0_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col0, 4)); + int8x16_t bv_col1_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col1, LowMaskU8x16)); + int8x16_t bv_col1_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col1, 4)); + + // subtract B zero point + bv_col0_0 = vsubq_s8(bv_col0_0, vdupq_n_s8(bzp_col0)); + bv_col0_1 = vsubq_s8(bv_col0_1, vdupq_n_s8(bzp_col0)); + bv_col1_0 = vsubq_s8(bv_col1_0, vdupq_n_s8(bzp_col1)); + bv_col1_1 = vsubq_s8(bv_col1_1, vdupq_n_s8(bzp_col1)); + + // rows 0 and 1 of A + { + // load A + const int8x16_t av_row0_0 = vld1q_s8(QuantADataPtrRow0 + 0); + const int8x16_t av_row0_1 = vld1q_s8(QuantADataPtrRow0 + 16); + const int8x16_t av_row1_0 = vld1q_s8(QuantADataPtrRow1 + 0); + const int8x16_t av_row1_1 = vld1q_s8(QuantADataPtrRow1 + 16); + + // quantized dot product + const int32x4_t dot00 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row0_0, bv_col0_0), av_row0_1, bv_col0_1); + const int32x4_t dot01 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row0_0, bv_col1_0), av_row0_1, bv_col1_1); + const int32x4_t dot10 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row1_0, bv_col0_0), av_row1_1, bv_col0_1); + const int32x4_t dot11 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row1_0, bv_col1_0), av_row1_1, bv_col1_1); + + // convert to float + const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); + const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); + const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); + const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + + // multiply by scale and update accumulator + acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); + acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); + acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); + acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + } + + // rows 2 and 3 of A + { + // load A + const int8x16_t av_row2_0 = vld1q_s8(QuantADataPtrRow2 + 0); + const int8x16_t av_row2_1 = vld1q_s8(QuantADataPtrRow2 + 16); + const int8x16_t av_row3_0 = vld1q_s8(QuantADataPtrRow3 + 0); + const int8x16_t av_row3_1 = vld1q_s8(QuantADataPtrRow3 + 16); + + // quantized dot product + const int32x4_t dot20 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row2_0, bv_col0_0), av_row2_1, bv_col0_1); + const int32x4_t dot21 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row2_0, bv_col1_0), av_row2_1, bv_col1_1); + const int32x4_t dot30 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row3_0, bv_col0_0), av_row3_1, bv_col0_1); + const int32x4_t dot31 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row3_0, bv_col1_0), av_row3_1, bv_col1_1); + + // convert to float + const float32x4_t dot_f32_20 = vcvtq_f32_s32(dot20); + const float32x4_t dot_f32_21 = vcvtq_f32_s32(dot21); + const float32x4_t dot_f32_30 = vcvtq_f32_s32(dot30); + const float32x4_t dot_f32_31 = vcvtq_f32_s32(dot31); + + // multiply by scale and update accumulator + acc20 = vfmaq_f32(acc20, dot_f32_20, vdupq_n_f32(scale20)); + acc21 = vfmaq_f32(acc21, dot_f32_21, vdupq_n_f32(scale21)); + acc30 = vfmaq_f32(acc30, dot_f32_30, vdupq_n_f32(scale30)); + acc31 = vfmaq_f32(acc31, dot_f32_31, vdupq_n_f32(scale31)); + } + + // increment block data pointers to next sub-block + QuantADataPtrRow0 += 32; + QuantADataPtrRow1 += 32; + QuantADataPtrRow2 += 32; + QuantADataPtrRow3 += 32; + QuantBDataPtr += 16; + } + + // increment other block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + SumPtr[ldc * 0 + 0] = vaddvq_f32(acc00); + SumPtr[ldc * 0 + 1] = vaddvq_f32(acc01); + SumPtr[ldc * 1 + 0] = vaddvq_f32(acc10); + SumPtr[ldc * 1 + 1] = vaddvq_f32(acc11); + SumPtr[ldc * 2 + 0] = vaddvq_f32(acc20); + SumPtr[ldc * 2 + 1] = vaddvq_f32(acc21); + SumPtr[ldc * 3 + 0] = vaddvq_f32(acc30); + SumPtr[ldc * 3 + 1] = vaddvq_f32(acc31); + + if (BiasPtr != nullptr) { + SumPtr[ldc * 0 + 0] += BiasPtr[0]; + SumPtr[ldc * 0 + 1] += BiasPtr[1]; + SumPtr[ldc * 1 + 0] += BiasPtr[0]; + SumPtr[ldc * 1 + 1] += BiasPtr[1]; + SumPtr[ldc * 2 + 0] += BiasPtr[0]; + SumPtr[ldc * 2 + 1] += BiasPtr[1]; + SumPtr[ldc * 3 + 0] += BiasPtr[0]; + SumPtr[ldc * 3 + 1] += BiasPtr[1]; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + constexpr size_t BlkLen = 16; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); + + // load B + const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); + const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); + int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + bv1 = vsubq_s8(bv1, bzp1); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(int32x4_t{}, av0, bv0); + const int32x4_t dot1 = vdotq_s32(int32x4_t{}, av1, bv1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 8 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + + // load B + const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); + + const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); + const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(int32x4_t{}, av0, bv0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + constexpr size_t BlkLen = 32; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); + const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + bv_lo1 = vsubq_s8(bv_lo1, bzp1); + bv_hi1 = vsubq_s8(bv_hi1, bzp1); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo0, bv_lo0), av_hi0, bv_hi0); + const int32x4_t dot1 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo1, bv_lo1), av_hi1, bv_hi1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo0, bv_lo0), av_hi0, bv_hi0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + + // load B zero point + const int8x16_t bzp = [&]() -> int8x16_t { + if constexpr (HasZeroPoint) { + return vdupq_n_s8( + ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) + : std::to_integer((*QuantBZeroPointPtr) >> 4) + ); + } else { + return vdupq_n_s8(8); + } + }(); + + const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { + // load A + const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); + const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); + const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); + const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp); + bv1 = vsubq_s8(bv1, bzp); + bv2 = vsubq_s8(bv2, bzp); + bv3 = vsubq_s8(bv3, bzp); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av0, bv0), av1, bv1); + const int32x4_t dot1 = vdotq_s32(vdotq_s32(int32x4_t{}, av2, bv2), av3, bv3); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale); + + // increment block data pointers to next sub-block + QuantADataPtr += 16 * 4; + QuantBDataPtr += 16 * 2; + } + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +AdvanceColPtrs( + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const std::byte*& QuantBDataColPtr, + const float*& QuantBScaleColPtr, + const std::byte*& QuantBZeroPointColPtr, + const float*& BiasPtr, + float*& SumPtr +) +{ + QuantBDataColPtr += NumCols * StrideQuantBData; + QuantBScaleColPtr += NumCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NumCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NumCols : 0; + SumPtr += NumCols; +} + +template +MLAS_FORCEINLINE void +AdvanceRowPtrs( + size_t StrideQuantA, + size_t ldc, + const std::byte*& QuantARowPtr, + float*& SumRowPtr +) +{ + QuantARowPtr += NumRows * StrideQuantA; + SumRowPtr += NumRows * ldc; +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLen16( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 16; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 3) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 4x2 tiles of output + SQ4BitGemm_CompInt8_Compute4x2_BlkLen16( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 4x1 tile of output + for (size_t i = 0; i < 4; ++i) { + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr + StrideQuantA * i, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc * i, + BlockCountK + ); + } + } + + // Move to next 4 rows + AdvanceRowPtrs<4>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 4; + } + + while (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + + // Move to next row + AdvanceRowPtrs<1>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 1; + } +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLen32( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 3) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 4x2 tiles of output + SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 4x1 tile of output + for (size_t i = 0; i < 4; ++i) { + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr + StrideQuantA * i, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc * i, + BlockCountK + ); + } + } + + // Move to next 4 rows + AdvanceRowPtrs<4>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 4; + } + + while (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + + // Move to next row + AdvanceRowPtrs<1>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 1; + } +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 3) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 4x2 tiles of output + SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 4x1 tile of output + for (size_t i = 0; i < 4; ++i) { + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr + StrideQuantA * i, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc * i, + BlockCountK + ); + } + } + + // Move to next 4 rows + AdvanceRowPtrs<4>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 4; + } + + while (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + + // Move to next row + AdvanceRowPtrs<1>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 1; + } +} + +template +void +SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmKernel_CompInt8_BlkLen16( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmKernel_CompInt8_BlkLen32( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else { + SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } +} + +} // namespace + +size_t +SQ4BitGemmKernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else { + constexpr bool HasZeroPoint = false; + SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } + + return CountM; +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index b88f2d6a4637e..08066f030a381 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -126,7 +126,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, } else { MergeWeights(q_weight, k_weight, v_weight, result, hidden_size); } - initializer.set_raw_data(result.data(), gsl::narrow(element_count) * sizeof(float)); + utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow(element_count) * sizeof(float)); } else { // data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 const MLFloat16* q_weight = q_initializer.data(); const MLFloat16* k_weight = k_initializer.data(); @@ -138,7 +138,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, } else { MergeWeights(q_weight, k_weight, v_weight, result, hidden_size); } - initializer.set_raw_data(result.data(), gsl::narrow(element_count) * sizeof(MLFloat16)); + utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow(element_count) * sizeof(MLFloat16)); } return graph_utils::AddInitializer(graph, initializer); diff --git a/onnxruntime/core/optimizer/attention_fusion.h b/onnxruntime/core/optimizer/attention_fusion.h index acb478da5f31a..befb66b5aa960 100644 --- a/onnxruntime/core/optimizer/attention_fusion.h +++ b/onnxruntime/core/optimizer/attention_fusion.h @@ -19,7 +19,8 @@ class AttentionFusion : public GraphTransformer { Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; private: - static bool FuseSubGraph(Node& layer_norm, const Node& add_after_layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_index_map, const logging::Logger& logger); + static bool FuseSubGraph(Node& layer_norm, const Node& add_after_layer_norm, Graph& graph, int64_t hidden_size, + std::map& mask_index_map, const logging::Logger& logger); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index 6233e5b839bb9..ca744adddbeec 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -23,7 +23,8 @@ struct MatchGemmResult { }; // Compare the expected parameters (starts, ends, axes and step) -bool CheckSliceParameters(const Graph& graph, const Node& slice, const std::vector& input_indices, const std::vector& expected_values, const logging::Logger& logger) { +bool CheckSliceParameters(const Graph& graph, const Node& slice, const std::vector& input_indices, + const std::vector& expected_values, const logging::Logger& logger) { ORT_ENFORCE(input_indices.size() == expected_values.size() && input_indices.size() > 0); // Here assumes that the last element of input_indices is the maximum one. diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc index 48df511d0c672..471e4ee7c03a3 100644 --- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc +++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc @@ -491,7 +491,8 @@ Status CommonSubexpressionElimination::ApplyImpl(Graph& graph, bool& modified, i if (graph_outputs.count(output_def) > 0) { // Currently, we don't support eliminating the graph's outputs. - LOGS(logger, VERBOSE) << "Not eliminating output " << output_def->Name() << " of node " << node->Name() << "[" << node->OpType() << "] because it's the graph's output."; + LOGS(logger, VERBOSE) << "Not eliminating output " << output_def->Name() << " of node " << node->Name() + << "[" << node->OpType() << "] because it's the graph's output."; continue; } diff --git a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc index 913f3b6811183..86a7a4d6afbf8 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc @@ -188,7 +188,7 @@ NodeArg* CreateInitializerFromVector(Graph& graph, "The total count of dims does not match the size of values. ", "total_count: ", total_count, " values.size(): ", values.size()); - const_tensor.set_raw_data(values.data(), values.size() * sizeof(int64_t)); + utils::SetRawDataInTensorProto(const_tensor, values.data(), values.size() * sizeof(int64_t)); return &graph_utils::AddInitializer(graph, const_tensor); } diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 9df300d6f4f88..1466de51d0b99 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -82,8 +82,7 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) { shape_constant.set_name(constant_arg_out->Name()); shape_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); shape_constant.add_dims(clamped_slice_length); - shape_constant.set_raw_data(dim_values.data() + start, - clamped_slice_length * sizeof(int64_t)); + utils::SetRawDataInTensorProto(shape_constant, dim_values.data() + start, clamped_slice_length * sizeof(int64_t)); ONNX_NAMESPACE::TensorShapeProto result_shape; result_shape.add_dim()->set_dim_value(clamped_slice_length); constant_arg_out->SetShape(result_shape); diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 7b6f829b7a0a4..e8e395678436e 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -465,15 +465,13 @@ static NodeArg* ExtractEmbedding(Graph& graph, if (!CheckEmbeddingData(data, batch_size, element_count)) { return nullptr; } - - initializer.set_raw_data(data, gsl::narrow(element_count) * sizeof(float)); + utils::SetRawDataInTensorProto(initializer, data, gsl::narrow(element_count) * sizeof(float)); } else { // data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 const MLFloat16* data = old_initializer.data(); if (!CheckEmbeddingData(data, batch_size, element_count)) { return nullptr; } - - initializer.set_raw_data(data, gsl::narrow(element_count) * sizeof(MLFloat16)); + utils::SetRawDataInTensorProto(initializer, data, gsl::narrow(element_count) * sizeof(MLFloat16)); } NodeArg& node_arg = graph_utils::AddInitializer(graph, initializer); diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.h b/onnxruntime/core/optimizer/free_dim_override_transformer.h index f9553339a7ce6..18e0b128b8641 100644 --- a/onnxruntime/core/optimizer/free_dim_override_transformer.h +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/optimizer/graph_transformer.h" diff --git a/onnxruntime/core/optimizer/gather_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h index 098278a77dafe..f431d98b3b827 100644 --- a/onnxruntime/core/optimizer/gather_fusion.h +++ b/onnxruntime/core/optimizer/gather_fusion.h @@ -10,7 +10,7 @@ namespace onnxruntime { /** @Class GatherSliceToSplitFusion -Fuse multiple Gather/Slice nodes that comsuming one output to one Split node. +Fuse multiple Gather/Slice nodes that consuming one output to one Split node. */ class GatherSliceToSplitFusion : public GraphTransformer { public: diff --git a/onnxruntime/core/optimizer/gemm_sum_fusion.h b/onnxruntime/core/optimizer/gemm_sum_fusion.h index 0e2ec104703f3..9b2fa22ecc317 100644 --- a/onnxruntime/core/optimizer/gemm_sum_fusion.h +++ b/onnxruntime/core/optimizer/gemm_sum_fusion.h @@ -12,14 +12,14 @@ namespace onnxruntime { Rewrite rule that fuses Gemm and Sum nodes to a single Gemm node. This fusion can be applied in the following scenario: -1) Sum at output of Gemm: when the output of a Gemm is immedietly summed with +1) Sum at output of Gemm: when the output of a Gemm is immediately summed with exactly one other element, we can fuse this Sum with Gemm by using the other Sum input as C, provided that the C input to the Gemm is missing. This is supported for opset >= 11, as this is when Gemm input C became optional. TODO: Support the Add use case: Sum(x, y) ~= Add. -This patterm is attempted to be triggered only on nodes with op type "Gemm". +This pattern is attempted to be triggered only on nodes with op type "Gemm". A --> Gemm --> D --> Sum --> E ^ ^ diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 53c7f39bdd7f1..4298551aec412 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -71,7 +71,6 @@ #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rocm_blas_alt_impl.h" #include "core/optimizer/rule_based_graph_transformer.h" -#include "core/optimizer/shape_input_merge.h" #include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/slice_elimination.h" #include "core/optimizer/transpose_optimizer.h" @@ -215,9 +214,9 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); } - // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as it can create - // more opportunities for CSE. For example, if A and B nodes consume same different args but produce same output - // or consume different initializers with same value, by default, CSE will not merge them. + // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for + // CSE. For example, if A and B nodes consume different initializers with same value, by default, + // CSE will not merge them. InlinedHashSet excluded_initializers; excluded_initializers.reserve(session_options.initializers_to_share_map.size()); for (const auto& p : session_options.initializers_to_share_map) { @@ -225,7 +224,6 @@ InlinedVector> GenerateTransformers( } const InlinedHashSet no_limit_empty_ep_list = {}; transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, excluded_initializers)); - transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq, session_options.config_options)); diff --git a/onnxruntime/core/optimizer/identity_elimination.h b/onnxruntime/core/optimizer/identity_elimination.h index 5e76275207c32..4b20edec12dfa 100644 --- a/onnxruntime/core/optimizer/identity_elimination.h +++ b/onnxruntime/core/optimizer/identity_elimination.h @@ -26,6 +26,6 @@ class EliminateIdentity : public RewriteRule { bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; -}; // namespace onnxruntime +}; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 3679a40d32eee..33fd613bb1a50 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -6,8 +6,7 @@ #include #include -#include "core/common/gsl.h" -#include "core/common/path.h" +#include #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_external_data_info.h" #include "core/platform/env.h" @@ -25,13 +24,13 @@ Initializer::Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type, } } -Initializer::Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const Path& model_path) { +Initializer::Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& model_path) { ORT_ENFORCE(utils::HasDataType(tensor_proto), "Initializer must have a datatype"); #if !defined(__wasm__) // using full filepath is required by utils::TensorProtoToTensor(). One exception is WebAssembly platform, where // external data is not loaded from real file system. if (utils::HasExternalData(tensor_proto)) { - ORT_ENFORCE(!model_path.IsEmpty(), + ORT_ENFORCE(!model_path.empty(), "model_path must not be empty. Ensure that a path is provided when the model is created or loaded."); } #endif @@ -46,7 +45,7 @@ Initializer::Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const // This must be pre-allocated Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, std::make_shared()); - ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path.ToPathString().c_str(), tensor_proto, w)); + ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path, tensor_proto, w)); data_ = std::move(w); } diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index 78e3fd6a3d24e..3099faed18ac3 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -7,10 +7,9 @@ #include #include #include - +#include #include "core/common/common.h" #include "core/common/narrow.h" -#include "core/common/path.h" #include "core/framework/allocator.h" #include "core/optimizer/graph_transformer.h" #include "core/framework/tensor_shape.h" @@ -28,7 +27,7 @@ class Initializer final { gsl::span dims); Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const Path& model_path = {}); + const std::filesystem::path& model_path = {}); ~Initializer() = default; diff --git a/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc b/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc index bc0523d517ee6..7d249ea715e8d 100644 --- a/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc +++ b/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc @@ -41,7 +41,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l auto input_defs = isinf_node.MutableInputDefs(); // see if there is a Cast before IsInf - // This will happen if input type is FP16 but IsInf doesnt support fp16, so it will be cast to float/double + // This will happen if input type is FP16 but IsInf doesn't support fp16, so it will be cast to float/double // This Cast can be skipped as we are replacing the subgraph with IsAllFinite, which supports FP16 auto cast1_node_iter = isinf_node.InputNodesBegin(); if (cast1_node_iter != isinf_node.InputNodesEnd() && diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h index 7a43483cf37d4..39cd0dd186d53 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.h +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -8,7 +8,7 @@ namespace onnxruntime { /* * This fusion submerges a BatchNormalization operator to it's super - * precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() + * preceding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() * is true. */ class MatmulBNFusion : public RewriteRule { @@ -24,4 +24,4 @@ class MatmulBNFusion : public RewriteRule { Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index e4cdeadbf54d7..338722fb00782 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -17,7 +17,8 @@ namespace onnxruntime { namespace { template struct ExtractScalarAsFloatDispatchTarget { - Status operator()(const ONNX_NAMESPACE::TensorProto& tensor_proto, const Path& model_path, float& scalar_float) { + Status operator()(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& model_path, + float& scalar_float) { T scalar; ORT_RETURN_IF_ERROR(utils::UnpackTensor(tensor_proto, model_path, &scalar, 1)); scalar_float = static_cast(scalar); diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 2b29473f876c3..46f306b92bed5 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -428,7 +428,8 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_W_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); nchwc_conv_W_tensor_proto.set_name(graph_.GenerateNodeArgName("reorder")); - nchwc_conv_W_tensor_proto.set_raw_data(reordered_filter.data(), reordered_filter.size() * sizeof(float)); + utils::SetRawDataInTensorProto(nchwc_conv_W_tensor_proto, reordered_filter.data(), + reordered_filter.size() * sizeof(float)); nchwc_conv_W_tensor_proto.add_dims(nchwc_output_channels); nchwc_conv_W_tensor_proto.add_dims(filter_input_channels); @@ -458,7 +459,8 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_B_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); nchwc_conv_B_tensor_proto.set_name(graph_.GenerateNodeArgName("reorder")); - nchwc_conv_B_tensor_proto.set_raw_data(aligned_bias.data(), gsl::narrow(nchwc_output_channels) * sizeof(float)); + utils::SetRawDataInTensorProto(nchwc_conv_B_tensor_proto, aligned_bias.data(), + gsl::narrow(nchwc_output_channels) * sizeof(float)); nchwc_conv_B_tensor_proto.add_dims(nchwc_output_channels); @@ -883,7 +885,8 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto; nchwc_conv_W_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); nchwc_conv_W_tensor_proto.set_name(graph_.GenerateNodeArgName("bn_scale")); - nchwc_conv_W_tensor_proto.set_raw_data(padded_buffer.data(), gsl::narrow(nchwc_channels) * sizeof(float)); + utils::SetRawDataInTensorProto(nchwc_conv_W_tensor_proto, padded_buffer.data(), + gsl::narrow(nchwc_channels) * sizeof(float)); nchwc_conv_W_tensor_proto.add_dims(nchwc_channels); nchwc_conv_W_tensor_proto.add_dims(1); nchwc_conv_W_tensor_proto.add_dims(1); @@ -896,7 +899,8 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { ONNX_NAMESPACE::TensorProto nchwc_conv_B_tensor_proto; nchwc_conv_B_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); nchwc_conv_B_tensor_proto.set_name(graph_.GenerateNodeArgName("bn_B")); - nchwc_conv_B_tensor_proto.set_raw_data(padded_buffer.data(), gsl::narrow(nchwc_channels) * sizeof(float)); + utils::SetRawDataInTensorProto(nchwc_conv_B_tensor_proto, padded_buffer.data(), + gsl::narrow(nchwc_channels) * sizeof(float)); nchwc_conv_B_tensor_proto.add_dims(nchwc_channels); auto* nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto); diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index 1eabc079f3a20..ed7d5feb2beb3 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -30,7 +30,7 @@ static size_t EstimateInputsOutputs(gsl::span nodes) { OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const InitializedTensorSet& initialized_tensor_set, - const Path& model_path, + const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, const std::function& is_sparse_initializer_func) : execution_provider_(execution_provider), @@ -52,7 +52,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, OrtValue ort_value; ORT_RETURN_IF_ERROR( utils::TensorProtoToOrtValue(Env::Default(), - model_path.IsEmpty() ? nullptr : model_path.ToPathString().c_str(), + model_path, tensor_proto, allocator_ptr_, ort_value)); initializers_[idx] = std::move(ort_value); @@ -77,7 +77,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, - const Path& model_path, + const std::filesystem::path& /* model_path */, const IExecutionProvider& execution_provider, const std::function& is_sparse_initializer_func) : execution_provider_(execution_provider), @@ -88,8 +88,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, ORT_THROW_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::make_unique())); // Create MLValues related maps - auto initialize_maps = [this, &initialized_tensor_set, &model_path](const NodeArg& arg, size_t /*index*/) -> Status { - (void)model_path; + auto initialize_maps = [this, &initialized_tensor_set](const NodeArg& arg, size_t /*index*/) -> Status { int idx = ort_value_name_idx_map_.Add(arg.Name()); ort_value_idx_nodearg_map_.insert_or_assign(idx, &arg); diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index 3dbf6c1d97aa6..b0f7f461661b5 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/common/inlined_containers.h" #include "core/graph/graph.h" @@ -24,13 +25,13 @@ class OptimizerExecutionFrame final : public IExecutionFrame { public: Info(const std::vector& nodes, const InitializedTensorSet& initialized_tensor_set, - const Path& model_path, + const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, const std::function& is_sparse_initializer_func); Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, - const Path& model_path, + const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, const std::function& is_sparse_initializer_func); diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 6f0f38b1de56e..18e462c04dff3 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -129,7 +129,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { weights_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); weights_proto_u8.set_name(weight_tensor_proto->name() + "_s8_2_u8"); weights_proto_u8.mutable_dims()->CopyFrom(weight_tensor_proto->dims()); - weights_proto_u8.set_raw_data(w_temp.data(), static_cast(w_temp.size())); + utils::SetRawDataInTensorProto(weights_proto_u8, w_temp.data(), static_cast(w_temp.size())); input_defs[w_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8); ONNX_NAMESPACE::TensorProto weight_zp_proto_u8; @@ -140,7 +140,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { r_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); r_proto_u8.set_name(r_tensor_proto->name() + "_s8_2_u8"); r_proto_u8.mutable_dims()->CopyFrom(r_tensor_proto->dims()); - r_proto_u8.set_raw_data(r_temp.data(), static_cast(r_temp.size())); + utils::SetRawDataInTensorProto(r_proto_u8, r_temp.data(), static_cast(r_temp.size())); input_defs[r_idx] = &graph_utils::AddInitializer(graph, r_proto_u8); ONNX_NAMESPACE::TensorProto r_zp_proto_u8; diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc index 49ce06e47c98a..507bc71709b2f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc @@ -1,10 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" +#include #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/selectors_actions/actions.h" #include "core/optimizer/utils.h" @@ -31,98 +32,106 @@ bool CleanUpNodeSequence(NodeSequence node_sequence_type, Graph& graph, NodeInde if (!match_first(first_node) || // not filtering on provider currently // !graph_utils::IsSupportedProvider(first_node, compatible_execution_providers) || - !optimizer_utils::CheckOutputEdges(graph, first_node, 1)) { + !(first_node.GetOutputEdgesCount() >= 1)) { return false; } - Node& second_node = *graph.GetNode(first_node.OutputNodesBegin()->Index()); - if (!match_second(second_node) - // not filtering on provider currently - // || !graph_utils::IsSupportedProvider(second_node, compatible_execution_providers) - ) { - return false; + std::vector second_node_ptrs; + for (auto node_it = first_node.OutputNodesBegin(); node_it != first_node.OutputNodesEnd(); ++node_it) { + second_node_ptrs.push_back(graph.GetNode(node_it->Index())); } - if (node_sequence_type == NodeSequence::DQ_Q) { - // for DQ -> Q, check for constant, matching scale/ZP values + for (auto second_node_ptr : second_node_ptrs) { + // check for constant, matching scale/ZP values const auto get_constant_initializer = [&graph](const std::string& initializer_name) { return graph.GetConstantInitializer(initializer_name, true); }; - if (!QDQ::IsQDQPairSupported(second_node, first_node, get_constant_initializer, graph.ModelPath())) { + const bool produces_graph_output = graph.NodeProducesGraphOutput(*second_node_ptr); + const auto output_edges_count = second_node_ptr->GetOutputEdgesCount(); + + if (!match_second(*second_node_ptr) || + !QDQ::IsQDQPairSupported(first_node, *second_node_ptr, get_constant_initializer, graph.ModelPath(), false) || + (produces_graph_output && output_edges_count != 0) || + (!produces_graph_output && output_edges_count != 1)) { return false; } } // we have a node sequence to clean up - - // we support a second_node that produces a graph output if it has no output edges, or a second_node with one output edge. - const bool produces_graph_output = graph.NodeProducesGraphOutput(second_node); - const auto output_edges_count = second_node.GetOutputEdgesCount(); - - if ((produces_graph_output && output_edges_count != 0) || - (!produces_graph_output && output_edges_count != 1)) { - return false; + if (logger.GetSeverity() == logging::Severity::kVERBOSE) { + LOGS(logger, VERBOSE) << "Found back-to-back nodes: " + << first_node.OpType() << " with name \"" << first_node.Name() << "\""; + for (auto& second_node_ptr : second_node_ptrs) { + LOGS(logger, VERBOSE) << ", " << second_node_ptr->OpType() << " with name \"" << second_node_ptr->Name() << "\""; + } } - LOGS(logger, VERBOSE) << "Cleaning up back-to-back nodes: " - << first_node.OpType() << " with name \"" << first_node.Name() << "\" and " - << second_node.OpType() << " with name \"" << second_node.Name() << "\""; - - // src node or graph input/initializer -> first_node -> second_node -> downstream node or graph output - NodeIndex src_node_idx = 0; - int src_arg_idx = -1; - NodeIndex downstream_node_idx = 0; - int downstream_arg_idx = -1; - - // input could be node or initializer/graph input so need to handle both. - // if it's a node we need to replace the edge, so need info on which output idx it was attached to on the src node. - const Node::EdgeEnd* input_edge = nullptr; - if (first_node.GetInputEdgesCount() == 1) { - input_edge = &*first_node.InputEdgesBegin(); - src_node_idx = input_edge->GetNode().Index(); - src_arg_idx = input_edge->GetSrcArgIndex(); - // remove edge from src to first_node. dest arg idx is 0 as first_node (Q or DQ) only has one input - graph.RemoveEdge(src_node_idx, first_node.Index(), src_arg_idx, 0); - } + for (auto second_node_ptr : second_node_ptrs) { + Node& second_node = *graph.GetNode(second_node_ptr->Index()); + // we support a second_node that produces a graph output if it has no output edges, or a second_node with one + // output edge. + const bool produces_graph_output = graph.NodeProducesGraphOutput(second_node); + + // src node or graph input/initializer -> first_node -> second_node -> downstream node or graph output + NodeIndex src_node_idx = 0; + int src_arg_idx = -1; + NodeIndex downstream_node_idx = 0; + int downstream_arg_idx = -1; + + // input could be node or initializer/graph input so need to handle both. + // if it's a node we need to replace the edge, so need info on which output idx it was attached to on the src node. + const Node::EdgeEnd* input_edge = nullptr; + if (first_node.GetInputEdgesCount() == 1) { + input_edge = &*first_node.InputEdgesBegin(); + src_node_idx = input_edge->GetNode().Index(); + src_arg_idx = input_edge->GetSrcArgIndex(); + // remove edge from src to first_node. dest arg idx is 0 as first_node (Q or DQ) only has one input + if (second_node_ptr == second_node_ptrs.back()) { + graph.RemoveEdge(src_node_idx, first_node.Index(), src_arg_idx, 0); + } + } - // remove edge between pair we're removing - // both DQ and Q are single input single output so src idx and dest idx must be 0 - graph.RemoveEdge(first_node.Index(), second_node.Index(), 0, 0); + // remove edge between pair we're removing + // both DQ and Q are single input single output so src idx and dest idx must be 0 + graph.RemoveEdge(first_node.Index(), second_node.Index(), 0, 0); - if (!produces_graph_output) { - // remove edge to downstream node - const Node::EdgeEnd& output_edge = *second_node.OutputEdgesBegin(); - downstream_node_idx = output_edge.GetNode().Index(); - downstream_arg_idx = output_edge.GetDstArgIndex(); + if (!produces_graph_output) { + // remove edge to downstream node + const Node::EdgeEnd& output_edge = *second_node.OutputEdgesBegin(); + downstream_node_idx = output_edge.GetNode().Index(); + downstream_arg_idx = output_edge.GetDstArgIndex(); - // source arg idx is 0 as Q/DQ only has one output - graph.RemoveEdge(second_node.Index(), downstream_node_idx, 0, downstream_arg_idx); + // source arg idx is 0 as Q/DQ only has one output + graph.RemoveEdge(second_node.Index(), downstream_node_idx, 0, downstream_arg_idx); - // replace input on downstream node - Node& downstream_node = *graph.GetNode(downstream_node_idx); - downstream_node.MutableInputDefs()[downstream_arg_idx] = first_node.MutableInputDefs()[0]; + // replace input on downstream node + Node& downstream_node = *graph.GetNode(downstream_node_idx); + downstream_node.MutableInputDefs()[downstream_arg_idx] = first_node.MutableInputDefs()[0]; - // create edge between src_node (if available) and downstream node - if (input_edge) { - graph.AddEdge(src_node_idx, downstream_node_idx, src_arg_idx, downstream_arg_idx); - } - } else { - NodeArg* graph_output_nodearg = second_node.MutableOutputDefs()[0]; - if (src_arg_idx >= 0) { - // update the src node to produce the graph output that was being provided by second_node - Node& src_node = *graph.GetNode(src_node_idx); - src_node.MutableOutputDefs()[src_arg_idx] = graph_output_nodearg; + // create edge between src_node (if available) and downstream node + if (input_edge) { + graph.AddEdge(src_node_idx, downstream_node_idx, src_arg_idx, downstream_arg_idx); + } } else { - // add Identity node to connect the graph input or initializer to the graph output. - Node& id_node = graph.AddNode(graph.GenerateNodeName("QDQFinalCleanupTransformer"), - "Identity", "", {first_node.MutableInputDefs()[0]}, {graph_output_nodearg}); - id_node.SetExecutionProviderType(second_node.GetExecutionProviderType()); + NodeArg* graph_output_nodearg = second_node.MutableOutputDefs()[0]; + if (src_arg_idx >= 0 && second_node_ptrs.size() == 1) { + // update the src node to produce the graph output that was being provided by second_node + Node& src_node = *graph.GetNode(src_node_idx); + src_node.MutableOutputDefs()[src_arg_idx] = graph_output_nodearg; + } else { + // add Identity node to connect the graph input or initializer to the graph output. + Node& id_node = graph.AddNode(graph.GenerateNodeName("QDQFinalCleanupTransformer"), + "Identity", "", {first_node.MutableInputDefs()[0]}, {graph_output_nodearg}); + id_node.SetExecutionProviderType(second_node.GetExecutionProviderType()); + } } - } - graph.RemoveNode(first_node.Index()); - graph.RemoveNode(second_node.Index()); + if (second_node_ptr == second_node_ptrs.back()) { + graph.RemoveNode(first_node.Index()); + } + graph.RemoveNode(second_node.Index()); + } return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc index 199fbffc9f723..f2033dcbc1b03 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc @@ -60,7 +60,7 @@ static bool QDQ_S8_to_U8(Graph& graph, Node& q_node, Node& dq_node) { ONNX_NAMESPACE::TensorProto zp_tensor_proto_u8; zp_tensor_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); zp_tensor_proto_u8.set_name(graph.GenerateNodeArgName("qdq_s8_to_u8_zp_conversion")); - zp_tensor_proto_u8.set_raw_data(&q_zp_value, sizeof(uint8_t)); + utils::SetRawDataInTensorProto(zp_tensor_proto_u8, &q_zp_value, sizeof(uint8_t)); NodeArg* zp_u8_arg = &graph_utils::AddInitializer(graph, zp_tensor_proto_u8); auto q_output_node_arg_name = graph.GenerateNodeArgName("qdq_s8_to_u8_quant"); diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index e245636ce9a84..a4d1ea3c7cf56 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -17,7 +17,14 @@ namespace onnxruntime::QDQ { bool IsQDQPairSupported( const Node& q_node, const Node& dq_node, const GetConstantInitializerFn& get_const_initializer, - const Path& model_path) { + const std::filesystem::path& model_path, + bool check_op_type) { + if (check_op_type) { + if (!MatchQNode(q_node) || !MatchDQNode(dq_node)) { + return false; + } + } + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); ConstPointerContainer> q_input_defs = q_node.InputDefs(); @@ -79,7 +86,7 @@ bool IsQDQPairSupported( bool IsDQQConversion( const Node& dq_node, const Node& q_node, const GetConstantInitializerFn& get_const_initializer, - const Path& model_path) { + const std::filesystem::path& model_path) { ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); ConstPointerContainer> q_input_defs = q_node.InputDefs(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index 8333168b0093f..5d11b8bfd5558 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -5,6 +5,7 @@ #include #include +#include namespace ONNX_NAMESPACE { class TensorProto; @@ -36,7 +37,8 @@ using GetConstantInitializerFn = std::function Q sequence represents a conversion in quantization data type. // Example of uint8 to uint16: @@ -48,7 +50,7 @@ bool IsQDQPairSupported( bool IsDQQConversion( const Node& dq_node, const Node& q_node, const GetConstantInitializerFn& get_const_initializer, - const Path& model_path); + const std::filesystem::path& model_path); // Check if DQ is supported in extended level QDQ transformers. It requires: // 1. DQ doesn't have optional input. diff --git a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h index 6caa35ea61ed7..1c1341fe5a127 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h +++ b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h @@ -27,7 +27,7 @@ inline bool Int8TensorProto2Uint8( if (nullptr == src) { uint8_t zero_val = 128; dst.set_name(graph.GenerateNodeArgName("weight_zp_s8_2_u8")); - dst.set_raw_data(&zero_val, sizeof(uint8_t)); + utils::SetRawDataInTensorProto(dst, &zero_val, sizeof(uint8_t)); return true; } @@ -58,7 +58,7 @@ inline bool Int8TensorProto2Uint8( p++; } if (force || should_convert) { - dst.set_raw_data(temp.data(), size_t(temp.size())); + utils::SetRawDataInTensorProto(dst, temp.data(), size_t(temp.size())); return true; } return false; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 3d2a81ce7f8cd..3497ea4c85523 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -5,6 +5,7 @@ #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/graph/node_attr_utils.h" +#include "core/framework/tensorprotoutils.h" namespace onnxruntime { namespace QDQ { @@ -132,7 +133,7 @@ struct SetOptionalZeroPoint { ONNX_NAMESPACE::TensorProto tensor_proto; tensor_proto.set_name(name); tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); - tensor_proto.set_raw_data(a.data(), sizeof(int8_t)); + onnxruntime::utils::SetRawDataInTensorProto(tensor_proto, a.data(), sizeof(int8_t)); return tensor_proto; }; @@ -145,8 +146,7 @@ struct SetOptionalZeroPoint { ONNX_NAMESPACE::TensorProto tensor_proto; tensor_proto.set_name(name); tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); - tensor_proto.set_raw_data(a.data(), sizeof(uint8_t)); - + onnxruntime::utils::SetRawDataInTensorProto(tensor_proto, a.data(), sizeof(uint8_t)); return tensor_proto; }; static ONNX_NAMESPACE::TensorProto GetOptionalZeroPointInt8() { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index de36202afff29..f388206551172 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -5,7 +5,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/framework/node_unit.h" #include "core/graph/basic_types.h" diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 7768a835d5042..7f94e18458be2 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -435,7 +435,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo shape_initializer_proto.set_name(shape_def->Name()); shape_initializer_proto.add_dims(static_cast(shape_value.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - shape_initializer_proto.set_raw_data(shape_value.data(), shape_value.size() * sizeof(int64_t)); + utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t)); auto& new_node_arg = graph_utils::AddInitializer(graph, shape_initializer_proto); // Safely remove concat parent nodes which have only one output diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h index e52ab16efe95a..9384bfa7027cd 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.h +++ b/onnxruntime/core/optimizer/selectors_actions/actions.h @@ -6,7 +6,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/graph/graph_utils.h" // TODO: Minimize usage of this given we want to use Actions in a minimal build #include "core/graph/runtime_optimization_record.h" #include "core/optimizer/selectors_actions/helpers.h" diff --git a/onnxruntime/core/optimizer/selectors_actions/helpers.h b/onnxruntime/core/optimizer/selectors_actions/helpers.h index cf5489dc1960a..c3d50d6de05ad 100644 --- a/onnxruntime/core/optimizer/selectors_actions/helpers.h +++ b/onnxruntime/core/optimizer/selectors_actions/helpers.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/basic_types.h" -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/graph/graph.h" #include "core/graph/runtime_optimization_record.h" diff --git a/onnxruntime/core/optimizer/shape_input_merge.cc b/onnxruntime/core/optimizer/shape_input_merge.cc deleted file mode 100644 index dec1382319f16..0000000000000 --- a/onnxruntime/core/optimizer/shape_input_merge.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/optimizer/shape_input_merge.h" - -#include "core/graph/graph_utils.h" - -namespace onnxruntime { - -namespace { -std::string GetShapeString(const NodeArg* input_arg) { - auto shape = input_arg->Shape(); - if (!shape) return ""; - std::stringstream ss; - ss << "["; - for (int i = 0; i < shape->dim_size(); ++i) { - if (i != 0) ss << ","; - auto dim = shape->dim(i); - if (dim.has_dim_value()) { - ss << std::to_string(dim.dim_value()); - } else if (dim.has_dim_param()) { - ss << "'" << dim.dim_param() << "'"; - } else { - return ""; - } - } - ss << "]"; - return ss.str(); -} - -} // namespace - -Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - GraphViewer graph_viewer(graph); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - InlinedHashMap> input_hash_to_nodes; - for (auto node_index : node_topology_list) { - auto* p_node = graph.GetNode(node_index); - if (!p_node) continue; // we removed the node as part of an earlier fusion - ORT_RETURN_IF_ERROR(Recurse(*p_node, modified, graph_level, logger)); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(*p_node, "Shape", {1, 13, 15, 19, 21}) || - !graph_utils::IsSupportedProvider(*p_node, GetCompatibleExecutionProviders())) { - continue; - } - std::string shape_str = GetShapeString(p_node->InputDefs()[0]); - if (shape_str.empty()) continue; - if (input_hash_to_nodes.find(shape_str) == input_hash_to_nodes.end()) { - input_hash_to_nodes[shape_str] = InlinedVector(); - } - input_hash_to_nodes[shape_str].emplace_back(p_node); - } - - // All Shape nodes are processed in topological order, so we can safely merge the inputs to the first node's input. - for (auto& kv : input_hash_to_nodes) { - if (kv.second.size() < 2) continue; - NodeArg* first_input_arg = kv.second[0]->MutableInputDefs()[0]; - bool is_first_input_arg_graph_input = graph.IsInputsIncludingInitializers(first_input_arg); - for (size_t i = 1; i < kv.second.size(); ++i) { - Node* p_node = kv.second[i]; - const NodeArg* input_arg = p_node->InputDefs()[0]; - if (input_arg->Name() == first_input_arg->Name()) continue; - if (!graph.IsInputsIncludingInitializers(input_arg) && p_node->GetInputEdgesCount()) { - const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin(); - graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0); - } - graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg); - if (!is_first_input_arg_graph_input && kv.second[0]->GetInputEdgesCount()) { - const Node::EdgeEnd& first_input_edge = *kv.second[0]->InputEdgesBegin(); - graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0); - } - modified = true; - } - } - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/shape_input_merge.h b/onnxruntime/core/optimizer/shape_input_merge.h deleted file mode 100644 index 5cb943998487b..0000000000000 --- a/onnxruntime/core/optimizer/shape_input_merge.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_transformer.h" - -namespace onnxruntime { - -/** -@Class ShapeInputMerge -Merge all shape inputs having same shape value to a single shape input. -This change will not affect the performance, but it open chances for CSE fusion to merge nodes. -*/ -class ShapeInputMerge : public GraphTransformer { - public: - ShapeInputMerge(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("ShapeInputMerge", compatible_execution_providers) {} - - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc index a54904ff15e1e..5c09e5225ab9c 100644 --- a/onnxruntime/core/optimizer/stft_decomposition.cc +++ b/onnxruntime/core/optimizer/stft_decomposition.cc @@ -45,7 +45,7 @@ NodeArg* AddInitializer(Graph& graph, const char* name, const int64_t (&shape)[T element_count *= shape[i]; proto.add_dims(shape[i]); } - proto.set_raw_data(begin, element_count * sizeof(TDataType)); + utils::SetRawDataInTensorProto(proto, begin, element_count * sizeof(TDataType)); return &graph_utils::AddInitializer(graph, proto); } diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 0d7ab70eba613..f1e94dd4fe9e4 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -11,7 +11,7 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { // implements MemCpy node insertion in graph transform -// note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTranformer +// note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer class TransformerMemcpyImpl { public: TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index e6ffd0d91372b..d4ed9c4e26cc6 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -12,7 +12,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/make_string.h" #include "core/graph/constants.h" diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index c532f56b3d3d9..f756d01413eae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -46,11 +46,12 @@ class ApiValueInfo final : public api::ValueInfoRef { class ApiTensor final : public api::TensorRef { private: const onnx::TensorProto& tensor_proto_; - const Path& model_path_; + const std::filesystem::path& model_path_; AllocatorPtr cpu_allocator_; public: - explicit ApiTensor(const onnx::TensorProto& tensor_proto, const Path& model_path, AllocatorPtr cpu_allocator) + explicit ApiTensor(const onnx::TensorProto& tensor_proto, const std::filesystem::path& model_path, + AllocatorPtr cpu_allocator) : tensor_proto_(tensor_proto), model_path_(model_path), cpu_allocator_(std::move(cpu_allocator)) {} const onnx::TensorProto& TensorProto() { @@ -289,10 +290,12 @@ std::vector ApiTensor::Data() const { auto tensor_shape_dims = utils::GetTensorShapeFromTensorProto(tensor_proto_); TensorShape tensor_shape{std::move(tensor_shape_dims)}; onnxruntime::Tensor tensor(tensor_dtype, tensor_shape, cpu_allocator_); - ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path_.ToPathString().c_str(), + ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path_, tensor_proto_, tensor)); size_t num_bytes = gsl::narrow_cast(tensor.SizeInBytes()); const uint8_t* data = static_cast(tensor.DataRaw()); + // TODO: the returned data is unaligned, which does not meet the alignment requirement that mlas requires. Because + // the returned type is a vector, not a Tensor or tensor buffer that is allocated from a CPU allocator. return std::vector(data, data + num_bytes); } // @@ -554,7 +557,7 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector(dtype)); tensor_proto.set_name(name); - tensor_proto.set_raw_data(data.data(), data.size()); for (int64_t dim : shape) { tensor_proto.add_dims(dim); } + utils::SetRawDataInTensorProto(tensor_proto, data.data(), data.size()); const auto& node_arg = graph_utils::AddInitializer(graph_, tensor_proto); return node_arg.Name(); diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 6917f42091bf3..c42b31e64d129 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -22,7 +22,7 @@ limitations under the License. #include #include #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/path_string.h" @@ -147,6 +147,8 @@ class Env { virtual std::vector GetDefaultThreadAffinities() const = 0; + virtual int GetL2CacheSize() const = 0; + /// \brief Returns the number of micro-seconds since the Unix epoch. virtual uint64_t NowMicros() const { return env_time_->NowMicros(); diff --git a/onnxruntime/core/platform/path_lib.h b/onnxruntime/core/platform/path_lib.h index a9d89f32e91d3..fca8990f14821 100644 --- a/onnxruntime/core/platform/path_lib.h +++ b/onnxruntime/core/platform/path_lib.h @@ -281,7 +281,7 @@ void LoopDir(const std::string& dir_name, T func) { ORT_TRY { struct dirent* dp; while ((dp = readdir(dir)) != nullptr) { - std::basic_string filename = ConcatPathComponent(dir_name, dp->d_name); + std::basic_string filename = ConcatPathComponent(dir_name, dp->d_name); if (stat(filename.c_str(), &stats) != 0) { continue; } diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 9999550c241c8..04cf5ff6a3329 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -26,7 +26,9 @@ limitations under the License. #include #include #include +#if !defined(_AIX) #include +#endif #include #include @@ -43,8 +45,12 @@ limitations under the License. #define ORT_USE_CPUINFO #endif +#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) +#include +#endif + #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/platform/scoped_resource.h" @@ -302,6 +308,22 @@ class PosixEnv : public Env { return ret; } + int GetL2CacheSize() const override { +#ifdef _SC_LEVEL2_CACHE_SIZE + return static_cast(sysconf(_SC_LEVEL2_CACHE_SIZE)); +#else + int value = 0; // unknown +#if (defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)) && defined(HW_L2CACHESIZE) + int mib[2] = {CTL_HW, HW_L2CACHESIZE}; + size_t len = sizeof(value); + if (sysctl(mib, 2, &value, &len, NULL, 0) < 0) { + return -1; // error + } +#endif + return value; +#endif + } + void SleepForMicroseconds(int64_t micros) const override { while (micros > 0) { timespec sleep_time; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index dc090e446e60f..73319cd9c9b1c 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -26,7 +26,7 @@ limitations under the License. #include #include -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/common/span_utils.h" @@ -303,6 +303,10 @@ std::vector WindowsEnv::GetDefaultThreadAffinities() const { return cores_.empty() ? std::vector(DefaultNumCores(), LogicalProcessors{}) : cores_; } +int WindowsEnv::GetL2CacheSize() const { + return l2_cache_size_; +} + WindowsEnv& WindowsEnv::Instance() { static WindowsEnv default_env; return default_env; @@ -851,6 +855,7 @@ ProcessorInfo WindowsEnv::GetProcessorAffinityMask(int global_processor_id) cons } WindowsEnv::WindowsEnv() { + l2_cache_size_ = 0; InitializeCpuInfo(); } @@ -924,9 +929,57 @@ void WindowsEnv::InitializeCpuInfo() { } iter += size; } + + DWORD newLength = 0; + GetLogicalProcessorInformationEx(RelationCache, nullptr, &newLength); + last_error = GetLastError(); + if (last_error != ERROR_INSUFFICIENT_BUFFER) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(ERROR) << "Failed to calculate byte size for saving cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + if (newLength > returnLength) { + // Re-allocate + allocation = std::make_unique(newLength); + processorInfos = reinterpret_cast(allocation.get()); + } + + if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(ERROR) << "Failed to fetch cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + iter = reinterpret_cast(processorInfos); + end = iter + newLength; + + while (iter < end) { + auto processor_info = reinterpret_cast(iter); + auto size = processor_info->Size; + + if (processor_info->Relationship == RelationCache && + processor_info->Cache.Level == 2) { + // L2 cache + l2_cache_size_ = static_cast(processor_info->Cache.CacheSize); + break; + } + + iter += size; + } + if (logging::LoggingManager::HasDefaultLogger()) { LOGS_DEFAULT(VERBOSE) << "Found total " << cores_.size() << " core(s) from windows system:"; LOGS_DEFAULT(VERBOSE) << log_stream.str(); + LOGS_DEFAULT(VERBOSE) << "\nDetected L2 cache size: " << l2_cache_size_ << " bytes"; } } } // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/env.h b/onnxruntime/core/platform/windows/env.h index 79739db9e5640..395aface1d809 100644 --- a/onnxruntime/core/platform/windows/env.h +++ b/onnxruntime/core/platform/windows/env.h @@ -55,6 +55,7 @@ class WindowsEnv : public Env { static int DefaultNumCores(); int GetNumPhysicalCpuCores() const override; std::vector GetDefaultThreadAffinities() const override; + int GetL2CacheSize() const override; static WindowsEnv& Instance(); PIDType GetSelfPid() const override; Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const override; @@ -113,6 +114,8 @@ class WindowsEnv : public Env { * } */ std::vector cores_; + + int l2_cache_size_; /* * "global_processor_info_map_" is a map of: * global_processor_id <--> (group_id, local_processor_id) diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index 5fb7f7a65161d..b0f9eaf4f62d2 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -104,7 +104,16 @@ HRESULT EtwRegistrationManager::Status() const { void EtwRegistrationManager::RegisterInternalCallback(const EtwInternalCallback& callback) { std::lock_guard lock(callbacks_mutex_); - callbacks_.push_back(callback); + callbacks_.push_back(&callback); +} + +void EtwRegistrationManager::UnregisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock(callbacks_mutex_); + auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), + [&callback](const EtwInternalCallback* ptr) { + return ptr == &callback; + }); + callbacks_.erase(new_end, callbacks_.end()); } void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( @@ -126,6 +135,8 @@ void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( } EtwRegistrationManager::~EtwRegistrationManager() { + std::lock_guard lock(callbacks_mutex_); + callbacks_.clear(); ::TraceLoggingUnregister(etw_provider_handle); } @@ -150,7 +161,7 @@ void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, PVOID CallbackContext) { std::lock_guard lock(callbacks_mutex_); for (const auto& callback : callbacks_) { - callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); } } diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index 5d35d101f1242..3af45b813a625 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -71,6 +71,8 @@ class EtwRegistrationManager { void RegisterInternalCallback(const EtwInternalCallback& callback); + void UnregisterInternalCallback(const EtwInternalCallback& callback); + private: EtwRegistrationManager(); ~EtwRegistrationManager(); @@ -90,7 +92,7 @@ class EtwRegistrationManager { _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext); - std::vector callbacks_; + std::vector callbacks_; OrtMutex callbacks_mutex_; mutable OrtMutex provider_change_mutex_; OrtMutex init_mutex_; diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index d7d423e4a483e..3401507ae911f 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -12,7 +12,7 @@ #endif #include "core/common/logging/logging.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 850f40e846248..86067d377205b 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -64,7 +64,7 @@ bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; UCHAR WindowsTelemetry::level_ = 0; UINT64 WindowsTelemetry::keyword_ = 0; -std::vector WindowsTelemetry::callbacks_; +std::vector WindowsTelemetry::callbacks_; OrtMutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { @@ -86,6 +86,9 @@ WindowsTelemetry::~WindowsTelemetry() { TraceLoggingUnregister(telemetry_provider_handle); } } + + std::lock_guard lock_callbacks(callbacks_mutex_); + callbacks_.clear(); } bool WindowsTelemetry::IsEnabled() const { @@ -108,8 +111,17 @@ UINT64 WindowsTelemetry::Keyword() const { // } void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { - std::lock_guard lock(callbacks_mutex_); - callbacks_.push_back(callback); + std::lock_guard lock_callbacks(callbacks_mutex_); + callbacks_.push_back(&callback); +} + +void WindowsTelemetry::UnregisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock_callbacks(callbacks_mutex_); + auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), + [&callback](const EtwInternalCallback* ptr) { + return ptr == &callback; + }); + callbacks_.erase(new_end, callbacks_.end()); } void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( @@ -131,9 +143,9 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { - std::lock_guard lock(callbacks_mutex_); + std::lock_guard lock_callbacks(callbacks_mutex_); for (const auto& callback : callbacks_) { - callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); } } diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 27cd20c2d21d1..ed80f13e633ac 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -66,13 +66,15 @@ class WindowsTelemetry : public Telemetry { static void RegisterInternalCallback(const EtwInternalCallback& callback); + static void UnregisterInternalCallback(const EtwInternalCallback& callback); + private: static OrtMutex mutex_; static uint32_t global_register_count_; static bool enabled_; static uint32_t projection_; - static std::vector callbacks_; + static std::vector callbacks_; static OrtMutex callbacks_mutex_; static OrtMutex provider_change_mutex_; static UCHAR level_; diff --git a/onnxruntime/core/providers/coreml/DebugMLProgram.md b/onnxruntime/core/providers/coreml/DebugMLProgram.md new file mode 100644 index 0000000000000..e41a515594303 --- /dev/null +++ b/onnxruntime/core/providers/coreml/DebugMLProgram.md @@ -0,0 +1,85 @@ +# Steps to debug an ML Program operator implementation + +Basic debugging of everything, excluding model execution, (e.g. partitioning, checking if operator is supported, +adding CoreML operator input/outputs) can be done anywhere as the code is setup to build and be able to create the +protobuf based CoreML Model on all platforms. + +To debug model execution issues you will need a macOS machine. + +## Debugging invalid output + +If there is a crash during execution or unexpected output, the best approach is to see what using coremltools directly +produces. + +NOTE: that doesn't guarantee coremltools is correct as there could be a bug in their implementation. It does however +provide a data point on whether we are generating the same CoreML model as the coremltools python. + +### Comparing to coremltools output + +Create a small test script that replicates the inputs/outputs of the operator you are debugging. +This script should use the coremltools library to run the operator and print the output. +This can be used to compare the CoreML EP's output with the coremltools output. + +https://apple.github.io/coremltools/docs-guides/source/model-intermediate-language.html#create-a-mil-program + +Usage is reasonably intuitive. The below example defines a model with 2 inputs and a matmul operator. +The model is printed, and run with randomly generated inputs. The output from doing so is printed. + +```python +import numpy as np +import coremltools as ct +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +x_shape = (1, 4) +y_shape = (10, 4, 3) + +@mb.program(input_specs=[mb.TensorSpec(shape=x_shape), mb.TensorSpec(shape=y_shape)], + opset_version=target) +def prog(x, y): + # For reference, a constant can be added using `mb.const` and specifying the data in the `val` parameter. + # c_shape = (3, ) + # c_data = np.random.random_sample(c_shape) + # c = mb.const(val=c_data) + + # call the operator you are debugging with the inputs/constants. + # See the spec for the operator names, input/outputs and supported data types. + # https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html + z = mb.matmul(x=x, y=y) + + # can have additional function calls here if there are multiple operators involved. + # Contrived example that uses a constant and the output from a previous operator: + # z = mb.add(x=z, y=c) + + return z + +# Prints the MIL program in a reasonably concise manner. +print(prog) + +# Convert to ML Program model +m = ct.convert(prog, minimum_deployment_target=target) + +# If you want to dump the full protobuf of the model uncomment this. +# You can compare the values to what is being set by the ORT CoreML EP code if you suspect any issues there. +# spec = m.get_spec() +# print(spec) + +# run the model to generate output for comparison with the CoreML EP output +x = np.random.rand(*x_shape) +y = np.random.rand(*y_shape) + +print(m.predict({'x': x, 'y': y})) +``` + +## Dumping the ORT generated mlmodel + +You can also dump the mlmodel generated by the ORT CoreML EP. This can be handy with larger models. + +In a debug build, set the ORT_COREML_EP_MODEL_DIR environment variable to a directory where you want the ML Package +containing the mlmodel to be saved. The model will remain after the CoreML EP exits, unlike the default behavior +where we write it to a temporary directory that is automatically removed on application exit. + +Script to dump: [dump_mlprogram_model.py](dump_mlprogram_model.py) + +See [here](https://github.com/microsoft/onnxruntime/blob/3c0b407709fd3c71755ed046edd688b30a786d94/onnxruntime/core/providers/coreml/model/host_utils.h#L70-L75) for environment variable setup and [usage](https://github.com/search?q=repo%3Amicrosoft%2Fonnxruntime%20kOverrideModelOutputDirectoryEnvVar%20&type=code). diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index dee87ce3632a8..0e21715513707 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -26,6 +26,8 @@ class ActivationOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override; int GetMinSupportedOpSet(const Node& node) const override; + + bool SupportsMLProgram() const override { return true; } }; void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -74,33 +76,61 @@ Status AddPReluWeight(ModelBuilder& model_builder, const Node& node, Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - const auto& op_type(node.OpType()); - if (op_type == "Sigmoid") { - layer->mutable_activation()->mutable_sigmoid(); - } else if (op_type == "Tanh") { - layer->mutable_activation()->mutable_tanh(); - } else if (op_type == "Relu") { - layer->mutable_activation()->mutable_relu(); - } else if (op_type == "PRelu") { - auto* prelu = layer->mutable_activation()->mutable_prelu(); - ORT_RETURN_IF_ERROR(AddPReluWeight(model_builder, node, logger, *prelu)); - } else if (op_type == "LeakyRelu") { - NodeAttrHelper helper(node); - const auto alpha = helper.Get("alpha", 0.01f); - - auto* leaky_relu = layer->mutable_activation()->mutable_leakyrelu(); - leaky_relu->set_alpha(alpha); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation + std::string_view coreml_op_type; + if (op_type == "Sigmoid") { + coreml_op_type = "sigmoid"; + } else if (op_type == "Tanh") { + coreml_op_type = "tanh"; + } else if (op_type == "Relu") { + coreml_op_type = "relu"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + AddOperationOutput(*op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(op)); + + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + if (op_type == "Sigmoid") { + layer->mutable_activation()->mutable_sigmoid(); + } else if (op_type == "Tanh") { + layer->mutable_activation()->mutable_tanh(); + } else if (op_type == "Relu") { + layer->mutable_activation()->mutable_relu(); + } else if (op_type == "PRelu") { + auto* prelu = layer->mutable_activation()->mutable_prelu(); + ORT_RETURN_IF_ERROR(AddPReluWeight(model_builder, node, logger, *prelu)); + } else if (op_type == "LeakyRelu") { + NodeAttrHelper helper(node); + const auto alpha = helper.Get("alpha", 0.01f); + + auto* leaky_relu = layer->mutable_activation()->mutable_leakyrelu(); + leaky_relu->set_alpha(alpha); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } @@ -165,9 +195,20 @@ bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_para bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& op_type = node.OpType(); - if (op_type == "PRelu") { - return IsPReluOpSupported(node, input_params, logger); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + if (op_type == "PRelu" || op_type == "LeakyRelu") { + return false; + } + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + if (op_type == "PRelu") { + return IsPReluOpSupported(node, input_params, logger); + } } + return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index cbea969904ed5..2fcf9a1d7d9ba 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -140,30 +140,44 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span> shape) { + std::optional> shape, bool convert_scalar = false) { tensor_type.set_datatype(data_type); if (shape) { - tensor_type.set_rank(shape->size()); - for (const auto& dim : *shape) { - if (dim >= 0) { - tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim)); - } else { - tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + auto rank = shape->size(); + if (convert_scalar && rank == 0) { + // CoreML scalar has shape {1} + tensor_type.set_rank(1); + tensor_type.add_dimensions()->mutable_constant()->set_size(1); + } else { + tensor_type.set_rank(rank); + for (const auto& dim : *shape) { + if (dim >= 0) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim)); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } } } } } void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type, - const ONNX_NAMESPACE::TensorShapeProto* shape) { + const ONNX_NAMESPACE::TensorShapeProto* shape, bool convert_scalar = false) { tensor_type.set_datatype(data_type); if (shape) { - tensor_type.set_rank(shape->dim_size()); - for (const auto& dim : shape->dim()) { - if (dim.has_dim_value()) { - tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim.dim_value())); - } else { - tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + auto rank = shape->dim_size(); + if (convert_scalar && rank == 0) { + // CoreML scalar has shape {1} + tensor_type.set_rank(1); + tensor_type.add_dimensions()->mutable_constant()->set_size(1); + } else { + tensor_type.set_rank(rank); + for (const auto& dim : shape->dim()) { + if (dim.has_dim_value()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim.dim_value())); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } } } } @@ -281,13 +295,13 @@ template MILSpec::Value CreateScalarTensorValue(const int32_t& data); template MILSpec::Value CreateScalarTensorValue(const std::string& data); template MILSpec::Value CreateScalarTensorValue(const bool& data); -COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) { +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg, bool convert_scalar) { MILSpec::NamedValueType nvt; nvt.set_name(node_arg.Name()); MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype(); SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()), - node_arg.Shape()); + node_arg.Shape(), convert_scalar); return nvt; } @@ -308,7 +322,7 @@ void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& outp MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()), - output.Shape()); + output.Shape(), /*convert_scalar*/ true); } void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type, diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 2804589065631..97fb83b6dc482 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -7,7 +7,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/status.h" #include "core/graph/basic_types.h" #include "core/providers/common.h" @@ -114,8 +114,10 @@ template COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data); /// Create a NamedValueType from an ONNX tensor NodeArg. +/// NodeArg to create NamedValueType from. +/// If true, scalar shapes are converted to 1D. /// Used to create inputs for the 'main' function in an ML Program. -COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg); +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg, bool convert_scalar = false); /// /// Add an input argument to a MILSpec::Operation diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 8daf64dc4a457..7338fc18fe779 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -109,19 +109,11 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N ORT_IGNORE_RETURN_VALUE(GetShape(b, b_shape, logger)); int64_t b0 = -1, b1 = -1; - // ML Program MatMul supports N-D input if (model_builder.CreateMLProgram() && is_matmul) { - if (b_shape.size() == 1) { - // B is treated as {b_shape[0], 1} according to the numpy rules. - b0 = b_shape[0]; - b1 = 1; - } else { - // last 2 dims are used - b0 = b_shape[b_shape.size() - 2]; - b1 = b_shape[b_shape.size() - 1]; - } + // ML Program MatMul supports N-D input, however we don't use the 'K' or 'N' values calculated below for it + // so we don't need to update b0 or b1. } else { - // we only support 2D input + // we only support 2D input for all other combinations b0 = b_shape[0]; b1 = b_shape[1]; } @@ -182,7 +174,6 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N model_builder.AddOperation(std::move(gemm_op)); } else { // CoreML implementation is the same as ONNX MatMul. - // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.matmul auto matmul_op = model_builder.CreateOperation(node, "matmul"); AddOperationInput(*matmul_op, "x", a.Name()); AddOperationInput(*matmul_op, "y", b.Name()); @@ -268,14 +259,28 @@ bool GemmOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara } if (is_matmul) { + const auto a_rank = a_shape.size(); + const auto b_rank = b_shape.size(); + if (input_params.create_mlprogram) { - // ML Program matmul op has numpy semantics the same as the ONNX spec so we can use directly + // ML Program matmul op has numpy semantics the same as the ONNX spec, so we can use directly. + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.matmul + // + // There does appear to be a bug in handling one of the inputs being 1D, so for now skip these. + // See https://github.com/apple/coremltools/issues/2263 + // + // If required for perf we could manually do the shape alterations the spec documents (convert input to 2D, + // and remove extra dimension from output), as the 2D input is correctly handled by CoreML matmul. + if ((a_rank == 1 && b_rank > 1) || (a_rank > 1 && b_rank == 1)) { + LOGS(logger, VERBOSE) << "Skipping due to bug in CoreML ML Program when one of the inputs is 1D."; + return false; + } } else { // we could potentially support 1D and 3D if required. beyond 3D the dims that merge diverge. // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/onnx/_operators.py#L1607 // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/backend/nn/op_mapping.py#L1374 // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#innerproductlayerparams - if (a_shape.size() != 2 || b_shape.size() != 2) { + if (a_rank != 2 || b_rank != 2) { LOGS(logger, VERBOSE) << "a and b inputs must be 2D. "; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 6c2fcc2ace856..3400f09b4056f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -259,13 +259,13 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa // Onnx spec requires output sizes to be a positive integer, so we are not checking that here if (output_size_h % input_size_h != 0) { LOGS(logger, VERBOSE) << "Resize: output_size_h: " << output_size_h - << " is not a mutliple of input_size_h: " << input_size_h; + << " is not a multiple of input_size_h: " << input_size_h; return false; } if (output_size_w % input_size_w != 0) { LOGS(logger, VERBOSE) << "Resize: output_size_w: " << output_size_w - << " is not a mutliple of input_size_w: " << input_size_w; + << " is not a multiple of input_size_w: " << input_size_w; return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc index f6a61d55a3d63..831c4cf4d08ba 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -14,13 +15,13 @@ namespace coreml { class TransposeOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - NodeAttrHelper helper(node); std::vector perm = helper.Get("perm", std::vector()); std::vector input_shape; @@ -33,12 +34,27 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(perm.size() == input_dims, "Perm and input should have same dimension"); } - *layer->mutable_transpose()->mutable_axes() = {perm.cbegin(), perm.cend()}; +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::unique_ptr op = model_builder.CreateOperation(node, "transpose"); + AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + AddOperationInput(*op, "perm", model_builder.AddConstant(op->type(), "perm", perm)); + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + *layer->mutable_transpose()->mutable_axes() = {perm.cbegin(), perm.cend()}; - model_builder.AddLayer(std::move(layer)); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index eb4723a3b9746..eec0fcce51dbc 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -838,13 +838,8 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (create_ml_program_) { if (is_input) { // the model inputs need to be wired up as args to the 'main' function. - auto tensor_value_type = CreateNamedTensorValueType(node_arg); + auto tensor_value_type = CreateNamedTensorValueType(node_arg, /*convert_scalar*/ true); tensor_value_type.set_name(name); - if (node_arg.Shape()->dim_size() == 0) { - // update shape from {} to {1} (same change we made at the model input level above). - tensor_value_type.mutable_type()->mutable_tensortype()->set_rank(1); - tensor_value_type.mutable_type()->mutable_tensortype()->add_dimensions()->mutable_constant()->set_size(1); - } mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type)); } else { @@ -911,6 +906,7 @@ Status ModelBuilder::SaveModel() { #if defined(COREML_ENABLE_MLPROGRAM) if (create_ml_program_) { + // we need to jump through some hoops to get the model path the ML Program load wants. std::string tmp_model_path = model_output_path_ + "/tmp/model.mlmodel"; CreateEmptyFile(tmp_model_path); diff --git a/onnxruntime/core/providers/coreml/dump_mlprogram_model.py b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py index a3ceee70684dc..dce98e5138d98 100644 --- a/onnxruntime/core/providers/coreml/dump_mlprogram_model.py +++ b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py @@ -5,6 +5,11 @@ if len(sys.argv) < 2: print(f"Usage: {sys.argv[0]} ") print("If generated by onnxruntime this will be /Data/com.microsoft.onnxruntime/model.mlmodel") + print( + "The ML Package created by the CoreML EP can saved to a specific directory in a debug build of onnxruntime " + "by setting the environment variable ORT_COREML_EP_MODEL_DIR to the desired directory." + ) + sys.exit(-1) model_path = sys.argv[1] @@ -13,7 +18,9 @@ spec = m.get_spec() print(spec) -# Example code if you want to filter output or do more advanced things +# Example code if you want to filter output or do more advanced things. +# In the below example we print out the value of an attribute of one specific node from a larger model. +# # main = spec.mlProgram.functions["main"] # block = main.block_specializations[main.opset] # print(f"{len(block.operations)} operators") diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index e3cd43d786fc3..c4c3b38bba516 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -8,7 +8,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/common/status.h" #include "core/platform/ort_mutex.h" diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 3edcdb3f95e46..1d506099b4367 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -13,7 +13,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" diff --git a/onnxruntime/core/providers/coreml/shape_utils.h b/onnxruntime/core/providers/coreml/shape_utils.h index 0a1fd47cfdfe4..23ee51af63d48 100644 --- a/onnxruntime/core/providers/coreml/shape_utils.h +++ b/onnxruntime/core/providers/coreml/shape_utils.h @@ -7,7 +7,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/graph/node_arg.h" diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index 9837aabe786cd..c65dd2a04bf55 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -23,7 +23,7 @@ #include "core/framework/TensorSeq.h" #include "core/providers/utils.h" -#include "core/common/gsl.h" +#include #ifdef _MSC_VER #pragma warning(pop) diff --git a/onnxruntime/core/providers/cpu/controlflow/scan.h b/onnxruntime/core/providers/cpu/controlflow/scan.h index 76c27827b99a1..8516fa786da31 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan.h +++ b/onnxruntime/core/providers/cpu/controlflow/scan.h @@ -3,7 +3,7 @@ #pragma once #include -#include "core/common/gsl.h" +#include #ifndef SHARED_PROVIDER #include "core/common/common.h" diff --git a/onnxruntime/core/providers/cpu/controlflow/scan_8.cc b/onnxruntime/core/providers/cpu/controlflow/scan_8.cc index 67853790415e9..cea3217578129 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan_8.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan_8.cc @@ -246,9 +246,9 @@ Status Scan8Impl::ValidateSubgraphInput(int start_input, int end_input, bool is_ auto this_batch_size = input_shape[0]; - if (batch_size_ < 0) + if (batch_size_ < 0) { batch_size_ = this_batch_size; - else { + } else { if (batch_size_ != this_batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Scan inputs have inconsistent batch size. Previous value was ", batch_size_, " but ", graph_inputs[i]->Name(), " has batch size of ", @@ -263,7 +263,8 @@ Status Scan8Impl::ValidateSubgraphInput(int start_input, int end_input, bool is_ max_sequence_len_ = this_seq_len; } else { if (max_sequence_len_ != this_seq_len) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Scan inputs have inconsistent sequence lengths. Previous value was ", + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Scan inputs have inconsistent sequence lengths. Previous value was ", max_sequence_len_, " but ", graph_inputs[i]->Name(), " has length of ", this_seq_len); } @@ -396,13 +397,15 @@ Status Scan8Impl::Execute(const FeedsFetchesManager& ffm) { // Setup input OrtValue streams std::vector::Iterator> scan_input_stream_iterators; - scan_input_stream_iterators.reserve(static_cast(info_.num_variadic_inputs) - info_.num_loop_state_variables); + scan_input_stream_iterators.reserve(static_cast(info_.num_variadic_inputs) - + info_.num_loop_state_variables); for (int i = info_.num_loop_state_variables, end = info_.num_variadic_inputs; i < end; ++i) { const auto& ort_value = GetSubgraphInputMLValue(context_, i); // forward - if (directions_[static_cast(i) - info_.num_loop_state_variables] == static_cast(ScanDirection::kForward)) { + if (directions_[static_cast(i) - info_.num_loop_state_variables] == + static_cast(ScanDirection::kForward)) { // the iterator is self contained, so we don't need to keep the OrtValueTensorSlicer instance around scan_input_stream_iterators.push_back(device_helpers_.create_const_slicer_func(ort_value, 1, b).begin()); } else { // reverse @@ -417,8 +420,10 @@ Status Scan8Impl::Execute(const FeedsFetchesManager& ffm) { } // Call the subgraph for each item in the sequence - status = IterateSequence(context_, session_state_, batch_loop_state_variables[onnxruntime::narrow(b)], scan_input_stream_iterators, - sequence_len, info_.num_loop_state_variables, info_.num_variadic_inputs, info_.num_outputs, + status = IterateSequence(context_, session_state_, batch_loop_state_variables[onnxruntime::narrow(b)], + scan_input_stream_iterators, + sequence_len, info_.num_loop_state_variables, info_.num_variadic_inputs, + info_.num_outputs, implicit_inputs_, output_iterators_, ffm); // zero out any remaining values in the sequence diff --git a/onnxruntime/core/providers/cpu/controlflow/scan_9.cc b/onnxruntime/core/providers/cpu/controlflow/scan_9.cc index f7548fbf60503..24d233c0594fa 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan_9.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan_9.cc @@ -21,7 +21,7 @@ #include "core/providers/cpu/tensor/utils.h" #include "core/providers/cpu/tensor/transpose.h" -#include "core/common/gsl.h" +#include #ifdef _MSC_VER #pragma warning(pop) diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index c4a83efa01a91..fd7b19dea724d 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -192,6 +192,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } + void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const override { + p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); + } #ifndef DISABLE_CONTRIB_OPS Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) override { @@ -294,12 +299,6 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); } Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } - void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, - gsl::span input_dims, - InlinedVector& scales) const override { - p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); - } - #ifdef ENABLE_ATEN Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } #endif diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index c0e674827e4d1..840d6f8e3e7aa 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -141,7 +141,9 @@ struct ProviderHostCPU { virtual Status Scan__Compute(const Scan<9>* p, OpKernelContext* ctx) = 0; virtual Status Scan__SetupSubgraphExecutionInfo(Scan<8>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; virtual Status Scan__SetupSubgraphExecutionInfo(Scan<9>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; - + virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const = 0; #ifndef DISABLE_CONTRIB_OPS virtual Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) = 0; virtual Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) = 0; @@ -203,10 +205,6 @@ struct ProviderHostCPU { virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0; virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; - virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, - gsl::span input_dims, - InlinedVector& scales) const = 0; - #ifdef ENABLE_ATEN virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; #endif diff --git a/onnxruntime/core/providers/cpu/generator/random.h b/onnxruntime/core/providers/cpu/generator/random.h index 2ff6549794ff4..8a0390fe7af8c 100644 --- a/onnxruntime/core/providers/cpu/generator/random.h +++ b/onnxruntime/core/providers/cpu/generator/random.h @@ -4,7 +4,7 @@ #pragma once #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/core/providers/cpu/math/hardmax.h b/onnxruntime/core/providers/cpu/math/hardmax.h index 02b9b96fd3bfd..1b77a30a164e6 100644 --- a/onnxruntime/core/providers/cpu/math/hardmax.h +++ b/onnxruntime/core/providers/cpu/math/hardmax.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/core/providers/cpu/math/sign.cc b/onnxruntime/core/providers/cpu/math/sign.cc index 60080135bbd26..1d3b444c83b6d 100644 --- a/onnxruntime/core/providers/cpu/math/sign.cc +++ b/onnxruntime/core/providers/cpu/math/sign.cc @@ -3,7 +3,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/data_types.h" diff --git a/onnxruntime/core/providers/cpu/math/softmax.h b/onnxruntime/core/providers/cpu/math/softmax.h index 448a97bfbe0a4..cac674b42945e 100644 --- a/onnxruntime/core/providers/cpu/math/softmax.h +++ b/onnxruntime/core/providers/cpu/math/softmax.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 0b6c35ffabb1b..cae20b42725b8 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -22,7 +22,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/util/math.h" #include "core/util/math_cpuonly.h" diff --git a/onnxruntime/core/providers/cpu/ml/cast_map.cc b/onnxruntime/core/providers/cpu/ml/cast_map.cc index 8dcc3393f581f..f21eb99dd64e2 100644 --- a/onnxruntime/core/providers/cpu/ml/cast_map.cc +++ b/onnxruntime/core/providers/cpu/ml/cast_map.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/ml/cast_map.h" #include -#include "core/common/gsl.h" +#include using namespace ::onnxruntime::common; namespace { diff --git a/onnxruntime/core/providers/cpu/ml/category_mapper.cc b/onnxruntime/core/providers/cpu/ml/category_mapper.cc index f56e98a286cfe..88d83b2a2d33f 100644 --- a/onnxruntime/core/providers/cpu/ml/category_mapper.cc +++ b/onnxruntime/core/providers/cpu/ml/category_mapper.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/ml/category_mapper.h" #include -#include "core/common/gsl.h" +#include using namespace ::onnxruntime::common; namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/ml/feature_vectorizer.cc b/onnxruntime/core/providers/cpu/ml/feature_vectorizer.cc index 24fbfac59473d..6e46b2279f7dd 100644 --- a/onnxruntime/core/providers/cpu/ml/feature_vectorizer.cc +++ b/onnxruntime/core/providers/cpu/ml/feature_vectorizer.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/ml/feature_vectorizer.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace ml { diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.cc b/onnxruntime/core/providers/cpu/ml/label_encoder.cc index 65102b62a963b..67f38638d8dab 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.cc +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/ml/label_encoder.h" #include -#include "core/common/gsl.h" +#include using namespace ::onnxruntime::common; namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.h b/onnxruntime/core/providers/cpu/ml/label_encoder.h index 0f9f7cfb5dba6..f7a454cf519e1 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.h +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once - +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/providers/cpu/ml/ml_common.h" @@ -123,7 +123,7 @@ std::vector GetAttribute(const OpKernelInfo& info, const std::string& name, c } const SafeInt tensor_size(element_count); std::vector out(tensor_size); - result = utils::UnpackTensor(attr_tensor_proto, Path(), out.data(), tensor_size); + result = utils::UnpackTensor(attr_tensor_proto, std::filesystem::path(), out.data(), tensor_size); ORT_ENFORCE(result.IsOK(), "LabelEncoder could not unpack tensor attribute ", name); return out; } @@ -134,7 +134,7 @@ T GetDefault(const OpKernelInfo& info, const std::string& attr_name, const T& ba auto result = info.GetAttr("default_tensor", &attr_tensor_proto); if (result.IsOK() && utils::HasDataType(attr_tensor_proto)) { T default_value; - result = utils::UnpackTensor(attr_tensor_proto, Path(), &default_value, 1); + result = utils::UnpackTensor(attr_tensor_proto, std::filesystem::path(), &default_value, 1); ORT_ENFORCE(result.IsOK(), "LabelEncoder could not unpack default tensor ", attr_name); return default_value; } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { diff --git a/onnxruntime/core/providers/cpu/nn/flatten.h b/onnxruntime/core/providers/cpu/nn/flatten.h index 45fb644d2a0e8..b776d5e28d578 100644 --- a/onnxruntime/core/providers/cpu/nn/flatten.h +++ b/onnxruntime/core/providers/cpu/nn/flatten.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/common/gsl.h" +#include #include "core/providers/cpu/tensor/utils.h" #include "core/providers/common.h" diff --git a/onnxruntime/core/providers/cpu/nn/lrn.h b/onnxruntime/core/providers/cpu/nn/lrn.h index e797ffda87f7f..dc27672aa056d 100644 --- a/onnxruntime/core/providers/cpu/nn/lrn.h +++ b/onnxruntime/core/providers/cpu/nn/lrn.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/exceptions.h" diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h index dfc7a2b68699c..6d54c24b3808b 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h @@ -18,7 +18,7 @@ #include "core/common/safeint.h" #include "core/platform/threadpool.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace rnn { diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index cbc4d8360d4b1..6742bab4fa4a2 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/narrow.h" diff --git a/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc b/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc index 62c3dbfc87a75..6eef6b35a8fda 100644 --- a/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc +++ b/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/providers/common.h" #include "core/providers/cpu/tensor/transpose.h" #include "core/util/math_cpuonly.h" diff --git a/onnxruntime/core/providers/cpu/tensor/reverse_sequence.cc b/onnxruntime/core/providers/cpu/tensor/reverse_sequence.cc index a03d1143d058e..31ce89b3ed556 100644 --- a/onnxruntime/core/providers/cpu/tensor/reverse_sequence.cc +++ b/onnxruntime/core/providers/cpu/tensor/reverse_sequence.cc @@ -11,7 +11,7 @@ #pragma warning(disable : 4996) #endif -#include "core/common/gsl.h" +#include #ifdef _MSC_VER #pragma warning(pop) diff --git a/onnxruntime/core/providers/cpu/tensor/shape_op.h b/onnxruntime/core/providers/cpu/tensor/shape_op.h index b9e9389950193..05d22595dd838 100644 --- a/onnxruntime/core/providers/cpu/tensor/shape_op.h +++ b/onnxruntime/core/providers/cpu/tensor/shape_op.h @@ -9,7 +9,7 @@ #include "core/framework/op_kernel.h" #endif -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/tensor/slice_compute_metadata.h b/onnxruntime/core/providers/cpu/tensor/slice_compute_metadata.h index 05f1479315fe8..c21cc13ad3160 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice_compute_metadata.h +++ b/onnxruntime/core/providers/cpu/tensor/slice_compute_metadata.h @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/tensor_shape.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/tensor/split.cc b/onnxruntime/core/providers/cpu/tensor/split.cc index 724814e2f1d5b..4e43085fe288b 100644 --- a/onnxruntime/core/providers/cpu/tensor/split.cc +++ b/onnxruntime/core/providers/cpu/tensor/split.cc @@ -4,7 +4,7 @@ #include "core/providers/cpu/tensor/split.h" #include "core/common/narrow.h" -#include "core/common/gsl.h" +#include #include "core/common/safeint.h" #include "core/framework/copy.h" #include "core/framework/element_type_lists.h" diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.h b/onnxruntime/core/providers/cpu/tensor/transpose.h index fda41c28a2567..54d3584ba0dad 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.h +++ b/onnxruntime/core/providers/cpu/tensor/transpose.h @@ -10,7 +10,7 @@ #include "core/framework/op_kernel.h" #endif -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/tensor/unique.cc b/onnxruntime/core/providers/cpu/tensor/unique.cc index 135bef0860cab..ab99d87da83fd 100644 --- a/onnxruntime/core/providers/cpu/tensor/unique.cc +++ b/onnxruntime/core/providers/cpu/tensor/unique.cc @@ -4,7 +4,7 @@ #include "core/providers/cpu/tensor/unique.h" #include #include -#include "core/common/gsl.h" +#include #include "core/framework/op_kernel_type_control_utils.h" #include "core/providers/common.h" #include "core/providers/op_kernel_type_control.h" diff --git a/onnxruntime/core/providers/cpu/tensor/utils.h b/onnxruntime/core/providers/cpu/tensor/utils.h index 17eac5417f0a3..6adcfec852690 100644 --- a/onnxruntime/core/providers/cpu/tensor/utils.h +++ b/onnxruntime/core/providers/cpu/tensor/utils.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "core/common/gsl.h" +#include #include "core/common/narrow.h" #ifndef SHARED_PROVIDER diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 61da125b40953..0b56cac1038e4 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -16,7 +16,7 @@ #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fast_divmod.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace cuda { diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 103c79c93b2ca..7851da7fa91a3 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -10,7 +10,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_execution_provider_info.h" diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 3c0bf183362dd..58e57572131b1 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -81,6 +81,9 @@ CudaStream::CudaStream(cudaStream_t stream, cudnn_handle_ = external_cudnn_handle; CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream)); } +#else + (void)(external_cudnn_handle); + (void)(external_cublas_handle); #endif } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index 9aa011c1d0ec4..31fc63a86d640 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -6,7 +6,7 @@ #include "core/providers/cuda/cudnn_common.h" #include "core/common/inlined_containers.h" -#include "core/common/gsl.h" +#include #include "shared_inc/cuda_call.h" #include "core/providers/cpu/tensor/utils.h" #ifndef USE_CUDA_MINIMAL @@ -174,7 +174,11 @@ cudnnDataType_t CudnnTensor::GetDataType() { template <> cudnnDataType_t CudnnTensor::GetDataType() { +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200 + return CUDNN_DATA_BFLOAT16; +#else ORT_THROW("cuDNN doesn't support BFloat16."); +#endif } template <> diff --git a/onnxruntime/core/providers/cuda/math/softmax.h b/onnxruntime/core/providers/cuda/math/softmax.h index bbe63e66e67db..6f4016b655c96 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.h +++ b/onnxruntime/core/providers/cuda/math/softmax.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cuda/multi_tensor/common.cuh b/onnxruntime/core/providers/cuda/multi_tensor/common.cuh index 2d7928776bb8a..49a0a9c51472f 100644 --- a/onnxruntime/core/providers/cuda/multi_tensor/common.cuh +++ b/onnxruntime/core/providers/cuda/multi_tensor/common.cuh @@ -10,7 +10,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace cuda { diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 2bc937526f94c..7b827f8d04593 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/cudnn_common.h" diff --git a/onnxruntime/core/providers/cuda/rnn/gru.h b/onnxruntime/core/providers/cuda/rnn/gru.h index 6f5c5ab6e956f..e5ea1ed3e670e 100644 --- a/onnxruntime/core/providers/cuda/rnn/gru.h +++ b/onnxruntime/core/providers/cuda/rnn/gru.h @@ -4,7 +4,7 @@ #pragma once #include "cudnn_rnn_base.h" -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_common.h" #include diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index 2df0a38d22f57..1f7df9b6fc2e3 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -11,7 +11,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/float16.h" #include "core/providers/cuda/shared_inc/fast_divmod.h" diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc similarity index 100% rename from onnxruntime/contrib_ops/cuda/grid_sample.cc rename to onnxruntime/core/providers/cuda/tensor/grid_sample.cc diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/core/providers/cuda/tensor/grid_sample.h similarity index 100% rename from onnxruntime/contrib_ops/cuda/grid_sample.h rename to onnxruntime/core/providers/cuda/tensor/grid_sample.h diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu similarity index 100% rename from onnxruntime/contrib_ops/cuda/grid_sample_impl.cu rename to onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h similarity index 100% rename from onnxruntime/contrib_ops/cuda/grid_sample_impl.h rename to onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc index 5f73512ab8696..52775b2e8be7a 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.cc +++ b/onnxruntime/core/providers/cuda/tensor/split.cc @@ -76,6 +76,33 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { auto input_dims = input_shape.GetDims(); auto output_dimensions{input_shape.AsShapeVector()}; +#ifndef USE_ROCM + if (split_sizes.size() == 3 && ((axis + 1) == gsl::narrow_cast(input_shape.NumDimensions()))) { + // we use (axis + 1) == num_dimensions to check if we are splitting on inner most axis. + // only when split on inner axis and output size is 3, we can use Split3Inner. + // this kernel is not using pin_memory, so it is ok for using cuda graph. + output_dimensions[axis] = split_sizes[0]; + Tensor* output0 = ctx->Output(0, TensorShape{output_dimensions}); + output_dimensions[axis] = split_sizes[1]; + Tensor* output1 = ctx->Output(1, TensorShape{output_dimensions}); + output_dimensions[axis] = split_sizes[2]; + Tensor* output2 = ctx->Output(2, TensorShape{output_dimensions}); + + // if input tensor is empty, we don't need to launch kernel, but still need to set output tensor. + if (input_tensor->Shape().Size() <= 0) return Status::OK(); + + return Split3Inner(Stream(ctx), + input_tensor->DataType()->Size(), + split_sizes[0], split_sizes[1], + split_sizes[2], + input_tensor->DataRaw(), + output0->MutableDataRaw(), + output1->MutableDataRaw(), + output2->MutableDataRaw(), + input_dims); + } +#endif + CudaAsyncBuffer output_ptr(this, num_outputs); gsl::span output_ptr_span = output_ptr.CpuSpan(); TensorShapeVector axis_dimension_input_output_mapping(input_dims[axis]); diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.cu b/onnxruntime/core/providers/cuda/tensor/split_impl.cu index b0ff856a43970..e2f42e4d5855c 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.cu @@ -157,5 +157,114 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block return Status::OK(); } +#ifndef USE_ROCM +template +__global__ void _Split3InnerKernel(const int64_t size0_in_byte, + const int64_t size1_in_byte, + const int64_t size2_in_byte, + const void* input_data, + void* output_data0, + void* output_data1, + void* output_data2, + const int64_t inner_size_in_byte) { + // each block copy one row of input data + auto size0 = size0_in_byte / sizeof(T); + auto size1 = size1_in_byte / sizeof(T); + auto size2 = size2_in_byte / sizeof(T); + auto inner_size = inner_size_in_byte / sizeof(T); + auto output0_vec = reinterpret_cast(output_data0) + blockIdx.x * size0; + auto output1_vec = reinterpret_cast(output_data1) + blockIdx.x * size1; + auto output2_vec = reinterpret_cast(output_data2) + blockIdx.x * size2; + auto input_vec = reinterpret_cast(input_data) + blockIdx.x * inner_size; + // all size and pointer are aligned to sizeof(T) + // so here use all threads in the block to do vectorized copy + + for (auto tid = threadIdx.x; tid < inner_size; tid += blockDim.x) { + auto data = input_vec[tid]; + if (tid < size0) { + output0_vec[tid] = data; + } else if (tid < (size0 + size1)) { + output1_vec[tid - size0] = data; + } else { + output2_vec[tid - size0 - size1] = data; + } + } +} + +Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1, + const int64_t size2, const void* input_data, void* output_data0, void* output_data1, + void* output_data2, const gsl::span& input_shape) { + CUDA_LONG outer_size = 1; + for (size_t i = 0; i < input_shape.size() - 1; ++i) { + outer_size *= static_cast(input_shape[i]); + } + CUDA_LONG inner_size_in_byte = static_cast(input_shape[input_shape.size() - 1] * element_size); + + auto select = [](size_t value) { + if (value % 16 == 0) { + return 16; + } else if (value % 8 == 0) { + return 8; + } else if (value % 4 == 0) { + return 4; + } else if (value % 2 == 0) { + return 2; + } else { + return 1; + } + }; + + auto input_v = reinterpret_cast(input_data); + auto output_v0 = reinterpret_cast(output_data0); + auto output_v1 = reinterpret_cast(output_data1); + auto output_v2 = reinterpret_cast(output_data2); + auto size0_in_byte = size0 * element_size; + auto size1_in_byte = size1 * element_size; + auto size2_in_byte = size2 * element_size; + + auto VEC_SIZE = std::min(select(size0_in_byte), std::min(select(size1_in_byte), select(size2_in_byte))); + auto min_output_vec_size = std::min(select(output_v0), std::min(select(output_v1), select(output_v2))); + VEC_SIZE = std::min(VEC_SIZE, std::min(select(input_v), min_output_vec_size)); + + // determine threads based on the size of the output + auto threadsPerBlock = kNumThreadsPerBlock; + if ((inner_size_in_byte / VEC_SIZE) <= 128) { + // use less threads when the size is small + threadsPerBlock = 128; + } + + switch (VEC_SIZE) { +#define CASE_ELEMENT_TYPE(type) \ + _Split3InnerKernel<<>>( \ + size0_in_byte, \ + size1_in_byte, \ + size2_in_byte, \ + input_data, \ + output_data0, \ + output_data1, \ + output_data2, \ + inner_size_in_byte) + case 16: + CASE_ELEMENT_TYPE(int4); + break; + case 8: + CASE_ELEMENT_TYPE(int64_t); + break; + case 4: + CASE_ELEMENT_TYPE(int32_t); + break; + case 2: + CASE_ELEMENT_TYPE(int16_t); + break; + default: + CASE_ELEMENT_TYPE(int8_t); + break; +#undef CASE_ELEMENT_TYPE + } + + return Status::OK(); +} +#endif + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.h b/onnxruntime/core/providers/cuda/tensor/split_impl.h index 16961cfb7d22d..62e0da7716608 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.h @@ -19,5 +19,9 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data, void** output_data, const size_t input_size); +Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1, + const int64_t size2, const void* input_data, void* output_data0, void* output_data1, + void* output_data2, const gsl::span& input_shape); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.h b/onnxruntime/core/providers/cuda/tensor/transpose.h index da73d1308dbf7..31bb21078c1b1 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.h +++ b/onnxruntime/core/providers/cuda/tensor/transpose.h @@ -4,7 +4,7 @@ #pragma once #include "core/providers/shared_library/provider_api.h" -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cpu/tensor/transpose.h" diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index 17533eb3d9a72..cbf745d3c7b4f 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -230,7 +230,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, } const bool is_2D = X_dims.size() == 2; - const bool is_nchw = is_2D ? true : (scales[1] == 1.0f && scales[1] == 1.0f); + const bool is_nchw = is_2D ? true : (scales[0] == 1.0f && scales[1] == 1.0f); ORT_RETURN_IF_NOT(is_nchw, "Resize 'Cubic' mode only supports NCWH layout " diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp index ea66289c351ea..541254ffaf7f0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp @@ -32,7 +32,7 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataTyp }; } -bool IsSigned(DML_TENSOR_DATA_TYPE dataType) +bool IsSigned(DML_TENSOR_DATA_TYPE dataType) noexcept { switch (dataType) { @@ -140,7 +140,33 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice) return deviceTypeMask; } -void GetDescendingPackedStrides(gsl::span sizes, /*out*/ gsl::span strides) +uint32_t GetBitMaskFromIndices(gsl::span indices) noexcept +{ + uint32_t bitMask = 0; + for (auto i : indices) + { + assert(i < 32); + bitMask |= (1 << i); + } + return bitMask; +} + +uint32_t CountLeastSignificantZeros(uint32_t value) noexcept +{ + // *Use std::countr_zero instead when codebase updated to C++20. + // Use bit twiddling hack rather than for loop. + uint32_t count = 32; + value &= -int32_t(value); + if (value) count--; + if (value & 0x0000FFFF) count -= 16; + if (value & 0x00FF00FF) count -= 8; + if (value & 0x0F0F0F0F) count -= 4; + if (value & 0x33333333) count -= 2; + if (value & 0x55555555) count -= 1; + return count; +} + +void GetDescendingPackedStrides(gsl::span sizes, /*out*/ gsl::span strides) noexcept { assert(sizes.size() == strides.size()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h index c4d260b9736df..5cd3fd0aea72c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h @@ -23,9 +23,11 @@ namespace Dml size_t ComputeByteSizeFromDimensions(gsl::span dimensions, MLOperatorTensorDataType tensorDataType); size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor); uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice); - void GetDescendingPackedStrides(gsl::span sizes, /*out*/ gsl::span strides); + uint32_t GetBitMaskFromIndices(gsl::span indices) noexcept; + uint32_t CountLeastSignificantZeros(uint32_t value) noexcept; + void GetDescendingPackedStrides(gsl::span sizes, /*out*/ gsl::span strides) noexcept; - bool IsSigned(DML_TENSOR_DATA_TYPE dataType); + bool IsSigned(DML_TENSOR_DATA_TYPE dataType) noexcept; template void CastToClampedScalarUnion(DML_TENSOR_DATA_TYPE dataType, T value, DML_SCALAR_UNION* outputValue) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 10b8b7fe42f86..2f110ba339beb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -20,7 +20,7 @@ namespace Dml DmlRuntimeFusedGraphKernel( const onnxruntime::OpKernelInfo& kernelInfo, std::shared_ptr indexedSubGraph, - const onnxruntime::Path& modelPath, + const std::filesystem::path& modelPath, std::vector>&& subgraphNodes, std::vector&& subgraphInputs, std::vector&& subgraphOutputs, @@ -314,7 +314,7 @@ namespace Dml mutable std::optional m_persistentResourceBinding; std::shared_ptr m_indexedSubGraph; - const onnxruntime::Path& m_modelPath; + const std::filesystem::path& m_modelPath; std::vector> m_subgraphNodes; std::vector m_subgraphInputs; @@ -341,7 +341,7 @@ namespace Dml onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( const onnxruntime::OpKernelInfo& info, std::shared_ptr indexedSubGraph, - const onnxruntime::Path& modelPath, + const std::filesystem::path& modelPath, std::vector>&& subgraphNodes, std::vector&& subgraphInputs, std::vector&& subgraphOutputs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h index d679c5aa5667c..e800175268557 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - +#include #include "core/framework/op_kernel.h" #include "GraphDescBuilder.h" #include "DmlRuntimeGraphFusionTransformer.h" @@ -10,7 +10,7 @@ namespace Dml onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( const onnxruntime::OpKernelInfo& info, std::shared_ptr indexedSubGraph, - const onnxruntime::Path& modelPath, + const std::filesystem::path& modelPath, std::vector>&& subgraphNodes, std::vector&& subgraphInputs, std::vector&& subgraphOutputs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 2bd9377e4c2fa..387767f821b3e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -232,7 +232,7 @@ namespace Dml::GraphDescBuilder const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, const ExecutionProviderImpl* executionHandle, - const onnxruntime::Path& modelPath, + const std::filesystem::path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, gsl::span subgraphOutputs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 4055984b40405..3f778b3a7feba 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include #include "MLOperatorAuthorImpl.h" #include "ExecutionProvider.h" @@ -41,7 +42,7 @@ namespace Dml const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, const ExecutionProviderImpl* executionHandle, - const onnxruntime::Path& modelPath, + const std::filesystem::path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, gsl::span subgraphOutputs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index f29fbc7a1a65b..0a2a5bbcbedaf 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -842,7 +842,7 @@ namespace Windows::AI::MachineLearning::Adapter const onnx::TensorProto* tensorProto = &attributeProto->t(); // An empty path is used as external weights are not currently supported in this case - Microsoft::WRL::ComPtr tensorWrapper = wil::MakeOrThrow(const_cast(tensorProto), onnxruntime::Path()); + Microsoft::WRL::ComPtr tensorWrapper = wil::MakeOrThrow(const_cast(tensorProto), std::filesystem::path()); *tensor = tensorWrapper.Detach(); return S_OK; } @@ -1545,7 +1545,7 @@ namespace Windows::AI::MachineLearning::Adapter ORT_CATCH_RETURN } - OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl) + OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const std::filesystem::path& modelPath) : m_impl(impl) { // The tensor may be stored as raw data or in typed fields. if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) @@ -2826,7 +2826,7 @@ namespace Windows::AI::MachineLearning::Adapter { // An empty path is used as external weights are not currently supported in this case Microsoft::WRL::ComPtr tensorWrapper = wil::MakeOrThrow( - const_cast(ctx->getInputData(index)), onnxruntime::Path()); + const_cast(ctx->getInputData(index)), std::filesystem::path()); return tensorWrapper; } ); @@ -3018,7 +3018,7 @@ namespace Windows::AI::MachineLearning::Adapter std::tuple, size_t> UnpackTensor( const onnx::TensorProto& initializer, - const onnxruntime::Path& modelPath) + const std::filesystem::path& modelPath) { std::unique_ptr unpackedTensor; size_t tensorByteSize = 0; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 59e253e88457a..7e51ce026d365 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -2,6 +2,8 @@ // Licensed under the MIT License. #pragma once +#include + #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h" #include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" @@ -283,7 +285,7 @@ class OnnxTensorWrapper : public WRL::Base, public Closable public: OnnxTensorWrapper() = default; - OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath); + OnnxTensorWrapper(onnx::TensorProto* impl, const std::filesystem::path& modelPath); uint32_t STDMETHODCALLTYPE GetDimensionCount() const noexcept override; @@ -681,5 +683,5 @@ bool TryGetStaticInputShapes(const onnxruntime::Node& node, EdgeShapes& inputSha bool TryGetStaticOutputShapes(const onnxruntime::Node& node, EdgeShapes& outputShapes); bool ContainsEmptyDimensions(const EdgeShapes& shapes, gsl::span ignoredShapeIndices = gsl::span()); -std::tuple, size_t> UnpackTensor(const onnx::TensorProto& initializer, const onnxruntime::Path& modelPath); +std::tuple, size_t> UnpackTensor(const onnx::TensorProto& initializer, const std::filesystem::path& modelPath); } // namespace Windows::AI::MachineLearning::Adapter diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp index c6a87da705a99..32d6af73aae8d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp @@ -40,16 +40,32 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0) ); } + MLOperatorTensorDataType ADatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::A).tensorDataType; MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::B).tensorDataType; + gsl::span outputSizes = m_outputTensorDescs[0].GetSizes(); std::vector ATensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::A); - std::vector ExpectedAScaleTensorShape = {1, 1, 1, 1}; - std::vector ExpectedAZeroPointTensorShape = {1, 1, 1, 1}; + std::vector BTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::B); + std::vector ExpectedAScaleTensorShape(outputSizes.size(), 1); + std::vector ExpectedAZeroPointTensorShape(outputSizes.size(), 1); + ML_CHECK_VALID_ARGUMENT(outputSizes.size() >= 4); + ML_CHECK_VALID_ARGUMENT(ATensorShape.size() >= 2); + ML_CHECK_VALID_ARGUMENT(BTensorShape.size() >= 2); + ML_CHECK_VALID_ARGUMENT(ATensorShape.size() + 2 >= outputSizes.size()); + ML_CHECK_VALID_ARGUMENT(BTensorShape.size() + 2 >= outputSizes.size()); + std::vector AShapeBroadcasted(outputSizes.begin(), outputSizes.end()); + std::copy(ATensorShape.end() - (outputSizes.size() - 2), + ATensorShape.end(), + AShapeBroadcasted.begin() + 2); + std::vector BShapeBroadcasted(outputSizes.begin(), outputSizes.end()); + std::copy(BTensorShape.end() - (outputSizes.size() - 2), + BTensorShape.end(), + BShapeBroadcasted.begin() + 2); // output edges between DynQL and MMItoFloat node TensorDesc intermediateQuantizedATensorDesc = TensorDesc( BDatatype, - gsl::make_span(ATensorShape), + gsl::make_span(AShapeBroadcasted), gsl::make_span(ATensorShape), TensorAxis::DoNotCoerce, TensorAxis::W, @@ -80,6 +96,30 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator 0 // guaranteedBaseOffsetAlignment ); + TensorDesc broadcastedATensorDesc = TensorDesc( + ADatatype, + AShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting). + ATensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'. + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + TensorDesc broadcastedBTensorDesc = TensorDesc( + BDatatype, + BShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting). + BTensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'. + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + DML_TENSOR_DESC namedBroadcastedATensorDesc = broadcastedATensorDesc.GetDmlDesc(); + DML_TENSOR_DESC namedBroadcastedBTensorDesc = broadcastedBTensorDesc.GetDmlDesc(); DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc(); DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc(); DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc(); @@ -88,7 +128,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator std::vector outputDescs = GetDmlOutputDescs(); DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {}; - dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A]; + dynamicQuantizeLinearOperatorDesc.InputTensor = &namedBroadcastedATensorDesc; dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc; dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc; dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc; @@ -99,7 +139,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator matrixMultiplyIntergerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor; matrixMultiplyIntergerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor; matrixMultiplyIntergerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor; - matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B]; + matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &namedBroadcastedBTensorDesc; matrixMultiplyIntergerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale]; matrixMultiplyIntergerToFloatOperatorDesc.BZeroPointTensor = hasBZP? &inputDescs[OnnxInputIndex::B_zero_point] : nullptr; matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp index d5bf54de53c30..51b19603f5122 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp @@ -3,6 +3,70 @@ #include "precomp.h" +// With a single equation, the Einstein summation operator can represent a variety of operators including: matmul, +// summation, transposition, diagonal slice, diagonal sum (trace), inner (dot) product, outer product... +// +// Parameters NumPy equivalent Description +// ------------------------------------------------------------------------------------------------------------- +// ('i', A1) A1 returns a view of A1 +// ('i->', A1) sum(A1) sums the values of A1 +// ('i,i->i', A1, B1) A1 * B1 element-wise multiplication of A1 and B1 +// ('i,i->', A1, B1) inner(A1, B1) or dot(A1, B1) inner product of A1 and B1 +// ('i,i', A1, B1) inner(A1, B1) or dot(A1, B1) inner product of A1 and B1 +// ('i,j->ij', A1, B1) outer(A1, B1) outer product of A1 and B1 +// ('ij->ij', A2) A2 returns a view of A2 +// ('ij', A2) A2 returns a view of A2 +// ('ji', A2) A2.T view transpose of A2 +// ('ji->ij', A2) A2.T view transpose of A2 +// ('ii->i', A2) diag(A2) view main diagonal of A2 +// ('ii->', A2) trace(A2) sums main diagonal of A2 +// ('ij->', A2) sum(A2) sums the values of A2 +// ('ij->j', A2) sum(A2, axis=0) sum down the columns of A2 (across rows) +// ('ij->i', A2) sum(A2, axis=1) sum horizontally along the rows of A2 +// ('ij,ij->ij', A2, B2) A2 * B2 element-wise multiplication of A2 and B2 +// ('ij,ji->ij', A2, B2) A2 * B2.transpose() element-wise multiplication of A2 and B2.T +// ('ij,jk', A2, B2) matmul(A2, B2) or dot(A2, B2) matrix multiplication of A2 and B2 +// ('ij,jk->ik', A2, B2) matmul(A2, B2) or dot(A2, B2) matrix multiplication of A2 and B2 +// ('bij,bjk->bik', A2, B2) matmul(A3, B3) matrix multiplication of A3 and B3 (a stack of 2D matrices) +// ('bij,bkj->bik', A2, B2) matmul(A3, transpose(B3)) matrix multiplication of A3 and B3 (a stack of 2D matrices) +// ('ij,kj->ik', A2, B2) inner(A2, B2) inner product of A2 and B2 +// ('ij,kj->ikj', A2, B2) A2[:, None] * B2 each row of A2 multiplied by B2 +// ('ij,kl->ijkl', A2, B2) A2[:, :, None, None] * B2 each value of A2 multiplied by B2 +// (',ij', 3, B2) Scalar times array: array([[ 0, 3, 6], [ 9, 12, 15]]) +// ("ij,j", A2, B1) matvec(A2, B1) Matrix and vector. +// ("ii,ii->i", A2, B2) A2.diag() * B2.diag() diagonals multiplied by each other +// ("ii,ii->", A2, B2) dot(A2.diag(), B2.diag()) dot product of diagonals +// +// Decomposition: +// +// Ultimately though EinSum is equivalent to an elementwise multiplication into an internal product tensor +// (given a helper function to reproject all inputs so they're shape-compatible) followed by sum reduction. +// +// 1. Determine the size of the internal product tensor by concatenating the dimensions of all inputs, +// counting each unique label once. So "bij,bjk->bik" would yield an internal product of shape [b,i,j,k]. +// 2. Project each input tensor as needed to the internal product shape (transposing and/or broadcasting). +// So an input of shape [b,i] with product shape of [b,j,i,k] would insert broadcasted j and k dimensions. +// An input of shape [a,b,c] with product shape of [b,c,a] would require a transpose. +// The input shape [a,b,a] with product shape of [a,b] would collapse the first two input 'a' dimensions. +// 3. Multiply elementwise every input tensor to compute the internal product. +// 4. Sum reduce the product tensor to the final output shape, reducing along any missing dimensions. +// So a product shape of [b,j,i,k] and output shape of [b,i,k] reduces along j. +// +// ReduceSum( +// Mul( +// ExpandTransposeCollapseAsNeeded(A, aAxesToProductAxes), +// ExpandTransposeCollapseAsNeeded(B, bAxesToProductAxes), +// ), +// reductionAxes, +// keepdims=false +// ) +// +// Notes: +// +// - DirectML has no direct EinSum operator, but common cases map to existing operators. +// - EinSum can accept a variable number of input tensors, but the DML EP only supports a limited count +// (falling back to CPU otherwise). + namespace Dml { @@ -13,30 +77,36 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper : DmlOperator(kernelCreationContext), EinSumHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion) { - ML_CHECK_VALID_ARGUMENT(static_cast(kernelCreationContext.GetInputCount()) + 1 == m_components.size(), - "EinSum input tensor count is inconsistent with the equation component count."); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 1, "EinSum expects at least one input tensor."); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "EinSum expects one output tensor."); + ML_CHECK_VALID_ARGUMENT( + static_cast(kernelCreationContext.GetInputCount()) + 1 == m_components.size(), + "EinSum input tensor count is inconsistent with the equation component count." + ); + assert(m_recognizedOperatorType != RecognizedOperatorType::None && "Unrecognized EinSum operators should have fallen back to CPU"); std::vector> inputIndices = {0,1,2}; std::vector> outputIndices = {0}; uint32_t bindableInputCount = kernelCreationContext.GetInputCount(); if (IsMatMulOperatorType()) { - ++bindableInputCount; // Account for the optional C tensor. + ++bindableInputCount; // Account for the optional C tensor. } inputIndices.resize(bindableInputCount); - constexpr uint32_t dimCount = 2; - DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices, std::nullopt, std::nullopt, dimCount); + uint32_t minimumDimensionCount = 1; + DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices, std::nullopt, std::nullopt, minimumDimensionCount); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - static_assert(RecognizedOperatorType::Total == static_cast(12), "Update this switch."); + static_assert(RecognizedOperatorType::Total == static_cast(6), "Update this switch statement."); switch (m_recognizedOperatorType) { case RecognizedOperatorType::Multiply: { + ReprojectTensorDescsToProductTensor(); + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC operatorDesc = {}; operatorDesc.ATensor = &inputDescs[0]; operatorDesc.BTensor = &inputDescs[1]; @@ -46,115 +116,53 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper } break; - case RecognizedOperatorType::OuterProduct: - { - std::array aSizes = {m_inputTensorDescs[0].GetSizes().back(), 1}; - TensorDesc aTensorDesc = TensorDesc(m_inputTensorDescs[0].GetDmlDataType(), aSizes); - auto aDmlTensorDesc = aTensorDesc.GetDmlDesc(); - - std::array bSizes = {1, m_inputTensorDescs[1].GetSizes().back()}; - TensorDesc bTensorDesc = TensorDesc(m_inputTensorDescs[1].GetDmlDataType(), bSizes); - auto bDmlTensorDesc = bTensorDesc.GetDmlDesc(); - - DML_GEMM_OPERATOR_DESC operatorDesc = {}; - operatorDesc.ATensor = &aDmlTensorDesc; - operatorDesc.BTensor = &bDmlTensorDesc; - operatorDesc.OutputTensor = &outputDescs[0]; - operatorDesc.Alpha = 1.0; - operatorDesc.Beta = 0.0; - operatorDesc.FusedActivation = nullptr; - - SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext); - } - break; - case RecognizedOperatorType::MatMul: - case RecognizedOperatorType::MatMulTransposeA: - case RecognizedOperatorType::MatMulTransposeB: { - DML_GEMM_OPERATOR_DESC operatorDesc = {}; - operatorDesc.ATensor = &inputDescs[0]; - operatorDesc.BTensor = &inputDescs[1]; - // No operatorDesc.CTensor - operatorDesc.OutputTensor = &outputDescs[0]; - operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE; - operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE; - operatorDesc.Alpha = 1.0; - operatorDesc.Beta = 0.0; - operatorDesc.FusedActivation = nullptr; - - SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext); - } - break; - case RecognizedOperatorType::MatMulNhcw: - case RecognizedOperatorType::MatMulNhcwTransposeA: - case RecognizedOperatorType::MatMulNhcwTransposeB: - { - // Transpose via input strides. The output tensor is not strided. Support only 4D for now. - assert(m_components.size() == 3); - assert(m_components[0].GetDimensionCount() == m_components[2].GetDimensionCount()); - assert(m_components[1].GetDimensionCount() == m_components[2].GetDimensionCount()); - assert(m_components[2].GetDimensionCount() == 4); - - // Remap transposed strides from NCHW to NHCW - constexpr std::array labelIndices = {0, 2, 1, 3}; - - assert(m_inputTensorDescs.size() >= 2); - for (uint32_t inputIndex = 0; inputIndex < 2; ++inputIndex) - { - TensorDesc& tensorDesc = m_inputTensorDescs[inputIndex]; - auto originalStrides = tensorDesc.GetStrides(); - std::vector inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(inputIndex); - std::vector inputStrides(inputSizes.size()); - - // If there were no strides, compute them based in descending packed order - // based on the input sizes. - if (originalStrides.empty()) - { - Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides); - } - else // Copy the original strides. - { - assert(originalStrides.size() >= inputStrides.size()); - size_t offset = originalStrides.size() - inputStrides.size(); - inputStrides.assign(originalStrides.begin() + offset, originalStrides.end()); - } - - std::vector newStrides(inputStrides.size()); - std::vector newSizes(inputStrides.size()); - for (size_t dim = 0, dimensionCount = inputStrides.size(); dim < dimensionCount; ++dim) - { - uint32_t labelIndex = labelIndices[dim]; - assert(labelIndex < inputStrides.size()); - newSizes[dim] = inputSizes[labelIndex]; - newStrides[dim] = inputStrides[labelIndex]; - } - - // Override the initial input tensor with the new strides. - tensorDesc = TensorDesc(tensorDesc.GetDmlDataType(), newSizes, newStrides, 0); - tensorDesc.GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view. - } - - std::vector outputSizes = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0); - std::vector newOutputSizes(outputSizes.size()); - assert(outputSizes.size() == labelIndices.size()); - - for (size_t dim = 0; dim < outputSizes.size(); ++dim) + assert(m_components.size() == 3 && "EinSum matmul expects 2 inputs and 1 output"); + assert(m_productDimensions.size() - 1 <= 4 && "DML Einsum matmul handles up to 4D"); + + // Generate bitmasks for each of the active axes per tensor using their labels. + const auto input0Labels = m_components[0].GetLabels(m_labelIndices); + const auto input1Labels = m_components[1].GetLabels(m_labelIndices); + const auto outputLabels = m_components[2].GetLabels(m_labelIndices); + const uint32_t input0AxesMask = GetBitMaskFromIndices(input0Labels); + const uint32_t input1AxesMask = GetBitMaskFromIndices(input1Labels); + const uint32_t outputAxesMask = GetBitMaskFromIndices(outputLabels); + + // Find each of the interesting axes, including the one being reduced, height, width, batch, and channel. + // - the reduced axis is the term missing from the output. + // - height and width are the unique axes respectively found in only input A or input B. + // - the batch (if present) is the first axis shared by both inputs, and the channel is the subsequent common one. + // If any axis is not found (say it's a 2D GEMM), then the axis value will be beyond the rank, which is + // safely handled correctly during projection as an inserted axis. + + auto findAndClearAxis = [](uint32_t& currentAxesMask, uint32_t contraintAxesMask) -> uint32_t { - uint32_t labelIndex = labelIndices[dim]; - newOutputSizes[dim] = outputSizes[labelIndex]; - } - - m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newOutputSizes, std::nullopt, 0); - m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view. + uint32_t foundAxis = CountLeastSignificantZeros(currentAxesMask & ~contraintAxesMask); + currentAxesMask &= ~(1 << foundAxis); + return foundAxis; + }; + + uint32_t remainingAxesMask = ~0u; + uint32_t reductionAxis = findAndClearAxis(/*inout*/ remainingAxesMask, outputAxesMask); + uint32_t heightAxis = findAndClearAxis(/*inout*/ remainingAxesMask, input1AxesMask); + uint32_t widthAxis = findAndClearAxis(/*inout*/ remainingAxesMask, input0AxesMask); + uint32_t batchAxis = findAndClearAxis(/*inout*/ remainingAxesMask, 0); + uint32_t channelAxis = findAndClearAxis(/*inout*/ remainingAxesMask, 0); + + // Reproject all inputs and the output to the needed order pattern for DML compatibility, + // which only accepts the rightmost axis as GEMM-reducible when TransB is true. + ReprojectTensorDescToGivenAxes(/*inout*/ m_inputTensorDescs[0], input0Labels, {{batchAxis, channelAxis, heightAxis, reductionAxis}}); + ReprojectTensorDescToGivenAxes(/*inout*/ m_inputTensorDescs[1], input1Labels, {{batchAxis, channelAxis, widthAxis, reductionAxis}}); + ReprojectTensorDescToGivenAxes(/*inout*/ m_outputTensorDescs[0], outputLabels, {{batchAxis, channelAxis, heightAxis, widthAxis}}); DML_GEMM_OPERATOR_DESC operatorDesc = {}; operatorDesc.ATensor = &inputDescs[0]; operatorDesc.BTensor = &inputDescs[1]; // No operatorDesc.CTensor operatorDesc.OutputTensor = &outputDescs[0]; - operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE; - operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE; + operatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE; + operatorDesc.TransB = DML_MATRIX_TRANSFORM_TRANSPOSE; operatorDesc.Alpha = 1.0; operatorDesc.Beta = 0.0; operatorDesc.FusedActivation = nullptr; @@ -165,41 +173,10 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper case RecognizedOperatorType::ReduceSum: { - // Get how many axes are kept in the final output, either 0 or 1 supported - // meaning full reduction or partial with one dimension left. *It could be - // generalized to support any number of output dimensions, but it would need - // to accomodate for Transposition too if the output labels are reordered. - auto keptAxes = m_components.back().GetLabels(m_labelIndices); - assert(keptAxes.size() <= 1); - - // DML expects output rank to match input rank (as if ONNX ReduceSum keepdims=1). - // So replace the existing tensor description with the input sizes, except that - // reduced dimensions have size 1. - std::vector reducedAxes; - std::vector inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); - std::vector outputSizes = inputSizes; - - // Determine which axes are being reduced by taking the opposite of those kept. - uint32_t keptAxesMask = 0; - for (auto axis : keptAxes) - { - keptAxesMask |= (1 << axis); - } - for (uint32_t axis = 0, axisCount = static_cast(outputSizes.size()); axis < axisCount; ++axis) - { - if (~keptAxesMask & (1< reducedAxes = GetReductionAxes(); operatorDesc.InputTensor = inputDescs.data(); operatorDesc.OutputTensor = outputDescs.data(); operatorDesc.Function = DML_REDUCE_FUNCTION_SUM; @@ -211,48 +188,8 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper break; case RecognizedOperatorType::Transpose: - case RecognizedOperatorType::Identity: { - if (m_recognizedOperatorType == RecognizedOperatorType::Transpose) - { - // Transpose via input strides. The output tensor is not strided. - assert(m_components.front().GetDimensionCount() == m_components.back().GetDimensionCount()); - auto originalStrides = m_inputTensorDescs.front().GetStrides(); - std::vector inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); - std::vector inputStrides(inputSizes.size()); - - // If there were no strides, compute them based in descending packed order - // based on the input sizes. - if (originalStrides.empty()) - { - Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides); - } - else // Copy the original strides. - { - assert(originalStrides.size() >= inputStrides.size()); - size_t offset = originalStrides.size() - inputStrides.size(); - inputStrides.assign(originalStrides.begin() + offset, originalStrides.end()); - } - - // Remap transposed strides using the component labels from input to output. - auto labelIndices = m_components.back().GetLabels(m_labelIndices); - - std::vector newStrides(inputStrides.size()); - std::vector newSizes(inputStrides.size()); - for (size_t i = 0, dimensionCount = inputStrides.size(); i < dimensionCount; ++i) - { - uint32_t labelIndex = labelIndices[i]; - assert(labelIndex < inputStrides.size()); - newSizes[i] = inputSizes[labelIndex]; - newStrides[i] = inputStrides[labelIndex]; - } - - // Override the initial input tensor with the new strides. - m_inputTensorDescs.front() = TensorDesc(m_inputTensorDescs.front().GetDmlDataType(), newSizes, newStrides, 0); - m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newSizes, std::nullopt, 0); - m_inputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view. - m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view. - } + ReprojectTensorDescsToProductTensor(); DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC operatorDesc = {}; operatorDesc.InputTensor = inputDescs.data(); @@ -262,10 +199,208 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper } break; + case RecognizedOperatorType::MultiplyReduceSum: + { + // DML has no generic DML_OPERATOR_DOT_PRODUCT. So construct one via a graph of mul+sumReduce. + + ReprojectTensorDescsToProductTensor(); + TensorDesc productTensorDesc(m_outputTensorDescs.front().GetDmlDataType(), m_productDimensions); + auto dmlProductTensorDesc = productTensorDesc.GetDmlDesc(); + + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC multiplyOperatorDesc = {}; + multiplyOperatorDesc.ATensor = &inputDescs[0]; + multiplyOperatorDesc.BTensor = &inputDescs[1]; + multiplyOperatorDesc.OutputTensor = &dmlProductTensorDesc; + DML_OPERATOR_DESC multiplyOperatorDescWithEnum = { DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &multiplyOperatorDesc }; + + DML_REDUCE_OPERATOR_DESC reduceSumOperatorDesc = {}; + std::vector reducedAxes = GetReductionAxes(); + reduceSumOperatorDesc.Function = DML_REDUCE_FUNCTION_SUM; + reduceSumOperatorDesc.InputTensor = &dmlProductTensorDesc; + reduceSumOperatorDesc.OutputTensor = &outputDescs[0]; + reduceSumOperatorDesc.Axes = reducedAxes.data(); + reduceSumOperatorDesc.AxisCount = gsl::narrow_cast(reducedAxes.size()); + DML_OPERATOR_DESC reduceSumOperatorDescWithEnum = { DML_OPERATOR_REDUCE, &reduceSumOperatorDesc }; + + enum NodeIndex + { + NodeIndexMultiply, + NodeIndexReduceSum, + NodeIndexTotal, + }; + + const DML_OPERATOR_DESC* operatorDescPointers[2] = + { + &multiplyOperatorDescWithEnum, // NodeIndexMultiply + &reduceSumOperatorDescWithEnum, // NodeIndexReduceSum + }; + + DML_INPUT_GRAPH_EDGE_DESC inputEdges[2]; + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdges[1]; + DML_OUTPUT_GRAPH_EDGE_DESC outputEdges[1]; + + DML_INPUT_GRAPH_EDGE_DESC& input0ToMultiplyEdge = inputEdges[0]; + input0ToMultiplyEdge.GraphInputIndex = 0; + input0ToMultiplyEdge.ToNodeIndex = NodeIndexMultiply; + input0ToMultiplyEdge.ToNodeInputIndex = 0; + + DML_INPUT_GRAPH_EDGE_DESC& input1ToMultiplyEdge = inputEdges[1]; + input1ToMultiplyEdge.GraphInputIndex = 1; + input1ToMultiplyEdge.ToNodeIndex = NodeIndexMultiply; + input1ToMultiplyEdge.ToNodeInputIndex = 1; + + DML_INTERMEDIATE_GRAPH_EDGE_DESC& multiplyToReduceSumEdge = intermediateEdges[0]; + multiplyToReduceSumEdge.FromNodeIndex = NodeIndexMultiply; + multiplyToReduceSumEdge.FromNodeOutputIndex = 0; + multiplyToReduceSumEdge.ToNodeIndex = NodeIndexReduceSum; + multiplyToReduceSumEdge.ToNodeInputIndex = 0; + + DML_OUTPUT_GRAPH_EDGE_DESC& reduceSumToOutputEdge = outputEdges[0]; + reduceSumToOutputEdge.FromNodeIndex = NodeIndexReduceSum; + reduceSumToOutputEdge.FromNodeOutputIndex = 0; + reduceSumToOutputEdge.GraphOutputIndex = 0; + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = uint32_t(std::size(inputEdges)); + operatorGraphDesc.inputEdges = std::data(inputEdges); + operatorGraphDesc.intermediateEdgeCount = uint32_t(std::size(intermediateEdges)); + operatorGraphDesc.intermediateEdges = std::data(intermediateEdges); + operatorGraphDesc.outputEdgeCount = uint32_t(std::size(outputEdges)); + operatorGraphDesc.outputEdges = std::data(outputEdges); + operatorGraphDesc.nodeCount = uint32_t(std::size(operatorDescPointers)); + operatorGraphDesc.nodes = std::data(operatorDescPointers); + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + } + break; + default: return; } } + + // Reproject all inputs and the output to the intermediate product tensor. + // e.g. + // + // Equation: i,j->ji + // + // [1] [4,5,6,7] [4, 8,12] + // [2] -> [5,10,15] + // [3] [6,12,18] + // [7,14,21] + // + // Expand inputs 0 and 1 to 2D via strides to be directly broadcast-compatible. + // + // [1,1,1,1] [4,5,6,7] [4, 8,12] + // [2,2,2,2] [4,5,6,7] -> [5,10,15] + // [3,3,3,3] [4,5,6,7] [6,12,18] + // [7,14,21] + // + // Transpose the output to be shape-compatible: + // + // [1,1,1,1] [4,5,6,7] [ 4, 5, 6, 7] + // [2,2,2,2] [4,5,6,7] -> [ 8,10,12,14] + // [3,3,3,3] [4,5,6,7] [12,15,18,21] + // + void ReprojectTensorDescsToProductTensor() + { + assert(!m_components.empty() && "Equation components should have already been parsed."); + assert(m_inputTensorDescs.size() + m_outputTensorDescs.size() == m_components.size()); + + for (size_t i = 0, count = m_inputTensorDescs.size(); i < count; ++i) + { + auto inputLabels = m_components[i].GetLabels(m_labelIndices); + ReprojectTensorDescToProductTensor(/*inout*/ m_inputTensorDescs[i], inputLabels, /*isReduced*/ false); + } + auto outputLabels = m_components.back().GetLabels(m_labelIndices); + ReprojectTensorDescToProductTensor(/*inout*/ m_outputTensorDescs.front(), outputLabels, /*isReduced*/ true); + } + + // Project the given tensor for shape compatibility to the internal product tensor, which may include broadcasting, + // transposition, and collapsing repeated terms (e.g. iji,i->j with 2 i's in the first term with strides summed). + // + // e.g. + // + // Axis labels: 3,0,2 // the 2 in the inputShape[0] corresponds to productDimensions[3]. + // Original tensor shape: [2,3,4] + // Original tensor strides: [12,4,1] // packed strides right-to-left + // Product tensor shape: [3,5,4,2] // transposed relative to input, with 1 more axis not in input tensor + // Reprojected shape: [3,5,4,2] // identical to product shape + // (or when isReduced) [3,1,4,2] // inserted dimension is 1 + // Reprojected strides: [4,0,1,12] // the newly inserted tensor has 0 stride for broadcasting + // + void ReprojectTensorDescToProductTensor( + /*inout*/ TensorDesc& tensorDesc, + gsl::span axisLabels, + bool isReduced // Return 1's for any missing dimensions not in axisLabels. + ) + { + assert(m_productDimensions.size() == m_uniqueLabelCount && "Product dimensions were not computed yet"); + const size_t newRank = m_productDimensions.size(); + + // Compute the default strides of the tensor (non-transposed). + tensorDesc.EnsureStridesExist(); + const auto originalSizes = tensorDesc.GetSizes(); + const auto originalStrides = tensorDesc.GetStrides(); + assert(originalSizes.size() >= axisLabels.size()); + assert(originalStrides.size() >= axisLabels.size()); + + // Set default sizes for shape compatibility with the product tensor, and + // set strides to 0's initially to broadcast any missing dimensions. + std::vector newSizes; + std::vector newStrides(newRank, 0u); // Default to 0 to broadcast missing entries. + if (isReduced) + { + newSizes.resize(newRank, 1u); // Fill with 1's initially for any missing (reduced) dimensions. + } + else + { + newSizes = m_productDimensions; // Use the product tensor shape directly. Missing axes will be broadcasted. + } + + // Scatter the original sizes and strides into the corresponding product tensor axis. + for (size_t i = 0, count = axisLabels.size(); i < count; ++i) + { + uint32_t productAxis = axisLabels[i]; + if (productAxis < newRank) + { + newSizes[productAxis] = originalSizes[i]; + newStrides[productAxis] += originalStrides[i]; // Add to combine diagonal cases like i,j,i->i,j + } + } + tensorDesc.SetDimensionsAndStrides(newSizes, newStrides); + tensorDesc.EnsureDimensionCount(1, TensorAxis::RightAligned); + } + + // Reproject a tensor to the given axis arrangement. + // The new tensor will have rank == newAxes.size(). + // e.g. + // + // product tensor shape = [2,3,4,5,6] // m_productDimensions + // newAxes = [4,2,0,1] + // new tensor shape = [6,4,2,3] + // + void ReprojectTensorDescToGivenAxes( + /*inout*/ TensorDesc& tensorDesc, + gsl::span axisLabels, + gsl::span newAxes + ) + { + // First, reproject the original dimensions up to the product tensor. + ReprojectTensorDescToProductTensor(/*inout*/ tensorDesc, axisLabels, /*isReduced*/ false); + tensorDesc.PermuteDimensions(newAxes, TensorAxis::LeftAligned); + } + + std::vector GetReductionAxes() const + { + // Determine which axes are reduced by looking for any output dimensions of size 1. + // Note this could include dimensions that are not actually being reduced and simply + // already had size 1 from the input, but such cases harmless nops either way. + + auto outputSizes = m_outputTensorDescs.front().GetSizes(); + std::vector reducedAxes; + FindValueIndices(outputSizes, 1u, /*out*/ reducedAxes); + return reducedAxes; + } }; void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) @@ -276,7 +411,7 @@ void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool* EinSumHelper helper(attributes); auto recognizedOperatorType = helper.GetRecognizedOperatorType(); - static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast(12), "Update this function."); + static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast(6), "Verify if this function needs updating."); *isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp index f26a2ac6fa79a..f738e9c6626fa 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp @@ -201,7 +201,7 @@ TensorDesc::TensorDesc( assert(m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)); } -gsl::span TensorDesc::GetStrides() const +gsl::span TensorDesc::GetStrides() const noexcept { if (m_bufferTensorDesc.Strides == nullptr) { @@ -212,8 +212,6 @@ gsl::span TensorDesc::GetStrides() const void TensorDesc::SetStrides(gsl::span strides) { - m_bufferTensorDesc.Strides = strides.empty() ? nullptr : strides.data(); - if (!strides.empty()) { ML_CHECK_VALID_ARGUMENT(strides.size() <= std::size(m_strides)); @@ -221,6 +219,8 @@ void TensorDesc::SetStrides(gsl::span strides) std::copy(strides.begin(), strides.end(), m_strides); } + m_bufferTensorDesc.Strides = strides.empty() ? nullptr : m_strides; + m_bufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize( m_bufferTensorDesc.DataType, m_bufferTensorDesc.DimensionCount, @@ -228,7 +228,7 @@ void TensorDesc::SetStrides(gsl::span strides) strides.empty() ? nullptr : m_strides); } -DML_TENSOR_DESC TensorDesc::GetDmlDesc() +DML_TENSOR_DESC TensorDesc::GetDmlDesc() noexcept { if (m_tensorType == DML_TENSOR_TYPE_INVALID) { @@ -289,6 +289,15 @@ void TensorDesc::ForceUnsignedDataType() } } +// Add additional padding 1's to ensure the count is at least that large. +void TensorDesc::EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment) +{ + if (m_bufferTensorDesc.DimensionCount < newDimensionCount) + { + SetDimensionCount(newDimensionCount, alignment); + } +} + void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment) { ML_CHECK_VALID_ARGUMENT(newDimensionCount <= MaximumDimensionCount); @@ -321,38 +330,48 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm m_bufferTensorDesc.DimensionCount = newDimensionCount; } -// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout +void TensorDesc::SetDimensionsAndStrides(gsl::span sizes, gsl::span strides) +{ + static_assert(sizeof(m_sizes) == sizeof(m_strides)); + ML_CHECK_VALID_ARGUMENT(sizes.size() <= std::size(m_sizes)); + ML_CHECK_VALID_ARGUMENT(strides.empty() || strides.size() == sizes.size()); + + std::copy(sizes.begin(), sizes.end(), m_sizes); + m_bufferTensorDesc.DimensionCount = static_cast(sizes.size()); + SetStrides(strides); +} + void TensorDesc::PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment) { + const uint32_t oldRank = m_bufferTensorDesc.DimensionCount; EnsureStridesExist(); SetDimensionCount(static_cast(dimensionMapping.size()), alignment); - // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping - std::vector tempSizes{m_sizes, m_sizes + MaximumDimensionCount}; - std::vector tempStrides{m_strides, m_strides + MaximumDimensionCount}; + // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping. + // Note using MaximumDimensionCount instead of oldRank is intentional here, because the old rank could + // be smaller or larger than the new rank, but it will never be larger than MaximumDimensionCount. + std::vector oldSizes{m_sizes, m_sizes + MaximumDimensionCount}; + std::vector oldStrides{m_strides, m_strides + MaximumDimensionCount}; for (size_t i = 0; i < dimensionMapping.size(); i++) { - m_sizes[i] = tempSizes[dimensionMapping[i]]; - m_strides[i] = tempStrides[dimensionMapping[i]]; + uint32_t sourceAxis = dimensionMapping[i]; + m_sizes[i] = sourceAxis < oldRank ? oldSizes[sourceAxis] : 1; + m_strides[i] = sourceAxis < oldRank ? oldStrides[sourceAxis] : 0; } m_bufferTensorDesc.Sizes = m_sizes; m_bufferTensorDesc.Strides = m_strides; } -void TensorDesc::EnsureStridesExist() +void TensorDesc::EnsureStridesExist() noexcept { if (m_bufferTensorDesc.Strides != nullptr) { - // Strides are populated + // Strides are already populated return; } - uint32_t stride = 1; - for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;) - { - m_strides[i] = stride; - stride *= m_sizes[i]; - } + GetDescendingPackedStrides({m_sizes, m_bufferTensorDesc.DimensionCount}, {m_strides, m_bufferTensorDesc.DimensionCount}); + m_bufferTensorDesc.Strides = m_strides; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h index 909e2084d0163..bd9f8a46600b9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h @@ -32,18 +32,28 @@ namespace Dml uint32_t guaranteedBaseOffsetAlignment ); - DML_TENSOR_DESC GetDmlDesc(); + DML_TENSOR_DESC GetDmlDesc() noexcept; - inline DML_TENSOR_DATA_TYPE GetDmlDataType() const { return m_bufferTensorDesc.DataType; } - inline MLOperatorTensorDataType GetMlOperatorDataType() const { return m_mlOperatorTensorDataType; } + inline DML_TENSOR_DATA_TYPE GetDmlDataType() const noexcept { return m_bufferTensorDesc.DataType; } + inline MLOperatorTensorDataType GetMlOperatorDataType() const noexcept { return m_mlOperatorTensorDataType; } void ForceUnsignedDataType(); - inline bool IsValid() const { return m_tensorType != DML_TENSOR_TYPE_INVALID; } + inline bool IsValid() const noexcept { return m_tensorType != DML_TENSOR_TYPE_INVALID; } inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; } void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment); - gsl::span GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; } - gsl::span GetStrides() const; + void EnsureDimensionCount(uint32_t newDimensionCount, TensorAxis alignment); + + gsl::span GetSizes() const noexcept { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; } + gsl::span GetStrides() const noexcept; void SetStrides(gsl::span strides); + void EnsureStridesExist() noexcept; + + void SetDimensionsAndStrides(gsl::span sizes, gsl::span strides); + + // Rearranges existing m_sizes and m_strides by gathering axes from dimensionMapping. + // It IS legal to change the number of dimensions by adding filler, dropping entire dimensions for a new view, + // and even duplicating logical dimensions. Axes beyond the original rank will be filled by size 1 and stride 0. + // e.g. Existing sizes [2,3,4] with [2,0] yields [4,2]. void PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment); inline uint64_t GetBufferSizeInBytes() const @@ -91,8 +101,6 @@ namespace Dml uint32_t m_sizes[MaximumDimensionCount] = {}; uint32_t m_strides[MaximumDimensionCount] = {}; DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {}; - - void EnsureStridesExist(); }; class TensorDescBuilder diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h index 02166f992449e..a3f2777a0c805 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h @@ -6,7 +6,7 @@ #include #include #include - +#include namespace Dml { @@ -16,21 +16,14 @@ namespace Dml return g_converterToUtf16.from_bytes(str.data()); } - static inline std::wstring GetModelName(const onnxruntime::Path& modelPath) + static inline std::wstring GetModelName(const std::filesystem::path& modelPath) { - if (modelPath.GetComponents().empty()) + if (modelPath.empty() || !modelPath.has_filename() || !modelPath.has_extension()) { return L""; } - const onnxruntime::PathString& pathString = modelPath.GetComponents().back(); - size_t dotPosition = pathString.find_last_of('.'); - if (dotPosition == std::string::npos) - { - return L""; - } - - return pathString.substr(0, dotPosition); + return modelPath.stem().native(); } static inline std::wstring GetSanitizedFileName(std::wstring_view name) @@ -138,4 +131,4 @@ namespace StringUtil return {}; } -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h index 5ecc49f237481..e9df3fd20aff9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h @@ -26,7 +26,7 @@ #include #include -#include "core/common/gsl.h" +#include #ifdef _GAMING_XBOX_SCARLETT #include diff --git a/onnxruntime/core/providers/dml/GraphTransformers/precomp.h b/onnxruntime/core/providers/dml/GraphTransformers/precomp.h index a2cce6baed7b6..7b146e3c4d9b1 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/precomp.h +++ b/onnxruntime/core/providers/dml/GraphTransformers/precomp.h @@ -14,4 +14,4 @@ #include #include -#include "core/common/gsl.h" +#include diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 32079a90ea73a..686cdbe774a48 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -5,7 +5,7 @@ #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" #include "MLOperatorAuthorPrivate.h" -#include "core/common/gsl.h" +#include #include #ifdef ORT_NO_EXCEPTIONS diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 44c63089564d3..3a7cf28ef903e 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -68,6 +68,17 @@ namespace OperatorHelper originalValues = std::move(expanded); } + uint32_t GetBitMaskFromIndices(gsl::span indices) noexcept + { + uint32_t bitMask = 0; + for (auto i : indices) + { + assert(i < 32); + bitMask |= (1 << i); + } + return bitMask; + } + float CastFloat16ToFloat32(uint16_t input) { // Promote float16m10e5s1 to float32m23e8s1. @@ -1384,6 +1395,7 @@ namespace OperatorHelper m_recognizedOperatorType = DetermineRecognizedOperatorType(); } + // Updates: m_components, m_labelIndices, m_uniqueLabelCount. void EinSumHelper::ParseEquationComponents() { // Parse an equation like 'ij,jk->ik' into components {ij, jk, ik} mapping letters to @@ -1477,134 +1489,125 @@ namespace OperatorHelper currentComponent.labelIndexEnd = static_cast(m_labelIndices.size()); m_components.push_back(currentComponent); } + + m_uniqueLabelCount = labelMap.size(); } - EinSumHelper::RecognizedOperatorType EinSumHelper::DetermineRecognizedOperatorType() + EinSumHelper::RecognizedOperatorType EinSumHelper::DetermineRecognizedOperatorType() const { if (m_components.empty()) { - return RecognizedOperatorType::None; // Parsing may have found unsupported components - treating as unknown. + return RecognizedOperatorType::None; // Parsing may have found unsupported components - treating as unknown. } - // std::ranges::equal is not supported yet. - auto equals = [](gsl::span a, gsl::span b) + auto areIdenticalAxes = [](gsl::span a, gsl::span b) -> bool { - return std::equal(a.begin(), a.end(), b.begin(), b.end()); + return GetBitMaskFromIndices(a) == GetBitMaskFromIndices(b); }; - auto as_span = [](std::initializer_list il) { + auto as_span = [](std::initializer_list il) + { return gsl::make_span(il.begin(), il.size()); }; - std::array componentRanks; - if (m_components.size() > componentRanks.size()) + // Identify any common patterns that map to existing DirectML operators + // (identity, transpose, sum reduction, matmul, multiplication...). + + constexpr size_t maximumSupportedComponentCount = 3; // 2 inputs + 1 output + // Bail if more than 2 inputs. EinSum is generic and can handle any variable number of inputs, + // but 3-input models have yet to be seen. + if (m_components.size() > maximumSupportedComponentCount) { - // No recognized operator takes more than 2 inputs and 1 output. - // EinSum itself is generic and can handle any variable number of inputs, - // but DML's operators expect fixed counts. return RecognizedOperatorType::None; } - else if (m_components.size() == 2) + else if (m_components.size() == 2) // 1 input, 1 output { - auto inputLabels = m_components[0].GetLabels(m_labelIndices); - auto outputLabels = m_components[1].GetLabels(m_labelIndices); - if (inputLabels.size() == outputLabels.size()) + auto outputLabels = m_components.back().GetLabels(m_labelIndices); + + // Use reduction if the output has fewer dimensions than total dimension labels. + if (outputLabels.size() <= m_uniqueLabelCount) { - // Check identity. - if (equals(inputLabels, outputLabels)) - { - // Handles: "->", "i->i", "ij->ij", "ijk->ijk", "ijkl->ijkl" ... - return RecognizedOperatorType::Identity; - } - else // Transpose since a permutation exists. - { - // Handles: "ij->ji", "ijk->kji", "ijkl->lkji", "ijkl->ijkl" ... - return RecognizedOperatorType::Transpose; - } + // Handles: "ij->i", "i->", "ij->", "ijkl->jl", "ijkl->", "iji->" ... + return RecognizedOperatorType::ReduceSum; } - else if (outputLabels.empty()) // Scalar output, with all inputs reduced. + else { - // Handles: "i->", "ij->", "ijk->", "ijkl->" ... - return RecognizedOperatorType::ReduceSum; + // Handles identity: "->", "i->i", "ij->ij", "ijk->ijk", "ijkl->ijkl" ... + // Handles transpose: "ij->ji", "ijk->kji", "ijkl->lkji", "ijkl->ijkl" ... + // Handles diagional: "ii->i", "iij->ji" + return RecognizedOperatorType::Transpose; } } - else if (m_components.size() == 3) + else if (m_components.size() == 3) // 2 inputs, 1 output { - // If all components have the same size and label order, then apply elementwise multiplication. - auto inputALabels = m_components[0].GetLabels(m_labelIndices); - auto inputBLabels = m_components[1].GetLabels(m_labelIndices); + auto input0Labels = m_components[0].GetLabels(m_labelIndices); + auto input1Labels = m_components[1].GetLabels(m_labelIndices); auto outputLabels = m_components[2].GetLabels(m_labelIndices); - if (equals(inputALabels, outputLabels) && equals(inputBLabels, outputLabels)) + + // Use elementwise multiplication when no reduction occurs. + if (outputLabels.size() == m_uniqueLabelCount) { - // Handles: "i,i->i", "ij,ij->ij", "ijk,ijk->ijk", "ijkl,ijkl->ijkl" ... + // Handles: "i,i->i", "ij,ij->ij", "ijk,ijk->kji", "ijkl,klij->jilk" ... return RecognizedOperatorType::Multiply; } - } - - // Otherwise check for special cases of dedicated operators... - - struct RecognizedOperatorInfo - { - RecognizedOperatorType recognizedOperatorType; - std::initializer_list componentRanks; - std::initializer_list labelIndices; - }; - - const RecognizedOperatorInfo recognizedOperators[] = { - {RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik - {RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik - {RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik - {RecognizedOperatorType::OuterProduct, {1,1,2},{0, 1, 0,1}}, // i,j->ij - {RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik - {RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik - {RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik - {RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik - {RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik - {RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik - {RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod) - {RecognizedOperatorType::MatMulNhcw, {4,4,4},{0,1,2,3, 0,3,2,4, 0,1,2,4}}, // aibj,ajbk->aibk - {RecognizedOperatorType::MatMulNhcwTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,3,2,4}}, // ajbi,ajbk->aibk - {RecognizedOperatorType::MatMulNhcwTransposeB, {4,4,4},{0,1,2,3, 0,4,2,3, 0,1,2,4}}, // aibj,akbj->aibk - {RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i - {RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j - }; - - // For each recognized operator, compare the labels-per-component and label indices. - for (auto& recognizedOperator : recognizedOperators) - { - if (equals(m_labelIndices, as_span(recognizedOperator.labelIndices)) - && m_components.size() == recognizedOperator.componentRanks.size()) + // Use matrix multiplication when exactly 1 dimension is being reduced, + // the output is up to 4D (DML limit), and the inputs have distinct axes. + else if (outputLabels.size() + 1 == m_uniqueLabelCount + && outputLabels.size() <= 4 + && !areIdenticalAxes(input0Labels, input1Labels)) { - for (size_t i = 0; i < m_components.size(); ++i) - { - componentRanks[i] = m_components[i].GetDimensionCount(); - } - - if (equals(gsl::make_span(componentRanks.data(), m_components.size()), as_span(recognizedOperator.componentRanks))) - { - return recognizedOperator.recognizedOperatorType; - } + return RecognizedOperatorType::MatMul; + } + // Otherwise use an elementwise multiplication and sum reduction combo. + // This is the most generic, but it also uses more intermediate memory. + else + { + // Handles: "ij,ji->", "ijkl,abij->ab", ... + return RecognizedOperatorType::MultiplyReduceSum; } } return RecognizedOperatorType::None; } - std::vector EinSumHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + // Updates: m_productDimensions. + void EinSumHelper::ExtractLabelSizesFromTensors( + const IKernelInformationAdapter& kernelInformation, + const IShapeInformationAdapter& shapeInformation + ) { - assert(!m_components.empty()); // Should have already parsed components. + // Get the dimension length for each term label. e.g. + // + // equation = "ijk,jm->im" + // input[0].shape = [3,2,4] + // input[1].shape = [2,5] + // output[0].shape = [3,5] + // + // Yields: + // + // i = 3 // m_productDimensions[0] + // j = 2 // m_productDimensions[1] + // k = 4 // m_productDimensions[2] + // m = 5 // m_productDimensions[3] + // + // The character labels are already resolved to numeric indices in occurrence order, and so the product + // dimensions are simply an array of sizes: + // + // label sizes = [3,2,4,5] + // + assert(!m_components.empty()); // Should have already parsed components. - uint32_t inputCount = shapeInfo.GetInputCount(); - uint32_t outputCount = shapeInfo.GetOutputCount(); + uint32_t inputCount = kernelInformation.GetInputCount(); + uint32_t outputCount = kernelInformation.GetOutputCount(); ML_CHECK_VALID_ARGUMENT(inputCount + 1 == m_components.size(), "Mismatch between input tensor count and string equation component count."); ML_CHECK_VALID_ARGUMENT(outputCount == 1, "EinSum expects exactly 1 output tensor."); - std::vector labelSizes(m_labelIndices.size(), UINT_MAX); + m_productDimensions.resize(m_uniqueLabelCount, UINT_MAX); // Read every input tensor, comparing labels to ensure consistent sizes from the equation parsed earlier. for (uint32_t i = 0; i < inputCount; ++i) { - auto inputShape = shapeInfo.GetInputTensorShape(i); + auto inputShape = shapeInformation.GetInputTensorShape(i); auto& component = m_components[i]; auto labelIndices = component.GetLabels(m_labelIndices); uint32_t dimensionCount = component.GetDimensionCount(); @@ -1618,18 +1621,23 @@ namespace OperatorHelper // e.g. Given "ij,ji", both i's and both j's must match dimension sizes. uint32_t dimensionSize = inputShape[j]; uint32_t labelIndex = labelIndices[j]; - assert(labelIndex < labelSizes.size()); + assert(labelIndex < m_productDimensions.size()); - if (labelSizes[labelIndex] == UINT_MAX) + if (m_productDimensions[labelIndex] == UINT_MAX) { - labelSizes[labelIndex] = dimensionSize; + m_productDimensions[labelIndex] = dimensionSize; } else { - ML_CHECK_VALID_ARGUMENT(labelSizes[labelIndex] == dimensionSize, "All labels must have the same dimension sizes."); + ML_CHECK_VALID_ARGUMENT(m_productDimensions[labelIndex] == dimensionSize, "All labels must have the same dimension sizes."); } } } + } + + std::vector EinSumHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + assert(!m_components.empty()); // Should have already parsed components. // Generate output dimensions from corresponding input tensor labels. // e.g. Given ij,jk->ij with [2,3] and [3,5], the output is [2,5]. @@ -1637,7 +1645,7 @@ namespace OperatorHelper auto outputLabelIndices = m_components.back().GetLabels(m_labelIndices); for (auto labelIndex : outputLabelIndices) { - outputDimensions.push_back(labelSizes[labelIndex]); + outputDimensions.push_back(m_productDimensions[labelIndex]); } return { EdgeShapes(outputDimensions) }; @@ -1645,13 +1653,8 @@ namespace OperatorHelper bool EinSumHelper::IsMatMulOperatorType() const noexcept { - return m_recognizedOperatorType == RecognizedOperatorType::OuterProduct || - m_recognizedOperatorType == RecognizedOperatorType::MatMul || - m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA || - m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB || - m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcw || - m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA || - m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB; + static_assert(RecognizedOperatorType::Total == static_cast(6), "Verify this for any potentially new matrix multiplication operators."); + return m_recognizedOperatorType == RecognizedOperatorType::MatMul; } std::vector MatMulHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 06bed80a7c27d..b775de0b39cf4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -762,9 +762,9 @@ class ReduceHelper : public ReduceHelperBase class EinSumHelper { -public: void Initialize(); +public: // Info_t is used to obtain attributes which will be used for calculating the output shape later. // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template @@ -772,6 +772,7 @@ class EinSumHelper { m_equation = info.GetAttribute(AttrName::Equation); Initialize(); + ExtractLabelSizesFromTensors(KernelInformationAdapter(info), ShapeInformationAdapter(shape)); } EinSumHelper(const MLOperatorAttributes& info) @@ -785,17 +786,11 @@ class EinSumHelper enum class RecognizedOperatorType { None, - Identity, - Multiply, - OuterProduct, - MatMul, - MatMulTransposeA, - MatMulTransposeB, - MatMulNhcw, - MatMulNhcwTransposeA, - MatMulNhcwTransposeB, - ReduceSum, - Transpose, + Transpose, // 1 input, rearrangement or diagonal slice, no reduction + ReduceSum, // 1 input, no multiplication, just sum reduction + Multiply, // 2 inputs, elementwise multiplication, no sum reduction + MatMul, // 2 inputs, elementwise multiplication, sum reduction on 1 axis. + MultiplyReduceSum, // 2 inputs, elementwise multiplication, sum reduction on multiple axes Total, }; @@ -805,7 +800,11 @@ class EinSumHelper protected: void ParseEquationComponents(); - RecognizedOperatorType DetermineRecognizedOperatorType(); + void ExtractLabelSizesFromTensors( + const IKernelInformationAdapter& kernelInformation, + const IShapeInformationAdapter& shapeInformation + ); + RecognizedOperatorType DetermineRecognizedOperatorType() const; protected: struct Component @@ -824,9 +823,10 @@ class EinSumHelper }; std::string m_equation; - std::vector m_labelIndices; // Concatenation of all labels as rebased indices ("ij,ai" -> 0,1,2,0). - std::vector m_components; // All components in order, including inputs and output. - std::vector m_outputDimensions; + size_t m_uniqueLabelCount = 0; // e.g. ij,jk->ij has 3 unique labels. + std::vector m_labelIndices; // Concatenation of all labels as rebased indices ("ij,ai" -> 0,1,2,0). + std::vector m_components; // All components in order, including inputs and output. + std::vector m_productDimensions; // Dimensions of each unique label (size() == m_uniqueLabelCount). RecognizedOperatorType m_recognizedOperatorType = RecognizedOperatorType::None; }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h index a64d1e01c6cc5..6c47e60e63b8d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h @@ -17,4 +17,4 @@ #include #include -#include "core/common/gsl.h" +#include diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h index 718469f740d4d..831b10c3e147f 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h @@ -66,9 +66,9 @@ class DnnlConv { * - For Onnx a non-dilated kernel would be all 1s * - For OneDNN a non-dilated kernel would be all 0s * - * The memory dimentions returned is in the form expected for OneDNN each dilation dimention - * will be 1 less than the dilated dimention expected by Onnx specification. Be aware of this - * fact as 'dilations' are used in any calcuations since this could result in an off-by-one + * The memory dimensions returned is in the form expected for OneDNN each dilation dimension + * will be 1 less than the dilated dimension expected by Onnx specification. Be aware of this + * fact as 'dilations' are used in any calculations since this could result in an off-by-one * error. */ dnnl::memory::dims GetDilations(DnnlNode& node, ConvShape shape); @@ -115,4 +115,4 @@ class DnnlConv { }; } // namespace ort_dnnl -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.h index f0928974b1317..3a27788745ef0 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.h @@ -49,9 +49,9 @@ class DnnlConvGrad { * - For Onnx a non-dilated kernel would be all 1s * - For OneDNN a non-dilated kernel would be all 0s * - * The memory dimentions returned is in the form expected for OneDNN each dilation dimention - * will be 1 less than the dilated dimention expected by Onnx specification. Be aware of this - * fact as 'dilations' are used in any calcuations since this could result in an off-by-one + * The memory dimensions returned is in the form expected for OneDNN each dilation dimension + * will be 1 less than the dilated dimension expected by Onnx specification. Be aware of this + * fact as 'dilations' are used in any calculations since this could result in an off-by-one * error. */ dnnl::memory::dims GetDilations(DnnlNode& node, ConvShape shape); @@ -62,4 +62,4 @@ class DnnlConvGrad { }; } // namespace ort_dnnl -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h index d1cea23fca245..dac4e743ea198 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h @@ -20,7 +20,7 @@ class DnnlQAttention { MASK_INDEX = 5, INPUT_ZP = 6, WEIGHTS_ZP = 7, - PAST = 8 // not suppoted + PAST = 8 // not supported }; enum OutputTensors : int { diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index b0f510f054a03..61c035bc29ed5 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -98,6 +98,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, +#endif + }, + { + kVSINPUExecutionProvider, +#ifdef USE_VSINPU + true, +#else + false, #endif }, { diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 59dfbb0e94922..c51bf5ce9d4a6 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/providers/cpu/nn/conv_transpose_attributes.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index f43dd814aa959..7a945471c7701 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -4,7 +4,7 @@ #pragma once #include "core/providers/js/js_kernel.h" -#include "core/common/gsl.h" +#include #include "core/providers/cpu/tensor/transpose.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 72193ef6268c1..94480c308b99f 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -60,17 +60,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); } } else if (src_device.Type() == OrtDevice::GPU) { -#ifndef MIGRAPHX_STREAM_SYNC - if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) { - // copying from GPU to pinned memory, this is non-blocking - HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - } else { - // copying from GPU to CPU memory, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); - } -#else HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); -#endif } else { // copying between cpu memory memcpy(dst_data, src_data, bytes); diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc similarity index 83% rename from onnxruntime/core/providers/migraphx/hip_allocator.cc rename to onnxruntime/core/providers/migraphx/migraphx_allocator.cc index 53f10e318e65f..0693eea056416 100644 --- a/onnxruntime/core/providers/migraphx/hip_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -3,7 +3,7 @@ #include "core/providers/shared_library/provider_api.h" #include "migraphx_call.h" -#include "hip_allocator.h" +#include "migraphx_allocator.h" #include "core/common/status.h" #include "core/framework/float16.h" #include "core/common/status.h" @@ -11,7 +11,7 @@ namespace onnxruntime { -void HIPAllocator::CheckDevice() const { +void MIGraphXAllocator::CheckDevice() const { #ifndef NDEBUG // check device to match at debug build // if it's expected to change, call hipSetDevice instead of the check @@ -23,7 +23,7 @@ void HIPAllocator::CheckDevice() const { #endif } -void* HIPAllocator::Alloc(size_t size) { +void* MIGraphXAllocator::Alloc(size_t size) { CheckDevice(); void* p = nullptr; if (size > 0) { @@ -32,12 +32,12 @@ void* HIPAllocator::Alloc(size_t size) { return p; } -void HIPAllocator::Free(void* p) { +void MIGraphXAllocator::Free(void* p) { CheckDevice(); (void)hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown } -void* HIPExternalAllocator::Alloc(size_t size) { +void* MIGraphXExternalAllocator::Alloc(size_t size) { void* p = nullptr; if (size > 0) { p = alloc_(size); @@ -49,7 +49,7 @@ void* HIPExternalAllocator::Alloc(size_t size) { return p; } -void HIPExternalAllocator::Free(void* p) { +void MIGraphXExternalAllocator::Free(void* p) { free_(p); std::lock_guard lock(lock_); auto it = reserved_.find(p); @@ -59,7 +59,7 @@ void HIPExternalAllocator::Free(void* p) { } } -void* HIPExternalAllocator::Reserve(size_t size) { +void* MIGraphXExternalAllocator::Reserve(size_t size) { void* p = Alloc(size); if (!p) return nullptr; std::lock_guard lock(lock_); diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h similarity index 78% rename from onnxruntime/core/providers/migraphx/hip_allocator.h rename to onnxruntime/core/providers/migraphx/migraphx_allocator.h index 3244f9f04ea70..64da844e8c714 100644 --- a/onnxruntime/core/providers/migraphx/hip_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -9,12 +9,12 @@ namespace onnxruntime { -class HIPAllocator : public IAllocator { +class MIGraphXAllocator : public IAllocator { public: - HIPAllocator(int device_id, const char* name) + MIGraphXAllocator(int device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device_id)), device_id, OrtMemTypeDefault)) {} virtual void* Alloc(size_t size) override; @@ -24,14 +24,14 @@ class HIPAllocator : public IAllocator { void CheckDevice() const; }; -class HIPExternalAllocator : public HIPAllocator { +class MIGraphXExternalAllocator : public MIGraphXAllocator { typedef void* (*ExternalAlloc)(size_t size); typedef void (*ExternalFree)(void* p); typedef void (*ExternalEmptyCache)(); public: - HIPExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) - : HIPAllocator(device_id, name) { + MIGraphXExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) + : MIGraphXAllocator(device_id, name) { alloc_ = reinterpret_cast(alloc); free_ = reinterpret_cast(free); empty_cache_ = reinterpret_cast(empty_cache); @@ -55,7 +55,7 @@ class HIPPinnedAllocator : public IAllocator { HIPPinnedAllocator(int device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, device_id), + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(device_id)), device_id, OrtMemTypeCPUOutput)) {} virtual void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index 5248ac2f39214..9807cd646e51c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -1,10 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef _WIN32 +#include +#else #include -#include -#include -#include +#endif + +#include #include "core/common/common.h" #include "core/common/status.h" #include "core/providers/shared_library/provider_api.h" @@ -34,16 +37,20 @@ std::conditional_t RocmCall( ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { if (retCode != successCode) { try { - char hostname[HOST_NAME_MAX]; - if (gethostname(hostname, HOST_NAME_MAX) != 0) - strcpy(hostname, "?"); +#ifdef _WIN32 + // According to the POSIX spec, 255 is the safe minimum value. + static constexpr int HOST_NAME_MAX = 255; +#endif + std::string hostname(HOST_NAME_MAX, 0); + if (gethostname(hostname.data(), HOST_NAME_MAX) != 0) + hostname = "?"; int currentHipDevice; (void)hipGetDevice(¤tHipDevice); (void)hipGetLastError(); // clear last HIP error static char str[1024]; snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", libName, (int)retCode, RocmErrString(retCode), currentHipDevice, - hostname, + hostname.c_str(), file, line, exprString, msg); if constexpr (THRW) { // throw an exception with the error info @@ -68,9 +75,5 @@ std::conditional_t RocmCall( template Status RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); template void RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); -template Status RocmCall(rocblas_status retCode, const char* exprString, const char* libName, rocblas_status successCode, const char* msg, const char* file, const int line); -template void RocmCall(rocblas_status retCode, const char* exprString, const char* libName, rocblas_status successCode, const char* msg, const char* file, const int line); -template Status RocmCall(miopenStatus_t retCode, const char* exprString, const char* libName, miopenStatus_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(miopenStatus_t retCode, const char* exprString, const char* libName, miopenStatus_t successCode, const char* msg, const char* file, const int line); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h index 15d385a636b76..f6a95cebf34b5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.h +++ b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -4,8 +4,6 @@ #pragma once #include "migraphx_inc.h" -#pragma once - namespace onnxruntime { // ----------------------------------------------------------------------- diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 6ee85c3a4c047..097b16ecde536 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -13,12 +13,11 @@ #include "core/common/logging/severity.h" #include "migraphx_execution_provider.h" #include "migraphx_execution_provider_utils.h" -#include "hip_allocator.h" +#include "migraphx_allocator.h" #include "gpu_data_transfer.h" #include "migraphx_inc.h" -// TODO: find a better way to share this -#include "core/providers/rocm/rocm_stream_handle.h" +#include "migraphx_stream_handle.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245) @@ -102,10 +101,10 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c } MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, device_id_(info.device_id) { + : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { InitProviderOrtApi(); // Set GPU device to be used - HIP_CALL_THROW(hipSetDevice(device_id_)); + HIP_CALL_THROW(hipSetDevice(info_.device_id)); t_ = migraphx::target(info.target_device.c_str()); // whether fp16 is enable @@ -181,16 +180,10 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); } - ROCBLAS_CALL_THROW(rocblas_create_handle(&external_rocblas_handle_)); - ROCBLAS_CALL_THROW(rocblas_set_stream(external_rocblas_handle_, stream_)); - - MIOPEN_CALL_THROW(miopenCreate(&external_miopen_handle_)); - MIOPEN_CALL_THROW(miopenSetStream(external_miopen_handle_, stream_)); - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " - << "device_id: " << device_id_ + << "device_id: " << info_.device_id << ", migraphx_fp16_enable: " << fp16_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", migraphx_int8_enable: " << int8_enable_ @@ -205,17 +198,14 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv } MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { - ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(external_rocblas_handle_))); - ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(external_miopen_handle_))); } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return CreateROCMAllocator(device_id, onnxruntime::CUDA); }, device_id_); + [](OrtDevice::DeviceId device_id) { return CreateMIGraphXAllocator(device_id, onnxruntime::CUDA); }, info_.device_id); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - ORT_UNUSED_PARAMETER(device_id); - return CreateROCMPinnedAllocator(onnxruntime::CUDA_PINNED); + return CreateMIGraphXPinnedAllocator(device_id, onnxruntime::CUDA_PINNED); }, 0); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; @@ -254,40 +244,40 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, migraphx_shape_datatype_t& mgx_type) { mgx_type = migraphx_shape_float_type; switch (type) { - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: mgx_type = migraphx_shape_half_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: mgx_type = migraphx_shape_float_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: mgx_type = migraphx_shape_double_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: mgx_type = migraphx_shape_int16_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: mgx_type = migraphx_shape_int32_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: mgx_type = migraphx_shape_uint16_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: mgx_type = migraphx_shape_uint32_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: mgx_type = migraphx_shape_uint64_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: mgx_type = migraphx_shape_bool_type; break; default: @@ -303,7 +293,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { std::vector result; int num = nums.size(); for (int i = 0; i < num; ++i) { - result.push_back(nums[i]); + result.push_back(static_cast(nums[i])); } return result; @@ -501,16 +491,9 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co if (arg_s != nullptr) { const auto& tensor_dims = arg_s->dim(); std::vector dims; - std::transform(tensor_dims.begin(), - tensor_dims.end(), - std::back_inserter(dims), - [&](auto&& d) -> std::size_t { - if (d.has_dim_value()) { - return d.dim_value(); - } else { - return 0; - } - }); + for (auto&& dim : tensor_dims) { + dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0); + } if (dims == std::vector{0}) { return true; } @@ -546,8 +529,8 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, - const logging::Logger& logger) { - // Then check whether a subgraph should fallback to CPU + [[maybe_unused]] const logging::Logger& logger) { + // Then check whether a subgraph should fall back to CPU // 1. Check whether a subgraph contains a RNN operator std::unordered_set rnn_names = {"RNN", "GRU", "LSTM"}; std::unordered_set op_names = {"AveragePool", "Conv", "Gemm", "LRN", "MatMul", "MaxPool"}; @@ -591,17 +574,10 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v if (arg_s == nullptr) return false; const auto& tensor_dims = arg_s->dim(); std::vector dims; - std::transform(tensor_dims.begin(), - tensor_dims.end(), - std::back_inserter(dims), - [&](auto&& d) -> std::size_t { - if (d.has_dim_value()) { - return d.dim_value(); - } else { - return 1; - } - }); - return (std::accumulate(dims.begin(), dims.end(), 1, std::multiplies{}) > 300); + for (auto&& dim : tensor_dims) { + dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 1); + } + return (std::accumulate(dims.begin(), dims.end(), 1ULL, std::multiplies{}) > 300); })) { return false; } @@ -623,7 +599,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v static bool IsNodeSupported(const std::set& op_set, const onnxruntime::GraphViewer& graph_viewer, const NodeIndex node_idx, - const logging::Logger& logger) { + [[maybe_unused]] const logging::Logger& logger) { const auto& node = graph_viewer.GetNode(node_idx); const auto& optype = node->OpType(); const auto& domain = node->Domain(); @@ -1442,14 +1418,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // lock to avoid race condition std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); -#ifdef MIGRAPHX_STREAM_SYNC void* rocm_stream; Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream)); auto prog_outputs = prog.run_async(m, static_cast(rocm_stream)); -#else - auto prog_outputs = prog.eval(m); - HIP_CALL_THROW(hipDeviceSynchronize()); -#endif + // In case of input parameters are reused as output parameter call hipMemcpy auto output_num = prog_outputs.size(); if (prog_output_indices.size() < output_num) { @@ -1478,8 +1450,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const { auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)]; - RegisterRocmStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, - false /*TODO:external_stream_*/, external_miopen_handle_, external_rocblas_handle_); + RegisterMIGraphXStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, false /*TODO:external_stream_*/); } OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { @@ -1487,7 +1458,6 @@ OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/); return default_device_; } -#ifdef MIGRAPHX_STREAM_SYNC Status MIGraphXExecutionProvider::Sync() const { HIP_CALL_THROW(hipStreamSynchronize(static_cast(nullptr))); @@ -1512,5 +1482,4 @@ Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxrunti return Status::OK(); } -#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 1977f71b8b1cf..f34ca320d0a5a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,9 +3,6 @@ #pragma once -#include -#include - #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" #include "core/platform/ort_mutex.h" @@ -14,8 +11,6 @@ #include #include -// TODO: find a better way to share this -// #include "core/providers/cuda/rocm_stream_handle.h" namespace onnxruntime { @@ -62,13 +57,11 @@ class MIGraphXExecutionProvider : public IExecutionProvider { explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); ~MIGraphXExecutionProvider(); -#ifdef MIGRAPHX_STREAM_SYNC Status Sync() const override; Status OnRunStart(const onnxruntime::RunOptions& run_options) override; Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; -#endif std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, @@ -85,7 +78,13 @@ class MIGraphXExecutionProvider : public IExecutionProvider { OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; + int GetDeviceId() const override { return info_.device_id; } + ProviderOptions GetProviderOptions() const override { + return MIGraphXExecutionProviderInfo::ToProviderOptions(info_); + } + private: + MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; @@ -98,7 +97,6 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool load_compiled_model_ = false; std::string load_compiled_path_; bool dump_model_ops_ = false; - int device_id_; migraphx::target t_; OrtMutex mgx_mu_; hipStream_t stream_ = nullptr; @@ -109,8 +107,6 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::unordered_map map_no_input_shape_; AllocatorPtr allocator_; - miopenHandle_t external_miopen_handle_ = nullptr; - rocblas_handle external_rocblas_handle_ = nullptr; std::unique_ptr metadef_id_generator_; }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 8411e3eef096b..68d5d9af98ea4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -14,7 +14,7 @@ namespace onnxruntime { // Information needed to construct trt execution providers. struct MIGraphXExecutionProviderInfo { std::string target_device; - int device_id{0}; + OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 071070e92a209..9274b5696185c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -28,7 +28,7 @@ bool IsGraphInput(const GraphViewer& graph, const std::string& name) { return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); } -bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) { +bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; return graph.GetInitializedTensor(name, initializer); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_inc.h b/onnxruntime/core/providers/migraphx/migraphx_inc.h index 96b24051ace76..2b035b20f619f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_inc.h +++ b/onnxruntime/core/providers/migraphx/migraphx_inc.h @@ -4,5 +4,5 @@ #pragma once #include -#include +#include #include diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index dd24dbdc76d2f..6d199930116e8 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -6,7 +6,7 @@ #include "core/providers/migraphx/migraphx_provider_factory.h" #include "migraphx_execution_provider.h" #include "migraphx_provider_factory_creator.h" -#include "hip_allocator.h" +#include "migraphx_allocator.h" #include "gpu_data_transfer.h" #include "core/framework/provider_options.h" @@ -33,10 +33,23 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() { return std::make_unique(info_); } +struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { + std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { + return std::make_unique(device_id, name); + } + + std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { + return std::make_unique(device_id, name); + } + +} g_info; + struct MIGraphX_Provider : Provider { + void* GetInfo() override { return &g_info; } + std::shared_ptr CreateExecutionProviderFactory(int device_id) override { MIGraphXExecutionProviderInfo info; - info.device_id = device_id; + info.device_id = static_cast(device_id); info.target_device = "gpu"; return std::make_shared(info); } @@ -44,7 +57,7 @@ struct MIGraphX_Provider : Provider { std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { auto& options = *reinterpret_cast(provider_options); MIGraphXExecutionProviderInfo info; - info.device_id = options.device_id; + info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; info.int8_enable = options.migraphx_int8_enable; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index ac9834e64942a..b257a4318dc0e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -10,4 +10,13 @@ struct IExecutionProviderFactory; struct MIGraphXExecutionProviderInfo; enum class ArenaExtendStrategy : int32_t; struct MIGraphXExecutionProviderExternalAllocatorInfo; + +struct ProviderInfo_MIGraphX { + virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; + + protected: + ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc new file mode 100644 index 0000000000000..9c5bb4ecf5c97 --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "migraphx_stream_handle.h" + +namespace onnxruntime { + +struct MIGraphXNotification : public synchronize::Notification { + MIGraphXNotification(Stream& s) : Notification(s) { + HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming)); + } + + ~MIGraphXNotification() { + if (event_) + HIP_CALL_THROW(hipEventDestroy(event_)); + } + + void Activate() override { + // record event with hipEventBlockingSync so we can support sync on host without busy wait. + HIP_CALL_THROW(hipEventRecord(event_, static_cast(stream_.GetHandle()))); + } + + void wait_on_device(Stream& device_stream) { + ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream.GetDevice().ToString()); + // launch a wait command to the migraphx stream + HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); + }; + + void wait_on_host() { + // CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); + HIP_CALL_THROW(hipEventSynchronize(event_)); + } + + hipEvent_t event_; +}; + +MIGraphXStream::MIGraphXStream(hipStream_t stream, + const OrtDevice& device, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_migraphx_stream) + : Stream(stream, device), + cpu_allocator_(cpu_allocator), + release_cpu_buffer_on_migraphx_stream_(release_cpu_buffer_on_migraphx_stream) { +} + +MIGraphXStream::~MIGraphXStream() { + ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); + if (own_stream_) { + auto* handle = GetHandle(); + if (handle) + HIP_CALL_THROW(hipStreamDestroy(static_cast(handle))); + } +} + +std::unique_ptr MIGraphXStream::CreateNotification(size_t /*num_consumers*/) { + return std::make_unique(*this); +} + +void MIGraphXStream::Flush() { + if (own_stream_) + HIP_CALL_THROW(hipStreamSynchronize(static_cast(GetHandle()))); +} + +void MIGraphXStream::EnqueDeferredCPUBuffer(void* cpu_buffer) { + // stream is per thread, so don't need lock + deferred_cpu_buffers_.push_back(cpu_buffer); +} + +struct CpuBuffersInfo { + // This struct stores the information needed + // to release CPU buffers allocated for GPU kernels. + // It's used to enqueue their release after + // associated GPU kernels in a MIGraphX stream. + + // This is a CPU allocator in MIGraphX EP. + // It must be the one used to allocate the + // following pointers. + AllocatorPtr allocator; + // buffers[i] is the i-th pointer added by + // AddDeferredReleaseCPUPtr for a specific + // MIGraphX stream. For example, this fields + // should contain all values in + // deferred_release_buffer_pool_[my_stream] + // when release my_stream's buffers. + std::unique_ptr buffers; + // CPU buffer buffers[i]. + // Number of buffer points in "buffers". + size_t n_buffers; +}; + +static void ReleaseCpuBufferCallback(void* raw_info) { + std::unique_ptr info = std::make_unique(); + info.reset(reinterpret_cast(raw_info)); + for (size_t i = 0; i < info->n_buffers; ++i) { + info->allocator->Free(info->buffers[i]); + } +} + +Status MIGraphXStream::CleanUpOnRunEnd() { + if (deferred_cpu_buffers_.empty()) + return Status::OK(); + // Release the ownership of cpu_buffers_info so that the underlying + // object will keep alive until the end of ReleaseCpuBufferCallback. + if (release_cpu_buffer_on_migraphx_stream_ && cpu_allocator_->Info().alloc_type == OrtArenaAllocator) { + std::unique_ptr cpu_buffers_info = std::make_unique(); + cpu_buffers_info->allocator = cpu_allocator_; + cpu_buffers_info->buffers = std::make_unique(deferred_cpu_buffers_.size()); + for (size_t i = 0; i < deferred_cpu_buffers_.size(); ++i) { + cpu_buffers_info->buffers[i] = deferred_cpu_buffers_.at(i); + } + cpu_buffers_info->n_buffers = deferred_cpu_buffers_.size(); + HIP_RETURN_IF_ERROR(hipLaunchHostFunc(static_cast(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release())); + } else { + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(GetHandle()))); + for (auto* buffer : deferred_cpu_buffers_) { + cpu_allocator_->Free(buffer); + } + } + + deferred_cpu_buffers_.clear(); + return Status::OK(); +} + +void* MIGraphXStream::GetResource(int version, int id) const { + ORT_ENFORCE(version <= ORT_ROCM_RESOUCE_VERSION, "resource version unsupported!"); + void* resource{}; + switch (id) { + case RocmResource::hip_stream_t: + return reinterpret_cast(GetHandle()); + default: + break; + } + return resource; +} + +// CPU Stream command handles +void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { + static_cast(¬ification)->wait_on_device(stream); +} + +void WaitMIGraphXNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) { + static_cast(¬ification)->wait_on_host(); +} + +void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, + const OrtDevice::DeviceType device_type, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_migraphx_stream, + hipStream_t external_stream, + bool use_existing_stream) { + // wait migraphx notification on migraphx ep + stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitMIGraphXNotificationOnDevice); + // wait migraphx notification on cpu ep + stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitMIGraphXNotificationOnHost); + if (!use_existing_stream) + stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream](const OrtDevice& device) { + HIP_CALL_THROW(hipSetDevice(device.Id())); + hipStream_t stream = nullptr; + HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); + }); + else + stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, + release_cpu_buffer_on_migraphx_stream, + external_stream](const OrtDevice& device) { + return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); + }); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h new file mode 100644 index 0000000000000..03a7c1607e3ad --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/framework/stream_handles.h" +#include "migraphx_inc.h" +#include "migraphx_call.h" + +#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr)) + +namespace onnxruntime { +void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification); + +struct MIGraphXStream : Stream { + MIGraphXStream(hipStream_t stream, + const OrtDevice& device, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_migraphx_stream); + + ~MIGraphXStream(); + + std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; + + void Flush() override; + + Status CleanUpOnRunEnd() override; + + void EnqueDeferredCPUBuffer(void* cpu_buffer); + + bool own_stream_{true}; + + virtual void* GetResource(int version, int id) const; + + virtual WaitNotificationFn GetWaitNotificationFn() const { return WaitMIGraphXNotificationOnDevice; } + + private: + std::vector deferred_cpu_buffers_; + AllocatorPtr cpu_allocator_; + bool release_cpu_buffer_on_migraphx_stream_{true}; +}; + +void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, + const OrtDevice::DeviceType device_type, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_migraphx_stream, + hipStream_t external_stream, + bool use_existing_stream); +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 745504ca04941..5108f90fc763a 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -185,7 +185,8 @@ bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit) { } common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, - const Path& model_path, float& scale, int32_t& zero_point) { + const std::filesystem::path& model_path, float& scale, + int32_t& zero_point) { scale = 0.0f; zero_point = 0; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h index a606b8aceb63d..6f76d04a9691d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h @@ -5,10 +5,11 @@ #include #include +#include #include "core/common/inlined_containers.h" #include "core/graph/basic_types.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksTypes.h" -#include "core/common/gsl.h" +#include // This is the minimal Android API Level required by ORT NNAPI EP to run // ORT running on any host system with Android API level less than this will fall back to CPU EP @@ -132,7 +133,7 @@ bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type); bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit); common::Status GetQuantizationScaleAndZeroPoint( - const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const Path& model_path, + const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::filesystem::path& model_path, float& scale, int32_t& zero_point); common::Status GetQuantizationScaleAndZeroPoint( diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index d0ae32378379d..12416ea0c121b 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -56,7 +56,13 @@ DEFINE_ADD_OPERAND_FROM_SCALAR(float, FLOAT32); #undef DEFINE_ADD_OPERAND_FROM_SCALAR void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { - skipped_initializers_.insert(tensor_name); + // decrement usage count if this is a known initializer. + // For simplicity the OpBuilder::AddInitializersToSkip implementations may call this for arbitrary input names + // without first checking if the value is an initializer. + auto entry = initializer_usage_.find(tensor_name); + if (entry != initializer_usage_.end()) { + entry->second -= 1; + } } Status ModelBuilder::Prepare() { @@ -87,7 +93,16 @@ static size_t GetPaddedByteSize(size_t size) { } void ModelBuilder::PreprocessInitializers() { + const auto& initializers = GetInitializerTensors(); + for (const auto& node_unit : node_unit_holder_) { + // find all initializers consumed. AddInitializersToSkip will potentially decrement the usage count. + for (const auto& input : node_unit->Inputs()) { + if (input.node_arg.Exists() && Contains(initializers, input.node_arg.Name())) { + initializer_usage_[input.node_arg.Name()]++; + } + } + if (const auto* op_builder = GetOpBuilder(*node_unit)) { op_builder->AddInitializersToSkip(*this, *node_unit); } @@ -208,11 +223,16 @@ Status ModelBuilder::RegisterInitializers() { std::vector> initializers(initializer_size); size_t sizeAll = 0; + const auto should_skip_initializer = [this](const std::string& name) -> bool { + const auto it = initializer_usage_.find(name); + return it == initializer_usage_.end() || it->second == 0; + }; + int i = 0; for (const auto& pair : initializer_tensors) { const auto& tensor = *pair.second; const auto& name = tensor.name(); - if (Contains(skipped_initializers_, name)) + if (should_skip_initializer(name)) continue; Shape shape; @@ -249,7 +269,7 @@ Status ModelBuilder::RegisterInitializers() { size_t offset = 0; for (const auto& pair : initializer_tensors) { const auto& tensor = *pair.second; - if (Contains(skipped_initializers_, tensor.name())) + if (should_skip_initializer(tensor.name())) continue; auto [index, size, padded_size] = initializers[i++]; @@ -439,10 +459,11 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( Status ModelBuilder::AddOperations() { const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); for (const auto node_idx : node_indices) { - LOGS_DEFAULT(VERBOSE) << "Adding node [" << node_idx << "]"; const auto* node(graph_viewer_.GetNode(node_idx)); const NodeUnit& node_unit = GetNodeUnit(node); + LOGS_DEFAULT(VERBOSE) << "Adding node [" << node_unit.Name() << "] at index [" << node_unit.Index() << "]"; + // Since we may have NodeUnit with multiple nodes, insert NodeUnit with the first occurrence of // its node(s) in topological order may cause the incorrect topological order while inserting // NodeUNits, for example, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index 8eddf389d3b52..b2118150dd304 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -134,7 +134,7 @@ class ModelBuilder { std::unordered_set operands_; std::unordered_set fused_activations_; - std::unordered_set skipped_initializers_; + std::unordered_map initializer_usage_; // All activation nodes (Relu, Relu1, Relu6) as a map std::unordered_map activation_node_units_; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index dab7bccf43396..1c82d5e7452fd 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -8,7 +8,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers_fwd.h" #include "core/common/logging/logging.h" @@ -1142,7 +1142,7 @@ bool IsQuantizationScaleSupported(const GraphViewer& graph_viewer, bool IsQuantizationZeroPointSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::string& op_type, - const Path& model_path, + const std::filesystem::path& model_path, bool is_quant_matmul, bool is_conv_matmul_u8s8_weight) { // zero point is optional here @@ -1282,7 +1282,7 @@ bool IsQuantizedIOSupported(const GraphViewer& graph_viewer, const NodeUnit& nod bool HasRequiredScaleAndZeroPoint(const GraphViewer& graph_viewer, const std::string& op_desc, const NodeUnitIODef& io_def, - const Path& path, + const std::filesystem::path& path, float required_scale, int32_t required_zp) { float scale = 0.0f; int32_t zp = 0; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h index 0844857a06d61..94e511e04dff3 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h @@ -5,6 +5,7 @@ #include #include +#include #include "core/common/common.h" #include "core/framework/node_unit.h" @@ -200,7 +201,7 @@ bool IsQuantizationScaleSupported(const GraphViewer& graph_viewer, bool IsQuantizationZeroPointSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::string& op_type, - const Path& model_path, + const std::filesystem::path& model_path, bool is_quant_matmul, bool is_conv_matmul_u8s8_weight); @@ -214,7 +215,7 @@ bool IsQuantizedIOSupported(const GraphViewer& graph_viewer, const NodeUnit& nod bool HasRequiredScaleAndZeroPoint(const GraphViewer& graph_viewer, const std::string& op_desc, const NodeUnitIODef& io_def, - const Path& path, + const std::filesystem::path& path, float required_scale, int32_t required_zp); // performs broadcasting operation on two shapes to make them compatible diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h index d3d6da8364b6e..2adf346332c66 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h @@ -2098,7 +2098,7 @@ struct NnApi { * @param executionCallback The execution callback to set. * @param callbackContext The context to be passed to the callbacks when they * are invoked. The context object may be used by multiple threads - * simulatenously, so it must be thread-safe. + * simultaneously, so it must be thread-safe. */ void (*SL_ANeuralNetworksDiagnostic_registerCallbacks)( ANeuralNetworksDiagnosticCompilationFinishedCallback compilationCallback, diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index d0ef447a46d21..1c027e39fa5f5 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -105,7 +105,11 @@ BackendManager::BackendManager(const GlobalContext& global_context, subgraph_context_, ep_ctx_handle_); } catch (const OnnxRuntimeException& ex) { - if (device_type.find("NPU") != std::string::npos) { +#if defined(OPENVINO_DISABLE_NPU_FALLBACK) + ORT_THROW(ex.what()); +#else + if (device_type.find("NPU") != std::string::npos && + !GetGlobalContext().disable_cpu_fallback) { LOGS_DEFAULT(WARNING) << ex.what(); LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." << "Falling back to OV CPU for execution"; @@ -122,6 +126,7 @@ BackendManager::BackendManager(const GlobalContext& global_context, } else { ORT_THROW(ex.what()); } +#endif } } } @@ -419,7 +424,13 @@ void BackendManager::Compute(OrtKernelContext* context) { subgraph_context_, ep_ctx_handle_); } catch (const OnnxRuntimeException& ex) { - if (GetGlobalContext().device_type.find("NPU") != std::string::npos) { + // Build option disables fallback to CPU on compilation failures with NPU. +#if defined(OPENVINO_DISABLE_NPU_FALLBACK) + LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."; + ORT_THROW(ex.what()); +#else + if (GetGlobalContext().device_type.find("NPU") != std::string::npos && + !GetGlobalContext().disable_cpu_fallback) { LOGS_DEFAULT(WARNING) << ex.what(); LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." << "Falling back to OV CPU for execution"; @@ -434,7 +445,10 @@ void BackendManager::Compute(OrtKernelContext* context) { } catch (std::string const& msg) { ORT_THROW(msg); } + } else { + ORT_THROW(ex.what()); } +#endif } backend_map_.insert({key, dynamic_backend}); } else { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 9da6e5945ab83..f8046bcb3a06f 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -545,6 +545,11 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { std::cout << "Inference successful" << std::endl; } + // Create a duplicate infer_request_ shared ptr on the stack in the current local scope, + // as the infer_request gets freed in the next stage the reference count for the infer_request decrements & + // thus we dont have any dangling ptr leading to seg faults in the debug mode subsequent execution call + OVInferRequestPtr infer_request_ = infer_request; + // Once the inference is completed, the infer_request becomes free and is placed back into pool of infer_requests_ inferRequestsQueue_->putIdleRequest(std::move(infer_request)); #ifndef NDEBUG @@ -552,7 +557,7 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { if (openvino_ep::backend_utils::IsDebugEnabled()) { inferRequestsQueue_->printstatus(); // Printing the elements of infer_requests_ vector pool only in debug mode std::string& hw_target = global_context_.device_type; - printPerformanceCounts(infer_request, std::cout, hw_target); + printPerformanceCounts(std::move(infer_request_), std::cout, hw_target); } #endif #endif diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 6e11cbf4a699f..598e985676f8d 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -21,6 +21,7 @@ struct GlobalContext { bool ep_context_embed_mode = true; bool export_ep_ctx_blob = false; bool enable_qdq_optimizer = false; + bool disable_cpu_fallback = false; size_t num_of_threads; std::string device_type; std::string precision_str; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 6f7e1fb607864..655e1b180388b 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -33,6 +33,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv global_context_->OpenVINO_Version = {OPENVINO_VERSION_MAJOR, OPENVINO_VERSION_MINOR}; global_context_->export_ep_ctx_blob = info.export_ep_ctx_blob_; global_context_->enable_qdq_optimizer = info.enable_qdq_optimizer_; + global_context_->disable_cpu_fallback = info.disable_cpu_fallback_; // to check if target device is available // using ie_core capability GetAvailableDevices to fetch list of devices plugged in @@ -89,14 +90,8 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, if (!(GetEnvironmentVar("ORT_OPENVINO_ENABLE_CI_LOG").empty())) { std::cout << "In the OpenVINO EP" << std::endl; } -#ifdef _WIN32 - std::wstring onnx_path = graph_viewer.ModelPath().ToPathString(); - global_context_->onnx_model_path_name = - std::string(onnx_path.begin(), onnx_path.end()); -#else - global_context_->onnx_model_path_name = - graph_viewer.ModelPath().ToPathString(); -#endif + global_context_->onnx_model_path_name = graph_viewer.ModelPath().string(); + global_context_->onnx_opset_version = graph_viewer.DomainToVersionMap().at(kOnnxDomain); diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index d950255c7727b..050fb91c51771 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -74,6 +74,7 @@ struct OpenVINOExecutionProviderInfo { bool disable_dynamic_shapes_{false}; bool export_ep_ctx_blob_{false}; bool enable_qdq_optimizer_{false}; + bool disable_cpu_fallback_{false}; OpenVINOExecutionProviderInfo() = delete; @@ -81,7 +82,7 @@ struct OpenVINOExecutionProviderInfo { size_t num_of_threads, std::string cache_dir, std::string model_priority, int num_streams, void* context, bool enable_opencl_throttling, bool disable_dynamic_shapes, bool export_ep_ctx_blob, - bool enable_qdq_optimizer) + bool enable_qdq_optimizer, bool disable_cpu_fallback) : precision_(precision), enable_npu_fast_compile_(enable_npu_fast_compile), num_of_threads_(num_of_threads), @@ -92,7 +93,8 @@ struct OpenVINOExecutionProviderInfo { enable_opencl_throttling_(enable_opencl_throttling), disable_dynamic_shapes_(disable_dynamic_shapes), export_ep_ctx_blob_(export_ep_ctx_blob), - enable_qdq_optimizer_(enable_qdq_optimizer) { + enable_qdq_optimizer_(enable_qdq_optimizer), + disable_cpu_fallback_(disable_cpu_fallback) { std::set ov_supported_device_types = {"CPU", "GPU", "GPU.0", "GPU.1", "NPU"}; if (dev_type == "") { diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index a45c1fd236af1..45bba431741c5 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -13,7 +13,8 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { const char* cache_dir, const char* model_priority, int num_streams, void* context, bool enable_opencl_throttling, bool disable_dynamic_shapes, - bool export_ep_ctx_blob, bool enable_qdq_optimizer) + bool export_ep_ctx_blob, bool enable_qdq_optimizer, + bool disable_cpu_fallback) : precision_(precision), enable_npu_fast_compile_(enable_npu_fast_compile), num_of_threads_(num_of_threads), @@ -23,7 +24,8 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { enable_opencl_throttling_(enable_opencl_throttling), disable_dynamic_shapes_(disable_dynamic_shapes), export_ep_ctx_blob_(export_ep_ctx_blob), - enable_qdq_optimizer_(enable_qdq_optimizer) { + enable_qdq_optimizer_(enable_qdq_optimizer), + disable_cpu_fallback_(disable_cpu_fallback) { device_type_ = (device_type == nullptr) ? "" : device_type; cache_dir_ = (cache_dir == nullptr) ? "" : cache_dir; } @@ -45,12 +47,14 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { bool disable_dynamic_shapes_; bool export_ep_ctx_blob_; bool enable_qdq_optimizer_; + bool disable_cpu_fallback_; }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { OpenVINOExecutionProviderInfo info(device_type_, precision_, enable_npu_fast_compile_, num_of_threads_, cache_dir_, model_priority_, num_streams_, context_, enable_opencl_throttling_, - disable_dynamic_shapes_, export_ep_ctx_blob_, enable_qdq_optimizer_); + disable_dynamic_shapes_, export_ep_ctx_blob_, enable_qdq_optimizer_, + disable_cpu_fallback_); return std::make_unique(info); } @@ -99,6 +103,8 @@ struct OpenVINO_Provider : Provider { bool enable_qdq_optimizer = false; + bool disable_cpu_fallback = false; + if (provider_options_map.find("device_type") != provider_options_map.end()) { device_type = provider_options_map.at("device_type").c_str(); @@ -256,6 +262,15 @@ struct OpenVINO_Provider : Provider { export_ep_ctx_blob = false; bool_flag = ""; } + + if (provider_options_map.find("disable_cpu_fallback") != provider_options_map.end()) { + bool_flag = provider_options_map.at("disable_cpu_fallback"); + if (bool_flag == "true" || bool_flag == "True") + disable_cpu_fallback = true; + else if (bool_flag == "false" || bool_flag == "False") + disable_cpu_fallback = false; + bool_flag = ""; + } return std::make_shared(const_cast(device_type.c_str()), const_cast(precision.c_str()), enable_npu_fast_compile, @@ -267,7 +282,8 @@ struct OpenVINO_Provider : Provider { enable_opencl_throttling, disable_dynamic_shapes, export_ep_ctx_blob, - enable_qdq_optimizer); + enable_qdq_optimizer, + disable_cpu_fallback); } void Initialize() override { diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory_creator.h b/onnxruntime/core/providers/openvino/openvino_provider_factory_creator.h index 4df653b022a66..bff70a90b6a70 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory_creator.h +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory_creator.h @@ -11,9 +11,11 @@ struct OrtOpenVINOProviderOptions; namespace onnxruntime { +struct SessionOptions; // defined in provider_bridge_ort.cc struct OpenVINOProviderFactoryCreator { - static std::shared_ptr Create(const ProviderOptions* provider_options_map); + static std::shared_ptr Create(ProviderOptions* provider_options_map, + const SessionOptions* session_options); static std::shared_ptr Create(const OrtOpenVINOProviderOptions* provider_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 42a58097e1635..47d3f2f793d7c 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -46,6 +46,10 @@ #include "core/providers/nnapi/nnapi_provider_factory_creator.h" #endif +#if defined(USE_VSINPU) +#include "core/providers/vsinpu/vsinpu_provider_factory_creator.h" +#endif + #if defined(USE_JSEP) #include "core/providers/js/js_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/argmax_min_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/argmax_min_op_builder.cc index 1be460c989732..c685fa065e2ba 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/argmax_min_op_builder.cc @@ -13,17 +13,12 @@ namespace onnxruntime { namespace qnn { // ArgMax/ArgMin support limitations: -// - HTP only: cannot generate a graph output // - HTP only: max input rank is 4. // - All backends: ONNX select_last_index attribute must be 0. class ArgMaxMinOpBuilder : public BaseOpBuilder { public: ArgMaxMinOpBuilder() : BaseOpBuilder("ArgMaxMinOpBuilder") {} - Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - protected: Qnn_DataType_t GetSupportedOutputDataType(size_t index, Qnn_DataType_t qnn_data_type) const override ORT_MUST_USE_RESULT; @@ -35,31 +30,18 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { bool do_op_validation) const override ORT_MUST_USE_RESULT; }; -Status ArgMaxMinOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const logging::Logger& logger) const { - // ONNX ArgMax/ArgMin ops output int64 indices, but the equivalent QNN ops output uint32 indices. - // The QNN HTP backend does not generally support the int64 type, but QNN EP can just use the uint32 type - // for ArgMax/ArgMin ops within the graph. However, if the ArgMin/ArgMax op **generates** a graph output, - // then we cannot support it on the HTP backend. - bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); - if (is_npu_backend) { - const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); - ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(output_name), - "QNN EP does not support ArgMin/ArgMax ops that generate a graph output."); - } - - return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); -} - Qnn_DataType_t ArgMaxMinOpBuilder::GetSupportedOutputDataType(size_t index, Qnn_DataType_t qnn_data_type) const { - // ONNX ArgMxx ops have int64 output, but QNN requires uint32. + // ONNX ArgMxx ops have int64 output, but QNN requires uint32 or int32. // If this node produces a graph output, BaseOpBuilder::ProcessOutputs() adds a Cast node after the ArgMxx op. - // Otherwise, it just set the output type to unit32. This only works for the QNN CPU backend, since the HTP backend - // does not generally support int64. + // Otherwise, it just set the output type to unit32 or int32. ORT_UNUSED_PARAMETER(index); - ORT_UNUSED_PARAMETER(qnn_data_type); - return QNN_DATATYPE_UINT_32; + if (qnn_data_type == QNN_DATATYPE_INT_64) { + return QNN_DATATYPE_INT_32; + } else if (qnn_data_type == QNN_DATATYPE_UINT_64) { + return QNN_DATATYPE_UINT_32; + } + + return qnn_data_type; } Status ArgMaxMinOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index e1156288d2f8f..2fbe59bf0d578 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -235,9 +235,8 @@ Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrap TensorShape new_tensor_shape(new_tensor_shape_dims); Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator); - onnxruntime::PathString model_path = qnn_model_wrapper.GetGraphViewer().ModelPath().ToPathString(); - const ORTCHAR_T* model_path_str = model_path.empty() ? nullptr : model_path.c_str(); - ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(Env::Default(), model_path_str, initializer, in_tensor)); + ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor( + Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), initializer, in_tensor)); ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor)); onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test"); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 1713f201c9d69..5283e9a559cd4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -127,7 +127,8 @@ Status ConvOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, int32_t elem_data_type = 0; ORT_RETURN_IF_ERROR(utils::GetOnnxTensorElemDataType(input_1.node_arg, elem_data_type)); - const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) || + const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) || + (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) || (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16); ORT_RETURN_IF_NOT(is_signed_type, "Conv weights must be of a signed quantized type if quantized per-channel"); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index 3f73ef76e9def..b7455314578de 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -51,7 +51,7 @@ Status PadOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, auto& pads_input_name = inputs[1].node_arg.Name(); ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(pads_input_name), "Qnn doesn't support dynamic pad input"); - if (node_unit.Inputs().size() > 2) { + if (inputs.size() > 2 && inputs[2].node_arg.Exists()) { auto& constant_value_input_name = inputs[2].node_arg.Name(); ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(constant_value_input_name), "Qnn doesn't support dynamic constant_value input"); @@ -227,13 +227,13 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap param_tensor_names.push_back(mode_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(mode_param)); - QnnParamWrapper multiples_param(node_unit.Index(), node_unit.Name(), QNN_OP_PAD_PARAM_PAD_AMOUNT, - std::move(pad_amount_dim), std::move(pad_amount)); - param_tensor_names.push_back(multiples_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(multiples_param)); + QnnParamWrapper pad_amount_param(node_unit.Index(), node_unit.Name(), QNN_OP_PAD_PARAM_PAD_AMOUNT, + std::move(pad_amount_dim), std::move(pad_amount)); + param_tensor_names.push_back(pad_amount_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(pad_amount_param)); // Process optional input constant_value - if (node_unit.Inputs().size() > 2) { + if (inputs.size() > 2 && inputs[2].node_arg.Exists()) { ORT_RETURN_IF_ERROR(ProcessConstantValue(qnn_model_wrapper, param_tensor_names, node_unit, inputs[2])); } // constant_value diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc index 88c94581a8887..b033c8723ea86 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc @@ -92,10 +92,8 @@ static Status GetInitializerInputData(const NodeUnitIODef& input, const QnnModel Tensor tensor(dtype, shape, std::make_shared()); // Deserialize initializer into Tensor. - onnxruntime::PathString model_path = qnn_model_wrapper.GetGraphViewer().ModelPath().ToPathString(); - const ORTCHAR_T* model_path_str = model_path.empty() ? nullptr : model_path.c_str(); - ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), model_path_str, - *initializer_proto, tensor)); + ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor( + onnxruntime::Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), *initializer_proto, tensor)); Status status; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index 047972294f78c..d22c0811682d0 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -66,17 +66,6 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N ORT_RETURN_IF_NOT(axis == -1 || axis == static_cast(rank - 1), "QNN TopK's axis is always the last dimension"); - // ONNX TopK outputs int64 indices, but the equivalent QNN op outputs uint32 indices. - // The QNN HTP backend does not generally support the int64 type, but QNN EP can just use the uint32 type - // for TopK ops within the graph. However, if the TopK op **generates** a graph output, - // then we cannot support it on the HTP backend. - bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); - if (is_npu_backend) { - const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); - ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(output_name), - "QNN EP does not support TopK ops that generate a graph output."); - } - return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index c8bd31bde77de..f44efb1eba6db 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -13,7 +13,7 @@ // #include "GPU/QnnGpuCommon.h" #include "DSP/QnnDspCommon.h" #include "HTP/QnnHtpCommon.h" -#include "core/common/gsl.h" +#include #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 3a8a8af17b90f..f85cdc401a152 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include "qnn_model_wrapper.h" #include "core/common/safeint.h" @@ -313,7 +315,8 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vector& zero_points) const { + /*out*/ std::vector& zero_points, + /*out*/ int32_t& onnx_data_type) const { const auto& graph_initializers = GetInitializerTensors(); auto iter = graph_initializers.find(initializer_name); ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for zero-point(s): ", @@ -323,13 +326,14 @@ Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name, ORT_RETURN_IF_NOT(zp_tensor_proto->has_data_type(), "Expected zero-point initializer ", initializer_name.c_str(), " to have a proto data type."); - const int32_t onnx_data_type = zp_tensor_proto->data_type(); + onnx_data_type = zp_tensor_proto->data_type(); std::vector initializer_bytes; ORT_RETURN_IF_ERROR(UnpackInitializerData(*zp_tensor_proto, initializer_bytes)); switch (onnx_data_type) { // QNN use -offset for some reason + case ONNX_NAMESPACE::TensorProto_DataType_INT4: // INT4 zero-points are unpacked as 8-bit values for QNN case ONNX_NAMESPACE::TensorProto_DataType_INT8: { auto int8_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); std::transform(int8_span.begin(), int8_span.end(), std::back_inserter(zero_points), @@ -338,6 +342,7 @@ Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name, }); break; } + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: // UINT4 zero-points are unpacked as 8-bit values for QNN case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { auto uint8_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); std::transform(uint8_span.begin(), uint8_span.end(), std::back_inserter(zero_points), @@ -584,10 +589,36 @@ void QnnModelWrapper::GetGraphInputOutputTensorWrapper(const std::vector& unpacked_tensor) const { if (initializer.data_location() == onnx::TensorProto_DataLocation_EXTERNAL) { - return onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), unpacked_tensor); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), + unpacked_tensor)); + } else { + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor)); + } + + int32_t onnx_data_type = initializer.data_type(); + + // If this is an int4, we need to unpack it because QNN treats int4 as a full int8. + if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); + const size_t num_elems = shape.Size(); + std::vector packed_int4_bytes = std::move(unpacked_tensor); + unpacked_tensor = std::vector(num_elems); + + auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); + auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); + ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); + } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); + const size_t num_elems = shape.Size(); + std::vector packed_int4_bytes = std::move(unpacked_tensor); + unpacked_tensor = std::vector(num_elems); + + auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); + auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); + ORT_RETURN_IF_NOT(UInt4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); } - return onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor); + return Status::OK(); } } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 0705a1d1b8f5b..9ab122b7f8e28 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -216,7 +216,9 @@ class QnnModelWrapper { Status UnpackScales(const std::string& initializer_name, std::vector& scales) const; // Unpack zero-points from initializer and convert to int32_t (1 zero-point for per-tensor, > 1 for per-channel). - Status UnpackZeroPoints(const std::string& initializer_name, std::vector& zero_points) const; + Status UnpackZeroPoints(const std::string& initializer_name, + /*out*/ std::vector& zero_points, + /*out*/ int32_t& onnx_data_type) const; // Checks if a tensor in the ONNX graph is per-channel quantized. Status IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& io_def, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc index 401d403c15b01..2d22c3c1b8226 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc @@ -9,6 +9,9 @@ #include "QnnTypes.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" +#define ALIGN_PTR_UP(ptr, align, type) \ + reinterpret_cast((reinterpret_cast(ptr) + (align)-1) & ~((align)-1)) + namespace onnxruntime { namespace qnn { @@ -38,9 +41,10 @@ QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const { return QnnQuantParamsWrapper(*this); } +// Initializes by copying from a Qnn_QuantizeParams_t. Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { - if (scale_offset_data_) { - scale_offset_data_.reset(nullptr); + if (per_channel_data_) { + per_channel_data_.reset(nullptr); params_ = QNN_QUANTIZE_PARAMS_INIT; } @@ -51,6 +55,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { switch (params.quantizationEncoding) { case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: + case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET: params_ = params; break; case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: { @@ -63,15 +68,49 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { const uint32_t num_elems = params.axisScaleOffsetEncoding.numScaleOffsets; if (num_elems > 0) { - scale_offset_data_ = std::make_unique(num_elems); - gsl::span src_span(params.axisScaleOffsetEncoding.scaleOffset, num_elems); - std::copy(src_span.begin(), src_span.end(), scale_offset_data_.get()); - params_.axisScaleOffsetEncoding.scaleOffset = scale_offset_data_.get(); + const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t); + constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t); + per_channel_data_ = std::make_unique(num_bytes + align); + Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*); + + std::memcpy(aligned_dst, params.axisScaleOffsetEncoding.scaleOffset, num_bytes); + params_.axisScaleOffsetEncoding.scaleOffset = aligned_dst; } else { params_.axisScaleOffsetEncoding.scaleOffset = nullptr; } break; } + case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: { + const uint32_t num_elems = params.bwAxisScaleOffsetEncoding.numElements; + + params_.encodingDefinition = params.encodingDefinition; + params_.quantizationEncoding = params.quantizationEncoding; + params_.bwAxisScaleOffsetEncoding.axis = params.bwAxisScaleOffsetEncoding.axis; + params_.bwAxisScaleOffsetEncoding.bitwidth = params.bwAxisScaleOffsetEncoding.bitwidth; + params_.bwAxisScaleOffsetEncoding.numElements = num_elems; + + // Deep copy the scales[] and offsets[] arrays + if (num_elems > 0) { + const size_t num_scale_bytes = num_elems * sizeof(float); + const size_t num_zp_bytes = num_elems * sizeof(int32_t); + const size_t num_bytes = num_scale_bytes + num_zp_bytes; + constexpr std::uintptr_t align = alignof(float); + static_assert(alignof(float) == alignof(int32_t)); + + per_channel_data_ = std::make_unique(num_bytes + align); + char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*); + char* zps_begin = scales_begin + num_scale_bytes; + + std::memcpy(scales_begin, params.bwAxisScaleOffsetEncoding.scales, num_scale_bytes); + std::memcpy(zps_begin, params.bwAxisScaleOffsetEncoding.offsets, num_zp_bytes); + params_.bwAxisScaleOffsetEncoding.scales = reinterpret_cast(scales_begin); + params_.bwAxisScaleOffsetEncoding.offsets = reinterpret_cast(zps_begin); + } else { + params_.bwAxisScaleOffsetEncoding.scales = nullptr; + params_.bwAxisScaleOffsetEncoding.offsets = nullptr; + } + break; + } default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params.quantizationEncoding); } @@ -79,11 +118,13 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { return Status::OK(); } +// Initialize this object from a (potentially) quantized ONNX tensor. +// QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers. Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def) { const std::optional& ort_quant_params = io_def.quant_param; - if (scale_offset_data_) { - scale_offset_data_.reset(nullptr); + if (per_channel_data_) { + per_channel_data_.reset(nullptr); params_ = QNN_QUANTIZE_PARAMS_INIT; } @@ -98,17 +139,25 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(ort_quant_params->scale.Name(), scales)); + bool is_int4_type = false; + if (ort_quant_params->zero_point != nullptr) { - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points)); + int32_t onnx_tp_type = 0; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points, + onnx_tp_type)); + + is_int4_type = (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) || + (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4); } const bool is_per_tensor = scales.size() == 1; - if (is_per_tensor) { + // QNN uses different structs to represent quantization parameters depending on + // - per-tensor vs per-channel + // - int4 vs not int4 + if (is_per_tensor && !is_int4_type) { params_.encodingDefinition = QNN_DEFINITION_DEFINED; params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; - - // Parse scale & zero_point params_.scaleOffsetEncoding.scale = scales[0]; if (ort_quant_params->zero_point != nullptr) { @@ -117,8 +166,62 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con } else { params_.scaleOffsetEncoding.offset = 0; } - } else { - // Per-channel quantization. + } else if (is_per_tensor && is_int4_type) { + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET; + params_.bwScaleOffsetEncoding.bitwidth = 4; + params_.bwScaleOffsetEncoding.scale = scales[0]; + + if (ort_quant_params->zero_point != nullptr) { + ORT_RETURN_IF_NOT(zero_points.size() == 1, "Expected one zero-point value"); + params_.bwScaleOffsetEncoding.offset = zero_points[0]; + } else { + params_.bwScaleOffsetEncoding.offset = 0; + } + } else if (!is_per_tensor && is_int4_type) { + const auto* io_shape = io_def.node_arg.Shape(); + ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape"); + const int32_t io_rank = io_shape->dim_size(); + + constexpr int64_t DEFAULT_QDQ_AXIS = 1; + int64_t axis = ort_quant_params->axis.value_or(DEFAULT_QDQ_AXIS); + if (axis < 0) { + axis += io_rank; + } + ORT_RETURN_IF_NOT(axis >= 0 && axis < io_rank, + "Quantization axis must be within the range [0, rank - 1]"); + + const size_t num_elems = scales.size(); + const bool no_zero_points = zero_points.empty(); + ORT_RETURN_IF_NOT(num_elems > 1, "Expected more than one scale value"); + ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems, + "Expected the same number of zero-points and scales for per-channel quantization"); + + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; + params_.bwAxisScaleOffsetEncoding.axis = static_cast(*(ort_quant_params->axis)); + params_.bwAxisScaleOffsetEncoding.bitwidth = 4; + params_.bwAxisScaleOffsetEncoding.numElements = static_cast(num_elems); + + const size_t num_scale_bytes = num_elems * sizeof(float); + const size_t num_zp_bytes = num_elems * sizeof(int32_t); + const size_t num_bytes = num_scale_bytes + num_zp_bytes; + constexpr std::uintptr_t align = alignof(float); + per_channel_data_ = std::make_unique(num_bytes + align); + + char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*); + char* zps_begin = scales_begin + num_scale_bytes; + gsl::span scales_span(reinterpret_cast(scales_begin), num_elems); + gsl::span zps_span(reinterpret_cast(zps_begin), num_elems); + + for (size_t i = 0; i < num_elems; i++) { + scales_span[i] = scales[i]; + zps_span[i] = no_zero_points ? 0 : zero_points[i]; + } + + params_.bwAxisScaleOffsetEncoding.scales = scales_span.data(); + params_.bwAxisScaleOffsetEncoding.offsets = zps_span.data(); + } else if (!is_per_tensor && !is_int4_type) { const auto* io_shape = io_def.node_arg.Shape(); ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape"); const int32_t io_rank = io_shape->dim_size(); @@ -140,8 +243,11 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems, "Expected the same number of zero-points and scales for per-channel quantization"); - scale_offset_data_ = std::make_unique(num_elems); - gsl::span data_span(scale_offset_data_.get(), num_elems); + const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t); + constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t); + per_channel_data_ = std::make_unique(num_bytes + align); + Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*); + gsl::span data_span(aligned_dst, num_elems); for (size_t i = 0; i < num_elems; i++) { data_span[i].scale = scales[i]; @@ -151,6 +257,8 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con params_.axisScaleOffsetEncoding.axis = static_cast(axis); params_.axisScaleOffsetEncoding.numScaleOffsets = static_cast(num_elems); params_.axisScaleOffsetEncoding.scaleOffset = data_span.data(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected tensor kind for QuantParamsWrapper::Init()"); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h index 3cf04da97a8ff..d1f93e5a692bc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h @@ -5,7 +5,7 @@ #include #include "QnnTypes.h" #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/framework/node_unit.h" namespace onnxruntime { @@ -48,17 +48,17 @@ class QnnQuantParamsWrapper { (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET)); } - bool IsPerChannel(bool include_bw = false) const { + bool IsPerChannel() const { return params_.encodingDefinition == QNN_DEFINITION_DEFINED && (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET || - (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)); + (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)); } // Handle transposing of a per-channel quantized tensor. The quantization parameter's axis // must be transposed using the inverse permutation of the Transpose. template Status HandleTranspose(gsl::span perm) { - if (!IsPerChannel(true)) { + if (!IsPerChannel()) { return Status::OK(); } @@ -82,7 +82,7 @@ class QnnQuantParamsWrapper { template Status HandleUnsqueeze(gsl::span orig_shape, gsl::span new_shape) { - if (!IsPerChannel(true)) { + if (!IsPerChannel()) { return Status::OK(); } @@ -134,7 +134,13 @@ class QnnQuantParamsWrapper { private: Qnn_QuantizeParams_t params_; - std::unique_ptr scale_offset_data_; // Stores per-channel scales and offsets + + // Stores arrays of per-channel scales and offsets. Fields in params_ point to this data. + // + // Use an opaque array of bytes because QNN uses different data layouts depending on the quantization encoding: + // - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: array of scale/zp pairs [{scale0, zp0}, {scale1, zp1}, ...] + // - QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: parallel arrays for scales and zps [scale0, ...] [zp0, zp1, ...] + std::unique_ptr per_channel_data_; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 19362daee6ca2..c2e500b8980ad 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -43,6 +43,8 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type) { } size_t GetElementSizeByType(ONNXTensorElementDataType elem_type) { const static std::unordered_map elem_type_to_size = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, sizeof(Int4x2)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, sizeof(UInt4x2)}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, sizeof(int16_t)}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)}, @@ -162,6 +164,12 @@ std::ostream& operator<<(std::ostream& out, const Qnn_DataType_t& data_type) { case QNN_DATATYPE_BOOL_8: out << "QNN_DATATYPE_BOOL_8"; break; + case QNN_DATATYPE_SFIXED_POINT_4: + out << "QNN_DATATYPE_SFIXED_POINT_4"; + break; + case QNN_DATATYPE_UFIXED_POINT_4: + out << "QNN_DATATYPE_UFIXED_POINT_4"; + break; default: ORT_THROW("Unknown Qnn Data type"); } @@ -216,6 +224,10 @@ std::ostream& operator<<(std::ostream& out, const Qnn_QuantizeParams_t& quantize if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { out << " scale=" << quantize_params.scaleOffsetEncoding.scale; out << " offset=" << quantize_params.scaleOffsetEncoding.offset; + } else if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) { + out << " bitwidth=" << quantize_params.bwScaleOffsetEncoding.bitwidth; + out << " scale=" << quantize_params.bwScaleOffsetEncoding.scale; + out << " offset=" << quantize_params.bwScaleOffsetEncoding.offset; } else if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { out << " axis=" << quantize_params.axisScaleOffsetEncoding.axis; size_t num_elems = quantize_params.axisScaleOffsetEncoding.numScaleOffsets; @@ -292,7 +304,9 @@ std::ostream& operator<<(std::ostream& out, const Qnn_ClientBuffer_t& client_buf T* data = reinterpret_cast(client_bufer.data); out << " dataSize=" << client_bufer.dataSize; uint32_t count = client_bufer.dataSize / sizeof(T); - count = count > 100 ? 100 : count; // limit to 100 data + const bool truncate = count > 100; + + count = truncate ? 100 : count; // limit to 100 data out << " clientBuf=("; for (uint32_t i = 0; i < count; i++) { if constexpr (sizeof(T) == 1) { @@ -301,7 +315,7 @@ std::ostream& operator<<(std::ostream& out, const Qnn_ClientBuffer_t& client_buf out << data[i] << " "; } } - out << ")"; + out << (truncate ? "..." : "") << ")"; return out; } @@ -319,6 +333,8 @@ std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor) { } out << ")"; out << " memType=" << GetQnnTensorMemType(tensor); +// TODO: the code below has compilation errors with the latest ABSL +#if 0 if (GetQnnTensorMemType(tensor) == QNN_TENSORMEMTYPE_RAW) { if (GetQnnTensorDataType(tensor) == QNN_DATATYPE_FLOAT_32) { operator<< (out, GetQnnTensorClientBuf(tensor)); @@ -341,6 +357,7 @@ std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor) { operator<< (out, GetQnnTensorClientBuf(tensor)); } } +#endif out << " quantizeParams:" << GetQnnTensorQParams(tensor); return out; } @@ -432,10 +449,12 @@ bool OnnxDataTypeToQnnDataType(const int32_t onnx_data_type, Qnn_DataType_t& qnn }; const std::unordered_map onnx_to_qnn_data_type_quantized = { + {ONNX_NAMESPACE::TensorProto_DataType_INT4, QNN_DATATYPE_SFIXED_POINT_8}, {ONNX_NAMESPACE::TensorProto_DataType_INT8, QNN_DATATYPE_SFIXED_POINT_8}, {ONNX_NAMESPACE::TensorProto_DataType_INT16, QNN_DATATYPE_SFIXED_POINT_16}, {ONNX_NAMESPACE::TensorProto_DataType_INT32, QNN_DATATYPE_SFIXED_POINT_32}, {ONNX_NAMESPACE::TensorProto_DataType_INT64, QNN_DATATYPE_INT_64}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT4, QNN_DATATYPE_UFIXED_POINT_8}, {ONNX_NAMESPACE::TensorProto_DataType_UINT8, QNN_DATATYPE_UFIXED_POINT_8}, {ONNX_NAMESPACE::TensorProto_DataType_UINT16, QNN_DATATYPE_UFIXED_POINT_16}, {ONNX_NAMESPACE::TensorProto_DataType_UINT32, QNN_DATATYPE_UFIXED_POINT_32}, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index d4616e0cefbf2..0ddaa97694217 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -233,7 +233,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio #ifdef _WIN32 auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); // Register callback for ETW capture state (rundown) - etwRegistrationManager.RegisterInternalCallback( + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( [&etwRegistrationManager, this]( LPCGUID SourceId, ULONG IsEnabled, @@ -270,6 +270,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio (void)qnn_backend_manager_->ResetQnnLogLevel(); } }); + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); #endif // In case ETW gets disabled later @@ -397,6 +398,11 @@ QNNExecutionProvider::~QNNExecutionProvider() { if (!cache) continue; ORT_IGNORE_RETURN_VALUE(cache->erase(this)); } + + // Unregister the ETW callback +#ifdef _WIN32 + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); +#endif } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -805,7 +811,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model, context_cache_path_cfg_, - graph_viewer_0.ModelPath().ToPathString(), + graph_viewer_0.ModelPath().native(), context_cache_path); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index c5d3098f87b3a..e7419dabb14d1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -15,6 +15,9 @@ #include #include #include +#ifdef _WIN32 +#include "core/platform/windows/logging/etw_sink.h" +#endif namespace onnxruntime { @@ -86,6 +89,9 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; bool enable_HTP_FP16_precision_ = false; +#ifdef _WIN32 + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; +#endif class PerThreadContext final { public: diff --git a/onnxruntime/core/providers/rocm/math/softmax.h b/onnxruntime/core/providers/rocm/math/softmax.h index 49bfddad36b43..57c1fc5068073 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.h +++ b/onnxruntime/core/providers/rocm/math/softmax.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/providers/rocm/rocm_kernel.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/rocm/miopen_common.cc b/onnxruntime/core/providers/rocm/miopen_common.cc index 6b01f02ae49b5..6b08d392069a1 100644 --- a/onnxruntime/core/providers/rocm/miopen_common.cc +++ b/onnxruntime/core/providers/rocm/miopen_common.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "miopen_common.h" -#include "core/common/gsl.h" +#include #include "core/providers/cpu/tensor/utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index 6214ec7bc0ea3..a2b587a56466f 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -49,10 +49,10 @@ miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, - const miopenConvFwdAlgorithm_t* algo, int n_algo) { + const miopenConvFwdAlgorithm_t* algo, int n_algo, int device_id = 0) { // TODO: get maximum available size from memory arena size_t free, total; - HIP_CALL_THROW(hipMemGetInfo(&free, &total)); + onnxruntime::rocm::hipMemGetInfoAlt(device_id, &free, &total); // Assuming 10% of fragmentation free = static_cast(static_cast(free) * 0.9); size_t max_ws_size = 0; @@ -283,7 +283,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) int algo_count = 1; const ROCMExecutionProvider* rocm_ep = static_cast(this->Info().GetExecutionProvider()); static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT; - size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos) + size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos, rocm_ep->GetDeviceId()) : AlgoSearchWorkspaceSize; IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm( diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc index 7974053c32497..ca12720fb3eb4 100644 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ b/onnxruntime/core/providers/rocm/rocm_call.cc @@ -143,6 +143,8 @@ template Status RocmCall(hiprandStatus_t retCode, const template void RocmCall(hiprandStatus_t retCode, const char* exprString, const char* libName, hiprandStatus_t successCode, const char* msg, const char* file, const int line); template Status RocmCall(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line); template void RocmCall(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line); +template Status RocmCall(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line); +template void RocmCall(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line); #ifdef ORT_USE_NCCL template Status RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); diff --git a/onnxruntime/core/providers/rocm/rocm_common.h b/onnxruntime/core/providers/rocm/rocm_common.h index 07b3e252c6000..a8ddb85233031 100644 --- a/onnxruntime/core/providers/rocm/rocm_common.h +++ b/onnxruntime/core/providers/rocm/rocm_common.h @@ -10,7 +10,7 @@ #include "core/providers/rocm/shared_inc/rocm_call.h" #include "core/providers/rocm/shared_inc/fast_divmod.h" #include "core/util/math.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace rocm { @@ -67,5 +67,17 @@ inline int warpSizeDynamic() { return deviceProp.warpSize; } +inline void hipMemGetInfoAlt(uint32_t deviceId, size_t* pFree, size_t* pTotal) { + const auto status = hipMemGetInfo(pFree, pTotal); + if (status != hipSuccess) { + size_t usedMemory = 0; + ROCMSMI_CALL_THROW(rsmi_init(0)); + ROCMSMI_CALL_THROW(rsmi_dev_memory_total_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, pTotal)); + ROCMSMI_CALL_THROW(rsmi_dev_memory_usage_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, &usedMemory)); + *pFree = *pTotal - usedMemory; + ROCMSMI_CALL_THROW(rsmi_shut_down()); + } +} + } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 76964e1aed93c..c1cedd47501ae 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -43,7 +43,8 @@ class Memcpy final : public OpKernel { // do we support async copy? // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory, // so we don't need the check here. - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, + Y->Location().device); ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); return Status::OK(); } else { @@ -88,10 +89,12 @@ class Memcpy final : public OpKernel { Y->Reserve(X_size); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), + alloc); auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, target_tensor->Location().device); - ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, + *ctx->GetComputeStream())); Y->Add(std::move(*target_tensor)); } return Status::OK(); @@ -127,7 +130,8 @@ ONNX_OPERATOR_KERNEL_EX( AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t gpu_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, + ROCMExecutionProviderExternalAllocatorInfo + external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) { if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( @@ -149,7 +153,8 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi device_id, true, {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, + : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), + -1, -1, -1, -1L)}, // make it stream aware true, // enable cross stream sharing? @@ -160,8 +165,11 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi } } -ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t /*gpu_mem_limit*/, - ArenaExtendStrategy /*arena_extend_strategy*/, ROCMExecutionProviderExternalAllocatorInfo /*external_allocator_info*/, +ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, + size_t /*gpu_mem_limit*/, + ArenaExtendStrategy /*arena_extend_strategy*/, + ROCMExecutionProviderExternalAllocatorInfo + /*external_allocator_info*/, OrtArenaCfg* /*default_memory_arena_cfg*/) { HIP_CALL_THROW(hipSetDevice(device_id)); @@ -229,7 +237,8 @@ void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { } ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, + : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + info.device_id)}, info_{info}, tuning_context_(this, &info_.tunable_op) { HIP_CALL_THROW(hipSetDevice(info_.device_id)); @@ -261,7 +270,7 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in size_t free = 0; size_t total = 0; - HIP_CALL_THROW(hipMemGetInfo(&free, &total)); + onnxruntime::rocm::hipMemGetInfoAlt(info_.device_id, &free, &total); OverrideTunableOpInfoByEnv(info_); @@ -313,7 +322,8 @@ ROCMExecutionProvider::PerThreadContext& ROCMExecutionProvider::GetPerThreadCont // get or create a context if (context_state_.retired_context_pool.empty()) { context = std::make_shared(info_.device_id, stream_, info_.gpu_mem_limit, - info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg); + info_.arena_extend_strategy, info_.external_allocator_info, + info_.default_memory_arena_cfg); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -357,7 +367,8 @@ Status ROCMExecutionProvider::Sync() const { Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured(0)) { + if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && + !GetPerThreadContext().IsGraphCaptured(0)) { LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; GetPerThreadContext().CaptureBegin(0); } @@ -471,7 +482,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, + LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow); @@ -504,20 +516,32 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, + LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, + LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, + LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, + LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, + LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int32_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int64_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint32_t, Add); @@ -573,7 +597,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 10, float, Clip); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Reciprocal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Reciprocal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Reciprocal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, + Reciprocal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Sqrt); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Sqrt); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Sqrt); @@ -587,12 +612,18 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Erf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, bool, Not); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, float, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, double, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, LRN); @@ -600,11 +631,14 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, + ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, + ConvTranspose); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, double, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, + AveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalAveragePool); diff --git a/onnxruntime/core/providers/rocm/rocm_pch.h b/onnxruntime/core/providers/rocm/rocm_pch.h index ccd17157ffb4d..723b990c8d290 100644 --- a/onnxruntime/core/providers/rocm/rocm_pch.h +++ b/onnxruntime/core/providers/rocm/rocm_pch.h @@ -14,6 +14,7 @@ #include #include #include +#include #ifdef ORT_USE_NCCL #include diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 88ef666678b3e..a739fe0a5d193 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -5,7 +5,7 @@ #include "core/providers/rocm/rocm_provider_factory.h" #include "core/providers/rocm/rocm_provider_factory_creator.h" -#include "core/common/gsl.h" +#include #include "core/providers/rocm/rocm_execution_provider.h" #include "core/providers/rocm/rocm_execution_provider_info.h" diff --git a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h index b6b40666b8bd0..253ded1911cb5 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h +++ b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h @@ -17,6 +17,7 @@ std::conditional_t RocmCall( #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) +#define ROCMSMI_CALL(expr) (RocmCall((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPSPARSE_CALL(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) @@ -27,6 +28,7 @@ std::conditional_t RocmCall( #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) +#define ROCMSMI_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPSPARSE_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 7cdfb0ffc19f2..2f54a04e15304 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -12,7 +12,7 @@ #include #include #include -#include "core/common/gsl.h" +#include #include #include #include @@ -108,6 +108,7 @@ struct NodeProto; struct SparseTensorProto; struct StringStringEntryProto; struct StringStringEntryProtos; // RepeatedPtrField +struct OperatorSetIdProto; struct TensorProto; struct TensorProtos; // RepeatedPtrField struct TensorShapeProto_Dimension; @@ -120,6 +121,7 @@ struct TypeProto_Sequence; struct TypeProto; struct ValueInfoProto; struct ValueInfoProtos; // RepeatedPtrField +struct FunctionProto; struct InferenceContext; class GraphInferencer; using InferenceFunction = std::function; @@ -146,6 +148,7 @@ struct ConfigOptions; struct DataTransferManager; struct IndexedSubGraph; struct IndexedSubGraph_MetaDef; +enum class IndexedSubGraph_SourceOfSchema : uint8_t; struct KernelCreateInfo; struct KernelDef; struct KernelDefBuilder; @@ -279,6 +282,9 @@ std::unique_ptr CreateCPUAllocator(const OrtMemoryInfo& memory_info) std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name); std::unique_ptr CreateCUDAPinnedAllocator(const char* name); +std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name); +std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name); + std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name); std::unique_ptr CreateROCMPinnedAllocator(const char* name); diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 27d8a0f06f565..7fb9fd3c8cfd5 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -353,16 +353,12 @@ std::unique_ptr CreateGPUDataTransfer() { #endif #ifdef USE_MIGRAPHX -std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) { - return g_host->CreateROCMAllocator(device_id, name); +std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) { + return g_host->CreateMIGraphXAllocator(device_id, name); } -std::unique_ptr CreateROCMPinnedAllocator(const char* name) { - return g_host->CreateROCMPinnedAllocator(name); -} - -std::unique_ptr CreateGPUDataTransfer() { - return g_host->CreateGPUDataTransfer(); +std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) { + return g_host->CreateMIGraphXPinnedAllocator(device_id, name); } #endif @@ -503,7 +499,7 @@ template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } -Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, +Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const std::filesystem::path& model_path, /*out*/ std::vector& unpacked_tensor) { return g_host->UnpackInitializerData(tensor, model_path, unpacked_tensor); } diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index cc3b13f696a96..382b3ac932520 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -5,6 +5,7 @@ #include #include #include +#include // Public wrappers around internal ort interfaces (currently) #include "core/providers/shared_library/provider_host_api.h" @@ -178,6 +179,11 @@ struct ProviderHost { virtual void CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) = 0; #endif +#ifdef USE_MIGRAPHX + virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; +#endif + #ifdef USE_ROCM virtual std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) = 0; virtual std::unique_ptr CreateROCMPinnedAllocator(const char* name) = 0; @@ -209,7 +215,7 @@ struct ProviderHost { virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) = 0; - virtual Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + virtual Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const std::filesystem::path& model_path, /*out*/ std::vector& unpacked_tensor) = 0; virtual uint16_t math__floatToHalf(float f) = 0; @@ -298,6 +304,11 @@ struct ProviderHost { virtual int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; virtual ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) = 0; + // OperatorSetIdProto + virtual std::string* OperatorSetIdProto__mutable_domain(ONNX_NAMESPACE::OperatorSetIdProto* p) = 0; + virtual void OperatorSetIdProto__set_version(ONNX_NAMESPACE::OperatorSetIdProto* p, int64_t version) = 0; + virtual int64_t OperatorSetIdProto__version(const ONNX_NAMESPACE::OperatorSetIdProto* p) = 0; + #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional virtual const ONNX_NAMESPACE::TypeProto& TypeProto_Optional__elem_type(const ONNX_NAMESPACE::TypeProto_Optional* p) = 0; @@ -414,6 +425,11 @@ struct ProviderHost { virtual void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) = 0; virtual ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) = 0; + virtual const ONNX_NAMESPACE::OperatorSetIdProto& ModelProto__opset_import(const ONNX_NAMESPACE::ModelProto* p, int index) = 0; + virtual ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__mutable_opset_import(ONNX_NAMESPACE::ModelProto* p, int index) = 0; + virtual int ModelProto__opset_import_size(const ONNX_NAMESPACE::ModelProto* p) = 0; + virtual ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__add_opset_import(ONNX_NAMESPACE::ModelProto* p) = 0; + // NodeProto virtual std::unique_ptr NodeProto__construct() = 0; virtual void NodeProto__operator_delete(ONNX_NAMESPACE::NodeProto* p) = 0; @@ -421,6 +437,7 @@ struct ProviderHost { virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0; virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0; virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0; + virtual ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) = 0; // TensorProto virtual std::unique_ptr TensorProto__construct() = 0; @@ -489,6 +506,64 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0; + // FunctionProto + virtual std::unique_ptr FunctionProto__construct() = 0; + virtual void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) = 0; + virtual bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) = 0; + virtual bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) = 0; + virtual std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& name) = 0; + + virtual bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& doc_string) = 0; + + virtual bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& domain) = 0; + + virtual const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0; // ConfigOptions @@ -540,6 +615,9 @@ struct ProviderHost { virtual void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) = 0; virtual const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) = 0; + virtual void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) = 0; + virtual IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) = 0; + // KernelDef virtual void KernelDef__operator_delete(KernelDef* p) = 0; virtual int KernelDef__ExecQueueId(const KernelDef* p) = 0; @@ -784,7 +862,7 @@ struct ProviderHost { virtual const std::string& NodeUnit__Name(const NodeUnit* p) noexcept = 0; virtual int NodeUnit__SinceVersion(const NodeUnit* p) noexcept = 0; virtual NodeIndex NodeUnit__Index(const NodeUnit* p) noexcept = 0; - virtual const Path& NodeUnit__ModelPath(const NodeUnit* p) noexcept = 0; + virtual const std::filesystem::path& NodeUnit__ModelPath(const NodeUnit* p) noexcept = 0; virtual ProviderType NodeUnit__GetExecutionProviderType(const NodeUnit* p) noexcept = 0; virtual const Node& NodeUnit__GetNode(const NodeUnit* p) noexcept = 0; @@ -806,7 +884,7 @@ struct ProviderHost { virtual void Model__operator_delete(Model* p) = 0; virtual Graph& Model__MainGraph(Model* p) = 0; virtual std::unique_ptr Model__ToProto(Model* p) = 0; - virtual std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) = 0; + virtual std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold) = 0; virtual const ModelMetaData& Model__MetaData(const Model* p) const noexcept = 0; virtual Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) = 0; @@ -834,7 +912,7 @@ struct ProviderHost { virtual const Graph* Graph__ParentGraph(const Graph* p) const = 0; virtual Graph* Graph__MutableParentGraph(Graph* p) = 0; virtual const std::string& Graph__Name(const Graph* p) const noexcept = 0; - virtual const Path& Graph__ModelPath(const Graph* p) const = 0; + virtual const std::filesystem::path& Graph__ModelPath(const Graph* p) const = 0; virtual const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0; virtual bool Graph__IsSubgraph(const Graph* p) = 0; virtual const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const = 0; @@ -868,7 +946,7 @@ struct ProviderHost { virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger) = 0; virtual const std::string& GraphViewer__Name(const GraphViewer* p) noexcept = 0; - virtual const Path& GraphViewer__ModelPath(const GraphViewer* p) noexcept = 0; + virtual const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept = 0; virtual const Node* GraphViewer__GetNode(const GraphViewer* p, NodeIndex node_index) = 0; virtual const NodeArg* GraphViewer__GetNodeArg(const GraphViewer* p, const std::string& name) = 0; @@ -902,13 +980,6 @@ struct ProviderHost { int execution_order) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; - // Path - virtual PathString Path__ToPathString(const Path* p) noexcept = 0; - virtual const std::vector& Path__GetComponents(const Path* p) noexcept = 0; - virtual bool Path__IsEmpty(const Path* p) noexcept = 0; - virtual std::unique_ptr Path__construct() = 0; - virtual void Path__operator_delete(ONNX_NAMESPACE::Path* p) = 0; - // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index fd2540b42a3db..de6c1da1d6430 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -80,6 +80,15 @@ struct StringStringEntryProtos final { PROVIDER_DISALLOW_ALL(StringStringEntryProtos) }; + +struct OperatorSetIdProto final { + std::string* mutable_domain() { return g_host->OperatorSetIdProto__mutable_domain(this); } + void set_version(int64_t version) { return g_host->OperatorSetIdProto__set_version(this, version); } + int64_t version() { return g_host->OperatorSetIdProto__version(this); } + + PROVIDER_DISALLOW_ALL(OperatorSetIdProto) +}; + struct AttributeProto final { static std::unique_ptr Create() { return g_host->AttributeProto__construct(); } void operator=(const AttributeProto& v) { g_host->AttributeProto__operator_assign(this, v); } @@ -178,6 +187,11 @@ struct ModelProto final { void set_ir_version(int64_t value) { return g_host->ModelProto__set_ir_version(this, value); } + const OperatorSetIdProto& opset_import(int index) const { return g_host->ModelProto__opset_import(this, index); } + OperatorSetIdProto* mutable_opset_import(int index) { return g_host->ModelProto__mutable_opset_import(this, index); } + int opset_import_size() const { return g_host->ModelProto__opset_import_size(this); } + OperatorSetIdProto* add_opset_import() { return g_host->ModelProto__add_opset_import(this); } + ModelProto() = delete; ModelProto(const ModelProto&) = delete; void operator=(const ModelProto&) = delete; @@ -190,6 +204,7 @@ struct NodeProto final { int attribute_size() { return g_host->NodeProto__attribute_size(this); } const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); } AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); } + AttributeProto* add_attribute() { return g_host->NodeProto__add_attribute(this); } NodeProto() = delete; NodeProto(const NodeProto&) = delete; @@ -372,6 +387,69 @@ struct ValueInfoProtos final { PROVIDER_DISALLOW_ALL(ValueInfoProtos) }; + +struct FunctionProto final { + static std::unique_ptr Create() { return g_host->FunctionProto__construct(); } + static void operator delete(void* p) { g_host->FunctionProto__operator_delete(reinterpret_cast(p)); } + + bool SerializeToString(std::string& string) const { return g_host->FunctionProto__SerializeToString(this, string); } + bool SerializeToOstream(std::ostream& output) const { return g_host->FunctionProto__SerializeToOstream(this, output); } + bool ParseFromString(const std::string& data) { return g_host->FunctionProto__ParseFromString(this, data); } + std::string SerializeAsString() const { return g_host->FunctionProto__SerializeAsString(this); } + + bool has_name() const { return g_host->FunctionProto__has_name(this); } + const std::string& name() const { return g_host->FunctionProto__name(this); } + void set_name(const std::string& name) { g_host->FunctionProto__set_name(this, name); } + + bool has_doc_string() const { return g_host->FunctionProto__has_doc_string(this); } + const std::string& doc_string() const { return g_host->FunctionProto__doc_string(this); } + void set_doc_string(const std::string& doc_string) { g_host->FunctionProto__set_doc_string(this, doc_string); } + + bool has_domain() const { return g_host->FunctionProto__has_domain(this); } + const std::string& domain() const { return g_host->FunctionProto__domain(this); } + void set_domain(const std::string& domain) { g_host->FunctionProto__set_domain(this, domain); } + + const std::string& input(int index) const { return g_host->FunctionProto__input(this, index); } + std::string* mutable_input(int index) { return g_host->FunctionProto__mutable_input(this, index); } + int input_size() const { return g_host->FunctionProto__input_size(this); } + void add_input(const std::string& value) { g_host->FunctionProto__add_input(this, value); } + + const std::string& output(int index) const { return g_host->FunctionProto__output(this, index); } + std::string* mutable_output(int index) { return g_host->FunctionProto__mutable_output(this, index); } + int output_size() const { return g_host->FunctionProto__output_size(this); } + void add_output(const std::string& value) { g_host->FunctionProto__add_output(this, value); } + + const std::string& attribute(int index) const { return g_host->FunctionProto__attribute(this, index); } + std::string* mutable_attribute(int index) { return g_host->FunctionProto__mutable_attribute(this, index); } + int attribute_size() const { return g_host->FunctionProto__attribute_size(this); } + void add_attribute(const std::string& value) { g_host->FunctionProto__add_attribute(this, value); } + + const AttributeProto& attribute_proto(int index) const { return g_host->FunctionProto__attribute_proto(this, index); } + AttributeProto* mutable_attribute_proto(int index) { return g_host->FunctionProto__mutable_attribute_proto(this, index); } + int attribute_proto_size() const { return g_host->FunctionProto__attribute_proto_size(this); } + AttributeProto* add_attribute_proto() { return g_host->FunctionProto__add_attribute_proto(this); } + + const NodeProto& node(int index) const { return g_host->FunctionProto__node(this, index); } + NodeProto* mutable_node(int index) { return g_host->FunctionProto__mutable_node(this, index); } + int node_size() const { return g_host->FunctionProto__node_size(this); } + NodeProto* add_node() { return g_host->FunctionProto__add_node(this); } + + const ValueInfoProto& value_info(int index) const { return g_host->FunctionProto__value_info(this, index); } + ValueInfoProtos* mutable_value_info() { return g_host->FunctionProto__mutable_value_info(this); } + ValueInfoProto* mutable_value_info(int index) { return g_host->FunctionProto__mutable_value_info(this, index); } + int value_info_size() const { return g_host->FunctionProto__value_info_size(this); } + ValueInfoProto* add_value_info() { return g_host->FunctionProto__add_value_info(this); } + + const StringStringEntryProto& metadata_props(int index) const { return g_host->FunctionProto__metadata_props(this, index); } + StringStringEntryProtos* mutable_metadata_props() { return g_host->FunctionProto__mutable_metadata_props(this); } + StringStringEntryProto* mutable_metadata_props(int index) { return g_host->FunctionProto__mutable_metadata_props(this, index); } + int metadata_props_size() const { return g_host->FunctionProto__metadata_props_size(this); } + StringStringEntryProto* add_metadata_props() { return g_host->FunctionProto__add_metadata_props(this); } + + FunctionProto() = delete; + FunctionProto(const FunctionProto&) = delete; + void operator=(const FunctionProto&) = delete; +}; } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -449,6 +527,12 @@ struct IndexedSubGraph_MetaDef final { void operator=(const IndexedSubGraph_MetaDef&) = delete; }; +enum class IndexedSubGraph_SourceOfSchema : uint8_t { + CREATE, + REUSE_OR_CREATE, + EXISTING, +}; + struct IndexedSubGraph final { static std::unique_ptr Create() { return g_host->IndexedSubGraph__construct(); } static void operator delete(void* p) { g_host->IndexedSubGraph__operator_delete(reinterpret_cast(p)); } @@ -458,6 +542,9 @@ struct IndexedSubGraph final { void SetMetaDef(std::unique_ptr&& meta_def_) { return g_host->IndexedSubGraph__SetMetaDef(this, std::move(*reinterpret_cast*>(&meta_def_))); } const IndexedSubGraph_MetaDef* GetMetaDef() const { return reinterpret_cast(g_host->IndexedSubGraph__GetMetaDef(this)); } + void SetSchemaSource(IndexedSubGraph_SourceOfSchema schema_source) { return g_host->IndexedSubGraph__SetSchemaSource(this, schema_source); } + IndexedSubGraph_SourceOfSchema GetSchemaSource() const { return g_host->IndexedSubGraph__GetSchemaSource(this); } + IndexedSubGraph() = delete; IndexedSubGraph(const IndexedSubGraph&) = delete; void operator=(const IndexedSubGraph&) = delete; @@ -811,7 +898,7 @@ struct NodeUnit final { const std::string& Name() const noexcept { return g_host->NodeUnit__Name(this); } int SinceVersion() const noexcept { return g_host->NodeUnit__SinceVersion(this); } NodeIndex Index() const noexcept { return g_host->NodeUnit__Index(this); } - const Path& ModelPath() const noexcept { return g_host->NodeUnit__ModelPath(this); } + const std::filesystem::path& ModelPath() const noexcept { return g_host->NodeUnit__ModelPath(this); } ProviderType GetExecutionProviderType() const noexcept { return g_host->NodeUnit__GetExecutionProviderType(this); } const Node& GetNode() const noexcept { return g_host->NodeUnit__GetNode(this); } @@ -840,7 +927,7 @@ struct Model final { Graph& MainGraph() { return g_host->Model__MainGraph(this); } std::unique_ptr ToProto() { return g_host->Model__ToProto(this); } - std::unique_ptr ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) { return g_host->Model__ToGraphProtoWithExternalInitializers(this, external_file_name, file_path, initializer_size_threshold); } + std::unique_ptr ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold) { return g_host->Model__ToGraphProtoWithExternalInitializers(this, external_file_name, file_path, initializer_size_threshold); } const ModelMetaData& MetaData() const noexcept { return g_host->Model__MetaData(this); } Model() = delete; @@ -873,7 +960,7 @@ struct Graph final { const Graph* ParentGraph() const { return g_host->Graph__ParentGraph(this); } Graph* MutableParentGraph() { return g_host->Graph__MutableParentGraph(this); } const std::string& Name() const noexcept { return g_host->Graph__Name(this); } - const Path& ModelPath() const { return g_host->Graph__ModelPath(this); } + const std::filesystem::path& ModelPath() const { return g_host->Graph__ModelPath(this); } const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); } bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); } const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->Graph__GetProducerNode(this, node_arg_name); } @@ -923,7 +1010,7 @@ class GraphViewer final { std::unique_ptr CreateModel(const logging::Logger& logger) const { return g_host->GraphViewer__CreateModel(this, logger); } const std::string& Name() const noexcept { return g_host->GraphViewer__Name(this); } - const Path& ModelPath() const noexcept { return g_host->GraphViewer__ModelPath(this); } + const std::filesystem::path& ModelPath() const noexcept { return g_host->GraphViewer__ModelPath(this); } const Node* GetNode(NodeIndex node_index) const { return g_host->GraphViewer__GetNode(this, node_index); } const NodeArg* GetNodeArg(const std::string& name) const { return g_host->GraphViewer__GetNodeArg(this, name); } @@ -968,19 +1055,6 @@ class GraphViewer final { void operator=(const GraphViewer&) = delete; }; -struct Path final { - static std::unique_ptr Create() { return g_host->Path__construct(); } - static void operator delete(void* p) { g_host->Path__operator_delete(reinterpret_cast(p)); } - - PathString ToPathString() const noexcept { return g_host->Path__ToPathString(this); } - const std::vector& GetComponents() const noexcept { return g_host->Path__GetComponents(this); } - bool IsEmpty() const noexcept { return g_host->Path__IsEmpty(this); } - - Path() = delete; - Path(const Path&) = delete; - void operator=(const Path&) = delete; -}; - struct OpKernelContext final { template const T& RequiredInput(int index) const; diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 2171ce056e029..42788f2960197 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -29,7 +29,7 @@ bool GraphHasCtxNode(const GraphViewer& graph_viewer) { return false; } -const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer) { +const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) { // find the top level graph const Graph* cur_graph = &graph_viewer.GetGraph(); while (cur_graph->IsSubgraph()) { diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index f8fefc12c3453..3be08d043da48 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -24,7 +24,7 @@ static const std::string EPCONTEXT_WARNING = for the best model loading time"; bool GraphHasCtxNode(const GraphViewer& graph_viewer); -const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer); +const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, const std::string engine_cache_path, diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 9c2db494f0e41..67cbc8f5d6f13 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -18,7 +18,7 @@ #include "core/providers/cuda/gpu_data_transfer.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" -#include "core/common/gsl.h" +#include #include #include #include @@ -70,7 +70,15 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_mapgetName(); auto dynamic_range_iter = dynamic_range_map.find(tensor_name); if (dynamic_range_iter != dynamic_range_map.end()) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif if (!network.getInput(i)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for network input " << tensor_name; return false; } } @@ -83,11 +91,20 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_mapgetOutput(j)->getName(); auto dynamic_range_iter = dynamic_range_map.find(tensor_name); if (dynamic_range_iter != dynamic_range_map.end()) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif if (!trt_layer->getOutput(j)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for tensor " << tensor_name; return false; } } else if (trt_layer->getType() == nvinfer1::LayerType::kCONSTANT) { nvinfer1::IConstantLayer* const_layer = static_cast(trt_layer); + const std::string const_layer_name = const_layer->getName(); auto trt_weights = const_layer->getWeights(); double max_weight = std::numeric_limits::min(); for (int64_t k = 0, end = trt_weights.count; k < end; ++k) { @@ -108,13 +125,26 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map(trt_weights.values)[k]; break; +#if NV_TENSORRT_MAJOR >= 10 + case nvinfer1::DataType::kINT64: + weight = static_cast(static_cast(trt_weights.values)[k]); + break; +#endif // NV_TENSORRT_MAJOR >= 10 default: - LOGS_DEFAULT(ERROR) << "Found unsupported datatype!"; + LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << const_layer_name; return false; } max_weight = std::max(max_weight, std::abs(weight)); } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif if (!trt_layer->getOutput(j)->setDynamicRange(static_cast(-max_weight), static_cast(max_weight))) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for layer " << const_layer_name; return false; } } @@ -160,11 +190,20 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { nvinfer1::TacticSource source{}; t = toUpper(t); if (t == "CUBLAS") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 source = nvinfer1::TacticSource::kCUBLAS; +#endif } else if (t == "CUBLASLT" || t == "CUBLAS_LT") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0"; +#if NV_TENSORRT_MAJOR < 9 source = nvinfer1::TacticSource::kCUBLAS_LT; +#endif } else if (t == "CUDNN") { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 source = nvinfer1::TacticSource::kCUDNN; +#endif } else if (t == "EDGE_MASK_CONVOLUTIONS") { source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS; } else if (t == "JIT_CONVOLUTIONS") { @@ -269,6 +308,7 @@ void CudaCall(cudaError retCode, const char* exprString, const return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } +#ifndef USE_CUDA_MINIMAL template <> Status CudaCall(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line) { return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line); @@ -288,7 +328,27 @@ template <> void CudaCall(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line) { return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } +#endif +#if NV_TENSORRT_MAJOR >= 10 +void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. + return outputPtr; +} +#else +// Only override this method when TensorRT <= 8.6 void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr @@ -305,6 +365,7 @@ void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*curr // if cudaMalloc fails, returns nullptr. return outputPtr; } +#endif void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { output_shapes.clear(); @@ -1081,20 +1142,26 @@ Status BindKernelOutput(Ort::KernelContext& ctx, TensorrtExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) { if (has_user_compute_stream) { CUDA_CALL_THROW(cudaSetDevice(device_id)); +#ifndef USE_CUDA_MINIMAL ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_))); ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream))); +#else + (void)(stream); +#endif } } TensorrtExecutionProvider::PerThreadContext::~PerThreadContext() { +#ifndef USE_CUDA_MINIMAL if (external_cublas_handle_ != nullptr) { ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_))); } if (external_cudnn_handle_ != nullptr) { ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_))); } +#endif trt_context_map_.clear(); } @@ -1230,10 +1297,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); +#ifndef USE_CUDA_MINIMAL ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_))); ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_))); +#endif } std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -1404,6 +1473,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (!ep_context_embed_mode_env.empty()) { ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); } + // incase the EP context is dumped the engine cache has to be enabled + if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) { + engine_cache_enable_ = true; + } enable_engine_cache_for_ep_context_model(); @@ -1699,8 +1772,10 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { } if (external_stream_) { +#ifndef USE_CUDA_MINIMAL ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_))); +#endif } if (!external_stream_ && stream_) { @@ -2178,7 +2253,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif SubGraphCollection_t next_nodes_list; const std::vector& subgraph_node_index = graph_viewer->GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); @@ -2318,12 +2400,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, // Construct subgraph capability from node list std::vector> result; // Get ModelPath - const auto& path_string = graph.ModelPath().ToPathString(); + const auto& path_string = graph.ModelPath().string(); #ifdef _WIN32 - wcstombs_s(nullptr, model_path_, sizeof(model_path_), path_string.c_str(), sizeof(model_path_)); + strncpy_s(model_path_, path_string.c_str(), sizeof(model_path_) - 1); #else - strcpy(model_path_, path_string.c_str()); + strncpy(model_path_, path_string.c_str(), sizeof(model_path_) - 1); #endif + model_path_[sizeof(model_path_) - 1] = '\0'; // If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation. @@ -3019,7 +3102,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } else { // Set INT8 per tensor dynamic range if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif if (!SetDynamicRange(*trt_network, dynamic_range_map)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name()); @@ -3138,18 +3228,21 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading if (context_memory_sharing_enable_) { - size_t mem_size = trt_engine->getDeviceMemorySize(); - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } - #if defined(_MSC_VER) #pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated +#pragma warning(disable : 4996) #endif - trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); + size_t mem_size = trt_engine->getDeviceMemorySize(); #if defined(_MSC_VER) #pragma warning(pop) +#endif + if (mem_size > max_ctx_mem_size_) { + max_ctx_mem_size_ = mem_size; + } +#if NV_TENSORRT_MAJOR < 10 + trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { trt_context = std::unique_ptr(trt_engine->createExecutionContext()); @@ -3415,7 +3508,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Set INT8 Per Tensor Dynamic range if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); } @@ -3596,14 +3696,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (context_update) { if (trt_state->context_memory_sharing_enable) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated -#endif +#if NV_TENSORRT_MAJOR < 10 *(trt_state->context) = std::unique_ptr( trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); -#if defined(_MSC_VER) -#pragma warning(pop) +#else + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { *(trt_state->context) = std::unique_ptr( @@ -3685,7 +3783,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Set execution context memory if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif if (mem_size > *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; } @@ -3816,17 +3921,21 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading if (context_memory_sharing_enable_) { - size_t mem_size = trt_engine->getDeviceMemorySize(); - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } #if defined(_MSC_VER) #pragma warning(push) -#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated +#pragma warning(disable : 4996) #endif - trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); + size_t mem_size = trt_engine->getDeviceMemorySize(); #if defined(_MSC_VER) #pragma warning(pop) +#endif + if (mem_size > max_ctx_mem_size_) { + max_ctx_mem_size_ = mem_size; + } +#if NV_TENSORRT_MAJOR < 10 + trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); #endif } else { trt_context = std::unique_ptr(trt_engine->createExecutionContext()); @@ -3992,7 +4101,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // Set execution context memory if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif if (mem_size > *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index f4dae57487f51..b58e86237860c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -3,9 +3,13 @@ #pragma once #include +#ifndef USE_CUDA_MINIMAL #include -#include - +#else +typedef void* cudnnHandle_t; +typedef void* cublasHandle_t; +typedef void* cudnnStatus_t; +#endif #include "core/providers/tensorrt/nv_includes.h" #include "core/platform/ort_mutex.h" @@ -116,8 +120,11 @@ using unique_pointer = std::unique_ptr; // class OutputAllocator : public nvinfer1::IOutputAllocator { public: +#if NV_TENSORRT_MAJOR >= 10 + void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override; +#else void* reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override; - +#endif void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override; void* getBuffer() { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h index df12d90338782..95abcd1bad2b8 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h @@ -537,10 +537,8 @@ HashValue TRTGenerateId(const GraphViewer& graph_viewer) { }; // Use the model's file name instead of the entire path to avoid cache regeneration if path changes - const auto& model_path_components = main_graph.ModelPath().GetComponents(); - - if (!model_path_components.empty()) { - std::string model_name = PathToUTF8String(model_path_components.back()); + if (main_graph.ModelPath().has_filename()) { + std::string model_name = PathToUTF8String(main_graph.ModelPath().filename()); LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name; // Ensure enough characters are hashed in case model names are too short diff --git a/onnxruntime/core/providers/tvm/tvm_api.cc b/onnxruntime/core/providers/tvm/tvm_api.cc index 37982d0bdb551..e9a7d002e77c8 100644 --- a/onnxruntime/core/providers/tvm/tvm_api.cc +++ b/onnxruntime/core/providers/tvm/tvm_api.cc @@ -16,8 +16,7 @@ #include #include "core/common/common.h" -#include "core/common/path.h" -#include "core/common/gsl.h" +#include #include "tvm_api.h" diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.h b/onnxruntime/core/providers/vitisai/imp/attr_proto.h index f4d56dd618a8c..bb2883512037b 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.h @@ -3,7 +3,7 @@ #pragma once #include #include "vaip/my_ort.h" -#include "core/common/gsl.h" +#include namespace vaip { diff --git a/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc new file mode 100644 index 0000000000000..ab31aa313cf6d --- /dev/null +++ b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc @@ -0,0 +1,682 @@ +// Standard headers/libs. +#include +#include +#include +#include + +// 3rd-party headers/libs. +#include + +#include "ep_context_utils.h" + +namespace onnxruntime { + +constexpr const char* kVitisAI = "vitisai"; + +std::unique_ptr ConvertIndexedSubGraphToFunctionProto( + const IndexedSubGraph& sub_graph, const Graph& parent_graph) { + auto p_func_proto = ONNX_NAMESPACE::FunctionProto::Create(); + auto* p_meta_def = const_cast(sub_graph.GetMetaDef()); + if (p_meta_def) { + p_func_proto->set_name(p_meta_def->name()); + p_func_proto->set_domain(p_meta_def->domain()); + for (const auto& input : p_meta_def->inputs()) { + p_func_proto->add_input(input); + } + auto* p_metadata_props_0 = p_func_proto->add_metadata_props(); + *(p_metadata_props_0->mutable_key()) = "meta_def_inputs_size"; + *(p_metadata_props_0->mutable_value()) = std::to_string(p_meta_def->inputs().size()); + for (const auto& output : p_meta_def->outputs()) { + p_func_proto->add_output(output); + } + // XXX: SerDes with different fields. + for (const auto& initializer : p_meta_def->constant_initializers()) { + p_func_proto->add_input(initializer); + } + // XXX: SerDes with different numbers of fields. + for (const auto& attr_pair : p_meta_def->attributes()) { + p_func_proto->add_attribute(attr_pair.first); + auto* p_attr_proto = p_func_proto->add_attribute_proto(); + *p_attr_proto = attr_pair.second; + } + p_func_proto->set_doc_string(p_meta_def->doc_string()); + // "since_version" + auto* p_metadata_props_1 = p_func_proto->add_metadata_props(); + *(p_metadata_props_1->mutable_key()) = "meta_def_since_version"; + *(p_metadata_props_1->mutable_value()) = std::to_string(p_meta_def->since_version()); + // "status" + auto* p_metadata_props_2 = p_func_proto->add_metadata_props(); + *(p_metadata_props_2->mutable_key()) = "meta_def_status"; + *(p_metadata_props_2->mutable_value()) = + std::to_string(static_cast(p_meta_def->status())); + // TODO: `MetaDef::type_and_shape_inference_function`. + } + auto p_parent_graph_proto = parent_graph.ToGraphProto(); + for (auto node_index : const_cast(sub_graph).Nodes()) { + auto* p_node_proto = p_parent_graph_proto->mutable_node(static_cast(node_index)); + auto* p_attr_proto = p_node_proto->add_attribute(); + p_attr_proto->set_name("parent_graph_node_index"); + p_attr_proto->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_proto->set_i(node_index); + *(p_func_proto->add_node()) = *p_node_proto; + } +#if 0 + // Alternative. + for (const auto node_index : sub_graph.Nodes()) { + const auto* p_node = parent_graph.GetNode(node_index); + auto p_node_proto = ONNX_NAMESPACE::NodeProto::Create(); + // XXX + p_node->ToProto(*p_node_proto, true); + auto* p_attr_proto = p_node_proto->add_attribute(); + p_attr_proto->set_name("parent_graph_node_index"); + p_attr_proto->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_proto->set_i(node_index); + *(p_func_proto.add_node()) = *p_node_proto; + } +#endif + auto* p_metadata_props_3 = p_func_proto->add_metadata_props(); + *(p_metadata_props_3->mutable_key()) = "schema_source"; + *(p_metadata_props_3->mutable_value()) = + std::to_string(static_cast(sub_graph.GetSchemaSource())); + return p_func_proto; +} + +std::unique_ptr ConvertFunctionProtoToIndexedSubGraph( + const std::unique_ptr& p_func_proto) { + auto p_isg = IndexedSubGraph::Create(); + // "meta_def_inputs_size" (optional) and "schema_source". + int func_metadata_props_size = p_func_proto->metadata_props_size(); + // Precisely, func_metadata_props_size == 4, which implies + // `IndexedSubGraph::meta_def_` is not null and `IndexedSubGraph::nodes` > 1. + if (func_metadata_props_size > 1) { + auto& prop0 = const_cast(p_func_proto->metadata_props(0)); + int isg_meta_def_inputs_size = std::stoi(*(prop0.mutable_value())); + auto p_meta_def = IndexedSubGraph_MetaDef::Create(); + p_meta_def->name() = p_func_proto->name(); + p_meta_def->domain() = p_func_proto->domain(); + auto& prop1 = const_cast(p_func_proto->metadata_props(1)); + p_meta_def->since_version() = std::stoi(*(prop1.mutable_value())); + auto& prop2 = const_cast(p_func_proto->metadata_props(2)); + p_meta_def->status() = static_cast(std::stoi(*(prop2.mutable_value()))); + auto& meta_def_inputs = p_meta_def->inputs(); + for (int i = 0; i < isg_meta_def_inputs_size; i++) { + meta_def_inputs.push_back(p_func_proto->input(i)); + } + auto& meta_def_outputs = p_meta_def->outputs(); + for (int i = 0, l = p_func_proto->output_size(); i < l; i++) { + meta_def_outputs.push_back(p_func_proto->output(i)); + } + auto& meta_def_initializers = p_meta_def->constant_initializers(); + for (int i = isg_meta_def_inputs_size, l = p_func_proto->input_size(); i < l; i++) { + meta_def_initializers.push_back(p_func_proto->input(i)); + } + auto& meta_def_attrs = p_meta_def->attributes(); + for (int i = 0, l = p_func_proto->attribute_size(); i < l; i++) { + meta_def_attrs.emplace(p_func_proto->attribute(i), p_func_proto->attribute_proto(i)); + } + p_meta_def->doc_string() = p_func_proto->doc_string(); + // TODO: `IndexedSubGraph::type_and_shape_inference_function`. + p_isg->SetMetaDef(std::move(p_meta_def)); + } + auto& isg_nodes = p_isg->Nodes(); + for (int i = 0, l = p_func_proto->node_size(); i < l; i++) { + const auto& node_proto = p_func_proto->node(i); + isg_nodes.push_back( + node_proto.attribute(const_cast(node_proto).attribute_size() - 1).i()); + } + auto schema_source = static_cast( + std::stoi(*(const_cast(p_func_proto->metadata_props(func_metadata_props_size - 1)).mutable_value()))); + p_isg->SetSchemaSource(schema_source); + return p_isg; +} + +std::string SerializeCapabilities( + const std::vector>& capability_ptrs, + const Graph& graph) { + std::stringstream ss; + for (const auto& p : capability_ptrs) { + auto& p_subgraph = p->SubGraph(); + auto p_func_proto = ConvertIndexedSubGraphToFunctionProto(*p_subgraph, graph); + std::string func_proto_buf; + p_func_proto->SerializeToString(func_proto_buf); + size_t buf_len = func_proto_buf.length(); + ss.write(reinterpret_cast(&buf_len), sizeof(buf_len)); + ss.write(func_proto_buf.data(), buf_len); + } + if (!ss.good()) { + ORT_THROW("Serialization stream bad"); + } + return ss.str(); +} + +void DeserializeCapabilities(const std::string& ser_capabilities, + std::vector>& capability_ptrs) { + std::istringstream ss(ser_capabilities); + while (!ss.eof()) { + size_t buf_len; + ss.read(reinterpret_cast(&buf_len), sizeof(buf_len)); + std::string buf(buf_len, '\0'); + ss.read(&buf[0], buf_len); + auto p_func_proto = ONNX_NAMESPACE::FunctionProto::Create(); + p_func_proto->ParseFromString(buf); + auto p_subgraph = ConvertFunctionProtoToIndexedSubGraph(p_func_proto); + capability_ptrs.push_back(ComputeCapability::Create(std::move(p_subgraph))); + } +} + +std::string SerializeOrigialGraph(const GraphViewer& graph_viewer) { + // XXX: Will Steps 1/2/3 suffice for restoring a model/graph later? + // Any information loss or mismatch? + // Step 1 + const Graph& orig_graph = graph_viewer.GetGraph(); + // Step 2 + const Model& orig_model = orig_graph.GetModel(); + // Step 3 + auto p_orig_model_proto = const_cast(orig_model).ToProto(); + if (p_orig_model_proto->opset_import_size() == 0) { + for (const auto& it : graph_viewer.DomainToVersionMap()) { + auto* p_opset_import = p_orig_model_proto->add_opset_import(); + *(p_opset_import->mutable_domain()) = it.first; + p_opset_import->set_version(it.second); + } + } + + nlohmann::json j_obj; + if (p_orig_model_proto->opset_import_size() > 0) { + for (int i = 0, n = p_orig_model_proto->opset_import_size(); i < n; ++i) { + auto& op_set_id_proto = const_cast(p_orig_model_proto->opset_import(i)); + j_obj[*op_set_id_proto.mutable_domain()] = std::to_string(op_set_id_proto.version()); + } + } + j_obj["orig_graph_name"] = graph_viewer.Name(); + // TODO: platform dependency (Linux vs Windows). + j_obj["orig_model_path"] = graph_viewer.ModelPath().string(); + + // XXX: `ModelProto::SerializeToString` will lose some info, + // e.g., ModelProto.opset_import. + std::string ser_buf; + p_orig_model_proto->SerializeToString(ser_buf); + j_obj["orig_model_proto_ser_str"] = ser_buf; + + return j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace); +} + +// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc". +ONNX_NAMESPACE::ModelProto* CreateEPContexModel( + const GraphViewer& graph_viewer, + const std::string& serialized_ctx_cache, + const std::string& ctx_cache_file_loc, + const int64_t embed_mode, + const std::string& backend_cache_dir, + const std::string& backend_cache_key, + bool saving_orig_graph, + const logging::Logger* p_logger) { + LOGS_DEFAULT(VERBOSE) << "[VitisAI EP]Creating EP context node"; + // Create a new graph/model, reusing the graph name, + // the op-domain-to-opset-version map, + // and the op schema registry of the current graph. + // XXX: This approach (immediately below) has a memory fault issue (std::bad_alloc). + // auto& ep_ctx_graph = graph_viewer.CreateModel(*p_logger)->MainGraph(); + // This apporach (immediately below) has no memory falut issue. + auto p_temp_model = graph_viewer.CreateModel(*p_logger); + auto& ep_ctx_graph = p_temp_model->MainGraph(); + + const auto& graph_inputs = graph_viewer.GetInputs(); + std::vector input_node_arg_ptrs; + input_node_arg_ptrs.reserve(graph_inputs.size()); + // XXX: vs `GraphViewer::GetInputsIncludingInitializers()`. + for (const auto* p_node_arg : graph_inputs) { + auto& temp_node_arg = ep_ctx_graph.GetOrCreateNodeArg( + p_node_arg->Name(), p_node_arg->TypeAsProto()); + input_node_arg_ptrs.push_back(&temp_node_arg); + } + const auto& graph_outputs = graph_viewer.GetOutputs(); + std::vector output_node_arg_ptrs; + output_node_arg_ptrs.reserve(graph_outputs.size()); + for (const auto* p_node_arg : graph_outputs) { + auto& temp_node_arg = ep_ctx_graph.GetOrCreateNodeArg(p_node_arg->Name(), p_node_arg->TypeAsProto()); + output_node_arg_ptrs.push_back(&temp_node_arg); + } + + // Attr "embed_mode". + auto p_attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_0->set_name(kEmbedModeAttr); + // p_attr_0->set_type(onnx::AttributeProto_AttributeType_INT); + p_attr_0->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_0->set_i(embed_mode); + // Attr "ep_cache_context". + auto p_attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_1->set_name(kEPCacheContextAttr); + // p_attr_1->set_type(onnx::AttributeProto_AttributeType_STRING); + p_attr_1->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + // Relative to the ONNX model file. + p_attr_1->set_s( + embed_mode == 0 ? fs::path(ctx_cache_file_loc).filename().string() : serialized_ctx_cache); + // Attr "source". + auto p_attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_2->set_name(kSourceAttr); + // p_attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); + p_attr_2->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_2->set_s(kVitisAIExecutionProvider); + // Attr "onnx_model_filename". + auto p_attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_3->set_name(kONNXModelFileNameAttr); + // p_attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); + p_attr_3->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_3->set_s(graph_viewer.ModelPath().filename().string()); + // Attr "notes". + auto p_attr_4 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_4->set_name(kNotesAttr); + // p_attr_4->set_type(onnx::AttributeProto_AttributeType_STRING); + p_attr_4->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + // FIXME: 2G-limit of ProtoBuf. + if (saving_orig_graph) { + p_attr_4->set_s(SerializeOrigialGraph(graph_viewer)); + } else { + nlohmann::json j_obj; + j_obj["backend_cache_dir"] = backend_cache_dir; + j_obj["backend_cache_key"] = backend_cache_key; + p_attr_4->set_s(j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace)); + } + + auto p_node_attrs = NodeAttributes::Create(); + constexpr int num_attrs = 5; + p_node_attrs->reserve(num_attrs); + p_node_attrs->emplace(kEmbedModeAttr, *p_attr_0); + p_node_attrs->emplace(kEPCacheContextAttr, *p_attr_1); + p_node_attrs->emplace(kSourceAttr, *p_attr_2); + p_node_attrs->emplace(kONNXModelFileNameAttr, *p_attr_3); + p_node_attrs->emplace(kNotesAttr, *p_attr_4); + + // Since we don't implement `IExecutionProvider::GetEpContextNodes()` and + // thus don't leverage `CreateEpContextModel()` in the file "graph_partitioner.cc", + // we specify a brand-new node name here. + ep_ctx_graph.AddNode(kEPContextOpName, kEPContextOp, "", input_node_arg_ptrs, output_node_arg_ptrs, p_node_attrs.get(), kEPContextOpDomain); + + auto res_status = ep_ctx_graph.Resolve(); + ORT_ENFORCE(res_status.IsOK(), res_status.ErrorMessage()); + LOGS_DEFAULT(VERBOSE) << "Created EP context model graph resolved"; + + auto p_ep_ctx_graph_viewer = ep_ctx_graph.CreateGraphViewer(); + auto p_temp_model_2 = p_ep_ctx_graph_viewer->CreateModel(*p_logger); + auto p_ep_ctx_model_proto = p_temp_model_2->ToProto(); + p_ep_ctx_graph_viewer->ToProto(*p_ep_ctx_model_proto->mutable_graph(), true, true); + p_ep_ctx_model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + return p_ep_ctx_model_proto.release(); +} + +// Ref.: `static common::Status Save(Model& model, int fd)` in the file "model.h". +void DumpEPContextModel( + const std::unique_ptr& p_model_proto, const std::string& ep_ctx_model_file_loc) { + std::fstream dump_stream(ep_ctx_model_file_loc, std::ios::out | std::ios::trunc | std::ios::binary); + p_model_proto->SerializeToOstream(dump_stream); + LOGS_DEFAULT(VERBOSE) << "[VitisAI EP] Dumped " << ep_ctx_model_file_loc; +} + +const Node* GetEPContextNodePtr(const Graph& graph) { + // TODO: Support for multi-node EP context model. + for (const auto* p_node : graph.Nodes()) { + if (p_node->OpType() == kEPContextOp) { + return p_node; + } + } + return nullptr; +} + +bool ValidateEPContextNode(const Graph& graph) { + // TODO: Support for multi-node EP context model. + const auto* p_node = GetEPContextNodePtr(graph); + assert(p_node != nullptr); + auto& attrs = p_node->GetAttributes(); + assert(attrs.count(kEmbedModeAttr) > 0); + assert(attrs.count(kEPCacheContextAttr) > 0); + assert(attrs.count(kSourceAttr) > 0); + const auto& source_val = attrs.at(kSourceAttr).s(); + if (source_val == kVitisAIExecutionProvider) { + return true; + } + size_t vitisai_len = std::strlen(kVitisAI); + assert(source_val.length() == vitisai_len); + for (size_t i = 0; i < vitisai_len; ++i) { + assert(static_cast(std::tolower(source_val[i])) == kVitisAI[i]); + } + return true; +} + +// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc". +void CreateEPContexNodes( + Graph* p_ep_ctx_graph, + const std::vector& fused_nodes_and_graphs, + const std::string& serialized_ctx_cache, + const std::string& ctx_cache_file_loc, + const int64_t embed_mode, + const std::string& backend_cache_dir, + const std::string& backend_cache_key, + bool saving_orig_graph, + const logging::Logger* p_logger) { + LOGS_DEFAULT(VERBOSE) << "[VitisAI EP]Creating EP context nodes"; + int fused_index = 0; + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + Node& fused_node = fused_node_graph.fused_node; + const auto& fused_name = fused_node.Name(); + const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; + // FIXME + const auto& graph_inputs = graph_viewer.GetInputs(); + std::vector input_node_arg_ptrs; + input_node_arg_ptrs.reserve(graph_inputs.size()); + // XXX: vs `GraphViewer::GetInputsIncludingInitializers()`. + for (const auto* p_node_arg : graph_inputs) { + auto& temp_node_arg = p_ep_ctx_graph->GetOrCreateNodeArg( + p_node_arg->Name(), p_node_arg->TypeAsProto()); + input_node_arg_ptrs.push_back(&temp_node_arg); + } + const auto& graph_outputs = graph_viewer.GetOutputs(); + std::vector output_node_arg_ptrs; + output_node_arg_ptrs.reserve(graph_outputs.size()); + for (const auto* p_node_arg : graph_outputs) { + auto& temp_node_arg = p_ep_ctx_graph->GetOrCreateNodeArg(p_node_arg->Name(), p_node_arg->TypeAsProto()); + output_node_arg_ptrs.push_back(&temp_node_arg); + } + + auto p_node_attrs = NodeAttributes::Create(); + if (fused_index == 0) { + p_node_attrs->reserve(7); + // Attr "ep_cache_context". + auto p_attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_1->set_name(kEPCacheContextAttr); + p_attr_1->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + // Relative to the ONNX model file. + p_attr_1->set_s( + embed_mode == 0 ? fs::path(ctx_cache_file_loc).filename().string() : serialized_ctx_cache); + p_node_attrs->emplace(kEPCacheContextAttr, *p_attr_1); + // Attr "notes". + auto p_attr_4 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_4->set_name(kNotesAttr); + p_attr_4->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + // FIXME: 2G-limit of ProtoBuf. + if (saving_orig_graph) { + p_attr_4->set_s(SerializeOrigialGraph(graph_viewer)); + } else { + nlohmann::json j_obj; + j_obj["backend_cache_dir"] = backend_cache_dir; + j_obj["backend_cache_key"] = backend_cache_key; + p_attr_4->set_s(j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace)); + } + p_node_attrs->emplace(kNotesAttr, *p_attr_4); + // Attr "main_context". + auto p_attr_5 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_5->set_name(kMainContextAttr); + p_attr_5->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_5->set_i(1); + p_node_attrs->emplace(kMainContextAttr, *p_attr_5); + } else { + p_node_attrs->reserve(5); + // Attr "main_context". + auto p_attr_5 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_5->set_name(kMainContextAttr); + p_attr_5->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_5->set_i(0); + p_node_attrs->emplace(kMainContextAttr, *p_attr_5); + } + // Attr "embed_mode". + auto p_attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_0->set_name(kEmbedModeAttr); + p_attr_0->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_0->set_i(embed_mode); + p_node_attrs->emplace(kEmbedModeAttr, *p_attr_0); + // Attr "source". + auto p_attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_2->set_name(kSourceAttr); + p_attr_2->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_2->set_s(kVitisAIExecutionProvider); + p_node_attrs->emplace(kSourceAttr, *p_attr_2); + // Attr "onnx_model_filename". + auto p_attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_3->set_name(kONNXModelFileNameAttr); + p_attr_3->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_3->set_s(graph_viewer.ModelPath().filename().string()); + p_node_attrs->emplace(kONNXModelFileNameAttr, *p_attr_3); + // Attr "partition_name". + auto p_attr_6 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_6->set_name(kPartitionNameAttr); + p_attr_6->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_6->set_s(fused_name); + p_node_attrs->emplace(kPartitionNameAttr, *p_attr_6); + + p_ep_ctx_graph->AddNode(fused_name, kEPContextOp, "", input_node_arg_ptrs, output_node_arg_ptrs, p_node_attrs.get(), kEPContextOpDomain); + + ++fused_index; + } + auto res_status = p_ep_ctx_graph->Resolve(); + ORT_ENFORCE(res_status.IsOK(), res_status.ErrorMessage()); + LOGS_DEFAULT(VERBOSE) << "Created EP context model graph resolved"; +} + +std::string RetrieveEPContextCache( + const Graph& graph, const PathString& ep_ctx_model_loc, bool binary_mode) { + // TODO: Support for multi-node EP context model. + const auto* p_node = GetEPContextNodePtr(graph); + const auto& attrs = p_node->GetAttributes(); + int64_t embed_mode = attrs.at(kEmbedModeAttr).i(); + const std::string& ep_ctx_cache = attrs.at(kEPCacheContextAttr).s(); + if (embed_mode) { + return ep_ctx_cache; + } + fs::path ep_ctx_fs_path(ep_ctx_model_loc); + // Attr "ep_cache_context" stores a relative path. + ep_ctx_fs_path.replace_filename(fs::path(ep_ctx_cache)); + // TODO: Validaion of the file location to make sure security is met. + if (!fs::exists(ep_ctx_fs_path) || !fs::is_regular_file(ep_ctx_fs_path)) { + ORT_THROW("File for EP context cache is missing"); + } + auto open_mode = binary_mode ? (std::ios::in | std::ios::binary) : std::ios::in; + std::ifstream ifs(ep_ctx_fs_path.string().c_str(), open_mode); + if (!ifs.is_open()) { + ORT_THROW("Exception opening EP context cache file"); + } + ifs.seekg(0, ifs.end); + std::streampos cache_len = ifs.tellg(); + if (cache_len == -1) { + ifs.close(); + ORT_THROW("Error when operating EP context cache file"); + } else if (cache_len == 0) { + ifs.close(); + LOGS_DEFAULT(WARNING) << "Empty EP context cache file: " << ep_ctx_fs_path.string(); + return ""; + } + ifs.seekg(0, ifs.beg); + char* buf = new char[static_cast(cache_len)]; + ifs.read(buf, cache_len); + if (!ifs.good()) { + ifs.close(); + ORT_THROW("Exception reading EP context cache file"); + } + ifs.close(); + std::string cache_payload(buf); + delete[] buf; + return cache_payload; +} + +void RetrieveBackendCacheInfo(const Graph& graph, std::string& cache_dir, std::string& cache_key) { + // TODO: Support for multi-node EP context model. + const auto* p_node = GetEPContextNodePtr(graph); + if (p_node == nullptr) { + LOGS_DEFAULT(WARNING) << "Failed to retrieve cache info due to no EP context nodes"; + return; + } + const auto& attrs = p_node->GetAttributes(); + const auto& notes_str = attrs.at(kNotesAttr).s(); + nlohmann::json j_obj = nlohmann::json::parse(notes_str); + cache_dir = j_obj["backend_cache_dir"].get(); + cache_key = j_obj["backend_cache_key"].get(); + if (cache_dir.empty()) { + LOGS_DEFAULT(WARNING) << "Retrieved backend cache dir empty"; + } + if (cache_key.empty()) { + LOGS_DEFAULT(WARNING) << "Retrieved backend cache key empty"; + } +} + +std::unique_ptr RetrieveOriginalGraph(const Graph& ep_ctx_graph) { + // TODO: Support for multi-node EP context model. + const auto* p_node = GetEPContextNodePtr(ep_ctx_graph); + const auto& attrs = p_node->GetAttributes(); + const auto& notes_str = attrs.at(kNotesAttr).s(); + nlohmann::json j_obj = nlohmann::json::parse(notes_str); + + const auto& orig_model_path = j_obj["orig_model_path"].get(); + bool model_loaded = false; + auto p_model_proto = ONNX_NAMESPACE::ModelProto::Create(); + if (!orig_model_path.empty() && fs::exists(orig_model_path) && fs::is_regular_file(orig_model_path)) { + auto load_status = Model::Load(ToPathString(orig_model_path), *p_model_proto); + model_loaded = load_status.IsOK(); + } + if (!model_loaded) { + p_model_proto->ParseFromString(j_obj["orig_model_proto_ser_str"].get()); + if (p_model_proto->opset_import_size() == 0) { + for (auto& elem : j_obj.items()) { + if (elem.key() == "orig_model_path" || elem.key() == "orig_graph_name" || elem.key() == "orig_model_proto_ser_str") { + continue; + } + auto* p_op_set_id_proto = p_model_proto->add_opset_import(); + *(p_op_set_id_proto->mutable_domain()) = elem.key(); + p_op_set_id_proto->set_version(std::stoll(elem.value().get())); + } + } + } + auto& logger = logging::LoggingManager::DefaultLogger(); + auto p_model = Model::Create(std::move(*p_model_proto), ToPathString(orig_model_path), nullptr, logger); + auto& graph = p_model->MainGraph(); + graph.ToGraphProto()->set_name(j_obj["orig_graph_name"].get()); + + return graph.CreateGraphViewer(); +} + +bool GraphHasEPContextNode(const Graph& graph) { + size_t vitisai_len = std::strlen(kVitisAI); + for (const auto* p_node : graph.Nodes()) { + if (p_node->OpType() != kEPContextOp) { + continue; + } + const auto& attrs = p_node->GetAttributes(); + if (attrs.count(kSourceAttr) == 0) { + continue; + } + const auto& source_val = attrs.at(kSourceAttr).s(); + if (source_val == kVitisAIExecutionProvider) { + return true; + } + if (source_val.length() != vitisai_len) { + continue; + } + size_t j = 0; + do { + if (static_cast(std::tolower(source_val[j])) != kVitisAI[j]) { + break; + } + ++j; + } while (j < vitisai_len); + if (j == vitisai_len) { + return true; + } + } + return false; +} + +bool FusedGraphHasEPContextNode( + const std::vector& fused_nodes_and_graphs) { + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + bool has_node = GraphHasEPContextNode(fused_node_graph.filtered_graph.get().GetGraph()); + if (has_node) { + return true; + } + } + return false; +} + +const fs::path& GetTopLevelModelPath(const GraphViewer& graph_viewer) { + const auto& graph = graph_viewer.GetGraph(); + const Graph* p_graph = &graph; + while (p_graph->IsSubgraph()) { + p_graph = p_graph->ParentGraph(); + } + return p_graph->ModelPath(); +} + +bool GetEPContextModelFileLocation( + const std::string& ep_ctx_model_path_cfg, + const PathString& model_path_str, + bool is_ep_ctx_model, + PathString& ep_ctx_model_file_loc) { + if (!ep_ctx_model_file_loc.empty()) { + return true; + } + if (!ep_ctx_model_path_cfg.empty()) { + ep_ctx_model_file_loc = ToPathString(ep_ctx_model_path_cfg); + } else if (!model_path_str.empty()) { + if (is_ep_ctx_model) { + ep_ctx_model_file_loc = model_path_str; + } else { + // Two alternatives for this case. + // Alternative 1: + // 1) Implement/override the method `IExecutionProvider::GetEpContextNodes()`. + // 2) And follow how the default path is implemented in `CreateEpContextModel()` + // in the file "graph_partitioner.cc". + // 3) Model dump is not required. + // Alternative 2: + // 1) Do NOT implement/override `IExecutionProvider::GetEpContextNodes()`. + // 2) No need to follow `CreateEpContextModel()` in the file "graph_partitioner.cc", + // freely implement what the default path is like. + // 3) Model dump is required. +#if 0 + ep_ctx_model_file_loc = model_path_str + ToPathString("_ctx.onnx"); +#endif +#if 1 + fs::path model_fs_path(model_path_str); + fs::path ep_ctx_model_fs_path(model_fs_path.parent_path() / model_fs_path.stem()); + ep_ctx_model_fs_path += fs::path("_ctx.onnx"); + ep_ctx_model_file_loc = ToPathString(ep_ctx_model_fs_path.string()); +#endif + } + } + return !ep_ctx_model_file_loc.empty(); +} + +// The file for EP context cache is in the same folder as the EP context model file. +PathString GetEPContextCacheFileLocation( + const PathString& ep_ctx_model_file_loc, const PathString& model_path_str) { + if (!ep_ctx_model_file_loc.empty()) { + fs::path ep_ctx_model_fs_path(ep_ctx_model_file_loc); + fs::path ep_ctx_cache_fs_path(ep_ctx_model_fs_path.parent_path() / ep_ctx_model_fs_path.stem()); + ep_ctx_cache_fs_path += fs::path("__ep_ctx_cache.bin"); + return ToPathString(ep_ctx_cache_fs_path.string()); + } + fs::path model_fs_path(model_path_str); + fs::path ep_ctx_cache_fs_path(model_fs_path.parent_path() / model_fs_path.stem()); + ep_ctx_cache_fs_path += fs::path("__ep_ctx_cache.bin"); + return ToPathString(ep_ctx_cache_fs_path.string()); +} + +std::string Slurp(const fs::path& file_location, bool binary_mode) { + // std::filesystem::value_type == onnxruntime::PathChar == ORTCHAR_T + // std::filesystem::string_type == onnxruntime::PathString + // const char* location_str = PathToUTF8String(file_location.native()).c_str(); + std::ifstream ifs; + ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit); + std::stringstream ss; + try { + auto open_mode = binary_mode ? (std::ios::in | std::ios::binary) : std::ios::in; + ifs.open(file_location.string().c_str(), open_mode); + ss << ifs.rdbuf(); + if (!ss.good()) { + LOGS_DEFAULT(WARNING) << "Failed to write to stream"; + } + ifs.close(); + } catch (std::system_error& se) { + LOGS_DEFAULT(WARNING) << "Failed to read " << file_location << ": " << se.code().message(); + } + return ss.str(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 29a1231fdce18..8c1dce0d3dc1a 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -53,6 +53,8 @@ struct OrtVitisAIEpAPI { std::vector>* (*compile_onnx_model_with_options)( const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); uint32_t (*vaip_get_version)(); + void (*get_backend_compilation_cache)(const std::string& model_path, const onnxruntime::Graph& graph, const char* json_config, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); + void (*restore_backend_compilation_cache)(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); void Ensure() { if (handle_) return; @@ -77,6 +79,8 @@ struct OrtVitisAIEpAPI { } std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version", (void**)&vaip_get_version); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "get_compilation_cache", (void**)&get_backend_compilation_cache)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "restore_compilation_cache", (void**)&restore_backend_compilation_cache)); } private: @@ -122,13 +126,7 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config vaip_core::DllSafe>> compile_onnx_model( const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { -#ifndef _WIN32 - auto model_path = graph_viewer.ModelPath().ToPathString(); -#else - using convert_t = std::codecvt_utf8; - std::wstring_convert strconverter; - auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); -#endif + auto model_path = graph_viewer.ModelPath().string(); if (s_library_vitisaiep.compile_onnx_model_with_options) { return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); } else { @@ -137,6 +135,17 @@ vaip_core::DllSafe>> c } } +void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data) { + const std::string& model_path = PathToUTF8String(model_path_str); + const onnxruntime::Graph& graph = graph_viewer.GetGraph(); + const auto json_str = config_to_json_str(options); + s_library_vitisaiep.get_backend_compilation_cache(model_path, graph, json_str.c_str(), compiler_codes, cache_dir, cache_key, cache_data); +} + +void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path) { + s_library_vitisaiep.restore_backend_compilation_cache(cache_dir, cache_key, cache_data, model_path); +} + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = @@ -218,9 +227,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { auto& logger = logging::LoggingManager::DefaultLogger(); auto& model = const_cast(const_model); auto model_proto = model.ToProto(); - auto file_path = model.MainGraph().ModelPath().ToPathString(); + auto file_path = model.MainGraph().ModelPath(); auto local_registries = IOnnxRuntimeOpSchemaRegistryList{model.MainGraph().GetSchemaRegistry()}; - auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger); + auto ret = Model::Create(std::move(*model_proto), ToPathString(file_path), &local_registries, logger); auto status = ret->MainGraph().Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); @@ -270,6 +279,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { graph.SetGraphResolveNeeded(); } auto status = graph.Resolve(); + if (!status.IsOK()) { + std::cerr << "graph resolve error:" << status.ErrorMessage() << std::endl; + } return status.Code(); }; the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto { diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index 7cd6da206a6cd..3f46fbde8c714 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -107,12 +107,11 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri auto graph_proto_subgraph = graph.ToGraphProto(); *model_proto->mutable_graph() = *graph_proto_subgraph; auto& logger = logging::LoggingManager::DefaultLogger(); - auto filename_data_relative_path = std::filesystem::path(); auto model = Model::Create(std::move(*model_proto), ToPathString(filename), nullptr, logger); if (initializer_size_threshold == std::numeric_limits::max()) { model_proto = model->ToProto(); } else { - model_proto = model->ToGraphProtoWithExternalInitializers(filename_dat, ToPathString(filename), initializer_size_threshold); + model_proto = model->ToGraphProtoWithExternalInitializers(ToPathString(filename_dat), ToPathString(filename), initializer_size_threshold); } auto& metadata = model->MetaData(); if (!metadata.empty()) { @@ -124,7 +123,7 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri *prop->mutable_value() = m.second; } } - std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary); + std::fstream output(ToPathString(filename), std::ios::out | std::ios::trunc | std::ios::binary); bool result = model_proto->SerializeToOstream(output); output << std::flush; vai_assert(result, "model serialize to ostream error"); diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index f53894b9d1efb..d5e9c63847fbe 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -12,8 +12,7 @@ gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& ten auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { std::vector unpacked_tensor; - auto path = onnxruntime::Path::Create(); - auto s = onnxruntime::utils::UnpackInitializerData(tensor, *path, unpacked_tensor); + auto s = onnxruntime::utils::UnpackInitializerData(tensor, std::filesystem::path(), unpacked_tensor); mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size()); mut_tensor.clear_float_data(); mut_tensor.clear_int32_data(); diff --git a/onnxruntime/core/providers/vitisai/include/ep_context_utils.h b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h new file mode 100644 index 0000000000000..61a595cf1ae15 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h @@ -0,0 +1,81 @@ +#pragma once + +// Standard headers/libs. +#include +#include +#include +#include + +// 1st-party headers/libs. +#include "core/providers/shared_library/provider_api.h" + +namespace fs = std::filesystem; + +namespace onnxruntime { + +constexpr const uint8_t kXCCode = 1; +constexpr const uint8_t kDDCode = 2; +constexpr const uint8_t kVCode = 4; + +static constexpr const char* kEPContextOp = "EPContext"; +static constexpr const char* kMainContextAttr = "main_context"; +static constexpr const char* kEPCacheContextAttr = "ep_cache_context"; +static constexpr const char* kEmbedModeAttr = "embed_mode"; +static constexpr const char* kPartitionNameAttr = "partition_name"; +static constexpr const char* kSourceAttr = "source"; +static constexpr const char* kEPSDKVersionAttr = "ep_sdk_version"; +static constexpr const char* kONNXModelFileNameAttr = "onnx_model_filename"; +static constexpr const char* kNotesAttr = "notes"; +static constexpr const char* kEPContextOpDomain = "com.microsoft"; +static constexpr const char* kEPContextOpName = "VitisAIEPContextOp"; + +std::unique_ptr +ConvertIndexedSubGraphToFunctionProto(const IndexedSubGraph&, const Graph&); + +std::unique_ptr ConvertFunctionProtoToIndexedSubGraph( + const std::unique_ptr&); + +std::string SerializeCapabilities( + const std::vector>&, const Graph&); + +void DeserializeCapabilities( + const std::string&, std::vector>&); + +std::string SerializeOrigialGraph(const GraphViewer&); + +// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc". +ONNX_NAMESPACE::ModelProto* CreateEPContexModel(const GraphViewer&, const std::string&, const std::string&, const int64_t, + const std::string&, const std::string&, bool, const logging::Logger*); + +// Ref.: `static common::Status Save(Model& model, int fd)` in the file "model.h". +void DumpEPContextModel(const std::unique_ptr&, const std::string&); + +const Node* GetEPContextNodePtr(const Graph&); + +bool ValidateEPContextNode(const Graph&); + +void CreateEPContexNodes(Graph*, const std::vector&, const std::string&, const std::string&, + const int64_t, const std::string&, const std::string&, bool, const logging::Logger*); + +std::string RetrieveEPContextCache(const Graph&, const PathString&, bool binary_mode = true); + +void RetrieveBackendCacheInfo(const Graph&, std::string&, std::string&); + +std::unique_ptr RetrieveOriginalGraph(const Graph&); + +bool GraphHasEPContextNode(const Graph&); + +bool FusedGraphHasEPContextNode( + const std::vector&); + +const fs::path& GetTopLevelModelPath(const GraphViewer&); + +bool GetEPContextModelFileLocation( + const std::string&, const PathString&, bool, PathString&); + +// The file for EP context cache is in the same folder as the EP context model file. +PathString GetEPContextCacheFileLocation(const PathString&, const PathString&); + +std::string Slurp(const fs::path&, bool binary_mode = false); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 1f8b8802e86b4..3fdbc60bb0ee6 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -14,3 +14,5 @@ void initialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); std::shared_ptr get_kernel_registry_vitisaiep(); const std::vector& get_domains_vitisaiep(); +void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); +void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index 46fc4ac9b2a5d..74482d8e9ee0e 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -38,7 +38,13 @@ enum TensorProto_DataType : int { TensorProto_DataType_UINT64 = 13, TensorProto_DataType_COMPLEX64 = 14, TensorProto_DataType_COMPLEX128 = 15, - TensorProto_DataType_BFLOAT16 = 16 + TensorProto_DataType_BFLOAT16 = 16, + TensorProto_DataType_FLOAT8E4M3FN = 17, + TensorProto_DataType_FLOAT8E4M3FNUZ = 18, + TensorProto_DataType_FLOAT8E5M2 = 19, + TensorProto_DataType_FLOAT8E5M2FNUZ = 20, + TensorProto_DataType_UINT4 = 21, + TensorProto_DataType_INT4 = 22 }; enum AttributeProto_AttributeType : int { AttributeProto_AttributeType_UNDEFINED = 0, diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_gsl.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_gsl.h index f6831604d883f..def522b8a3a0e 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_gsl.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_gsl.h @@ -3,7 +3,7 @@ #pragma once #ifdef ONNXRUNTIME_VITISAI_EP_STUB -#include "core/common/gsl.h" +#include #else #include #endif diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 62a7bb602e7e8..3346739890484 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { #define VAIP_ORT_API_MAJOR (3u) -#define VAIP_ORT_API_MINOR (0u) +#define VAIP_ORT_API_MINOR (1u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { uint32_t magic; // 'VAIP' or something else to make sure the following field diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 6fc09f3495aa1..f45b89649bfcb 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -2,22 +2,43 @@ // Licensed under the MIT License. #include "vitisai_execution_provider.h" +// Standard headers/libs. #include #include #include +#include + +// 1st-party headers/libs. +#include "core/platform/env_var_utils.h" +#include "core/common/exceptions.h" #include "vaip/capability.h" #include "vaip/global_api.h" +#include "ep_context_utils.h" using namespace ONNX_NAMESPACE; +namespace fs = std::filesystem; + namespace onnxruntime { constexpr const char* VITISAI = "VITISAI"; VitisAIExecutionProvider::VitisAIExecutionProvider( const ProviderOptions& info) + // const ProviderOptions& info, const SessionOptions* p_sess_opts) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { CreateKernelRegistry(); + + auto it = info_.find("ep_context_enable"); + ep_ctx_enabled_ = it != info_.end() && it->second == "1"; + it = info_.find("ep_context_embed_mode"); + ep_ctx_embed_mode_ = it != info_.end() && it->second != "0"; + // ep_ctx_embed_mode_ = it == info_.end() || it->second != "0"; + it = info_.find("ep_context_file_path"); + ep_ctx_model_path_cfg_ = it == info_.end() ? "" : it->second; + LOGS_DEFAULT(VERBOSE) << "EP Context cache enabled: " << ep_ctx_enabled_; + LOGS_DEFAULT(VERBOSE) << "EP context cache embed mode: " << ep_ctx_embed_mode_; + LOGS_DEFAULT(VERBOSE) << "User specified EP context cache path: " << ep_ctx_model_path_cfg_; } void VitisAIExecutionProvider::CreateKernelRegistry() { @@ -30,9 +51,115 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return get_kernel_registry_vitisaiep(); } +// This method is called after both `GetComputeCapabilityOps()` and `Compile()`. +// This timing is required to work with both compilation-based EPs and non-compilation-based EPs. +const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() const { + InlinedVector ep_context_node_ptrs; + // All preconditions are supposed to have happened. + if (p_ep_ctx_model_) { + auto& graph = p_ep_ctx_model_->MainGraph(); + for (const auto* p_node : graph.Nodes()) { + ep_context_node_ptrs.push_back(p_node); + } + } + return ep_context_node_ptrs; +} + +void VitisAIExecutionProvider::LoadEPContexModelFromFile() const { + // XXX: should "p_ep_ctx_model_" be checked or not? + if (!p_ep_ctx_model_ && !ep_ctx_model_file_loc_.empty()) { + auto status = Model::Load(ep_ctx_model_file_loc_, *p_ep_ctx_model_proto_); + if (!status.IsOK()) { + ORT_THROW("Loading EP context model failed from ", PathToUTF8String(ep_ctx_model_file_loc_)); + } + p_ep_ctx_model_ = Model::Create(std::move(*p_ep_ctx_model_proto_), ep_ctx_model_file_loc_, nullptr, *GetLogger()); + LOGS_DEFAULT(VERBOSE) << "Loaded EP context model from: " << PathToUTF8String(ep_ctx_model_file_loc_); + } else if (ep_ctx_model_file_loc_.empty()) { + LOGS_DEFAULT(WARNING) << "Cannot load an EP-context model due to bad file path"; + } +} + +void VitisAIExecutionProvider::PrepareEPContextEnablement( + const onnxruntime::GraphViewer& graph_viewer) const { + if (model_path_str_.empty()) { + // TODO: platform dependency (Linux vs Windows). + model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string()); + } + std::string backend_cache_dir, backend_cache_key; + get_backend_compilation_cache(model_path_str_, graph_viewer, info_, kXCCode, backend_cache_dir, backend_cache_key, backend_cache_data_); + info_["cacheDir"] = backend_cache_dir; + info_["cacheKey"] = backend_cache_key; + // Create a new model, reusing the graph name, the op-domain-to-opset-version map, + // the op schema registry of the current graph, etc. + p_ep_ctx_model_ = graph_viewer.CreateModel(*GetLogger()); + LOGS_DEFAULT(VERBOSE) << "Container model created"; +} + +void VitisAIExecutionProvider::FulfillEPContextEnablement( + const std::vector& fused_nodes_and_graphs) { + auto& ep_ctx_graph = p_ep_ctx_model_->MainGraph(); + if (!ep_ctx_embed_mode_) { + auto ep_ctx_cache_path_str = GetEPContextCacheFileLocation(ep_ctx_model_file_loc_, model_path_str_); + std::ofstream ep_ctx_cache_ofs(ep_ctx_cache_path_str.c_str(), std::ios::trunc); + if (!ep_ctx_cache_ofs.is_open()) { + ORT_THROW("Failed to open a file to write EP context cache: ", ep_ctx_cache_path_str.c_str()); + } + ep_ctx_cache_ofs.write(backend_cache_data_.c_str(), backend_cache_data_.length()); + if (!ep_ctx_cache_ofs.good()) { + ep_ctx_cache_ofs.close(); + ORT_THROW("Exception writing EP context cache file: ", ep_ctx_cache_path_str.c_str()); + } + ep_ctx_cache_ofs.close(); + CreateEPContexNodes(&ep_ctx_graph, fused_nodes_and_graphs, "", PathToUTF8String(ep_ctx_cache_path_str), 0, info_.at("cacheDir"), info_.at("cacheKey"), false, GetLogger()); + } else { + CreateEPContexNodes(&ep_ctx_graph, fused_nodes_and_graphs, backend_cache_data_, "", 1, info_["cacheDir"], info_["cacheKey"], false, GetLogger()); + } + if (GraphHasEPContextNode(ep_ctx_graph)) { + LOGS_DEFAULT(VERBOSE) << "Created model has EP context nodes"; + } else { + LOGS_DEFAULT(WARNING) << "No EP eontext nodes created"; + } +} + std::vector> VitisAIExecutionProvider::GetCapability( - const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { - if (graph.IsSubgraph()) { + const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + bool is_ep_ctx_model = GraphHasEPContextNode(graph_viewer.GetGraph()); + // TODO: platform dependency (Linux vs Windows). + model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string()); + if (GetEPContextModelFileLocation( + ep_ctx_model_path_cfg_, model_path_str_, is_ep_ctx_model, ep_ctx_model_file_loc_)) { + if (is_ep_ctx_model) { + LOGS_DEFAULT(VERBOSE) << "An EP context model passed in"; + ValidateEPContextNode(graph_viewer.GetGraph()); + std::string cache_dir, cache_key; + RetrieveBackendCacheInfo(graph_viewer.GetGraph(), cache_dir, cache_key); + info_["cacheDir"] = cache_dir; + info_["cacheKey"] = cache_key; + LOGS_DEFAULT(VERBOSE) << "Trying getting compilation cache from " << PathToUTF8String(ep_ctx_model_file_loc_); + auto ep_ctx_payload = RetrieveEPContextCache(graph_viewer.GetGraph(), ep_ctx_model_file_loc_, false); + restore_backend_compilation_cache(cache_dir, cache_key, ep_ctx_payload, graph_viewer.ModelPath().string()); + } else { + if (fs::exists(ep_ctx_model_file_loc_) && fs::is_regular_file(ep_ctx_model_file_loc_) && ep_ctx_enabled_) { + ORT_THROW("The inference session was created with a normal ONNX model but a model file with EP context cache exists at ", + PathToUTF8String(ep_ctx_model_file_loc_), ". Please remove the EP context model manually if you want to re-generate it."); + // Disable the flexibility implemented below by throwing an exception. + // Now the code below is unreachable but DCE will take care of it. + // We might want to re-enable it in future, so we keep it as is. + LoadEPContexModelFromFile(); + ValidateEPContextNode(p_ep_ctx_model_->MainGraph()); + std::string cache_dir, cache_key; + RetrieveBackendCacheInfo(p_ep_ctx_model_->MainGraph(), cache_dir, cache_key); + info_["cacheDir"] = cache_dir; + info_["cacheKey"] = cache_key; + auto ep_ctx_payload = RetrieveEPContextCache(p_ep_ctx_model_->MainGraph(), ep_ctx_model_file_loc_, false); + restore_backend_compilation_cache(cache_dir, cache_key, ep_ctx_payload, graph_viewer.ModelPath().string()); + } + } + } else { + LOGS_DEFAULT(WARNING) << "Failed to get EP context model file location"; + } + + if (graph_viewer.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. return {}; } @@ -40,13 +167,16 @@ std::vector> VitisAIExecutionProvider::GetCap // Only compiling a model once is currently supported return {}; } - execution_providers_ = std::make_unique(compile_onnx_model(graph, *GetLogger(), info_)); - auto result = vaip::GetComputeCapabilityOps(graph, execution_providers_.get(), vitisai_optypes_); + execution_providers_ = std::make_unique(compile_onnx_model(graph_viewer, *GetLogger(), info_)); + auto result = vaip::GetComputeCapabilityOps(graph_viewer, execution_providers_.get(), vitisai_optypes_); size_t index = 0u; for (auto& ep : **execution_providers_) { - result.emplace_back(vaip::XirSubgraphToComputeCapability1(graph, ep.get(), index)); + result.emplace_back(vaip::XirSubgraphToComputeCapability1(graph_viewer, ep.get(), index)); index = index + 1; } + if (ep_ctx_enabled_ && !is_ep_ctx_model) { + PrepareEPContextEnablement(graph_viewer); + } return result; } @@ -74,6 +204,10 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector #include #include #include #include +// 1st-party headers/libs. +// #include "core/framework/session_options.h" #include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" +#include "core/common/inlined_containers_fwd.h" // we cannot include vaip/vaip.hpp here because header file referred by // onnxruntime_pybind_state_common.cc @@ -24,9 +28,11 @@ namespace onnxruntime { class VitisAIExecutionProvider : public IExecutionProvider { public: explicit VitisAIExecutionProvider(const ProviderOptions& info); + // explicit VitisAIExecutionProvider(const ProviderOptions& info, + // const SessionOptions* p_sess_opts = nullptr); ~VitisAIExecutionProvider() = default; - std::vector> GetCapability(const onnxruntime::GraphViewer& graph, + std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return 0; } @@ -35,16 +41,34 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; + // This method is called after both `GetComputeCapabilityOps()` and `Compile()`. + // This timing is required to work with both compliation-based EPs and non-compilation-based EPs. + const InlinedVector GetEpContextNodes() const override; + private: void CreateKernelRegistry(); using my_ep_t = vaip_core::DllSafe>>; using my_ep_uptr_t = std::shared_ptr; // we have to hide the implementation by forward declaration. mutable my_ep_uptr_t execution_providers_; - ProviderOptions info_; + mutable ProviderOptions info_; std::vector custom_op_domains_; std::shared_ptr registry_; std::set vitisai_optypes_; + // EP context related. + bool ep_ctx_enabled_ = false; + bool ep_ctx_embed_mode_ = true; + std::string ep_ctx_model_path_cfg_{""}; + mutable std::string backend_cache_data_{""}; + mutable PathString model_path_str_{}; + mutable PathString ep_ctx_model_file_loc_{}; + mutable std::unique_ptr p_ep_ctx_model_; + mutable std::unique_ptr p_ep_ctx_model_proto_; + // It might need to be called before loading + // the EP context model that is compiled AOT/offline. + void LoadEPContexModelFromFile() const; + void PrepareEPContextEnablement(const onnxruntime::GraphViewer&) const; + void FulfillEPContextEnablement(const std::vector&); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc old mode 100755 new mode 100644 diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h new file mode 100644 index 0000000000000..7285e4d02c8e9 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h @@ -0,0 +1,131 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ReluOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Relu Activation."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +class SigmoidOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Sigmoid Activation."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +class TanhOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Tanh activation."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class LeakyReluOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating LeakyRelu activation."; + const auto& node = node_unit.GetNode(); + NodeAttrHelper helper(node); + auto alpha = helper.Get("alpha", 1.0f); + auto op = + graph_ep->GetGraph()->CreateOperation(alpha); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class EluOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Elu activation."; + const auto& node = node_unit.GetNode(); + NodeAttrHelper helper(node); + auto alpha = helper.Get("alpha", 1.0f); + auto op = + graph_ep->GetGraph()->CreateOperation(alpha); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class HardSigmoidOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating HardSigmoid activation."; + const auto& node = node_unit.GetNode(); + NodeAttrHelper helper(node); + auto alpha = helper.Get("alpha", 1.0f); + auto beta = helper.Get("beta", 1.0f); + auto op = graph_ep->GetGraph()->CreateOperation( + alpha, beta); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc new file mode 100644 index 0000000000000..653940fa5e611 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc @@ -0,0 +1,215 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +bool BaseOpBuilder::IsSupported(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) const { + auto initializers = graph_viewer.GetAllInitializedTensors(); + if (!HasSupportedOpSet(node_unit)) { + return false; + } + if (!HasSupportedInputOutputs(initializers, node_unit)) { + return false; + } + return IsOpSupported(graph_viewer, &node_unit.GetNode()); +} + +bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const { + // We do not support unknown(null) input shape + auto has_supported_shape = [](const NodeArg& node_arg, const std::string& name, const std::string& op_type) { + const auto* shape_proto = node_arg.Shape(); + if (!shape_proto) { + LOGS_DEFAULT(WARNING) << "Node [" << name << "] type [" << op_type + << "] Input [" << node_arg.Name() << "] has no shape"; + return false; + } + + // We do not support dynamic shape input yet, but resize op's second input can be empty + for (const auto& dim : shape_proto->dim()) { + if (!dim.has_dim_value()) { + LOGS_DEFAULT(WARNING) << "Dynamic shape is not supported for now, for input:" << node_arg.Name(); + return false; + } + if (dim.dim_value() == 0 && op_type != "Resize") { + LOGS_DEFAULT(WARNING) << "Zero in shape is not supported for now, for input:" << node_arg.Name(); + return false; + } + } + return true; + }; + + auto has_initialized_quant_param = [](const NodeArg& arg, const InitializedTensorSet& initializers) { + auto it = initializers.find(arg.Name()); + if (it == initializers.end()) { + LOGS_DEFAULT(WARNING) << "The quantization param must be an initializer tensor"; + return false; + } + return true; + }; + + for (const auto& input : node_unit.Inputs()) { + if (!input.node_arg.Exists()) { + continue; + } + if (!has_supported_shape(input.node_arg, node_unit.Name(), node_unit.OpType())) + return false; + + if (input.quant_param.has_value()) { + if (!has_supported_shape(input.quant_param->scale, node_unit.Name(), node_unit.OpType())) + return false; + + if (!has_initialized_quant_param(input.quant_param->scale, initializers)) + return false; + // zero point is optional + if (input.quant_param->zero_point) { + if (!has_supported_shape(*input.quant_param->zero_point, node_unit.Name(), node_unit.OpType())) + return false; + if (!has_initialized_quant_param(*input.quant_param->zero_point, initializers)) + return false; + if (input.quant_param->zero_point->Type() != input.node_arg.Type()) { + LOGS_DEFAULT(ERROR) << "Invalid input type because the data type mismatch with its' quant param type."; + return false; + } + } + } + } + for (const auto& output : node_unit.Outputs()) { + for (const auto& dim : output.node_arg.Shape()->dim()) { + if (!dim.has_dim_value()) { + LOGS_DEFAULT(WARNING) << "Dynamic shape is not supported for now, for output:" << output.node_arg.Name(); + return false; + } + if (dim.dim_value() == 0 && output.node_arg.Shape()->dim_size() > 1) { + LOGS_DEFAULT(WARNING) << "Zero in shape is not supported for now, for output:" << output.node_arg.Name(); + return false; + } + } + if (output.quant_param.has_value()) { + if (!has_supported_shape(output.quant_param->scale, node_unit.Name(), node_unit.OpType())) + return false; + + if (!has_initialized_quant_param(output.quant_param->scale, initializers)) + return false; + // zero point is optional + if (output.quant_param->zero_point) { + if (!has_supported_shape(*output.quant_param->zero_point, node_unit.Name(), node_unit.OpType())) + return false; + if (!has_initialized_quant_param(*output.quant_param->zero_point, initializers)) + return false; + } + } + } + return HasSupportedInputOutputsImpl(initializers, node_unit); +} + +bool BaseOpBuilder::HasSupportedInputOutputsImpl( + const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit) const { + // Check input/output data type, int64 is generally unsupported + // specific op builder can override this if the int64 input corresponds to VSINPU param + for (const auto& input : node_unit.Inputs()) { + auto input_type = input.node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&input.node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + } + for (const auto& output : node_unit.Outputs()) { + auto output_type = output.node_arg.Type(); + if (*output_type == "tensor(int64)" || !util::IsTypeSupported(&output.node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported output type : " + << *output_type; + return false; + } + } + return true; +} + +bool BaseOpBuilder::HasSupportedOpSet(const NodeUnit& node_unit) const { + auto since_version = node_unit.SinceVersion(); + if (since_version < GetMinSupportedOpSet(node_unit) || since_version > GetMaxSupportedOpSet(node_unit)) { + LOGS_DEFAULT(VERBOSE) << node_unit.OpType() << " opset [" << since_version + << "] is only supported for opset [" + << GetMinSupportedOpSet(node_unit) << ", " + << GetMaxSupportedOpSet(node_unit) << "]"; + return false; + } + + return true; +} + +bool BaseOpBuilder::BuildOp(vsi::npu::GraphEP* graph_ep, + const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) { + std::vector> inputs; + std::vector input_defs = node_unit.Inputs(); + std::vector output_defs = node_unit.Outputs(); + + for (const auto input_def : input_defs) { + auto it = std::find_if( + graph_ep->GetGraphInputs().begin(), graph_ep->GetGraphInputs().end(), + [input_def](const std::shared_ptr& info) { + return info->name == input_def.node_arg.Name(); + }); + tim::vx::TensorAttribute attr; + if (graph_viewer.IsConstantInitializer(input_def.node_arg.Name(), true)) { + attr = tim::vx::TensorAttribute::CONSTANT; + } else if (it == graph_ep->GetGraphInputs().end()) { + attr = tim::vx::TensorAttribute::TRANSIENT; + } else { + attr = tim::vx::TensorAttribute::INPUT; + } + + auto tensor = graph_ep->MapTIMVXTensor(graph_ep->GetGraph(), input_def, node_unit, + &graph_viewer, attr); + inputs.push_back(tensor); + } + + std::vector> outputs; + + for (auto output_def : output_defs) { + auto it = std::find_if( + graph_ep->GetGraphOutputs().begin(), graph_ep->GetGraphOutputs().end(), + [output_def](const std::shared_ptr& info) { + return info->name == output_def.node_arg.Name(); + }); + tim::vx::TensorAttribute attribute = + it == graph_ep->GetGraphOutputs().end() + ? tim::vx::TensorAttribute::TRANSIENT + : tim::vx::TensorAttribute::OUTPUT; + auto tensor = graph_ep->MapTIMVXTensor(graph_ep->GetGraph(), output_def, node_unit, + &graph_viewer, attribute); + outputs.push_back(tensor); + } + return HandleBuildOp(graph_ep, inputs, outputs, node_unit); +} +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h new file mode 100644 index 0000000000000..692dbc833f0a7 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h @@ -0,0 +1,75 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include "core/providers/vsinpu/builders/op_builder.h" +#include "core/providers/vsinpu/vsinpu_ep_graph.h" +#include "core/providers/vsinpu/vsinpu_util.h" +#include "tim/vx/operation.h" +#include "tim/vx/ops.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class BaseOpBuilder : public IOpBuilder { + public: + virtual ~BaseOpBuilder() = default; + + bool IsSupported(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) const override; + bool BuildOp(vsi::npu::GraphEP* graph_ep, + const onnxruntime::GraphViewer& graph_viewer, const NodeUnit& node_unit) override; + virtual bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const { + return true; + } + + virtual bool IsQuantizedOp(const NodeUnit& /* node_unit */) const { return false; } + + virtual int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const { return 1; } + virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 22; } + + virtual bool HasSupportedInputOutputsImpl( + const InitializedTensorSet& initializers, const NodeUnit& node_unit) const; + + // TODO(cfy): Check if this node_unit's type is supported + virtual bool IsNodeUnitTypeSupported(const NodeUnit& node_unit) const { return true; } + + virtual bool HandleBuildOp( + vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) { + return true; + } + + private: + bool HasSupportedOpSet(const NodeUnit& node_unit) const; + bool HasSupportedInputOutputs(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const; +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h new file mode 100644 index 0000000000000..57b080acdce4e --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h @@ -0,0 +1,48 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class CastOpBuilder : public BaseOpBuilder { + protected: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, std::vector>& inputs, + std::vector>& outputs, const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Cast Op."; + NodeAttrHelper helper(node_unit.GetNode()); + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc new file mode 100644 index 0000000000000..85096d0e262d7 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc @@ -0,0 +1,115 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include "core/providers/vsinpu/builders/impl/clip_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +namespace clip_internal { +template +struct LowMax { + constexpr static T low() { + return std::numeric_limits::lowest(); + } + constexpr static T max() { + return std::numeric_limits::max(); + } +}; +} // namespace clip_internal + +template +struct ClipOpBuilder::ClipImpl { + ClipImpl(vsi::npu::GraphEP* graph_ep, std::vector>& inputs, + std::vector>& outputs) { + T min_default = clip_internal::LowMax::low(); + T max_default = clip_internal::LowMax::max(); + + T* min_data = &min_default; + T* max_data = &max_default; + std::shared_ptr min_tensor = nullptr; + std::shared_ptr max_tensor = nullptr; + if (inputs.size() > 1) { + min_tensor = inputs[1]; + if (inputs.size() > 2) { + max_tensor = inputs[2]; + } + } + if (min_tensor) { + min_tensor->CopyDataFromTensor(min_data); + } + if (max_tensor) { + max_tensor->CopyDataFromTensor(max_data); + } + auto op = graph_ep->GetGraph()->CreateOperation( + static_cast(*min_data), static_cast(*max_data)); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + } +}; + +bool ClipOpBuilder::HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) { + LOGS_DEFAULT(INFO) << "Creating Clip Op."; + if (node_unit.SinceVersion() <= 6) { + NodeAttrHelper helper(node_unit.GetNode()); + auto min = helper.Get("min", -3.402e+38f); + auto max = helper.Get("max", 3.402e+38f); + auto op = graph_ep->GetGraph()->CreateOperation(min, max); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + } else { + switch (inputs[0]->GetDataType()) { + case tim::vx::DataType::INT8: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::UINT8: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::INT16: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::INT32: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::FLOAT16: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::FLOAT32: + default: + ClipImpl(graph_ep, inputs, outputs); + break; + } + } + return true; +} + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h new file mode 100644 index 0000000000000..9e5570645126e --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h @@ -0,0 +1,58 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ClipOpBuilder final : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + if (node->SinceVersion() > 6) { + if (node->InputDefs().size() > 1 && + !Contains(graph_viewer.GetAllInitializedTensors(), node->InputDefs()[1]->Name())) { + LOGS_DEFAULT(WARNING) << "Min/Max value must be const input or attribute."; + return false; + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override; + + private: + template + struct ClipImpl; +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h new file mode 100644 index 0000000000000..87c8c70247bee --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h @@ -0,0 +1,66 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ConcatOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + auto axis = helper.Get("axis", 0); + auto input_defs = node->InputDefs(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + if (axis >= rank || axis < -rank) { + LOGS_DEFAULT(ERROR) << "Axis is invalid in Concat."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Concat Op."; + NodeAttrHelper helper(node_unit.GetNode()); + auto axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + auto op = graph_ep->GetGraph()->CreateOperation(static_cast(axis), inputs.size()); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h new file mode 100644 index 0000000000000..3ed432c2efa1c --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h @@ -0,0 +1,163 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class ConvOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + if (shape.NumDimensions() == 5) { + LOGS_DEFAULT(WARNING) << "Not support conv3d yet."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + auto input_tensor = inputs[0]; + auto weight_tensor = inputs[1]; + auto OChannel_idx = weight_tensor->GetShape().size() - 1; + const bool is_1d_conv = + weight_tensor->GetShape().size() == 3 ? true : false; + NodeAttrHelper helper(node_unit.GetNode()); + auto padtype = helper.Get("auto_pad", std::string("")); + auto group = helper.Get("group", static_cast(1)); + + std::string op_type = (group != 1 && group == weight_tensor->GetShape()[OChannel_idx]) + ? "DepthwiseConv" + : (group != 1) ? "GroupConv" + : "Conv"; + op_type += is_1d_conv ? "1D" : "2D"; + std::string op_name = std::string("Creating ") + op_type + " Op"; + LOGS_DEFAULT(INFO) << op_name; + + uint32_t default_uint = 1; + std::vector default_vec = {1, 1}; + + auto stride = + helper.Get("strides", is_1d_conv ? std::vector{default_uint} + : default_vec); + auto dilation = + helper.Get("dilations", is_1d_conv ? std::vector{default_uint} + : default_vec); + + std::shared_ptr op; + if (padtype != "NOTSET") { // array "pads" is not set + if (group != 1 && group != weight_tensor->GetShape()[OChannel_idx]) { + if (is_1d_conv) { + op = graph_ep->GetGraph() + ->CreateOperation( + vsi::npu::util::GetPadType(padtype), stride[0], + dilation[0], group, tim::vx::DataLayout::WCN, + tim::vx::DataLayout::WIcOc); + } else { + op = graph_ep->GetGraph() + ->CreateOperation( + vsi::npu::util::GetPadType(padtype), + /* W_stride, H_stride*/ + std::array{stride[1], stride[0]}, + /* W_dilation, H_dilation*/ + std::array{dilation[1], dilation[0]}, group, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } else { + int32_t multiplier = group == 1 + ? 0 + : weight_tensor->GetShape()[OChannel_idx] / input_tensor->GetShape()[OChannel_idx - 1]; + if (is_1d_conv) { + op = graph_ep->GetGraph()->CreateOperation( + vsi::npu::util::GetPadType(padtype), stride[0], dilation[0], multiplier, + tim::vx::DataLayout::WCN, tim::vx::DataLayout::WIcOc); + } else { + op = graph_ep->GetGraph()->CreateOperation( + vsi::npu::util::GetPadType(padtype), + /* W_stride, H_stride*/ + std::array{stride[1], stride[0]}, + /* W_dilation, H_dilation*/ + std::array{dilation[1], dilation[0]}, multiplier, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } + } else { + auto pads = helper.Get("pads", std::vector{0U, 0U}); + if (group != 1 && group != weight_tensor->GetShape()[OChannel_idx]) { + if (is_1d_conv) { + op = graph_ep->GetGraph() + ->CreateOperation( + vsi::npu::util::GetPadType(padtype), + std::array{pads[0], pads[1]}, stride[0], + dilation[0], group, tim::vx::DataLayout::WCN, + tim::vx::DataLayout::WIcOc); + } else { + op = graph_ep->GetGraph() + ->CreateOperation( + /* W_begin,W_end, H_begin,H_end*/ std::array< + uint32_t, 4>{pads[1], pads[3], pads[0], pads[2]}, + /* W_stride, H_stide*/ + std::array{stride[1], stride[0]}, + /* W_dilation, H_dilation*/ + std::array{dilation[1], dilation[0]}, group, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } else { + int32_t multiplier = group == 1 + ? 0 + : weight_tensor->GetShape()[OChannel_idx] / input_tensor->GetShape()[OChannel_idx - 1]; + if (is_1d_conv) { + op = graph_ep->GetGraph()->CreateOperation( + std::array{pads[0], pads[1]}, stride[0], dilation[0], + multiplier, tim::vx::DataLayout::WCN, tim::vx::DataLayout::WIcOc); + } else { + op = graph_ep->GetGraph()->CreateOperation( + /* W_begin,W_end, H_begin,H_end*/ std::array{pads[1], pads[3], + pads[0], pads[2]}, + /* W_stride, H_stride*/ + std::array{stride[1], stride[0]}, + /* W_dilation, H_dilation*/ + std::array{dilation[1], dilation[0]}, multiplier, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } + } + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h new file mode 100644 index 0000000000000..5551ff2a0d31a --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h @@ -0,0 +1,84 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class DequantizeLinearOpBuilder : public BaseOpBuilder { + enum DequantizeINPUTS { + input_tensor = 0, + scale_tensor = 1, + zero_point_tensor = 2 + }; + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input_type = node_unit.Inputs()[0].node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + if (!node_unit.Inputs()[0].quant_param.has_value()) { + LOGS_DEFAULT(WARNING) << "The quantization params must be known."; + return false; + } + if (node_unit.Inputs()[0].quant_param->scale.Shape()->dim_size() != 0 && + node_unit.Inputs()[0].quant_param->scale.Shape()->dim(0).dim_value() != 1) { + LOGS_DEFAULT(WARNING) << "Per channel quantized input is not support in DequantizeLinear op."; + return false; + } + return true; + } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + if (helper.HasAttr("block_size") && helper.Get("block_size", 0) != 0) { + LOGS_DEFAULT(WARNING) << "Not support block quantization yet."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Dequantize Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/dropout_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/dropout_op_builder.h new file mode 100644 index 0000000000000..bd8fa0bc587cf --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/dropout_op_builder.h @@ -0,0 +1,81 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class DropoutOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + if (node_unit.Inputs().size() > 2) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = + initializers.at(node_unit.Inputs()[2].node_arg.Name()); + std::vector training_mode(1); + auto status = onnxruntime::utils::UnpackTensor( + *tensor_proto, + tensor_proto->has_raw_data() ? tensor_proto->raw_data().data() : nullptr, + tensor_proto->has_raw_data() ? tensor_proto->raw_data().size() : 0, + training_mode.data(), training_mode.size()); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << "Failed to get data training mode tensor."; + return false; + } + if (training_mode[0] == true) { + LOGS_DEFAULT(WARNING) << "Only support inference typed dropout now."; + return false; + } + } + if (node_unit.Inputs().size() > 1) return false; + return true; + } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + if (helper.HasAttr("seed")) { + LOGS_DEFAULT(WARNING) << "Not support seed in Dropout op."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating DropOut Op."; + auto op = graph_ep->GetGraph()->CreateOperation(1.0); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h new file mode 100644 index 0000000000000..df2e429f58b2f --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h @@ -0,0 +1,99 @@ + +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +#define ELEMENTWISE_OP_BUILDER(onnx_op_type, vsinpu_op_kind) \ + class onnx_op_type##OpBuilder : public BaseOpBuilder { \ + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, \ + const Node* node) const override { \ + for (auto input : node->InputDefs()) { \ + if (*input->Type() == "tensor(int64)") { \ + LOGS_DEFAULT(WARNING) << "Int64 type is not supported as elementwise operation input."; \ + return false; \ + } \ + } \ + return true; \ + } \ + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, \ + std::vector>& inputs, \ + std::vector>& outputs, \ + const NodeUnit& node_unit) override { \ + LOGS_DEFAULT(INFO) << "Creating " << #onnx_op_type << " Op"; \ + auto op = graph_ep->GetGraph() -> CreateOperation(); \ + (*op).BindInputs(inputs).BindOutputs(outputs); \ + return true; \ + ; \ + } \ + }; + +ELEMENTWISE_OP_BUILDER(Add, Add); +ELEMENTWISE_OP_BUILDER(Sub, Sub); +ELEMENTWISE_OP_BUILDER(Mul, Multiply); +ELEMENTWISE_OP_BUILDER(Div, Div); // not consider zero +ELEMENTWISE_OP_BUILDER(Abs, Abs); +ELEMENTWISE_OP_BUILDER(Sqrt, Sqrt); +ELEMENTWISE_OP_BUILDER(Exp, Exp); +ELEMENTWISE_OP_BUILDER(Floor, Floor); +ELEMENTWISE_OP_BUILDER(Log, Log); +ELEMENTWISE_OP_BUILDER(Sin, Sin); +ELEMENTWISE_OP_BUILDER(HardSwish, HardSwish); + +class PowOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input0_type = *node->InputDefs()[0]->Type(); + auto input1_type = *node->InputDefs()[1]->Type(); + if (input0_type != input1_type) { + if ((input0_type == "tensor(float)" && input1_type == "tensor(int32)") || + (input0_type == "tensor(int32)" && input1_type == "tensor(float)")) { + LOGS_DEFAULT(WARNING) << "Pow op does not support one of input is float32 while the other one is int32 type."; + return false; + } + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Pow Op"; + auto op = graph_ep->GetGraph() + ->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h new file mode 100644 index 0000000000000..fa8ccfc96a730 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h @@ -0,0 +1,66 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class FlattenOpBuilder : public BaseOpBuilder { + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Flatten Op."; + std::vector reshape_param; + if (outputs[0]->GetShape().size() == 2) { + reshape_param = outputs[0]->GetShape(); + } else { + auto input_shape = inputs[0]->GetShape(); + NodeAttrHelper helper(node_unit.GetNode()); + int64_t axis = helper.Get("axis", 1); + axis = util::ReverseAxis(static_cast(axis), input_shape.size()); + uint32_t first_dim = 1; + for (int64_t i = 0; i < axis; i++) { + first_dim *= inputs[0]->GetShape()[i]; + } + uint32_t second_dim = inputs[0]->GetSpec().GetElementNum() / first_dim; + reshape_param.push_back(first_dim); + reshape_param.push_back(second_dim); + } + auto op = graph_ep->GetGraph()->CreateOperation(reshape_param); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h new file mode 100644 index 0000000000000..bd91bbd81e1fa --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h @@ -0,0 +1,87 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class GatherOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input = node_unit.Inputs()[0]; + auto indices = node_unit.Inputs()[1]; + if (util::IsTypeSupported(&input.node_arg) && util::IsTypeSupported(&indices.node_arg)) { + if (*input.node_arg.Type() == "tensor(int64)") { + LOGS_DEFAULT(WARNING) << "Only support indices tensor to be int64 type in gather op."; + return false; + } + if (*indices.node_arg.Type() != "tensor(int64)" && *indices.node_arg.Type() != "tensor(int32)") { + LOGS_DEFAULT(WARNING) << "Unsupported indices tensor type in gather op."; + return false; + } + if (*indices.node_arg.Type() == "tensor(int64)" && !Contains(initializers, indices.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Only support const attribute if indice tensor is in int64 type."; + return false; + } + return true; + } + return false; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Gather Op."; + NodeAttrHelper helper(node_unit.GetNode()); + auto axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + auto op = graph_ep->GetGraph()->CreateOperation(axis, 0); + + bool is_i64_indices = inputs[1]->GetDataType() == tim::vx::DataType::INT64; + if (!is_i64_indices) { + (*op).BindInputs(inputs).BindOutputs(outputs); + } else { + std::vector origin_data(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(origin_data.data()); + std::vector transformed_data(origin_data.begin(), origin_data.end()); + tim::vx::TensorSpec ts = inputs[1]->GetSpec().SetAttribute(tim::vx::TensorAttribute::INPUT); + ts.SetDataType(tim::vx::DataType::INT32); + auto transformed_indices = graph_ep->GetGraph()->CreateTensor(ts, transformed_data.data()); + (*op).BindInput(inputs[0]).BindInput(transformed_indices).BindOutput(outputs[0]); + } + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h new file mode 100644 index 0000000000000..07388f862ae1d --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h @@ -0,0 +1,149 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class GemmOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + NodeAttrHelper helper(*node); + auto weight_units = helper.Get("transB", 0) == 1 + ? vsi::npu::util::GetTensorShape(*input_defs[1]).GetDims()[0] + : vsi::npu::util::GetTensorShape(*input_defs[1]).GetDims()[1]; + if (input_defs.size() > 2) { + auto bias_shape = vsi::npu::util::GetTensorShape(*input_defs[2]); + if (bias_shape.NumDimensions() == 1 && bias_shape.GetDims()[0] != weight_units) { + LOGS_DEFAULT(WARNING) << "Not support to broadcast bias shape."; + return false; + } else if (bias_shape.NumDimensions() == 2 && + (bias_shape.Size() != weight_units || + (bias_shape.GetDims()[0] != 1 && bias_shape.GetDims()[1] != 1))) { + LOGS_DEFAULT(WARNING) << "Not support 2-dims bias shape."; + return false; + } + + if (*input_defs[2]->Type() == "tensor(float16)" && + !graph_viewer.IsConstantInitializer(input_defs[2]->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Not support f16 bias with input attr."; + return false; + } + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Gemm Op."; + auto input_A = inputs[0]; + auto input_B = inputs[1]; + NodeAttrHelper helper(node_unit.GetNode()); + + auto trans_A = helper.Get("transA", 0); + auto trans_B = helper.Get("transB", 0); + const bool has_alpha = (helper.Get("alpha", 1.0f) != 1.0); + const bool has_beta = (helper.Get("beta", 1.0f) != 1.0); + const bool has_C = (inputs.size() == 3); + auto weight_units = helper.Get("transB", 0) == 1 ? inputs[1]->GetShape()[1] : inputs[1]->GetShape()[0]; + + tim::vx::TensorSpec coef_spec(tim::vx::DataType::FLOAT32, {1}, + tim::vx::TensorAttribute::CONSTANT); + + auto multiply_impl = [&](std::shared_ptr input, + std::shared_ptr coef, + std::shared_ptr output) { + auto multiply_op = graph_ep->GetGraph()->CreateOperation(); + (*multiply_op).BindInput(input).BindInput(coef).BindOutput(output); + graph_ep->GetOps().push_back(multiply_op); + }; + + auto transpose_impl = [&](std::shared_ptr input, + std::shared_ptr output) { + std::vector perm = {1U, 0U}; + auto transpose_op = graph_ep->GetGraph()->CreateOperation(perm); + (*transpose_op).BindInput(input).BindOutput(output); + graph_ep->GetOps().push_back(std::move(transpose_op)); + }; + + auto fc_impl = [&](std::vector> inputs, + std::shared_ptr output) { + auto fc_op = graph_ep->GetGraph()->CreateOperation(0, weight_units); + (*fc_op).BindInputs(inputs).BindOutput(output); + graph_ep->GetOps().push_back(std::move(fc_op)); + }; + + auto alpha_A = input_A; + std::shared_ptr beta_C; + auto final_A = input_A; + auto final_B = input_B; + + if (has_alpha) { + auto alpha_tensor = graph_ep->GetGraph()->CreateTensor(coef_spec); + auto alpha = helper.Get("alpha", 1.0f); + alpha_tensor->CopyDataToTensor(&alpha); + alpha_A = graph_ep->GetGraph()->CreateTensor( + input_A->GetSpec().AsTransientSpec()); + multiply_impl(input_A, alpha_tensor, alpha_A); + final_A = alpha_A; + } + if (has_beta) { + auto beta_tensor = graph_ep->GetGraph()->CreateTensor(coef_spec); + auto beta = helper.Get("beta", 1.0f); + beta_tensor->CopyDataToTensor(&beta); + beta_C = graph_ep->GetGraph()->CreateTensor( + inputs[2]->GetSpec().AsTransientSpec()); + multiply_impl(inputs[2], beta_tensor, beta_C); + } else if (has_C) { + beta_C = inputs[2]; + } + + if (trans_A) { + final_A = graph_ep->GetGraph()->CreateTensor( + input_A->GetSpec().AsTransientSpec()); + transpose_impl(alpha_A, final_A); + } + if (!trans_B) { + final_B = graph_ep->GetGraph()->CreateTensor( + input_B->GetSpec().AsTransientSpec()); + transpose_impl(input_B, final_B); + } + std::vector> fc_inputs = {final_A, final_B}; + + if (has_C) fc_inputs.push_back(beta_C); + fc_impl(fc_inputs, outputs[0]); + + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h new file mode 100644 index 0000000000000..f439736bde09f --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h @@ -0,0 +1,57 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class MatMulOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto output_defs = node->OutputDefs(); + if (output_defs[0]->Shape()->dim_size() == 0) { + LOGS_DEFAULT(WARNING) << "Inner product of 1-D tensor is not supported in MatMul op."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Matmul Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h new file mode 100644 index 0000000000000..085850e8a8b5d --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h @@ -0,0 +1,87 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class BatchNormOpBuilder : public BaseOpBuilder { + enum NormINPUTS { + input_tensor = 0, + scale_tensor = 1, + Bias_tensor = 2, + mean_tensor = 3, + var_tensor = 4 + }; + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 9; } + + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + NodeAttrHelper helper(*node); + auto training_mode = helper.Get("training_mode", 0); + if (training_mode) { + LOGS_DEFAULT(WARNING) << "Training is not supported in batch_norm op."; + return false; + } + if (helper.HasAttr("spatial")) { + LOGS_DEFAULT(WARNING) << "VSINPU does not support 'spatial' parameter."; + return false; + } + if (!graph_viewer.IsConstantInitializer(input_defs[NormINPUTS::scale_tensor]->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Not support mean/var/gamma/beta set as dynamic input yet."; + return false; + } + + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating BatchNorm Op."; + NodeAttrHelper helper(node_unit.GetNode()); + auto epsilon = helper.Get("epsilon", 1e-5f); + auto op = graph_ep->GetGraph()->CreateOperation(epsilon); + std::vector> reordered_inputs; + int indices[] = {NormINPUTS::input_tensor, NormINPUTS::mean_tensor, NormINPUTS::var_tensor, + NormINPUTS::scale_tensor, NormINPUTS::Bias_tensor}; + for (int i : indices) { + reordered_inputs.push_back(inputs[i]); + } + (*op).BindInputs(reordered_inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h new file mode 100644 index 0000000000000..34339875cd6e0 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h @@ -0,0 +1,153 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class BasePoolOpBuilder : public BaseOpBuilder { + public: + explicit BasePoolOpBuilder(tim::vx::PoolType pool_type) : pool_type_(pool_type) {} + + protected: + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, const Node* node) const override { + auto shape = vsi::npu::util::GetTensorShape(*node->InputDefs()[0]); + if (shape.NumDimensions() == 5) { + LOGS_DEFAULT(WARNING) << "3DPool is not supported yet."; + return false; + } + + NodeAttrHelper helper(*node); + if (helper.HasAttr("dilations")) { + LOGS_DEFAULT(WARNING) << "NonMaxPool with Dilation parameter is not supported."; + return false; + } + return true; + } + bool CreatePoolingOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const std::array& kernel_size, + const std::array& strides, + const std::array& pads, + bool is_global, + const tim::vx::RoundType ceil_mode) { + const bool is_1d_pool = inputs[0]->GetShape().size() == 3; + std::shared_ptr op; + + // Create the appropriate pooling operation + if (is_global) { + if (is_1d_pool) { + op = graph_ep->GetGraph()->CreateOperation(pool_type_, inputs[0]->GetShape()[0], + ceil_mode); + } else { + std::array input_size = {inputs[0]->GetShape()[0], inputs[0]->GetShape()[1]}; + op = graph_ep->GetGraph()->CreateOperation(pool_type_, input_size, ceil_mode); + } + + } else { + if (is_1d_pool) { + std::array arr = {pads[2], pads[0]}; + op = graph_ep->GetGraph()->CreateOperation(pool_type_, arr, + kernel_size[1], strides[1], ceil_mode); + } else { + op = graph_ep->GetGraph()->CreateOperation(pool_type_, pads, kernel_size, + strides, ceil_mode); + } + } + + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } + tim::vx::PoolType pool_type_; +}; + +class TraditionalPoolOpBuilder : public BasePoolOpBuilder { + public: + TraditionalPoolOpBuilder() : BasePoolOpBuilder(tim::vx::PoolType::MAX) {} + + protected: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + NodeAttrHelper helper(node_unit.GetNode()); + auto ksize = helper.Get("kernel_shape", std::vector{1U, 1U}); + auto strides = helper.Get("strides", std::vector{1U, 1U}); + auto pads = helper.Get("pads", std::vector{0U, 0U, 0U, 0U}); + tim::vx::RoundType ceil_mode = helper.Get("ceil_mode", 0U) == 0 + ? tim::vx::RoundType::FLOOR + : tim::vx::RoundType::CEILING; + return CreatePoolingOp(graph_ep, inputs, outputs, {ksize[1], ksize[0]}, {strides[1], strides[0]}, + {pads[1], pads[3], pads[0], pads[2]}, false, ceil_mode); + } +}; + +class GlobalPoolOpBuilder : public BasePoolOpBuilder { + public: + GlobalPoolOpBuilder() : BasePoolOpBuilder(tim::vx::PoolType::MAX) {} + + protected: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + NodeAttrHelper helper(node_unit.GetNode()); + tim::vx::RoundType ceil_mode = helper.Get("ceil_mode", 0U) == 0 + ? tim::vx::RoundType::FLOOR + : tim::vx::RoundType::CEILING; + return CreatePoolingOp(graph_ep, inputs, outputs, {}, {}, {}, true, ceil_mode); + } +}; + +class GlobalAveragePoolOpBuilder : public GlobalPoolOpBuilder { + public: + GlobalAveragePoolOpBuilder() { pool_type_ = tim::vx::PoolType::AVG; } +}; + +class GlobalMaxPoolOpBuilder : public GlobalPoolOpBuilder { + public: + GlobalMaxPoolOpBuilder() { pool_type_ = tim::vx::PoolType::MAX; } +}; + +class AveragePoolOpBuilder : public TraditionalPoolOpBuilder { + public: + AveragePoolOpBuilder() { pool_type_ = tim::vx::PoolType::AVG; } +}; + +class MaxPoolOpBuilder : public TraditionalPoolOpBuilder { + public: + MaxPoolOpBuilder() { pool_type_ = tim::vx::PoolType::MAX; } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h new file mode 100644 index 0000000000000..0c09d2183420c --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h @@ -0,0 +1,86 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class BaseQLinearOpBuilder : public BaseOpBuilder { + enum { + INPUT_A = 0, + INPUT_A_SCALE = 1, + INPUT_A_ZP = 2, + INPUT_B = 3, + INPUT_B_SCALE = 4, + INPUT_B_ZP = 5, + OUTPUT_SCALE = 6, + OUTPUT_ZP = 7, + }; + + protected: + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, const Node* node) const override { + for (int i = 0; i < node->InputDefs().size(); i++) { + if (i == INPUT_A || i == INPUT_B) continue; + if (!graph_viewer.IsConstantInitializer(node->InputDefs()[i]->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Only support const scale / zero point."; + return false; + } + } + return true; + } +}; + +class QLinearAddOpBuilder : public BaseQLinearOpBuilder { + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating QLinearAdd Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class QLinearMulOpBuilder : public BaseQLinearOpBuilder { + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating QLinearMul Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h new file mode 100644 index 0000000000000..4789c06855921 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h @@ -0,0 +1,49 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class QLinearConcatOpBuilder : public BaseOpBuilder { + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, std::vector>& inputs, + std::vector>& outputs, const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating QLinearConcat Op."; + NodeAttrHelper helper(node_unit.GetNode()); + int axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + auto op = graph_ep->GetGraph()->CreateOperation(axis, inputs.size()); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h new file mode 100644 index 0000000000000..a5ad42d3c41f4 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h @@ -0,0 +1,152 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/framework/tensorprotoutils.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class QLinearConvOpBuilder : public BaseOpBuilder { + enum QLinearConvINPUTS { + INPUT_TENSOR = 0, + INPUT_TENSOR_SCALE = 1, + INPUT_TENSOR_ZP = 2, + WEIGHT_TENSOR = 3, + WEIGHT_TENSOR_SCALE = 4, + WEIGHT_TENSOR_ZP = 5, + OUTPUT_TENSOR_SCALE = 6, + OUTPUT_TENSOR_ZP = 7, + BIAS_TENSOR = 8, + }; + + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[QLinearConvINPUTS::INPUT_TENSOR]); + auto w_scale_shape = vsi::npu::util::GetTensorShape(*input_defs[QLinearConvINPUTS::WEIGHT_TENSOR_SCALE]); + auto w_shape_dims = vsi::npu::util::GetTensorShape(*input_defs[QLinearConvINPUTS::WEIGHT_TENSOR]).GetDims(); + if (input_shape.NumDimensions() != 4) { + LOGS_DEFAULT(WARNING) << "Not support conv3d&& conv1d yet."; + return false; + } + + if (!graph_viewer.IsConstantInitializer(input_defs[QLinearConvINPUTS::INPUT_TENSOR_SCALE]->Name(), true) || + !graph_viewer.IsConstantInitializer(input_defs[WEIGHT_TENSOR]->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Not support quantization definitions or weights that are not constant yet."; + return false; + } + + if (w_shape_dims[2] > 15) { + LOGS_DEFAULT(WARNING) << "Not support weight kernel with height higher than 15."; + return false; + } + + if (w_scale_shape.Size() != 1 && *input_defs[WEIGHT_TENSOR]->Type() == "tensor(int8)") { + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_viewer.GetConstantInitializer(input_defs[QLinearConvINPUTS::WEIGHT_TENSOR_ZP]->Name(), true); + std::vector w_zp(tensor_proto->dims_size() == 0 ? 1 : tensor_proto->dims()[0]); + + auto status = onnxruntime::utils::UnpackTensor( + *tensor_proto, + tensor_proto->has_raw_data() ? tensor_proto->raw_data().data() : nullptr, + tensor_proto->has_raw_data() ? tensor_proto->raw_data().size() : 0, + w_zp.data(), w_zp.size()); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << "Failed to get data from weight zp tensor."; + return false; + } + if (std::any_of(w_zp.begin(), w_zp.end(), [](int i) { return i != 0; })) { + LOGS_DEFAULT(WARNING) << "Asymmetric perchannel quantization only allows uint8 datatype or int8 with all zero."; + return false; + } + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating QLinearConv Op."; + + NodeAttrHelper helper(node_unit.GetNode()); + auto padtype = helper.Get("auto_pad", std::string("")); + auto group = helper.Get("group", static_cast(1)); + std::vector default_vec = {1, 1, 1, 1}; + auto stride = + helper.Get("strides", default_vec); + auto dilation = + helper.Get("dilations", default_vec); + std::shared_ptr op; + if (padtype != "NOTSET") { // array "pads" is not set + if (group != 1 && group != inputs[1]->GetShape()[3]) { + op = graph_ep->GetGraph() + ->CreateOperation( + vsi::npu::util::GetPadType(padtype), + std::array{stride[1], stride[0]}, + std::array{dilation[1], dilation[0]}, group, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + + } else { + int32_t multiplier = group == 1 ? 0 : inputs[1]->GetShape()[3] / inputs[0]->GetShape()[2]; + op = graph_ep->GetGraph()->CreateOperation( + vsi::npu::util::GetPadType(padtype), + std::array{stride[1], stride[0]}, + std::array{dilation[1], dilation[0]}, multiplier, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } else { + std::vector default_pads(4, 0); + auto pads = helper.Get("pads", default_pads); + if (group != 1 && group != inputs[1]->GetShape()[3]) { + op = graph_ep->GetGraph() + ->CreateOperation( + std::array{pads[1], pads[3], pads[0], pads[2]}, + std::array{stride[1], stride[0]}, + std::array{dilation[1], dilation[0]}, group, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + + } else { + int32_t multiplier = group == 1 ? 0 : inputs[1]->GetShape()[3] / inputs[0]->GetShape()[2]; + op = graph_ep->GetGraph()->CreateOperation( + std::array{pads[1], pads[3], + pads[0], pads[2]}, + std::array{stride[1], stride[0]}, + std::array{dilation[1], dilation[0]}, multiplier, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h new file mode 100644 index 0000000000000..8b53d2d64296d --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h @@ -0,0 +1,84 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class QLinearMatMulOpBuilder : public BaseOpBuilder { + enum { + matrixA = 0, + A_scale = 1, + A_zero_point = 2, + matrixB = 3, + B_scale = 4, + B_zero_point = 5, + out_scale = 6, + out_zero_point = 7 + }; + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto A_def = input_defs[matrixA]; + auto B_def = input_defs[matrixB]; + for (auto def : input_defs) { + if (def->Name() == A_def->Name() || def->Name() == B_def->Name()) { + continue; + } else { + if (!graph_viewer.IsConstantInitializer(def->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Scale and zero point must be known before setting graph."; + return false; + } + } + } + int64_t A_elements = util::GetTensorShape(*input_defs[A_scale]).Size(); + int64_t B_elements = util::GetTensorShape(*input_defs[B_scale]).Size(); + int64_t Out_elements = util::GetTensorShape(*input_defs[out_scale]).Size(); + if (A_elements > 1 || B_elements > 1 || Out_elements > 1) { + LOGS_DEFAULT(WARNING) << "Per channel quantized input/output is not supported in QLinearMatmul Op."; + return false; + } + + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating QLinearMatmul Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h new file mode 100644 index 0000000000000..6c0fd44464e17 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h @@ -0,0 +1,80 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class QuantizeLinearOpBuilder : public BaseOpBuilder { + enum QuantizeINPUTS { + input_tensor = 0, + scale_tensor = 1, + zero_point_tensor = 2 + }; + + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto scale_shape = npu::util::GetTensorShape(*input_defs[QuantizeINPUTS::scale_tensor]); + NodeAttrHelper helper(*node); + if (helper.HasAttr("block_size") && helper.Get("block_size", 0) != 0) { + LOGS_DEFAULT(WARNING) << "Not support block quantization."; + return false; + } + if (!graph_viewer.IsConstantInitializer(input_defs[QuantizeINPUTS::scale_tensor]->Name(), true) || + (input_defs.size() == 3 && !graph_viewer.IsConstantInitializer( + input_defs[QuantizeINPUTS::zero_point_tensor]->Name(), true))) { + LOGS_DEFAULT(WARNING) << "Only support const scale / zero point."; + return false; + } + + if (scale_shape.Size() != 1) { + LOGS_DEFAULT(WARNING) << "Per channel quantized output is not supported in QuantizeLinearOp."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Quantize Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h new file mode 100644 index 0000000000000..d2e16026b712e --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h @@ -0,0 +1,83 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ReduceMeanOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + if (*input_defs[0]->Type() == "tensor(int32)") { + LOGS_DEFAULT(WARNING) << "Not support int32 reduce mean yet."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating ReduceMean Op."; + + NodeAttrHelper helper(node_unit.GetNode()); + std::vector def_axes; + auto input_shape_size = inputs[0]->GetShape().size(); + + if (node_unit.SinceVersion() < 18 && helper.HasAttr("axes")) { + def_axes = helper.Get("axes", def_axes); + } else if (inputs.size() > 1) { + def_axes.resize(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(def_axes.data()); + } else { + for (int64_t i = 0; i < input_shape_size; ++i) { + def_axes.push_back(i); + } + } + + std::vector axes(def_axes.begin(), def_axes.end()); + axes = util::ReverseAxis(axes, input_shape_size); + + if (helper.HasAttr("noop_with_empty_axes") && inputs.size() == 1 && helper.Get("noop_with_empty_axes", 0) == 1) { + outputs[0] = inputs[0]; + return true; + } + + bool keepdims = helper.Get("keepdims", 1) == 1; + auto op = graph_ep->GetGraph()->CreateOperation(axes, keepdims); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h new file mode 100644 index 0000000000000..db3ccb153ec35 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h @@ -0,0 +1,156 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ResizeOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input_type = node_unit.Inputs()[0].node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + if (node_unit.SinceVersion() > 10) { + if (node_unit.Inputs().size() > 2 && !Contains(initializers, node_unit.Inputs()[2].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Scale tensor must be constant."; + return false; + } + if (node_unit.Inputs().size() > 3 && !Contains(initializers, node_unit.Inputs()[3].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Size tensor must be constant."; + return false; + } + } else { + if (!Contains(initializers, node_unit.Inputs()[1].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Scale tensor must be constant."; + return false; + } + } + return true; + } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, const Node* node) const override { + auto shape = vsi::npu::util::GetTensorShape(*node->InputDefs()[0]); + if (shape.NumDimensions() > 4) { + LOGS_DEFAULT(WARNING) << "3D or more dimesions resize is not supported."; + return false; + } + + NodeAttrHelper helper(*node); + if (helper.Get("antialiax", 0) != 0) { + LOGS_DEFAULT(WARNING) << "Antialias attribute is not supported."; + return false; + } + auto& cooridinate = helper.Get("coordinate_transoformation_mode", "half_pixel"); + if (cooridinate != "align_corners" && cooridinate != "half_pixel") { + LOGS_DEFAULT(WARNING) << "Only support half_pixel and align_corners attributes now."; + return false; + } + if (helper.Get("keep_aspect_ratio_policy", "stretch") != "stretch") { + LOGS_DEFAULT(WARNING) << "Not support to keep aspect ratio."; + return false; + } + if (helper.Get("mode", "nearest") == "cubic") { + LOGS_DEFAULT(WARNING) << "Not support the cubic resize type yet."; + return false; + } + if (helper.HasAttr("axes")) { + LOGS_DEFAULT(WARNING) << "Axes-specifying is not support."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Resize Op."; + auto inputs_num = inputs.size(); + bool is_1dresize = inputs[0]->GetShape().size() == 1; + NodeAttrHelper helper(node_unit.GetNode()); + auto onnx_mode = helper.Get("mode", "nearest"); + auto coordinate_transformation = helper.Get("coordinate_transformation_mode", "half_pixel"); + bool is_size_set = helper.HasAttr("size"); + int32_t scale_index = node_unit.SinceVersion() > 10 ? 2 : 1; + + auto resize_type = onnx_mode == "nearest" ? tim::vx::ResizeType::NEAREST_NEIGHBOR : tim::vx::ResizeType::BILINEAR; + bool align_corners = coordinate_transformation == "align_corners"; + bool half_pixel_center = coordinate_transformation == "half_pixel"; + std::shared_ptr op = nullptr; + if (is_1dresize) { + int target_size; + if (is_size_set) { + int64_t onnx_size; + inputs[3]->CopyDataFromTensor(&onnx_size); + target_size = static_cast(onnx_size); + op = graph_ep->GetGraph()->CreateOperation(resize_type, 0.0f, align_corners, + half_pixel_center, target_size); + } else { + float scale; + inputs[scale_index]->CopyDataFromTensor(&scale); + op = graph_ep->GetGraph()->CreateOperation(resize_type, scale, align_corners, + half_pixel_center, 0); + } + } else { + int target_h, target_w; + if (is_size_set) { + std::vector onnx_sizes(inputs[3]->GetShape().size()); + inputs[3]->CopyDataFromTensor(onnx_sizes.data()); + target_h = static_cast(onnx_sizes[1]); + target_w = static_cast(onnx_sizes[0]); + op = graph_ep->GetGraph()->CreateOperation(resize_type, 0.0f, align_corners, + half_pixel_center, target_h, target_w); + } else { + auto input_shape = inputs[0]->GetShape(); + std::vector scales(input_shape.size()); + std::vector out_shape(input_shape.size()); + inputs[scale_index]->CopyDataFromTensor(scales.data()); + for (int i = 0; i < input_shape.size(); i++) { + out_shape[i] = input_shape[i] * scales[input_shape.size() - 1 - i]; + } + target_h = static_cast(out_shape[1]); + target_w = static_cast(out_shape[0]); + op = graph_ep->GetGraph()->CreateOperation(resize_type, 0, align_corners, + half_pixel_center, target_h, target_w); + } + } + + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/slice_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/slice_op_builder.h new file mode 100644 index 0000000000000..8457f91e02cb6 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/slice_op_builder.h @@ -0,0 +1,148 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +enum SliceInputs { + data = 0, + starts = 1, + ends = 2, + axes = 3, + steps = 4 +}; + +class SliceOpBuilder : public BaseOpBuilder { + public: + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 10; } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + const auto& iodef = node_unit.Inputs()[i]; + if (!util::IsTypeSupported(&iodef.node_arg) || + (i == 0 && *iodef.node_arg.Type() == "tensor(int64)") || + (i != 0 && !Contains(initializers, iodef.node_arg.Name()))) { + return false; + } + } + return true; + } + + template + void CopyTensorDataToVector(const std::shared_ptr& tensor, std::vector& vec) { + std::vector data(tensor->GetSpec().GetElementNum()); + tensor->CopyDataFromTensor(data.data()); + std::transform(data.begin(), data.end(), vec.begin(), [](T val) { + return static_cast(std::clamp(val, static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max()))); + }); + } + + void ProcessAxes(const std::vector>& inputs, + int dims, bool full_axes, + std::vector& timvx_starts, + std::vector& timvx_ends, + std::vector& timvx_strides) { + auto num_elements = full_axes ? dims : inputs[SliceInputs::axes]->GetSpec().GetElementNum(); + std::vector onnx_starts(num_elements), onnx_ends(num_elements), + onnx_axes(num_elements), onnx_strides(num_elements, 1); + + auto data_type = inputs[SliceInputs::starts]->GetSpec().GetDataType(); + std::iota(onnx_axes.begin(), onnx_axes.end(), 0); + if (data_type == tim::vx::DataType::INT64) { + CopyTensorDataToVector(inputs[SliceInputs::starts], onnx_starts); + CopyTensorDataToVector(inputs[SliceInputs::ends], onnx_ends); + if (inputs.size() > 3) { + CopyTensorDataToVector(inputs[SliceInputs::axes], onnx_axes); + if (inputs.size() == 5) { + CopyTensorDataToVector(inputs[SliceInputs::steps], onnx_strides); + } + } + } else { + CopyTensorDataToVector(inputs[SliceInputs::starts], onnx_starts); + CopyTensorDataToVector(inputs[SliceInputs::ends], onnx_ends); + if (inputs.size() > 3) { + CopyTensorDataToVector(inputs[SliceInputs::axes], onnx_axes); + if (inputs.size() == 5) { + CopyTensorDataToVector(inputs[SliceInputs::steps], onnx_strides); + } + } + } + + if (!full_axes) { + for (auto& axis : onnx_axes) { + axis = HandleNegativeAxis(axis, inputs[0]->GetShape().size()); + } + } + + for (int i = 0; i < dims; ++i) { + if (full_axes || std::find(onnx_axes.begin(), onnx_axes.end(), i) != onnx_axes.end()) { + int axes_index = std::distance(onnx_axes.begin(), std::find(onnx_axes.begin(), onnx_axes.end(), i)); + timvx_starts[i] = onnx_starts[axes_index]; + timvx_ends[i] = onnx_ends[axes_index]; + if (inputs.size() == 5) { + timvx_strides[i] = onnx_strides[axes_index]; + } + } else if (!full_axes) { + timvx_starts[i] = 0; + timvx_ends[i] = inputs[SliceInputs::data]->GetShape()[dims - i - 1]; + } + } + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Slice Op."; + auto total_dims = inputs[SliceInputs::data]->GetShape().size(); + bool full_axes = inputs.size() <= 3 || (inputs[SliceInputs::axes]->GetSpec().GetElementNum() == total_dims); + std::vector timvx_starts(total_dims), timvx_ends(total_dims), timvx_strides(total_dims, 1); + + ProcessAxes(inputs, total_dims, full_axes, timvx_starts, timvx_ends, timvx_strides); + + std::reverse(timvx_starts.begin(), timvx_starts.end()); + std::reverse(timvx_ends.begin(), timvx_ends.end()); + std::reverse(timvx_strides.begin(), timvx_strides.end()); + + auto op = graph_ep->GetGraph()->CreateOperation( + timvx_starts, timvx_ends, timvx_strides, 0, 0, 0); + op->BindInput(inputs[SliceInputs::data]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h new file mode 100644 index 0000000000000..055b3680db801 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h @@ -0,0 +1,103 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class SoftmaxOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + auto axis = helper.Get("axis", -1); + auto input_defs = node->InputDefs(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + if (axis >= rank || axis < -rank) { + LOGS_DEFAULT(ERROR) << "Axis is invalid in Softmax."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Softmax Op."; + NodeAttrHelper helper(node_unit.GetNode()); + int32_t def_val = node_unit.SinceVersion() < 13 ? 1 : -1; + auto axis = helper.Get("axis", def_val); + + if (def_val == 1) { + // In earlier opset version of softmax, input is coerced into 2D shape + // Attribute "axis" is to describe the axis of the inputs coerced to 2D but not take part in softmax computation + const bool is_2d_shape = inputs[0]->GetShape().size() == 2 ? true : false; + if (!is_2d_shape) { + axis = HandleNegativeAxis(axis, inputs[0]->GetShape().size()); + auto it = inputs[0]->GetShape().end(); + uint32_t last_dim = std::accumulate(it - axis, it, 1, std::multiplies()); + uint32_t first_dim = std::accumulate(inputs[0]->GetShape().begin(), it - axis, 1, std::multiplies()); + auto reshaped_spec = inputs[0]->GetSpec().AsTransientSpec().SetShape( + std::vector{first_dim, last_dim}); + auto reshaped_input = graph_ep->GetGraph()->CreateTensor(reshaped_spec); + auto reshaped_output = graph_ep->GetGraph()->CreateTensor( + inputs[0]->GetSpec().AsTransientSpec()); + + auto reshape_input_op = graph_ep->GetGraph()->CreateOperation( + std::vector{first_dim, last_dim}); + auto softmax_op = graph_ep->GetGraph()->CreateOperation(1, 0); + auto reshaped_output_op = graph_ep->GetGraph()->CreateOperation(inputs[0]->GetShape()); + + (*reshape_input_op).BindInputs(inputs).BindOutput(reshaped_input); + (*softmax_op).BindInput(reshaped_input).BindOutput(reshaped_output); + (*reshaped_output_op).BindInput(reshaped_output).BindOutputs(outputs); + + graph_ep->GetOps().push_back(std::move(reshape_input_op)); + graph_ep->GetOps().push_back(std::move(softmax_op)); + graph_ep->GetOps().push_back(std::move(reshaped_output_op)); + } else { + auto op = graph_ep->GetGraph()->CreateOperation(1, 0); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + } + } else { + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + auto op = graph_ep->GetGraph()->CreateOperation(1, static_cast(axis)); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + } + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h new file mode 100644 index 0000000000000..bc90ae1ff26a9 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h @@ -0,0 +1,89 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class SqueezeOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input_type = node_unit.Inputs()[0].node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + if (node_unit.SinceVersion() > 11) { + if (node_unit.Inputs().size() > 1 && !Contains(initializers, node_unit.Inputs()[1].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Only support const axes in Squeeze op."; + return false; + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Squeeze Op."; + + NodeAttrHelper helper(node_unit.GetNode()); + std::vector def_axes; + auto input_shape_size = inputs[0]->GetShape().size(); + + if (node_unit.SinceVersion() < 13 && helper.HasAttr("axes")) { + def_axes = helper.Get("axes", def_axes); + } else if (inputs.size() > 1) { + def_axes.resize(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(def_axes.data()); + } else { // if axes is empty from onnx, check input shape to determine + for (int64_t i = 0; i < input_shape_size; ++i) { + if (inputs[0]->GetShape()[i] == 1) { + def_axes.push_back(i); + } + } + } + + std::vector axes(def_axes.begin(), def_axes.end()); + axes = util::ReverseAxis(axes, input_shape_size); + + std::vector timvx_axes(axes.begin(), axes.end()); + + auto op = graph_ep->GetGraph()->CreateOperation(timvx_axes); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h new file mode 100644 index 0000000000000..9a53bef2a6693 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h @@ -0,0 +1,143 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ReshapeOpBuilder : public BaseOpBuilder { + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 5; } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input = node_unit.Inputs()[0]; + auto shape = node_unit.Inputs()[1]; + if (initializers.end() == initializers.find(shape.node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "Target shape of reshape op must be known."; + return false; + } + if (util::IsTypeSupported(&input.node_arg) && util::IsTypeSupported(&shape.node_arg)) { + if (*input.node_arg.Type() != "tensor(int64)") { + return true; + } + } + return false; + } + + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + + NodeAttrHelper helper(*node); + const bool allow_zero = helper.Get("allowzero", 0) == 1; + auto& perm_tensor_proto = *graph_viewer.GetConstantInitializer(input_defs[1]->Name(), true); + std::vector perm(perm_tensor_proto.dims()[0]); + auto status = onnxruntime::utils::UnpackTensor( + perm_tensor_proto, + perm_tensor_proto.has_raw_data() ? perm_tensor_proto.raw_data().data() : nullptr, + perm_tensor_proto.has_raw_data() ? perm_tensor_proto.raw_data().size() : 0, + perm.data(), perm.size()); + + // Check if perm has any 0's when allow zero is enabled. + if (allow_zero && std::find(perm.begin(), perm.end(), 0L) != perm.end()) { + LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 as dimension when allowzero is enabled"; + return false; + } + + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Reshape Op."; + std::vector new_shape(inputs[1]->GetShape()[0]); + inputs[1]->CopyDataFromTensor(new_shape.data()); + for (size_t i = 0; i < new_shape.size(); i++) { + if (new_shape[i] == 0) { + new_shape[i] = inputs[0]->GetShape()[inputs[0]->GetShape().size() - i - 1]; + } + } + + int64_t element_count = std::accumulate(new_shape.begin(), new_shape.end(), static_cast(1), + [&](int64_t a, int64_t b) { + return b == -1 ? a : a * b; + }); + auto negative_it = std::find(new_shape.begin(), new_shape.end(), -1); + if (negative_it != new_shape.end()) { + *negative_it = inputs[0]->GetSpec().GetElementNum() / element_count; + } + + std::vector new_shape_uint32(new_shape.begin(), new_shape.end()); + std::reverse(new_shape_uint32.begin(), new_shape_uint32.end()); + auto op = graph_ep->GetGraph()->CreateOperation(new_shape_uint32); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class TransposeOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto shape_dim = vsi::npu::util::GetTensorShape(*input_defs[0]).NumDimensions(); + NodeAttrHelper helper(*node); + auto perm = helper.Get("perm", std::vector(shape_dim, 1)); + if (perm.size() != shape_dim) { + LOGS_DEFAULT(VERBOSE) << "Size mismatch between perm vector and input shape."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Transpose Op."; + std::vector def_val(inputs[0]->GetShape().size()); + for (int64_t i = 0; i < def_val.size(); i++) def_val[i] = def_val.size() - i - 1; + + NodeAttrHelper helper(node_unit.GetNode()); + def_val = helper.Get("perm", def_val); + std::vector timvx_perm; + for (uint32_t i = 0; i < def_val.size(); i++) { + timvx_perm.push_back(def_val.size() - 1 - def_val[def_val.size() - i - 1]); + } + auto op = graph_ep->GetGraph()->CreateOperation(timvx_perm); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h new file mode 100644 index 0000000000000..501468251cea3 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h @@ -0,0 +1,72 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class TileOpBuilder : public BaseOpBuilder { + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 6; } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input = node_unit.Inputs()[0]; + auto multipliers = node_unit.Inputs()[1]; + if (initializers.end() == initializers.find(multipliers.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Multipliers of tile op must be known."; + return false; + } + if (util::IsTypeSupported(&input.node_arg) && util::IsTypeSupported(&multipliers.node_arg)) { + if (*input.node_arg.Type() != "tensor(int64)") { + return true; + } + } + LOGS_DEFAULT(WARNING) << "Input type not supported."; + return false; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Tile Op."; + std::vector multipliers(inputs[1]->GetShape()[0]); + inputs[1]->CopyDataFromTensor(multipliers.data()); + std::reverse(multipliers.begin(), multipliers.end()); + std::vector timvx_multipliers(multipliers.begin(), multipliers.end()); + auto op = graph_ep->GetGraph()->CreateOperation(timvx_multipliers); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h new file mode 100644 index 0000000000000..d7fbaeaa72fc8 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h @@ -0,0 +1,90 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class UnsqueezeOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input_type = node_unit.Inputs()[0].node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + if (node_unit.SinceVersion() > 11 && !Contains(initializers, node_unit.Inputs()[1].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Only support const axes in Unsqueeze op."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Unsqueeze Op."; + + NodeAttrHelper helper(node_unit.GetNode()); + std::vector def_axes; + auto input_shape_size = inputs[0]->GetShape().size(); + + if (node_unit.SinceVersion() < 13 && helper.HasAttr("axes")) { + def_axes = helper.Get("axes", def_axes); + } else if (inputs.size() > 1) { + def_axes.resize(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(def_axes.data()); + } else { // if axes is empty from onnx, check input shape to determine + for (int64_t i = 0; i < input_shape_size; ++i) { + if (inputs[0]->GetShape()[i] == 1) { + def_axes.push_back(i); + } + } + } + + std::vector axes(def_axes.begin(), def_axes.end()); + axes = util::ReverseAxis(axes, input_shape_size + axes.size()); + + std::vector timvx_axes(inputs[0]->GetShape().begin(), inputs[0]->GetShape().end()); + for (int32_t dim : axes) { + timvx_axes.insert(timvx_axes.begin() + dim, 1); + } + + auto op = graph_ep->GetGraph()->CreateOperation(timvx_axes); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/op_builder.h b/onnxruntime/core/providers/vsinpu/builders/op_builder.h new file mode 100644 index 0000000000000..d81a478149c6b --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/op_builder.h @@ -0,0 +1,48 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class GraphEP; + +class IOpBuilder { + public: + IOpBuilder() {} + virtual ~IOpBuilder() {} + virtual bool IsSupported(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) const { + return true; + } + virtual bool BuildOp(GraphEP* graph_ep, + const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) = 0; +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h new file mode 100644 index 0000000000000..27c148c1672c5 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h @@ -0,0 +1,136 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include "impl/activation_op_builder.h" +#include "impl/conv_op_builder.h" +#include "impl/elementwise_op_builder.h" +#include "impl/gemm_op_builder.h" +#include "impl/pool_op_builder.h" +#include "impl/qlinearconv_op_builder.h" +#include "impl/flatten_op_builder.h" +#include "impl/matmul_op_builder.h" +#include "impl/tensor_op_builder.h" +#include "impl/concat_op_builder.h" +#include "impl/softmax_op_builder.h" +#include "impl/norm_op_builder.h" +#include "impl/clip_op_builder.h" +#include "impl/reduce_op_builder.h" +#include "impl/quantize_op_builder.h" +#include "impl/dequantize_op_builder.h" +#include "impl/qlinearmatmul_op_builder.h" +#include "impl/qlinear_binary_op_builder.h" +#include "impl/qlinearconcat_op_builder.h" +#include "impl/gather_op_builder.h" +#include "impl/tile_op_builder.h" +#include "impl/squeeze_op_builder.h" +#include "impl/unsqueeze_op_builder.h" +#include "impl/resize_op_builder.h" +#include "impl/cast_op_builder.h" +#include "impl/dropout_op_builder.h" +#include "impl/slice_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +using createIOpBuildItemFunc = std::function()>; +using OpBuildItemType = std::map>; + +static const std::map reg = { +#define REGISTER_OP_BUILDER(ONNX_NODE_TYPE, BUILDER_TYPE) \ + { \ + ONNX_NODE_TYPE, [] { return std::make_unique(); } \ + } + + REGISTER_OP_BUILDER("Add", AddOpBuilder), + REGISTER_OP_BUILDER("Sub", SubOpBuilder), + REGISTER_OP_BUILDER("Mul", MulOpBuilder), + REGISTER_OP_BUILDER("Div", DivOpBuilder), + REGISTER_OP_BUILDER("Abs", AbsOpBuilder), + REGISTER_OP_BUILDER("Pow", PowOpBuilder), + REGISTER_OP_BUILDER("Sqrt", SqrtOpBuilder), + REGISTER_OP_BUILDER("Exp", ExpOpBuilder), + REGISTER_OP_BUILDER("Floor", FloorOpBuilder), + REGISTER_OP_BUILDER("Log", LogOpBuilder), + REGISTER_OP_BUILDER("Sin", SinOpBuilder), + REGISTER_OP_BUILDER("Conv", ConvOpBuilder), + REGISTER_OP_BUILDER("Gemm", GemmOpBuilder), + REGISTER_OP_BUILDER("Relu", ReluOpBuilder), + REGISTER_OP_BUILDER("LeakyRelu", LeakyReluOpBuilder), + REGISTER_OP_BUILDER("Tanh", TanhOpBuilder), + REGISTER_OP_BUILDER("Sigmoid", SigmoidOpBuilder), + REGISTER_OP_BUILDER("HardSigmoid", HardSigmoidOpBuilder), + REGISTER_OP_BUILDER("HardSwish", HardSwishOpBuilder), + REGISTER_OP_BUILDER("GlobalAveragePool", GlobalAveragePoolOpBuilder), + REGISTER_OP_BUILDER("QLinearConv", QLinearConvOpBuilder), + REGISTER_OP_BUILDER("Flatten", FlattenOpBuilder), + REGISTER_OP_BUILDER("MatMul", MatMulOpBuilder), + REGISTER_OP_BUILDER("GlobalMaxPool", GlobalMaxPoolOpBuilder), + REGISTER_OP_BUILDER("AveragePool", AveragePoolOpBuilder), + REGISTER_OP_BUILDER("MaxPool", MaxPoolOpBuilder), + REGISTER_OP_BUILDER("Reshape", ReshapeOpBuilder), + REGISTER_OP_BUILDER("Concat", ConcatOpBuilder), + REGISTER_OP_BUILDER("Softmax", SoftmaxOpBuilder), + REGISTER_OP_BUILDER("Transpose", TransposeOpBuilder), + REGISTER_OP_BUILDER("BatchNormalization", BatchNormOpBuilder), + REGISTER_OP_BUILDER("Clip", ClipOpBuilder), + REGISTER_OP_BUILDER("ReduceMean", ReduceMeanOpBuilder), + REGISTER_OP_BUILDER("QuantizeLinear", QuantizeLinearOpBuilder), + REGISTER_OP_BUILDER("DequantizeLinear", DequantizeLinearOpBuilder), + REGISTER_OP_BUILDER("QLinearMatMul", QLinearMatMulOpBuilder), + REGISTER_OP_BUILDER("QLinearAdd", QLinearAddOpBuilder), + REGISTER_OP_BUILDER("QLinearMul", QLinearMulOpBuilder), + REGISTER_OP_BUILDER("QLinearConcat", QLinearConcatOpBuilder), + REGISTER_OP_BUILDER("Gather", GatherOpBuilder), + REGISTER_OP_BUILDER("Tile", TileOpBuilder), + REGISTER_OP_BUILDER("Squeeze", SqueezeOpBuilder), + REGISTER_OP_BUILDER("Unsqueeze", UnsqueezeOpBuilder), + REGISTER_OP_BUILDER("Resize", ResizeOpBuilder), + REGISTER_OP_BUILDER("Cast", CastOpBuilder), + REGISTER_OP_BUILDER("Dropout", DropoutOpBuilder), + REGISTER_OP_BUILDER("Slice", SliceOpBuilder) +#undef REGISTER_OP_BUILDER +}; + +template +struct OpBuildConstructor { + T supported_builtins; + OpBuildConstructor( + const std::map reg) { + LOGS_DEFAULT(INFO) << "Initialize supported ops"; + for (const auto& kv : reg) { + supported_builtins.insert(std::make_pair(kv.first, kv.second())); + } + } +}; + +inline const OpBuildItemType& SupportedBuiltinOps() { + static OpBuildConstructor c(reg); + return c.supported_builtins; +} +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/patches/AccuracyCorrection.patch b/onnxruntime/core/providers/vsinpu/patches/AccuracyCorrection.patch new file mode 100644 index 0000000000000..d44190101d9fa --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/AccuracyCorrection.patch @@ -0,0 +1,26 @@ +diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc +index 47c18c478d..93b44501cd 100644 +--- a/onnxruntime/test/providers/checkers.cc ++++ b/onnxruntime/test/providers/checkers.cc +@@ -195,7 +195,7 @@ struct TensorCheck { + // For any other EPs, we still expect an exact match for the results + // TODO: Verify if DML can possibly have a ROUNDING_MODE parameter and conform to the other EPs #41968513 + if ((provider_type == kNnapiExecutionProvider || provider_type == kDmlExecutionProvider || +- provider_type == kXnnpackExecutionProvider) && ++ provider_type == kXnnpackExecutionProvider || provider_type == kVSINPUExecutionProvider) && + (has_abs_err || has_rel_err)) { + double threshold = has_abs_err ? *(params.absolute_error) + : 0.0; +diff --git a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc +index 2bc0df5e36..7beb78c2ff 100644 +--- a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc ++++ b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc +@@ -498,7 +498,7 @@ class QLinearConvOpTester { + // NOTE, for now the tolerance will only apply if the NNAPI is actually used, + // if for any reason the execution falls back to CPU, we still expect an exact match + // See, 'void Check(...' in onnxruntime/test/providers/provider_test_utils.cc +-#if defined(USE_NNAPI) || defined(USE_DML) ++#if defined(USE_NNAPI) || defined(USE_DML) || defined(USE_VSINPU) + // TODO: Verify if DML can possibly have a ROUNDING_MODE parameter and conform to the other EPs #41968513 + abs_error = 1.0f; + #endif diff --git a/onnxruntime/core/providers/vsinpu/patches/local_testing_record_res.patch b/onnxruntime/core/providers/vsinpu/patches/local_testing_record_res.patch new file mode 100644 index 0000000000000..e118ee104912f --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/local_testing_record_res.patch @@ -0,0 +1,343 @@ +diff --git a/onnxruntime/test/onnx/dataitem_request.cc b/onnxruntime/test/onnx/dataitem_request.cc +index 1ee302d5d5..5c2dd5ab00 100644 +--- a/onnxruntime/test/onnx/dataitem_request.cc ++++ b/onnxruntime/test/onnx/dataitem_request.cc +@@ -135,6 +135,7 @@ std::pair DataTaskRequestContext::RunImpl() { + } + + EXECUTE_RESULT res = EXECUTE_RESULT::SUCCESS; ++ int32_t out_idx = 0; + for (auto& output : expected_output_values) { + const std::string& output_name = output.first; + OrtValue* expected_output_value = output.second; // Automatic cast +@@ -170,7 +171,7 @@ std::pair DataTaskRequestContext::RunImpl() { + } else { // Both expect and actual OrtValues are not None, proceed with data checking + ret = + CompareOrtValue(*actual_output_value, *expected_output_value, per_sample_tolerance, +- relative_per_sample_tolerance, post_procesing); ++ relative_per_sample_tolerance, post_procesing, out_idx); + } + } else { // Expected output is None, ensure that the received output OrtValue is None as well + if (actual_output_value->IsAllocated()) { +@@ -223,9 +224,10 @@ std::pair DataTaskRequestContext::RunImpl() { + if (compare_result != COMPARE_RESULT::SUCCESS && !ret.second.empty()) { + LOGS_DEFAULT(ERROR) << test_case_.GetTestCaseName() << ":output=" << output_name << ":" << ret.second; + } +- if (compare_result != COMPARE_RESULT::SUCCESS) { +- break; +- } ++ // if (compare_result != COMPARE_RESULT::SUCCESS) { ++ // break; ++ // } ++ out_idx ++; + } + return std::make_pair(res, spent_time_); + } +diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc +index f1a7240ea3..436031dfa8 100644 +--- a/onnxruntime/test/providers/checkers.cc ++++ b/onnxruntime/test/providers/checkers.cc +@@ -154,6 +154,7 @@ struct TensorCheck { + } + + const bool has_abs_err = params.absolute_error.has_value(); ++ const int8_t default_abs_err = 1; + if (has_abs_err) { + double threshold = *(params.absolute_error); + +@@ -162,7 +163,8 @@ struct TensorCheck { + } + } else { + for (int i = 0; i < size; ++i) { +- EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i; ++ // EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i; ++ EXPECT_NEAR(cur_expected[i], cur_actual[i], default_abs_err) << "i:" << i; + } + } + } +diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc +index 3d53d4a3a0..8129af1820 100644 +--- a/onnxruntime/test/util/compare_ortvalue.cc ++++ b/onnxruntime/test/util/compare_ortvalue.cc +@@ -138,11 +138,75 @@ std::pair CompareFloatResult(const Tensor& outvalue + return res; + } + ++template ++std::pair CompareFloatResult(const Tensor& outvalue, const Tensor& expected_value, ++ double per_sample_tolerance, ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx) { ++ const size_t size1 = static_cast(expected_value.Shape().Size()); ++ const FLOAT_TYPE* expected_output = expected_value.Data(); ++ const FLOAT_TYPE* real_output = outvalue.Data(); ++ ++ std::string expected_name = "expected_res"+ std::to_string(out_idx)+ ".txt"; ++ std::string npures_name = "npu_res"+ std::to_string(out_idx)+ ".txt"; ++ std::ofstream expected_res(expected_name), npu_res(npures_name); ++ for(size_t i = 0 ; i < size1; i++){ ++ expected_res << expected_output[i] << std::endl; ++ npu_res << real_output[i] << std::endl; ++ } ++ expected_res.close(); ++ npu_res.close(); ++ ++ std::pair res = std::make_pair(COMPARE_RESULT::SUCCESS, ""); ++ double max_diff = 0; ++ size_t diff_count = 0; ++ for (size_t di = 0; di != size1; ++di) { ++ const double real_value = ++ post_processing ? std::max(0.0, std::min(255.0, real_output[di])) : real_output[di]; ++ const double diff = std::fabs(expected_output[di] - real_value); ++ const double tol = per_sample_tolerance + relative_per_sample_tolerance * std::fabs(expected_output[di]); ++ if (!IsResultCloselyMatch(real_value, expected_output[di], diff, tol)) { ++ res.first = COMPARE_RESULT::RESULT_DIFFERS; ++ // update error message if this is a larger diff ++ if (diff > max_diff || (std::isnan(diff) && !std::isnan(max_diff))) { ++ int64_t expected_int = 0; ++ int64_t real_int = 0; ++ memcpy(&expected_int, &expected_output[di], sizeof(FLOAT_TYPE)); ++ memcpy(&real_int, &real_output[di], sizeof(FLOAT_TYPE)); ++ ++ std::ostringstream oss; ++ oss << std::hex << "expected " << expected_output[di] << " (" << expected_int << "), got " << real_value << " (" ++ << real_int << ")" ++ << ", diff: " << diff << ", tol=" << tol << std::dec << " idx=" << di << "."; ++ res.second = oss.str(); ++ max_diff = diff; ++ } ++ ++diff_count; ++ } ++ } ++ ++ if (res.first == COMPARE_RESULT::SUCCESS) return res; ++ ++ std::ostringstream oss; ++ oss << res.second << " " << diff_count << " of " << size1 << " differ"; ++ res.second = oss.str(); ++ return res; ++} ++ ++ + template +-std::pair IsResultExactlyMatch(const Tensor& outvalue, const Tensor& expected_value) { ++std::pair IsResultExactlyMatch(const Tensor& outvalue, const Tensor& expected_value, int32_t out_idx) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const T* expected_output = expected_value.Data(); + const T* real_output = outvalue.Data(); ++ std::string expected_name = "expected_res"+ std::to_string(out_idx)+ ".txt"; ++ std::string npures_name = "npu_res"+ std::to_string(out_idx)+ ".txt"; ++ std::ofstream expected_res(expected_name), npu_res(npures_name); ++ for(size_t i = 0 ; i < size1; i++){ ++ expected_res << expected_output[i] << std::endl; ++ npu_res << real_output[i] << std::endl; ++ } ++ expected_res.close(); ++ npu_res.close(); + for (size_t di = 0; di != size1; ++di) { + if (expected_output[di] != real_output[di]) { + std::ostringstream oss; +@@ -201,7 +265,7 @@ std::pair CompareBFloat16Result(const Tensor& outva + + std::pair CompareTwoTensors(const Tensor& outvalue, const Tensor& expected_tensor, + double per_sample_tolerance, +- double relative_per_sample_tolerance, bool post_processing) { ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx) { + if (expected_tensor.Shape() != outvalue.Shape()) { + std::ostringstream oss; + oss << "shape mismatch, expect " << expected_tensor.Shape().ToString() << " got " << outvalue.Shape().ToString(); +@@ -209,30 +273,30 @@ std::pair CompareTwoTensors(const Tensor& outvalue, + } + if (outvalue.IsDataType()) { + return CompareFloatResult(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, +- post_processing); ++ post_processing, out_idx); + } else if (outvalue.IsDataType()) { + return CompareFloatResult(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, +- post_processing); ++ post_processing, out_idx); + } else if (outvalue.IsDataTypeString()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { + return CompareFloat16Result(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, + post_processing); +@@ -300,7 +364,7 @@ std::pair CompareSparseTensors(const SparseTensor& + " actual: ", actual.Format()); + + TEST_RETURN_IF_ERROR(CompareTwoTensors(actual.Values(), expected.Values(), +- per_sample_tolerance, relative_per_sample_tolerance, post_processing), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing, 0), + "While comparing sparse values"); + + if (actual.Format() == SparseFormat::kCoo) { +@@ -308,16 +372,16 @@ std::pair CompareSparseTensors(const SparseTensor& + auto expected_view = expected.AsCoo(); + + TEST_RETURN_IF_ERROR(CompareTwoTensors(actual_view.Indices(), expected_view.Indices(), +- per_sample_tolerance, relative_per_sample_tolerance, post_processing), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing, 0), + "Comparing COO indices"); + } else if (actual.Format() == SparseFormat::kCsrc) { + auto actual_view = actual.AsCsr(); + auto expected_view = expected.AsCsr(); + TEST_RETURN_IF_ERROR(CompareTwoTensors(actual_view.Inner(), expected_view.Inner(), +- per_sample_tolerance, relative_per_sample_tolerance, post_processing), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing, 0), + "Comparing Csr(c) inner indices"); + TEST_RETURN_IF_ERROR(CompareTwoTensors(actual_view.Outer(), expected_view.Outer(), +- per_sample_tolerance, relative_per_sample_tolerance, post_processing), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing, 0), + "Comparing Csr(c) outer indices"); + } + +@@ -385,7 +449,83 @@ std::pair CompareOrtValue(const OrtValue& o, const + return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, oss.str()); + } + return CompareTwoTensors(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, +- post_processing); ++ post_processing, 0); ++ } else if (o.IsSparseTensor()) { ++#if !defined(DISABLE_SPARSE_TENSORS) ++ TEST_RETURN_IF_NOT(expected_mlvalue.IsSparseTensor(), COMPARE_RESULT::TYPE_MISMATCH, ++ "SparseTensor is not expected as output"); ++ TEST_RETURN_IF_ERROR(CompareSparseTensors(o.Get(), expected_mlvalue.Get(), ++ per_sample_tolerance, relative_per_sample_tolerance, ++ post_processing), ++ "while comaring sparse tensors"); ++#endif ++ return std::make_pair(COMPARE_RESULT::SUCCESS, ""); ++ } else if (o.IsTensorSequence()) { ++ auto& expected_tensor_seq = expected_mlvalue.Get(); ++ auto expected_tensor_count = expected_tensor_seq.Size(); ++ ++ auto& actual_tensor_seq = o.Get(); ++ auto actual_tensor_count = actual_tensor_seq.Size(); ++ ++ if (expected_tensor_count != actual_tensor_count) { ++ std::ostringstream oss; ++ oss << "expected tensor count in the sequence: " << expected_tensor_count << " got " ++ << actual_tensor_count; ++ return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); ++ } ++ ++ if (!expected_tensor_seq.IsSameDataType(actual_tensor_seq)) { ++ std::ostringstream oss; ++ oss << "expected tensor type in the sequence: " << expected_tensor_seq.DataType() << " got " ++ << actual_tensor_seq.DataType(); ++ return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, oss.str()); ++ } ++ ++ for (size_t i = 0; i < expected_tensor_count; ++i) { ++ auto res = CompareTwoTensors(actual_tensor_seq.Get(i), expected_tensor_seq.Get(i), per_sample_tolerance, relative_per_sample_tolerance, ++ post_processing,0); ++ if (res.first != COMPARE_RESULT::SUCCESS) { ++ return res; ++ } ++ } ++ ++ return std::make_pair(COMPARE_RESULT::SUCCESS, ""); ++ ++ } else { ++ // Maps ++#if !defined(DISABLE_ML_OPS) ++ if (o.Type() == DataTypeImpl::GetType()) { ++ return CompareSeqOfMapToFloat(o.Get(), expected_mlvalue.Get(), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing); ++ } ++ if (o.Type() == DataTypeImpl::GetType()) { ++ return CompareSeqOfMapToFloat(o.Get(), expected_mlvalue.Get(), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing); ++ } ++ return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, ""); ++#else ++ return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, "Map type is not supported in this build."); ++#endif ++ } ++} ++ ++std::pair CompareOrtValue(const OrtValue& o, const OrtValue& expected_mlvalue, ++ double per_sample_tolerance, ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx) { ++ if (o.Type() != expected_mlvalue.Type()) { ++ return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, ""); ++ } ++ if (o.IsTensor()) { ++ const Tensor& outvalue = o.Get(); ++ const Tensor& expected_tensor = expected_mlvalue.Get(); ++ if (outvalue.DataType() != expected_tensor.DataType()) { ++ std::ostringstream oss; ++ oss << "expect " << ElementTypeToString(expected_tensor.DataType()) << " got " ++ << ElementTypeToString(outvalue.DataType()); ++ return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, oss.str()); ++ } ++ return CompareTwoTensors(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, ++ post_processing, out_idx); + } else if (o.IsSparseTensor()) { + #if !defined(DISABLE_SPARSE_TENSORS) + TEST_RETURN_IF_NOT(expected_mlvalue.IsSparseTensor(), COMPARE_RESULT::TYPE_MISMATCH, +@@ -419,7 +559,7 @@ std::pair CompareOrtValue(const OrtValue& o, const + + for (size_t i = 0; i < expected_tensor_count; ++i) { + auto res = CompareTwoTensors(actual_tensor_seq.Get(i), expected_tensor_seq.Get(i), per_sample_tolerance, relative_per_sample_tolerance, +- post_processing); ++ post_processing, out_idx); + if (res.first != COMPARE_RESULT::SUCCESS) { + return res; + } +diff --git a/onnxruntime/test/util/include/compare_ortvalue.h b/onnxruntime/test/util/include/compare_ortvalue.h +index 24b74b9002..8269346528 100644 +--- a/onnxruntime/test/util/include/compare_ortvalue.h ++++ b/onnxruntime/test/util/include/compare_ortvalue.h +@@ -24,7 +24,9 @@ enum class COMPARE_RESULT { SUCCESS, + std::pair CompareOrtValue(const OrtValue& real, const OrtValue& expected, + double per_sample_tolerance, + double relative_per_sample_tolerance, bool post_processing); +- ++std::pair CompareOrtValue(const OrtValue& real, const OrtValue& expected, ++ double per_sample_tolerance, ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx); + // verify if the 'value' matches the 'expected' ValueInfoProto. 'value' is a model output + std::pair VerifyValueInfo(const ONNX_NAMESPACE::ValueInfoProto& expected, + const OrtValue* value); +diff --git a/onnxruntime/test/util/include/test/compare_ortvalue.h b/onnxruntime/test/util/include/test/compare_ortvalue.h +index 545df706c9..170eb9dc4c 100644 +--- a/onnxruntime/test/util/include/test/compare_ortvalue.h ++++ b/onnxruntime/test/util/include/test/compare_ortvalue.h +@@ -28,7 +28,9 @@ enum class COMPARE_RESULT { + std::pair CompareOrtValue(const OrtValue& real, const OrtValue& expected, + double per_sample_tolerance, + double relative_per_sample_tolerance, bool post_processing); +- ++std::pair CompareOrtValue(const OrtValue& real, const OrtValue& expected, ++ double per_sample_tolerance, ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx); + // Compare two OrtValue numerically equal or not. The difference with CompareOrtValue is that this function + // will only check the numerical values of the OrtValue, and ignore the type, shape, etc. + // diff --git a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch new file mode 100644 index 0000000000000..2176ff559c684 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch @@ -0,0 +1,35 @@ +diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake +index 304aa77f54..5c22b7097b 100644 +--- a/cmake/onnxruntime_mlas.cmake ++++ b/cmake/onnxruntime_mlas.cmake +@@ -354,7 +354,7 @@ else() + ) + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") +- if (NOT APPLE) ++ if (NOT APPLE AND NOT onnxruntime_USE_VSINPU) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S +diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h +index cdfd283899..678a055b24 100644 +--- a/onnxruntime/core/mlas/inc/mlas.h ++++ b/onnxruntime/core/mlas/inc/mlas.h +@@ -82,6 +82,9 @@ Abstract: + + #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) + #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) ++#if !defined(USE_VSINPU) ++// Had to tempory disable fp16 under VeriSilicon ARM64 to avoid ++// conflict of compilation flag. + #if !defined(__APPLE__) + // Had to temporary disable fp16 under APPLE ARM64, as compiling + // the source files require a hardware specific compilation flag. +@@ -90,6 +93,7 @@ Abstract: + + #define MLAS_F16VEC_INTRINSICS_SUPPORTED + ++#endif // + #endif // + #endif // ARM64 + #endif // Visual Studio 16 or earlier does not support fp16 intrinsic diff --git a/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_cosine_sim.py b/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_cosine_sim.py new file mode 100644 index 0000000000000..e4e9b44fdc252 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_cosine_sim.py @@ -0,0 +1,29 @@ +import sys + +import numpy as np +from numpy.linalg import norm + + +def read_values(filename): + with open(filename) as file: + values = np.array([float(line.strip()) for line in file]) + return values + + +def cosine_similarity(vec1, vec2): + return np.dot(vec1, vec2) / (norm(vec1) * norm(vec2)) + + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: python cosine_similarity.py ") + sys.exit(1) + + file1 = sys.argv[1] + file2 = sys.argv[2] + + vec1 = read_values(file1) + vec2 = read_values(file2) + + similarity = cosine_similarity(vec1, vec2) + print(f"Cosine Similarity: {similarity}") diff --git a/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_topn.py b/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_topn.py new file mode 100644 index 0000000000000..cde75b7f18c1e --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_topn.py @@ -0,0 +1,34 @@ +import sys + + +def read_values(filename): + with open(filename) as file: + values = [(float(line.strip()), i + 1) for i, line in enumerate(file)] + return values + + +def top_n(values, N): + return sorted(values, key=lambda x: x[0], reverse=True)[:N] + + +def compare_files(cpu_file, npu_file, N): + cpu_values = read_values(cpu_file) + npu_values = read_values(npu_file) + + cpu_topn = top_n(cpu_values, N) + npu_topn = top_n(npu_values, N) + + print(f"Top-{N} values in {cpu_file}: {cpu_topn}") + print(f"Top-{N} values in {npu_file}: {npu_topn}") + + +if __name__ == "__main__": + if len(sys.argv) != 4: + print("Usage: python compare_topn.py ") + sys.exit(1) + + N = int(sys.argv[1]) + cpu_file = sys.argv[2] + npu_file = sys.argv[3] + + compare_files(cpu_file, npu_file, N) diff --git a/onnxruntime/core/providers/vsinpu/patches/test_scripts/result_compare.sh b/onnxruntime/core/providers/vsinpu/patches/test_scripts/result_compare.sh new file mode 100644 index 0000000000000..c27af51c26799 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/test_scripts/result_compare.sh @@ -0,0 +1,23 @@ +#!/bin/bash +res_file_dir=$1 +output_num=$2 + +# specifying N value +N=5 + +for i in $(seq 0 $((output_num-1))); +do + # 构建文件名 + golden_file="${res_file_dir}/expected_res${i}.txt" + npu_file="${res_file_dir}/npu_res${i}.txt" + + echo "Comparing Top-${N} for the output_${i}" + python3 compare_topn.py $N $golden_file $npu_file + + echo "--------------------------------" + + echo "Comparing Cosine Similarity for output_${i}:" + python3 compare_cosine_sim.py $golden_file $npu_file + + echo "" +done diff --git a/onnxruntime/core/providers/vsinpu/symbols.txt b/onnxruntime/core/providers/vsinpu/symbols.txt new file mode 100644 index 0000000000000..d69c92692f5fe --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/symbols.txt @@ -0,0 +1 @@ +OrtSessionOptionsAppendExecutionProvider_VSINPU diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc new file mode 100644 index 0000000000000..bbf8255ac2940 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc @@ -0,0 +1,300 @@ + +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include "core/providers/vsinpu/vsinpu_ep_graph.h" +#include "core/providers/vsinpu/builders/op_builder_factory.h" +#include "core/providers/vsinpu/vsinpu_util.h" +#include "core/framework/node_unit.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" + +namespace onnxruntime { + +namespace vsi { +namespace npu { +GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) { + Prepare(); + context_ = tim::vx::Context::Create(); + graph_ = context_->CreateGraph(); + compiled_ = false; +} + +bool GraphEP::Prepare() { + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + for (const auto& node_unit : node_unit_holder_) { + auto quant_op_type = util::GetQuantizedOpType(*node_unit); + + // Not a qlinear op or qdq node group + if (quant_op_type == util::QuantizedOpType::Unknown) + continue; + + const auto add_quantized_input = + [&all_quantized_op_inputs = all_quantized_op_inputs_](const NodeUnit& node_unit, size_t input_idx) { + const auto& input_name = node_unit.Inputs()[input_idx].node_arg.Name(); + all_quantized_op_inputs[input_name].push_back(&node_unit); + }; + + // All quantized ops EXCEPT QuantizeLinear has quantized input + if (quant_op_type != util::QuantizedOpType::QuantizeLinear) { + add_quantized_input(*node_unit, 0); + } + + if (util::IsQuantizedBinaryOp(quant_op_type)) { + add_quantized_input(*node_unit, 1); + if (util::IsQuantizedConv(quant_op_type) && node_unit->Inputs().size() == 3) { + add_quantized_input(*node_unit, 2); + } + } + } // All quantized inputs is recorded + return true; +} + +bool GraphEP::SupportedOp(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) { + const auto& supported_builtins = vsi::npu::SupportedBuiltinOps(); + const auto& target_node = node_unit.GetNode(); + const auto& it = supported_builtins.find(target_node.OpType()); + if (supported_builtins.end() != it) { + return it->second->IsSupported(graph_viewer, node_unit); + } + LOGS_DEFAULT(WARNING) << "Fallback unsupported op (node_unit) " << node_unit.OpType() + << " to cpu."; + return false; +} + +bool GraphEP::IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer) { + return SupportedOp(graph_viewer, node_unit); +} + +const NodeUnit& GraphEP::GetNodeUnit(const Node* node) const { + const auto node_unit_it = node_unit_map_.find(node); + ORT_ENFORCE(node_unit_it != node_unit_map_.end(), "Node does not have corresponding NodeUnit."); + return *node_unit_it->second; +} + +void GraphEP::UpdateTensorMap(const std::string& name, const std::shared_ptr& dst_tensor) { + auto it = tensors_.find(name); + if (it != tensors_.end()) { + it->second = dst_tensor; + } + for (auto& IO : graph_inputs_) { + if (IO->name == name) { + IO->tensor = dst_tensor; + break; + } + } + for (auto& IO : graph_outputs_) { + if (IO->name == name) { + IO->tensor = dst_tensor; + break; + } + } +} + +std::shared_ptr GraphEP::ConstructNodeIO(const std::shared_ptr& op, + std::vector input_arg, + std::vector output_arg) { + auto info = std::make_shared(); + info->op_ = op; + std::vector input_names, output_names; + if (input_arg.empty()) { + info->input_names_ = std::vector(); + } else { + input_names.reserve(input_arg.size()); + std::transform(input_arg.begin(), input_arg.end(), std::back_inserter(input_names), + [](const NodeArg* node) -> std::string { + return node->Name(); + }); + info->input_names_ = input_names; + } + if (output_arg.empty()) { + info->output_names_ = std::vector(); + } else { + output_names.reserve(output_arg.size()); + std::transform(output_arg.begin(), output_arg.end(), std::back_inserter(output_names), + [](const NodeArg* node) -> std::string { + return node->Name(); + }); + info->output_names_ = output_names; + } + + return info; +} + +bool GraphEP::BindTensors(const std::shared_ptr& nodeio_info) { + auto op = nodeio_info->op_; + auto input_names = nodeio_info->input_names_; + auto output_names = nodeio_info->output_names_; + if (!input_names.empty()) { + for (auto& name : input_names) { + if (tensors_.find(name) == tensors_.end() || tensors_[name] == nullptr) { + LOGS_DEFAULT(ERROR) << "Input tensor not defined or not found!"; + return false; + } + (*op).BindInput(tensors_[name]); + } + } + if (!output_names.empty()) { + for (auto& name : output_names) { + if (tensors_.find(name) == tensors_.end() || tensors_[name] == nullptr) { + LOGS_DEFAULT(ERROR) << "Output tensor not defined or not found!"; + return false; + } + (*op).BindOutput(tensors_[name]); + } + } + return true; +} + +std::shared_ptr GraphEP::MapTIMVXTensor( + std::shared_ptr& graph, const NodeUnitIODef nudef, + const NodeUnit& node_unit, + const GraphViewer* graph_viewer, tim::vx::TensorAttribute attribute) { + const auto& arg = nudef.node_arg; + + if (tensors_.end() != tensors_.find(nudef.node_arg.Name())) { + return tensors_.find(arg.Name())->second; + } + auto shape = vsi::npu::util::OnnxShapeToTIMVXShape(vsi::npu::util::GetTensorShape(arg)); + std::reverse(shape.begin(), shape.end()); + tim::vx::DataType dt = vsi::npu::util::OnnxDtypeToTIMVXDtype(arg.Type()); + tim::vx::TensorSpec spec = tim::vx::TensorSpec(dt, shape, attribute); + + // Tensors have same name may not have same status of quant_param existence, such as QLinearConv->MaxPool->QLinearConv + // Maxpool output tensor is not set quantization at first pass + bool is_qtensor = nudef.quant_param.has_value() || Contains(all_quantized_op_inputs_, arg.Name()); + if (is_qtensor) { + float scale = 0.0f; + int32_t zp = 0; + std::optional> scales; + std::optional> zps; + if (nudef.quant_param.has_value()) { + util::GetQuantizationScaleAndZeroPoint(graph_viewer_, + nudef, node_unit.ModelPath(), + scale, zp, scales, zps); + } else { + auto target_nodeunit = all_quantized_op_inputs_[arg.Name()][0]; + auto qinput = all_quantized_op_inputs_[arg.Name()][0]->Inputs(); + auto it = std::find_if(qinput.begin(), qinput.end(), [&arg](const NodeUnitIODef& nud) { + return nud.node_arg.Name() == arg.Name(); + }); + bool is_conv_bias = std::distance(qinput.begin(), it) == 2; + if (!is_conv_bias || it->quant_param.has_value()) { + util::GetQuantizationScaleAndZeroPoint(graph_viewer_, + *it, target_nodeunit->ModelPath(), + scale, zp, scales, zps); + } else if (!it->quant_param.has_value()) { + float in_scale, w_scale; + int32_t in_zp, w_zp; + std::optional> in_scales, w_scales; + std::optional> in_zps, w_zps; + + // onnx defines conv bias with non quantization, but it must be quantized in VSINPU support + // The bias scale is set as input_scale * weight_scale if per layer quantized, + // otherwise input_scale* weight_scale[i] if per channel quantized + util::GetQuantizationScaleAndZeroPoint(graph_viewer_, + qinput[0], target_nodeunit->ModelPath(), + in_scale, in_zp, in_scales, in_zps); + util::GetQuantizationScaleAndZeroPoint(graph_viewer_, + qinput[1], target_nodeunit->ModelPath(), + w_scale, w_zp, w_scales, w_zps); + scale = in_scale * w_scale; + zp = 0; + if (w_scales) { + std::vector temp; + for (size_t i = 0; i < w_scales->size(); i++) { + temp.push_back(w_scales.value()[i] * in_scale); + } + scales = temp; + } + } + } + tim::vx::Quantization quant; + // per tensor quantization + if (!scales.has_value()) { + quant.SetType(tim::vx::QuantType::ASYMMETRIC); + quant.SetScales({scale}); + quant.SetZeroPoints({zp}); + } else { // per channel quantization + if (zps.has_value()) { + bool has_nonzero = std::find_if(zps->begin(), zps->end(), [](int elem) { return elem != 0; }) != zps->end(); + if (has_nonzero && *arg.Type() == "tensor(uint8)") { + quant.SetType(tim::vx::QuantType::ASYMMETRIC_PER_CHANNEL); + } else { + quant.SetType(tim::vx::QuantType::SYMMETRIC_PER_CHANNEL); + } + quant.SetZeroPoints(zps.value()); + } else { + if (*arg.Type() == "tensor(int32)" || zp == 0) { + // set bias quant type + quant.SetType(tim::vx::QuantType::SYMMETRIC_PER_CHANNEL); + } else { + quant.SetType(tim::vx::QuantType::ASYMMETRIC_PER_CHANNEL); + } + quant.SetZeroPoints({zp}); + } + quant.SetScales(scales.value()); + quant.SetChannelDim(shape.size() - 1); + } + spec.SetQuantization(quant); + } + + std::shared_ptr tensor; + if (attribute == + tim::vx::TensorAttribute::CONSTANT) { // create const tensor + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_viewer_.GetConstantInitializer(arg.Name(), true); + std::shared_ptr unpackedTensor = + vsi::npu::util::UnpackTensor(&arg, *tensor_proto); + + const void* valueAddr = + reinterpret_cast(unpackedTensor.get()); + tensor = graph->CreateTensor(spec, valueAddr); + + } else { + tensor = graph->CreateTensor(spec); + } + for (auto& input : graph_inputs_) { + if (input->name == arg.Name()) { + input->tensor = tensor; + input->shape = vsi::npu::util::GetTensorShape(arg); + break; + } + } + for (auto& output : graph_outputs_) { + if (output->name == arg.Name()) { + output->tensor = tensor; + output->shape = utils::GetTensorShapeFromTensorShapeProto(*arg.Shape()); + break; + } + } + tensors_.insert({arg.Name(), tensor}); + return tensor; +} + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h new file mode 100644 index 0000000000000..49344770d060e --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h @@ -0,0 +1,117 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ + +#pragma once +#include +#include +#include +#include +#include +#include "builders/op_builder.h" +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/tensor.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +struct GraphIOInfo { + std::string name; + bool is_initializer; + std::shared_ptr tensor; + TensorShape shape; +}; + +struct NodeIOInfo { + std::shared_ptr op_; + std::vector input_names_; + std::vector output_names_; +}; + +class GraphEP { + public: + explicit GraphEP(const GraphViewer& graph_viewer); + ~GraphEP() {} + + bool Prepare(); + + static bool SupportedOp(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit); + + // If a node is supported by VSINPU in a partition node group + // `node_outputs_in_group` is the set of the output names of the nodes added to this group so far + static bool IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer); + + const NodeUnit& GetNodeUnit(const Node* node) const; + + bool& GetCompiled() { return compiled_; } + std::shared_ptr& GetGraph() { return graph_; } + std::vector>& GetOps() { return ops_; } + std::map>& GetTensors() { + return tensors_; + } + + std::vector>& GetGraphInputs() { + return graph_inputs_; + } + + std::vector>& GetGraphOutputs() { + return graph_outputs_; + } + + void UpdateTensorMap(const std::string& name, const std::shared_ptr& dst_tensor); + + std::shared_ptr ConstructNodeIO(const std::shared_ptr& op, + std::vector input_arg, std::vector output_arg); + + bool BindTensors(const std::shared_ptr& nodeio_info); + + std::shared_ptr MapTIMVXTensor( + std::shared_ptr& graph, const NodeUnitIODef nudef, + const NodeUnit& nodeunit, + const GraphViewer* graph_viewer, tim::vx::TensorAttribute attribute); + + private: + std::shared_ptr context_; + std::shared_ptr graph_; + std::map> tensors_; + std::vector> ops_; + std::vector> graph_inputs_; + std::vector> graph_outputs_; + + // Contains all quantized operators' input and the NodeUnit(s) using the input + // In the form of {input_name, [NodeUnit(s) using the input]} + std::unordered_map> all_quantized_op_inputs_; + const GraphViewer& graph_viewer_; + + // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is + // valid throughout the lifetime of the ModelBuilder + std::vector> node_unit_holder_; + std::unordered_map node_unit_map_; + bool compiled_; +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc new file mode 100644 index 0000000000000..466fe1f82461c --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -0,0 +1,279 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/framework/compute_capability.h" +#include "core/providers/vsinpu/vsinpu_execution_provider.h" +#include "core/providers/vsinpu/vsinpu_ep_graph.h" +#include "core/providers/vsinpu/builders/op_builder.h" +#include "core/providers/vsinpu/builders/op_builder_factory.h" +#include "core/providers/vsinpu/vsinpu_util.h" +#include "core/framework/kernel_registry.h" +#include "core/framework/node_unit.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/partitioning_utils.h" + +namespace onnxruntime { +VSINPUExecutionProvider::VSINPUExecutionProvider(const VSINPUExecutionProviderInfo& info) + : IExecutionProvider{onnxruntime::kVSINPUExecutionProvider}, + device_id_(info.device_id) { + AllocatorCreationInfo default_memory_info{ + [](int) { + return std::make_unique( + OrtMemoryInfo("VSINPU", OrtAllocatorType::OrtDeviceAllocator)); + }}; + + CreateAllocator(default_memory_info); + + AllocatorCreationInfo cpu_memory_info{ + [](int) { + return std::make_unique( + OrtMemoryInfo("VSINPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); + }}; + + CreateAllocator(cpu_memory_info); +} + +VSINPUExecutionProvider::~VSINPUExecutionProvider() {} + +std::vector> +VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const { + std::vector> result; + + if (graph_viewer.IsSubgraph()) { + return result; + } + + for (const auto& tensor : graph_viewer.GetAllInitializedTensors()) { + if (tensor.second->has_data_location()) { + LOGS_DEFAULT(VERBOSE) << "location:" << tensor.second->data_location(); + if (tensor.second->data_location() == + ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + LOGS_DEFAULT(WARNING) << "VSINPU: Initializers with external data location are not " + "currently supported"; + return result; + } + } + } + // Get all the NodeUnits in the graph_viewer + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + + // This holds the result of whether a NodeUnit is supported or not, + // to prevent nodes in a NodeUnit to be checked for multiple times + std::unordered_map node_unit_supported_result; + node_unit_supported_result.reserve(node_unit_holder.size()); + std::unordered_set node_outputs_in_current_group{}; + + const auto is_node_supported = [&](const Node& node) -> bool { + const NodeUnit* node_unit = node_unit_map.at(&node); + bool supported = false; + + // If we have visited one of the nodes in the node_unit, use the result directly + const auto it = node_unit_supported_result.find(node_unit); + if (it != node_unit_supported_result.cend()) { + supported = it->second; + } else { + // We only check the target node of the node unit + supported = vsi::npu::GraphEP::IsNodeSupportedInGroup(*node_unit, graph_viewer); + node_unit_supported_result[node_unit] = supported; + } + + LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported + << "] Operator type: [" << node.OpType() + << "] index: [" << node.Index() + << "] name: [" << node.Name() + << "] as part of the NodeUnit type: [" << node_unit->OpType() + << "] index: [" << node_unit->Index() + << "] name: [" << node_unit->Name() + << "]"; + + if (supported) { + // We want to save all the output names of nodes in the current group for easy query + for (const auto* output : node.OutputDefs()) { + node_outputs_in_current_group.insert(output->Name()); + } + } + return supported; + }; + + const auto on_group_closed = [&](const std::vector& group) -> bool { + // reset per-partition node group tracking + node_outputs_in_current_group.clear(); + return true; + }; + + const auto gen_metadef_name = [&]() { + static size_t group_counter = 0; + return "VSINPU_" + std::to_string(++group_counter); + }; + result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, + gen_metadef_name, "VSINPU", kVSINPUExecutionProvider, &node_unit_map); + std::for_each(result.begin(), result.end(), [&graph_viewer](auto& capability) { + if (capability && capability->sub_graph && capability->sub_graph->GetMetaDef()) { + const auto* meta_def = capability->sub_graph->GetMetaDef(); + bool has_any_non_constant_inputs = std::any_of(meta_def->inputs.begin(), + meta_def->inputs.end(), [&graph_viewer](const auto& input) { + return !graph_viewer.IsConstantInitializer(input, true); + }); + + // ALL inputs are constant + if (!has_any_non_constant_inputs) { + capability.reset(); + } + } + }); + + const auto num_of_partitions = result.size(); + const auto num_of_supported_nodes = std::accumulate( + result.begin(), result.end(), size_t{0}, + [](const auto& acc, const auto& partition) -> size_t { + return acc + (partition && partition->sub_graph ? partition->sub_graph->nodes.size() : 0); + }); + + const auto summary_msg = MakeString( + "VSINPUExecutionProvider::GetCapability,", + " number of partitions supported by VSINPU: ", num_of_partitions, + "; number of nodes in the graph: ", graph_viewer.NumberOfNodes(), + "; number of nodes supported by VSINPU: ", num_of_supported_nodes); + + // If the graph is partitioned in multiple subgraphs, and this may impact performance, + // we want to give users a summary message at warning level. + if (num_of_partitions > 1) { + LOGS_DEFAULT(WARNING) << summary_msg; + } else { + LOGS_DEFAULT(INFO) << summary_msg; + } + + return result; +} + +Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, + OrtKernelContext* context) { + Ort::KernelContext ctx(context); + size_t num_in = ctx.GetInputCount(); + const size_t num_inputs = graph_ep->GetGraphInputs().size(); + + for (size_t i = 0, j = 0; i < num_inputs; i++) { + if (!graph_ep->GetGraphInputs()[i]->is_initializer) { + const auto onnx_input_tensor = ctx.GetInput(i); + const auto tensor_info = onnx_input_tensor.GetTensorTypeAndShapeInfo(); + + auto origin_tensor = graph_ep->GetGraphInputs()[i]->tensor; + origin_tensor->CopyDataToTensor(onnx_input_tensor.GetTensorRawData(), + vsi::npu::util::GetTensorBytes(tensor_info)); + j++; + } + } + + if (!graph_ep->GetGraph()->Run()) { + LOGS_DEFAULT(ERROR) << "Failed to run graph."; + } + for (size_t i = 0; i < ctx.GetOutputCount(); i++) { + auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor; + auto out_shape = graph_ep->GetGraphOutputs()[i]->shape.GetDims(); + auto onnx_output_tensor = + ctx.GetOutput(i, out_shape.data(), out_shape.size()); + timvx_tensor->CopyDataFromTensor(const_cast(onnx_output_tensor.GetTensorRawData())); + } + + return Status::OK(); +} + +Status VSINPUExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; + std::shared_ptr graph_ep = std::make_shared(graph_viewer); + + for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) { + LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" + << graph_viewer.IsInitializedTensor(tensor->Name()); + auto input = std::make_shared(); + input->name = tensor->Name(); + input->is_initializer = graph_viewer.IsConstantInitializer(tensor->Name(), true); + graph_ep->GetGraphInputs().push_back(input); + } + for (auto tensor : graph_viewer.GetOutputs()) { + LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); + auto output = std::make_shared(); + output->name = tensor->Name(); + output->is_initializer = false; + graph_ep->GetGraphOutputs().push_back(output); + } + + auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto& node_index : node_indices) { + const auto node = graph_viewer.GetNode(node_index); + const NodeUnit& node_unit = graph_ep->GetNodeUnit(node); + + // Only add op when we hit the target node + if (node != &node_unit.GetNode()) { + continue; + } + LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]"; + vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit); + } + + LOGS_DEFAULT(INFO) << "Verifying graph"; + graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile(); + if (!graph_ep->GetCompiled()) { + LOGS_DEFAULT(ERROR) << "Failed to verify graph."; + } else { + LOGS_DEFAULT(INFO) << "Graph has been verified successfully."; + } + + NodeComputeInfo compute_info; + compute_info.create_state_func = [graph_ep](ComputeContext* /*context*/, + FunctionState* state) { + *state = graph_ep.get(); + return 0; + }; + + compute_info.compute_func = + [graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */, + OrtKernelContext* context) { + std::lock_guard lock(this->GetMutex()); + Status res = ComputeStateFunc(graph_ep.get(), context); + return res; + }; + + compute_info.release_state_func = [](FunctionState /*state*/) {}; + + node_compute_funcs.push_back(compute_info); + } + + return Status::OK(); +} + +std::shared_ptr VSINPUExecutionProvider::GetKernelRegistry() const { + static std::shared_ptr kernel_registry = std::make_shared(); + return kernel_registry; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h new file mode 100644 index 0000000000000..44318c332fdd0 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -0,0 +1,53 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include "core/framework/execution_provider.h" +#include "core/session/abi_session_options_impl.h" + +namespace onnxruntime { +struct VSINPUExecutionProviderInfo { + int device_id{0}; +}; + +class VSINPUExecutionProvider : public IExecutionProvider { + public: + explicit VSINPUExecutionProvider(const VSINPUExecutionProviderInfo& info); + virtual ~VSINPUExecutionProvider(); + + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& kernel_lookup) const override; + std::shared_ptr GetKernelRegistry() const override; + Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) override; + OrtMutex& GetMutex() { return mutex_; } + + private: + int device_id_; + OrtMutex mutex_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.cc b/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.cc new file mode 100644 index 0000000000000..5f2f961d95c09 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.cc @@ -0,0 +1,59 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include "core/framework/compute_capability.h" +#include "core/providers/vsinpu/vsinpu_provider_factory.h" +#include "core/providers/vsinpu/vsinpu_provider_factory_creator.h" +#include "core/providers/vsinpu/vsinpu_execution_provider.h" + +namespace onnxruntime { + +struct VSINPUProviderFactory : IExecutionProviderFactory { + VSINPUProviderFactory() {} + ~VSINPUProviderFactory() override {} + + std::unique_ptr CreateProvider() override; +}; + +std::unique_ptr VSINPUProviderFactory::CreateProvider() { + onnxruntime::VSINPUExecutionProviderInfo info; + return std::make_unique(info); +} + +std::shared_ptr CreateExecutionProviderFactory_VSINPU() { + return std::make_shared(); +} + +std::shared_ptr +VSINPUProviderFactoryCreator::Create() { + return std::make_shared(); +} + +} // namespace onnxruntime + +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_VSINPU, + _In_ OrtSessionOptions* options) { + options->provider_factories.push_back( + onnxruntime::VSINPUProviderFactoryCreator::Create()); + return nullptr; +} diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory_creator.h b/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory_creator.h new file mode 100644 index 0000000000000..e69185c0df816 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory_creator.h @@ -0,0 +1,34 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once + +#include + +#include "core/providers/providers.h" + +namespace onnxruntime { +struct VSINPUProviderFactoryCreator { + static std::shared_ptr Create(); +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_util.cc b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc new file mode 100644 index 0000000000000..5d2f701ceac20 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc @@ -0,0 +1,513 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ + +#include +#include +#include +#include +#include "core/providers/vsinpu/vsinpu_util.h" + +#include "core/optimizer/initializer.h" +#include "core/providers/shared/utils/utils.h" +namespace onnxruntime { + +template +struct shared_array_deletor { + void operator()(T const* ptr) { delete[] ptr; } +}; +namespace vsi { +namespace npu { +namespace util { +tim::vx::DataType OnnxDtypeToTIMVXDtype(const int32_t dtype) { + switch (dtype) { + case onnx::TensorProto_DataType_FLOAT: + return tim::vx::DataType::FLOAT32; + case onnx::TensorProto_DataType_FLOAT16: + return tim::vx::DataType::FLOAT16; + case onnx::TensorProto_DataType_INT8: + return tim::vx::DataType::INT8; + case onnx::TensorProto_DataType_UINT8: + return tim::vx::DataType::UINT8; + case onnx::TensorProto_DataType_INT32: + return tim::vx::DataType::INT32; + case onnx::TensorProto_DataType_INT16: + return tim::vx::DataType::INT16; + case onnx::TensorProto_DataType_UINT16: + return tim::vx::DataType::UINT16; + case onnx::TensorProto_DataType_BOOL: + return tim::vx::DataType::BOOL8; + default: + LOGS_DEFAULT(WARNING) << "Unsupported data type: " << dtype; + break; + } + return tim::vx::DataType::FLOAT32; +} + +tim::vx::DataType OnnxDtypeToTIMVXDtype(const ONNX_NAMESPACE::DataType type) { + static const std::map type_table = { + {"tensor(float)", tim::vx::DataType::FLOAT32}, + {"tensor(float16)", tim::vx::DataType::FLOAT16}, + {"tensor(int8)", tim::vx::DataType::INT8}, + {"tensor(uint8)", tim::vx::DataType::UINT8}, + {"tensor(int32)", tim::vx::DataType::INT32}, + {"tensor(int16)", tim::vx::DataType::INT16}, + {"tensor(uint16)", tim::vx::DataType::UINT16}, + {"tensor(int64)", tim::vx::DataType::INT64}, + {"tensor(bool)", tim::vx::DataType::BOOL8}, + }; + auto search = type_table.find(*type); + if (search != type_table.end()) { + return search->second; + } + LOGS_DEFAULT(WARNING) << "Unsupported data type: " << *type; + return tim::vx::DataType::FLOAT32; +} + +tim::vx::ShapeType OnnxShapeToTIMVXShape(const onnxruntime::TensorShape& ts) { + tim::vx::ShapeType timvx_shape(ts.NumDimensions()); + if (ts.NumDimensions() == 0) { + timvx_shape.push_back(1); + } else { + for (size_t i = 0; i < ts.NumDimensions(); i++) { + timvx_shape[i] = ts.GetDims()[i]; + } + } + return timvx_shape; +} + +std::string PrintNode(const onnxruntime::NodeArg& node_arg) { + auto shape = node_arg.Shape(); + if (shape == nullptr) { + return ""; + } + std::string s = node_arg.Name() + ":<"; + if (shape->dim_size() == 0) { + s += "1>, is a scalar"; + return s; + } + for (int i = 0; i < shape->dim_size(); i++) { + auto dim = shape->dim(i); + std::string s1; + std::stringstream ss; + ss << dim.dim_value(); + ss >> s1; + s += s1; + if (i < shape->dim_size() - 1) { + s += ","; + } else { + s += ">"; + } + } + return s; +} + +std::string PrintNode(const std::vector shape) { + if (shape.size() == 0) { + return ""; + } + std::string s = "<"; + for (std::size_t i = 0; i < shape.size(); i++) { + auto dim = shape[i]; + std::string s1; + std::stringstream ss; + ss << dim; + ss >> s1; + s += s1; + if (i < shape.size() - 1) { + s += ","; + } else { + s += ">"; + } + } + return s; +} + +size_t GetTensorElementSize(const ONNXTensorElementDataType type) { + switch (type) { + case onnx::TensorProto_DataType_INT64: + return 8; + case onnx::TensorProto_DataType_FLOAT: + case onnx::TensorProto_DataType_INT32: + return 4; + case onnx::TensorProto_DataType_FLOAT16: + case onnx::TensorProto_DataType_INT16: + case onnx::TensorProto_DataType_UINT16: + return 2; + case onnx::TensorProto_DataType_INT8: + case onnx::TensorProto_DataType_UINT8: + case onnx::TensorProto_DataType_BOOL: + return 1; + default: + break; + } + return 0; +} + +size_t GetTensorBytes(const Ort::TensorTypeAndShapeInfo& info) { + return info.GetElementCount() * GetTensorElementSize(info.GetElementType()); +} + +TensorShape GetTensorShape(const onnxruntime::NodeArg& node_arg) { + auto shape_proto = node_arg.Shape(); + std::vector dims; + if (shape_proto != nullptr) { + for (int i = 0; i < shape_proto->dim_size(); i++) { + auto dim = shape_proto->dim(i); + dims.push_back(dim.dim_value()); + } + } + if (dims.size() == 0) { + dims.push_back(1); + } + TensorShape ts(dims); + return ts; +} + +std::shared_ptr UnpackTensor( + const NodeArg* node_arg, const ONNX_NAMESPACE::TensorProto& initializer) { + std::shared_ptr unpackedTensor; + auto shape = GetTensorShape(*node_arg); + size_t elementCount = shape.Size(); + +#define CASE_PROTO(X, Y) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: { \ + size_t tensorByteSize = elementCount * sizeof(Y); \ + unpackedTensor.reset(new uint8_t[tensorByteSize], \ + shared_array_deletor()); \ + auto status = onnxruntime::utils::UnpackTensor( \ + initializer, \ + initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ + initializer.has_raw_data() ? initializer.raw_data().size() : 0, \ + reinterpret_cast(unpackedTensor.get()), elementCount); \ + if (!status.IsOK()) { \ + LOGS_DEFAULT(ERROR) << "Unpack tensor data failed."; \ + } \ + break; \ + } + switch (initializer.data_type()) { + CASE_PROTO(FLOAT, float); + CASE_PROTO(DOUBLE, double); + CASE_PROTO(BOOL, bool); + CASE_PROTO(INT8, int8_t); + CASE_PROTO(INT16, int16_t); + CASE_PROTO(INT32, int32_t); + CASE_PROTO(INT64, int64_t); + CASE_PROTO(UINT8, uint8_t); + CASE_PROTO(UINT16, uint16_t); + CASE_PROTO(UINT32, uint32_t); + CASE_PROTO(FLOAT16, onnxruntime::MLFloat16); + default: + return nullptr; + } + + return unpackedTensor; +} + +tim::vx::PadType GetPadType(const std::string type) { + static const std::map type_table = { + {"NOTSET", tim::vx::PadType::AUTO}, + {"SAME_UPPER", tim::vx::PadType::SAME}, + {"SAME_LOWER", tim::vx::PadType::SAME}, + {"VALID", tim::vx::PadType::VALID}, + }; + auto search = type_table.find(type); + if (search != type_table.end()) { + return search->second; + } + return tim::vx::PadType::NONE; +} + +int32_t ReverseAxis(int32_t origin_axis, int32_t length) { + int32_t axis = 0; + if (origin_axis < 0) { + origin_axis += length; + } + axis = length - origin_axis - 1; + return axis; +} + +std::vector ReverseAxis(std::vector origin_axes, int32_t length) { + std::vector axes; + for (int32_t& axis : origin_axes) { + if (axis < 0) { + axis += length; + } + axes.push_back(length - axis - 1); + } + std::sort(axes.begin(), axes.end()); + return axes; +} + +bool IsTypeSupported(const NodeArg* node_arg) { + const auto* type_proto = node_arg->TypeAsProto(); + if (!type_proto) { + return false; + } + + switch (type_proto->tensor_type().elem_type()) { + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + return true; + default: + return false; + } +} + +QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit) { + const auto& op_type = node_unit.OpType(); + if (node_unit.UnitType() == NodeUnit::Type::SingleNode) { + if (op_type == "DequantizeLinear") + return QuantizedOpType::DequantizeLinear; + else if (op_type == "QuantizeLinear") + return QuantizedOpType::QuantizeLinear; + else if (op_type == "QLinearConv") + return QuantizedOpType::QLinearConv; + else if (op_type == "QLinearMatMul") + return QuantizedOpType::QLinearMatMul; + else if (op_type == "QLinearAdd") + return QuantizedOpType::QLinearAdd; + else if (op_type == "QLinearMul") + return QuantizedOpType::QLinearMul; + else if (op_type == "QLinearSigmoid") + return QuantizedOpType::QLinearSigmoid; + else if (op_type == "QLinearAveragePool") + return QuantizedOpType::QLinearAveragePool; + } else if (node_unit.UnitType() == NodeUnit::Type::QDQGroup) { + if (op_type == "Conv") + return QuantizedOpType::QDQConv; + else if (op_type == "Resize") + return QuantizedOpType::QDQResize; + else if (op_type == "AveragePool") + return QuantizedOpType::QDQAveragePool; + else if (op_type == "Add") + return QuantizedOpType::QDQAdd; + else if (op_type == "Mul") + return QuantizedOpType::QDQMul; + else if (op_type == "Transpose") + return QuantizedOpType::QDQTranspose; + else if (op_type == "Reshape") + return QuantizedOpType::QDQReshape; + else if (op_type == "Softmax") + return QuantizedOpType::QDQSoftmax; + else if (op_type == "Concat") + return QuantizedOpType::QDQConcat; + else if (op_type == "Gemm") + return QuantizedOpType::QDQGemm; + else if (op_type == "MatMul") + return QuantizedOpType::QDQMatMul; + } + return QuantizedOpType::Unknown; +} + +ConvType GetConvType(const NodeUnit& node_unit, const InitializedTensorSet& initializers) { + NodeAttrHelper helper(node_unit); + const auto group = helper.Get("group", 1); + + const auto& weight = node_unit.Inputs()[1].node_arg.Name(); + const auto& weight_tensor = *initializers.at(weight); + + // For ONNX we only have 1 conv ops + // For VSINPU we have 3 + // Input is (W, H, C, N) + // group == 1, --> regular conv + // group != 1 && weight is (kW, kH, group, M), --> depthwise conv + // group != 1 && weight is (kW, kH, C/group, M), --> grouped conv + if (group == 1) + return ConvType::Regular; + else if ((weight_tensor.dims()[1] == group)) + return ConvType::Depthwise; + else + return ConvType::Grouped; +} + +bool IsQuantizedConv(QuantizedOpType quant_op_type) { + return (quant_op_type == QuantizedOpType::QLinearConv) || + (quant_op_type == QuantizedOpType::QDQConv); +} + +bool IsQuantizedPool(QuantizedOpType quant_op_type) { + return (quant_op_type == QuantizedOpType::QLinearAveragePool) || + (quant_op_type == QuantizedOpType::QDQAveragePool); +} + +bool IsQuantizedGemm(QuantizedOpType quant_op_type) { + return (quant_op_type == QuantizedOpType::QLinearMatMul) || + (quant_op_type == QuantizedOpType::QDQGemm) || + (quant_op_type == QuantizedOpType::QDQMatMul); +} + +bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type) { + return quant_op_type == QuantizedOpType::QLinearMatMul || + quant_op_type == QuantizedOpType::QLinearAdd || + quant_op_type == QuantizedOpType::QLinearMul || + quant_op_type == QuantizedOpType::QDQAdd || + quant_op_type == QuantizedOpType::QDQMul || + quant_op_type == QuantizedOpType::QDQGemm || + quant_op_type == QuantizedOpType::QDQMatMul || + IsQuantizedConv(quant_op_type); +} + +bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit) { + auto quant_op_type = GetQuantizedOpType(node_unit); + int32_t a_input_type, b_input_type; + if (!IsQuantizedBinaryOp(quant_op_type)) { + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() << "] is not a binary qlinear op"; + return false; + } + + const auto& inputs = node_unit.Inputs(); + if (!GetType(inputs[0].node_arg, a_input_type)) + return false; + if (!GetType(inputs[1].node_arg, b_input_type)) + return false; + + // QlinearConv/MatMul/QDQGemm/QDQMatMul supports u8u8 or u8s8 + // QLinearAdd/QLinearMul only support u8u8 + bool is_quant_conv_or_gemm = IsQuantizedConv(quant_op_type) || IsQuantizedGemm(quant_op_type); + + bool has_valid_qlinear_conv_weight = + (b_input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || + b_input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8); + + bool has_valid_qlinear_conv_input = + (a_input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || + a_input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8); + + if ((is_quant_conv_or_gemm && !has_valid_qlinear_conv_weight) || + (!is_quant_conv_or_gemm && a_input_type != b_input_type)) { + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() + << "] A Input type: [" << a_input_type + << "] B Input type: [" << b_input_type + << "] is not supported for now"; + return false; + } + + return true; +} + +void GetQuantizationScaleAndZeroPoint( + const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::filesystem::path& model_path, + float& scale, int32_t& zero_point, std::optional>& pcq_scales, + std::optional>& pcq_zps) { + scale = 0.0f; + zero_point = 0; + + const auto& quant_param = *io_def.quant_param; + { // get the scale + const auto& name = quant_param.scale.Name(); + const auto* s = graph_viewer.GetConstantInitializer(name); + if (!s) { + LOGS_DEFAULT(ERROR) << name + " is not a constant initializer"; + }; + Initializer unpacked_tensor(*s, model_path); + scale = unpacked_tensor.DataAsSpan()[0]; + + // per channel quantized handling + if (!unpacked_tensor.dims().empty() && unpacked_tensor.dims()[0] != 0 && unpacked_tensor.dims()[0] != 1) { + auto scales = unpacked_tensor.DataAsSpan(); + std::vector scales_vec(scales.begin(), scales.end()); + pcq_scales = onnxruntime::make_optional(std::move(scales_vec)); + } + } + + if (quant_param.zero_point) { // get the zero point if it exists + const auto& name = quant_param.zero_point->Name(); + const auto* s = graph_viewer.GetConstantInitializer(name); + if (!s) { + LOGS_DEFAULT(ERROR) << name + " is not a constant initializer"; + }; + Initializer unpacked_tensor(*s, model_path); + bool is_i8_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT8; + // some qdq conv bias is int32 quantized + bool is_int32_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT32; + zero_point = is_i8_zp + ? static_cast(unpacked_tensor.DataAsSpan()[0]) + : is_int32_zp ? static_cast(unpacked_tensor.DataAsSpan()[0]) + : static_cast(unpacked_tensor.DataAsByteSpan()[0]); + + // per channel quantized handling + if (!unpacked_tensor.dims().empty() && unpacked_tensor.dims()[0] != 0 && unpacked_tensor.dims()[0] != 1) { + auto type = unpacked_tensor.data_type(); + if (is_i8_zp) { + auto zps = unpacked_tensor.DataAsSpan(); + std::vector zps_vec(zps.begin(), zps.end()); + pcq_zps = onnxruntime::make_optional(std::move(zps_vec)); + } else if (is_int32_zp) { + auto zps = unpacked_tensor.DataAsByteSpan(); + std::vector zps_vec(zps.begin(), zps.end()); + pcq_zps = onnxruntime::make_optional(std::move(zps_vec)); + } else { + auto zps = unpacked_tensor.DataAsSpan(); + std::vector zps_vec(zps.begin(), zps.end()); + pcq_zps = onnxruntime::make_optional(std::move(zps_vec)); + } + } + } +} + +static bool IsInternalQuantizedNodeUnit(const NodeUnit& node_unit) { + // First, ignore QDQ NodeUnit which is not internal quantized node + if (node_unit.UnitType() == NodeUnit::Type::QDQGroup) + return false; + + // These operators can use uint8 input without specific QLinear version of it + // However, the mode has to be internal to the graph/partition (they cannot consume graph inputs) + static const std::unordered_set internal_quantized_op_types = { + "Transpose", + "Resize", + "Concat", + "MaxPool", + }; + + const auto& node = node_unit.GetNode(); + if (!Contains(internal_quantized_op_types, node.OpType())) + return false; + + int32_t input_type; + ORT_ENFORCE(GetType(*node.InputDefs()[0], input_type)); + + return input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8; +} + +bool GetType(const NodeArg& node_arg, int32_t& type) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + LOGS_DEFAULT(WARNING) << "NodeArg [" << node_arg.Name() << "] has no input type"; + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} +} // namespace util +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_util.h b/onnxruntime/core/providers/vsinpu/vsinpu_util.h new file mode 100644 index 0000000000000..09ed8d07dcd89 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_util.h @@ -0,0 +1,131 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ + +#pragma once +#include +#include +#include +#include "core/framework/op_kernel.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/tensorprotoutils.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/node_unit.h" +#include "tim/vx/tensor.h" +#include "tim/vx/types.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +namespace util { + +tim::vx::DataType OnnxDtypeToTIMVXDtype(const int32_t dtype); + +tim::vx::DataType OnnxDtypeToTIMVXDtype(const ONNX_NAMESPACE::DataType type); + +tim::vx::ShapeType OnnxShapeToTIMVXShape(const onnxruntime::TensorShape& ts); + +std::string PrintNode(const onnxruntime::NodeArg& node_arg); + +std::string PrintNode(const std::vector shape); + +size_t GetTensorElementSize(const ONNXTensorElementDataType type); + +size_t GetTensorBytes(const Ort::TensorTypeAndShapeInfo& info); + +TensorShape GetTensorShape(const onnxruntime::NodeArg& node_arg); + +std::shared_ptr UnpackTensor( + const NodeArg* node, const ONNX_NAMESPACE::TensorProto& initializer); + +tim::vx::PadType GetPadType(const std::string type); + +int32_t ReverseAxis(int32_t origin_axis, int32_t length); + +std::vector ReverseAxis(std::vector origin_axes, int32_t length); + +bool IsTypeSupported(const NodeArg* node_arg); + +enum class QuantizedOpType : uint8_t { + Unknown, // Unknown or not a quantized NodeUnit + DequantizeLinear, + QuantizeLinear, + QLinearConv, + QLinearMatMul, + QLinearAdd, + QLinearSigmoid, + QLinearAveragePool, + QLinearMul, + // Not yet supported + // QLinearReduceMean, + QDQConv, + QDQResize, + QDQAveragePool, + QDQAdd, + QDQMul, + QDQTranspose, + QDQReshape, + QDQSoftmax, + QDQConcat, + QDQGemm, + QDQMatMul, + // TODO(cfy) :Add other QDQ NodeUnit types +}; + +enum class ConvType : uint8_t { + Regular, + Depthwise, + Grouped, +}; +QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit); + +ConvType GetConvType(const NodeUnit& node_unit, const InitializedTensorSet& initializers); + +// If this is a quantized Conv (QLinearConv or QDQConv) +bool IsQuantizedConv(QuantizedOpType quant_op_type); + +// If this is a quantized Pool (QLinearAveragePool or QDQAveragePool) +bool IsQuantizedPool(QuantizedOpType quant_op_type); + +// If this is a quantized Gemm (QLinearMatMul or QDQMatMul/QDQGemm) +bool IsQuantizedGemm(QuantizedOpType quant_op_type); + +// This quantized op is an operator or qdq node unit takes 2 inputs and produces 1 output +// Such as QLinearConv, QLinearMatMul, QLinearAdd, QDQConv,... +bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type); + +// Check if a qlinear binary op has valid inputs, Qlinear[Conv/MatMul/Add] +bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit); + +void GetQuantizationScaleAndZeroPoint( + const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::filesystem::path& model_path, + float& scale, int32_t& zero_point, + std::optional>& pcq_scales, + std::optional>& pcq_zps); + +bool GetType(const NodeArg& node_arg, int32_t& type); + +} // namespace util +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 5d1794ed0afa8..44e6953db438e 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -136,21 +136,33 @@ bool IsSupportedDataType(const int32_t data_type, supported_data_types.end(); } -bool IsValidMultidirectionalBroadcast(std::vector& shape_a, - std::vector& shape_b, - const logging::Logger& logger) { +bool GetBidirectionalBroadcastShape(std::vector& shape_a, + std::vector& shape_b, + std::vector& output_shape) { size_t size_a = shape_a.size(); size_t size_b = shape_b.size(); size_t smaller_size = std::min(size_a, size_b); - for (size_t i = 0; i < smaller_size; i++) { + size_t larger_size = std::max(size_a, size_b); + + output_shape.resize(larger_size); + + for (size_t i = 0; i < larger_size; i++) { // right alignment size_t axis_a = size_a - i - 1; size_t axis_b = size_b - i - 1; - // Broadcastable tensors must either have each dimension the same size or equal to one. - if (shape_a[axis_a] != shape_b[axis_b] && shape_a[axis_a] != 1 && shape_b[axis_b] != 1) { - return false; + + if (i < smaller_size) { + // Broadcastable tensors must either have each dimension the same size or equal to one. + if (shape_a[axis_a] != shape_b[axis_b] && shape_a[axis_a] != 1 && shape_b[axis_b] != 1) { + return false; + } + output_shape[larger_size - i - 1] = std::max(shape_a[axis_a], shape_b[axis_b]); + } else { + // For the remaining dimensions in the larger tensor, copy the dimension size directly to the output shape. + output_shape[larger_size - i - 1] = (size_a > size_b) ? shape_a[axis_a] : shape_b[axis_b]; } } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 7240fa37d9cc9..496f886e5a076 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -160,7 +160,7 @@ static const InlinedHashMap op_map = { {"ArgMax", {"argMax", true}}, {"ArgMin", {"argMin", true}}, {"AveragePool", {"averagePool2d", true}}, - {"BatchNormalization", {"batchNormalization", false}}, + {"BatchNormalization", {"batchNormalization", true}}, {"Cast", {"cast", true}}, {"Ceil", {"ceil", true}}, {"Clip", {"clamp", true}}, @@ -176,11 +176,11 @@ static const InlinedHashMap op_map = { {"Equal", {"equal", true}}, {"Erf", {"erf", false}}, {"Exp", {"exp", true}}, - {"Expand", {"expand", false}}, + {"Expand", {"expand", true}}, {"Flatten", {"reshape", true}}, {"Floor", {"floor", true}}, {"Gather", {"gather", true}}, - {"Gelu", {"gelu", false}}, + {"Gelu", {"gelu", true}}, {"Gemm", {"gemm", true}}, {"GlobalAveragePool", {"averagePool2d", true}}, {"GlobalMaxPool", {"maxPool2d", true}}, @@ -190,8 +190,8 @@ static const InlinedHashMap op_map = { {"HardSigmoid", {"hardSigmoid", true}}, {"HardSwish", {"hardSwish", true}}, {"Identity", {"identity", true}}, - {"InstanceNormalization", {"instanceNormalization", false}}, - {"LayerNormalization", {"layerNormalization", false}}, + {"InstanceNormalization", {"instanceNormalization", true}}, + {"LayerNormalization", {"layerNormalization", true}}, {"LeakyRelu", {"leakyRelu", true}}, {"Less", {"lesser", true}}, {"LessOrEqual", {"lesserOrEqual", true}}, @@ -208,24 +208,24 @@ static const InlinedHashMap op_map = { {"Pad", {"pad", true}}, {"Pow", {"pow", true}}, {"PRelu", {"prelu", true}}, - {"Reciprocal", {"reciprocal", false}}, - {"ReduceL1", {"reduceL1", false}}, - {"ReduceL2", {"reduceL2", false}}, - {"ReduceLogSum", {"reduceLogSum", false}}, - {"ReduceLogSumExp", {"reduceLogSumExp", false}}, + {"Reciprocal", {"reciprocal", true}}, + {"ReduceL1", {"reduceL1", true}}, + {"ReduceL2", {"reduceL2", true}}, + {"ReduceLogSum", {"reduceLogSum", true}}, + {"ReduceLogSumExp", {"reduceLogSumExp", true}}, {"ReduceMax", {"reduceMax", true}}, {"ReduceMean", {"reduceMean", true}}, {"ReduceMin", {"reduceMin", true}}, {"ReduceProd", {"reduceProduct", true}}, {"ReduceSum", {"reduceSum", true}}, - {"ReduceSumSquare", {"reduceSumSquare", false}}, + {"ReduceSumSquare", {"reduceSumSquare", true}}, {"Relu", {"relu", true}}, {"Reshape", {"reshape", true}}, {"Resize", {"resample2d", true}}, {"Shape", {"slice", true}}, {"Sigmoid", {"sigmoid", true}}, {"Softplus", {"softplus", true}}, - {"Softsign", {"softsign", false}}, + {"Softsign", {"softsign", true}}, {"Sin", {"sin", true}}, {"Slice", {"slice", true}}, {"Softmax", {"softmax", true}}, @@ -277,9 +277,9 @@ static const std::unordered_set webnn_supp bool IsSupportedDataType(const int32_t data_type, const std::unordered_set& supported_data_types); -bool IsValidMultidirectionalBroadcast(std::vector& shape_a, - std::vector& shape_b, - const logging::Logger& logger); +bool GetBidirectionalBroadcastShape(std::vector& shape_a, + std::vector& shape_b, + std::vector& output_shape); bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 2c97ef490f4a3..23e19d5943144 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -63,14 +63,23 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); - // XNNPACK prelu operator expects slope to be a static value. - // https://github.com/google/XNNPACK/issues/4692 - // TODO: Remove this check after it is solved. - if (op_type == "PRelu" && !Contains(initializers, input_defs[1]->Name()) && device_type == WebnnDeviceType::CPU) { - LOGS(logger, VERBOSE) << "The second input (slope) for PRelu must be a constant initializer for WebNN CPU backend."; + std::vector input0_shape; + std::vector input1_shape; + if (!GetShape(*input_defs[0], input0_shape, logger) || + !GetShape(*input_defs[1], input1_shape, logger)) { return false; } + // 'prelu' op in WebNN CPU backend restricts the last dimension of input and slope to be same. + // TODO: Remove this workaround once the associated issue is resolved in Chromium: + // https://issues.chromium.org/issues/335517470. + if (op_type == "PRelu" && device_type == WebnnDeviceType::CPU) { + if (input0_shape.back() != input1_shape.back()) { + LOGS(logger, VERBOSE) << "The last dimension of input and slope for PRelu must be same for WebNN CPU backend."; + return false; + } + } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 4eaa4855ff7b2..847db6a9975c6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -37,7 +37,6 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod // skip the weight for conv as we need to transpose for preferred layout NHWC. if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W - model_builder.AddInputToSkip(node.InputDefs()[1]->Name()); } } @@ -168,7 +167,7 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, if (is_conv == 1) dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 else - dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv weight + dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight SafeInt num_elements = SafeInt(Product(dest_shape)); @@ -186,7 +185,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: element_size = sizeof(float); break; - break; default: break; } @@ -232,9 +230,11 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = emscripten::val::object(); + const auto& initializers(model_builder.GetInitializerTensors()); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); @@ -249,6 +249,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; + const bool is_constant_weight = Contains(initializers, weight_name); // Support conv1d by prepending a 1 or 2 size dimensions. if (is_conv1d) { // Reshape input. @@ -274,12 +275,15 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val options = emscripten::val::object(); ORT_RETURN_IF_ERROR(SetConvBaseOptions( model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); + bool depthwise = false; if (op_type == "Conv" || op_type == "ConvInteger") { int groups = options["groups"].as(); if (is_nhwc) { - bool depthwise = (groups == input_shape[3] && groups != 1); + depthwise = (groups == input_shape[3] && groups != 1); options.set("inputLayout", emscripten::val("nhwc")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); + if (is_constant_weight) { + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); + } if (!depthwise) { options.set("filterLayout", emscripten::val("ohwi")); } else { @@ -290,15 +294,37 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (is_nhwc) { options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d)); + if (is_constant_weight) { + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d)); + } } } emscripten::val filter = model_builder.GetOperand(weight_name); - if (!is_nhwc && is_conv1d) { - // Reshape weight to 4D for conv1d with NCHW preferred layout. - std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); - filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + + if (is_conv1d) { + // Reshape weight to 4D for conv1d. + if (!is_nhwc || !is_constant_weight) { + // The weight_shape has been appended 1's, reshape weight operand. + std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); + filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + } + } + + emscripten::val transpose_options = emscripten::val::object(); + if (is_nhwc && !is_constant_weight) { + // For NHWC preferred layout, if the weight is input: + // - Transpose it from iohw -> ohwi for convTranspose. + // - Transpose it from oihw -> ihwo for depthwise conv. + // - Transpose it from oihw -> ohwi for conv. + std::vector perm(4); + if (op_type == "ConvTranspose" || depthwise) { + perm = {1, 2, 3, 0}; // L_1230 for depthwise conv and convTranspose weight + } else { + perm = {0, 2, 3, 1}; // L_0231 + } + transpose_options.set("permutation", emscripten::val::array(perm)); + filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); } if (op_type == "Conv") { @@ -371,13 +397,6 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } - // WebNN CPU backend (XNNPACK) requires the filter operand to be a constant. - // https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739 - if (device_type == WebnnDeviceType::CPU && !Contains(initializers, input_defs[1]->Name())) { - LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; - return false; - } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc index 9d4de45fdafd1..9c75c00fa9273 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc @@ -44,18 +44,19 @@ Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); const auto& shape_tensor = *initializers.at(input_defs[1]->Name()); - std::vector new_shape; + std::vector new_shape; ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(shape_tensor, new_shape, logger), "Cannot get shape."); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input's shape."); - if (new_shape.size() < input_shape.size()) { - // Enlarge new shape to input.rank, right aligned with leading ones - new_shape.insert(new_shape.begin(), input_shape.size() - new_shape.size(), 1); - } + + std::vector output_shape; + ORT_RETURN_IF_NOT(GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape), "Cannot get output shape."); + emscripten::val output = model_builder.GetBuilder().call("expand", - input, emscripten::val::array(new_shape)); + input, + emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } @@ -95,11 +96,8 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } - if (new_shape.size() > input_shape.size()) { - LOGS(logger, VERBOSE) << "The size of shape must be less than or equal to the rank of input."; - } - - if (!IsValidMultidirectionalBroadcast(input_shape, new_shape, logger)) { + std::vector output_shape; + if (!GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape)) { LOGS(logger, VERBOSE) << "The input cannot expand to shape " << GetShapeString(new_shape); return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 90ad9b48d5866..a2aa0df5586e3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -87,11 +87,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder int64_t axis = helper.Get("axis", -1); axis = HandleNegativeAxis(axis, rank); std::vector axes(rank - SafeInt(axis)); - if (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) { - std::iota(axes.begin(), axes.end(), axis - 1); - } else { - std::iota(axes.begin(), axes.end(), axis); - } + std::iota(axes.begin(), axes.end(), axis); + options.set("axes", emscripten::val::array(axes)); output = model_builder.GetBuilder().call("layerNormalization", input, options); } else if (op_type == "InstanceNormalization") { diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index c46b04a3c29d9..6b0e1495f552d 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -104,8 +104,7 @@ Status ModelBuilder::RegisterInitializers() { if (tensor.has_raw_data()) { tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); } else { - unpacked_tensors_.push_back({}); - std::vector& unpacked_tensor = unpacked_tensors_.back(); + std::vector unpacked_tensor; ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); tensor_ptr = reinterpret_cast(unpacked_tensor.data()); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 80077b3abe56d..6a1688f16d2a6 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -66,7 +66,6 @@ class ModelBuilder { emscripten::val wnn_builder_ = emscripten::val::object(); DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; - std::vector> unpacked_tensors_; InlinedHashMap wnn_operands_; std::vector input_names_; std::vector output_names_; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 13ed29667debe..e839d6d17b7d9 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -17,24 +17,12 @@ namespace onnxruntime { -WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags, - const std::string& webnn_threads_number, const std::string& webnn_power_flags) +WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { - // Create WebNN context and graph builder. - const emscripten::val ml = emscripten::val::global("navigator")["ml"]; - if (!ml.as()) { - ORT_THROW("Failed to get ml from navigator."); - } - emscripten::val context_options = emscripten::val::object(); - context_options.set("deviceType", emscripten::val(webnn_device_flags)); // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; - // Set "numThreads" if it's not default 0. - if (webnn_threads_number.compare("0") != 0) { - context_options.set("numThreads", stoi(webnn_threads_number)); - } } else { preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { @@ -45,11 +33,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f ORT_THROW("Unknown WebNN deviceType."); } } - if (webnn_power_flags.compare("default") != 0) { - context_options.set("powerPreference", emscripten::val(webnn_power_flags)); - } - wnn_context_ = ml.call("createContext", context_options).await(); + wnn_context_ = emscripten::val::module_property("currentContext"); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } @@ -96,6 +81,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view const auto& logger = *GetLogger(); + if (!wnn_builder_.as()) { + // The GetCapability function may be called again after Compile due to the logic in the + // PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc). + // We need to re-create the wnn_builder_ here to avoid it's been released in last Compile. + wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); + } + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger); if (node_groups.empty()) { @@ -337,6 +329,9 @@ common::Status WebNNExecutionProvider::Compile(const std::vector> @@ -43,8 +42,8 @@ class WebNNExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; private: - emscripten::val wnn_context_ = emscripten::val::object(); - emscripten::val wnn_builder_ = emscripten::val::object(); + emscripten::val wnn_context_ = emscripten::val::undefined(); + mutable emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; webnn::WebnnDeviceType wnn_device_type_; diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc index 11acec8b1f354..7792aeabaabf2 100644 --- a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc +++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc @@ -10,27 +10,22 @@ using namespace onnxruntime; namespace onnxruntime { struct WebNNProviderFactory : IExecutionProviderFactory { - WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number, - const std::string& webnn_power_flags) - : webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {} + explicit WebNNProviderFactory(const std::string& webnn_device_flags) + : webnn_device_flags_(webnn_device_flags) {} ~WebNNProviderFactory() override {} std::unique_ptr CreateProvider() override; std::string webnn_device_flags_; - std::string webnn_threads_number_; - std::string webnn_power_flags_; }; std::unique_ptr WebNNProviderFactory::CreateProvider() { - return std::make_unique(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_); + return std::make_unique(webnn_device_flags_); } std::shared_ptr WebNNProviderFactoryCreator::Create( const ProviderOptions& provider_options) { - return std::make_shared(provider_options.at("deviceType"), - provider_options.at("numThreads"), - provider_options.at("powerPreference")); + return std::make_shared(provider_options.at("deviceType")); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 0366d9f893f7e..b815cc1570c96 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -5,7 +5,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers_fwd.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/transpose_helper.h" diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index d0c46142ac060..4c782f647371e 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -10,7 +10,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/data_types.h" #include "core/framework/error_code_helper.h" #include "core/framework/onnxruntime_typeinfo.h" @@ -584,7 +584,9 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernel auto tensorp = std::make_unique(type, tensor_shape, std::move(alloc_ptr)); // Deserialize TensorProto into pre-allocated, empty Tensor. - status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), nullptr, tensor_proto, *tensorp); + // TODO: here the TensorProto loses model path information, so it cannot be an external tensor. + status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), std::filesystem::path(), + tensor_proto, *tensorp); if (!status.IsOK()) { return onnxruntime::ToOrtStatus(status); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8d33bc154b0c..f0eed91d70440 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -250,6 +250,7 @@ std::atomic InferenceSession::global_session_id_{1}; std::map InferenceSession::active_sessions_; #ifdef _WIN32 OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ +onnxruntime::WindowsTelemetry::EtwInternalCallback InferenceSession::callback_ML_ORT_provider_; #endif static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options, @@ -374,15 +375,14 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, active_sessions_[global_session_id_++] = this; // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider - WindowsTelemetry::RegisterInternalCallback( - [this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { + callback_ML_ORT_provider_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( + [this](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { (void)SourceId; (void)Level; (void)MatchAnyKeyword; @@ -396,19 +396,18 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, LogAllSessions(); } }); + WindowsTelemetry::RegisterInternalCallback(callback_ML_ORT_provider_); // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - // Register callback for ETW capture state (rundown) - etwRegistrationManager.RegisterInternalCallback( - [&etwRegistrationManager, this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { (void)SourceId; (void)Level; (void)MatchAnyKeyword; @@ -439,6 +438,10 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, } } }); + + // Register callback for ETW capture state (rundown) + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); + #endif SetLoggingManager(session_options, session_env); @@ -720,9 +723,11 @@ InferenceSession::~InferenceSession() { } } - // Unregister the session + // Unregister the session and ETW callbacks #ifdef _WIN32 std::lock_guard lock(active_sessions_mutex_); + WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_); + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); #endif active_sessions_.erase(global_session_id_); @@ -875,9 +880,7 @@ common::Status InferenceSession::RegisterGraphTransformer( return graph_transformer_mgr_.Register(std::move(p_graph_transformer), level); } -common::Status InferenceSession::SaveToOrtFormat(const PathString& filepath) const { - ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ort format only supports little-endian machines"); - +common::Status InferenceSession::SaveToOrtFormat(const std::filesystem::path& filepath) const { // Get the byte size of the ModelProto and round it to the next MB and use it as flatbuffers' init_size // TODO: Investigate whether we should set a max size, and clarify the cost of having a buffer smaller than // what the total flatbuffers serialized size will be. @@ -915,7 +918,7 @@ common::Status InferenceSession::SaveToOrtFormat(const PathString& filepath) con uint8_t* buf = builder.GetBufferPointer(); int size = builder.GetSize(); file.write(reinterpret_cast(buf), size); - ORT_RETURN_IF_NOT(file, "Failed to save ORT format model to file: ", ToUTF8String(filepath)); + ORT_RETURN_IF_NOT(file, "Failed to save ORT format model to file: ", ToUTF8String(filepath.native())); } return Status::OK(); @@ -1267,8 +1270,9 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // for the result of the first step in layout transformation debug_graph_fn = [counter = 1, this](const Graph& graph) mutable { if (graph.GraphProtoSyncNeeded()) { - ORT_THROW_IF_ERROR( - Model::Save(*model_, "post_layout_transform_step_" + std::to_string(counter) + ".onnx")); + std::basic_ostringstream modelpath; + modelpath << ORT_TSTR("post_layout_transform_step_") << counter << ORT_TSTR(".onnx"); + ORT_THROW_IF_ERROR(Model::Save(*model_, modelpath.str())); } // counter is used to denote the step, so increment regardless of whether we wrote out the model in this step. @@ -1384,8 +1388,6 @@ Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len } Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort_format_model_bytes) { - static_assert(FLATBUFFERS_LITTLEENDIAN, "ORT format only supports little-endian machines"); - std::lock_guard l(session_mutex_); if (is_model_loaded_) { // already loaded diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 204b24974ff50..e1cd085d2c271 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/common/common.h" #include "core/common/inlined_containers.h" @@ -35,6 +36,10 @@ #include "core/platform/tracing.h" #include #endif +#ifdef _WIN32 +#include "core/platform/windows/logging/etw_sink.h" +#include "core/platform/windows/telemetry.h" +#endif namespace ONNX_NAMESPACE { class ModelProto; @@ -124,6 +129,8 @@ class InferenceSession { static std::map active_sessions_; #ifdef _WIN32 static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_ + static onnxruntime::WindowsTelemetry::EtwInternalCallback callback_ML_ORT_provider_; + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; #endif public: @@ -615,7 +622,7 @@ class InferenceSession { return !custom_schema_registries_.empty(); } - common::Status SaveToOrtFormat(const PathString& filepath) const; + common::Status SaveToOrtFormat(const std::filesystem::path& filepath) const; #endif /** diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 7f7ed5e436afe..4f9669a7dcc4c 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -5,6 +5,7 @@ // It implements onnxruntime::ProviderHost #include "core/common/inlined_containers.h" +#include "core/common/path_string.h" #include "core/framework/allocator_utils.h" #include "core/framework/config_options.h" #include "core/framework/compute_capability.h" @@ -27,6 +28,7 @@ #include "core/session/inference_session.h" #include "core/session/abi_session_options_impl.h" #include "core/session/ort_apis.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/provider_bridge_ort.h" #include "core/util/math.h" #include "core/framework/sparse_utils.h" @@ -35,6 +37,7 @@ #include "core/framework/model_metadef_id_generator.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_c_api.h" #include "core/common/string_helper.h" @@ -66,10 +69,12 @@ using StringStringEntryProtos = google::protobuf::RepeatedPtrField; using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; using ValueInfoProtos = google::protobuf::RepeatedPtrField; +using FunctionProtos = google::protobuf::RepeatedPtrField; } // namespace ONNX_NAMESPACE namespace onnxruntime { using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; +using IndexedSubGraph_SourceOfSchema = IndexedSubGraph::SourceOfSchema; } // namespace onnxruntime #include "core/common/cpuid_info.h" @@ -130,6 +135,8 @@ ProviderInfo_Dnnl& GetProviderInfo_Dnnl(); ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); ProviderInfo_ROCM& GetProviderInfo_ROCM(); ProviderHostCPU& GetProviderHostCPU(); +ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX(); +ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX(); ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops); struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Iterator { TensorShapeProto_Dimension_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {} @@ -241,6 +248,11 @@ struct ProviderHostImpl : ProviderHost { void CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) override { GetProviderInfo_CUDA().CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } #endif +#ifdef USE_MIGRAPHX + std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(device_id, name); } + std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_MIGraphX().CreateMIGraphXPinnedAllocator(device_id, name); } +#endif + #ifdef USE_ROCM std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_ROCM().CreateROCMAllocator(device_id, name); } std::unique_ptr CreateROCMPinnedAllocator(const char* name) override { return GetProviderInfo_ROCM().CreateROCMPinnedAllocator(name); } @@ -283,7 +295,7 @@ struct ProviderHostImpl : ProviderHost { Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } - Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const std::filesystem::path& model_path, /*out*/ std::vector& unpacked_tensor) override { return utils::UnpackInitializerData(tensor, model_path, unpacked_tensor); } @@ -391,6 +403,11 @@ struct ProviderHostImpl : ProviderHost { int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->size(); } ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) override { return p->at(index); }; + // OperatorSetIdProto + std::string* OperatorSetIdProto__mutable_domain(ONNX_NAMESPACE::OperatorSetIdProto* p) override { return p->mutable_domain(); } + void OperatorSetIdProto__set_version(ONNX_NAMESPACE::OperatorSetIdProto* p, int64_t version) override { return p->set_version(version); } + int64_t OperatorSetIdProto__version(const ONNX_NAMESPACE::OperatorSetIdProto* p) override { return p->version(); } + #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional (wrapped) const ONNX_NAMESPACE::TypeProto& TypeProto_Optional__elem_type(const ONNX_NAMESPACE::TypeProto_Optional* p) override { return p->elem_type(); } @@ -519,6 +536,11 @@ struct ProviderHostImpl : ProviderHost { void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) override { p->set_ir_version(value); } ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_metadata_props(); }; + const ONNX_NAMESPACE::OperatorSetIdProto& ModelProto__opset_import(const ONNX_NAMESPACE::ModelProto* p, int index) override { return p->opset_import(index); } + ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__mutable_opset_import(ONNX_NAMESPACE::ModelProto* p, int index) override { return p->mutable_opset_import(index); } + int ModelProto__opset_import_size(const ONNX_NAMESPACE::ModelProto* p) override { return p->opset_import_size(); } + ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__add_opset_import(ONNX_NAMESPACE::ModelProto* p) override { return p->add_opset_import(); } + // NodeProto (wrapped) std::unique_ptr NodeProto__construct() override { return std::make_unique(); } void NodeProto__operator_delete(ONNX_NAMESPACE::NodeProto* p) override { delete p; } @@ -526,6 +548,7 @@ struct ProviderHostImpl : ProviderHost { int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) override { return p->attribute_size(); } const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const override { return p->attribute(index); } ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) override { return p->mutable_attribute(index); } + ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) override { return p->add_attribute(); } // TensorProto (wrapped) std::unique_ptr TensorProto__construct() override { return std::make_unique(); } @@ -600,9 +623,65 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; } - static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { - auto* shape = ctx.getAttribute("shape"); - auto* data_type = ctx.getAttribute("data_type"); + // FunctionProto (wrapped) + std::unique_ptr FunctionProto__construct() override { return std::make_unique(); } + void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) override { delete p; } + + bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) override { return p->SerializeToString(&string); } + bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) override { return p->SerializeToOstream(&output); } + bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) override { return p->ParseFromString(data); } + std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) override { return p->SerializeAsString(); } + + bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_name(); } + const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->name(); } + void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const std::string& name) override { p->set_name(name); } + + bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_doc_string(); } + const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->doc_string(); } + void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const std::string& doc_string) override { p->set_doc_string(doc_string); } + + bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_domain(); } + const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->domain(); } + void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const std::string& domain) override { p->set_domain(domain); } + + const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->input(index); } + std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_input(index); } + int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->input_size(); } + void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_input(value); } + + const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->output(index); } + std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_output(index); } + int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->output_size(); } + void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_output(value); } + + const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->attribute(index); } + std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_attribute(index); } + int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->attribute_size(); } + void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_attribute(value); } + + const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->attribute_proto(index); } + ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_attribute_proto(index); } + int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->attribute_proto_size(); } + ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_attribute_proto(); } + + const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->node(index); } + ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_node(index); } + int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->node_size(); } + ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_node(); } + + const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->value_info(index); } + ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_value_info(index); } + ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) override { return p->mutable_value_info(); } + int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->value_info_size(); } + ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_value_info(); } + + const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->metadata_props(index); } + ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_metadata_props(index); } + ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->mutable_metadata_props(); } + int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); } + ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); } + + static int32_t convert_elem_type(const ONNX_NAMESPACE::AttributeProto* data_type) { int32_t elemType = 0; if (data_type->s() == "float32") { elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; @@ -612,8 +691,12 @@ struct ProviderHostImpl : ProviderHost { elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8; } else if (data_type->s() == "int32") { elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32; + } else if (data_type->s() == "uint32") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT32; } else if (data_type->s() == "int64") { elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64; + } else if (data_type->s() == "uint64") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT64; } else if (data_type->s() == "int1") { elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL; } else if (data_type->s() == "bfloat16") { @@ -624,17 +707,63 @@ struct ProviderHostImpl : ProviderHost { elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16; } else if (data_type->s() == "int16") { elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16; - } else { - return; + } else if (data_type->s() == "double") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; + } else if (data_type->s() == "string") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_STRING; + } else if (data_type->s() == "complex64") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64; + } else if (data_type->s() == "complex128") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128; + } else if (data_type->s() == "float8e4m3fn") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN; + } else if (data_type->s() == "float8e4m3fnuz") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ; + } else if (data_type->s() == "float8e5m2") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; + } else if (data_type->s() == "float8e5m2funz") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; + } else if (data_type->s() == "uint4") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4; + } else if (data_type->s() == "int4") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4; } - ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); - if (shape != nullptr) { - for (auto i = 0; i < shape->ints_size(); ++i) { - ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + return elemType; + } + + static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { + auto num_output = ctx.getNumOutputs(); + if (num_output == 1) { + auto* shape = ctx.getAttribute("shape"); + auto* data_type = ctx.getAttribute("data_type"); + if (data_type == nullptr) { + std::cerr << "Custom op is missing `data_type` attr." << std::endl; + return; + } + int32_t elemType = convert_elem_type(data_type); + ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); + if (shape != nullptr) { + for (auto i = 0; i < shape->ints_size(); ++i) { + ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + } + } else { + // set scalar type. + ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); } } else { - // set scalar type. - ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); + for (auto idx = 0u; idx < num_output; idx++) { + auto* shape = ctx.getAttribute("shape_" + std::to_string(idx)); + auto* data_type = ctx.getAttribute("data_type_" + std::to_string(idx)); + if (shape == nullptr || data_type == nullptr) { + // this output is optional + } else { + int32_t elemType = convert_elem_type(data_type); + ONNX_NAMESPACE::updateOutputElemType(ctx, idx, elemType); + for (auto i = 0; i < shape->ints_size(); ++i) { + ONNX_NAMESPACE::getOutputShape(ctx, idx, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + } + } + } } } @@ -734,9 +863,12 @@ struct ProviderHostImpl : ProviderHost { std::vector& IndexedSubGraph__Nodes(IndexedSubGraph* p) override { return p->nodes; } - void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) override { return p->SetMetaDef(std::move(meta_def_)); } + void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) override { p->SetMetaDef(std::move(meta_def_)); } const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) override { return p->GetMetaDef(); } + void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) override { p->schema_source = schema_source; } + IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) override { return p->schema_source; } + // KernelDef (wrapped) void KernelDef__operator_delete(KernelDef* p) override { delete p; } void KernelDef__SinceVersion(const KernelDef* p, int* start, int* end) override { return p->SinceVersion(start, end); } @@ -999,7 +1131,7 @@ struct ProviderHostImpl : ProviderHost { const std::string& NodeUnit__Name(const NodeUnit* p) noexcept override { return p->Name(); } int NodeUnit__SinceVersion(const NodeUnit* p) noexcept override { return p->SinceVersion(); } NodeIndex NodeUnit__Index(const NodeUnit* p) noexcept override { return p->Index(); } - const Path& NodeUnit__ModelPath(const NodeUnit* p) noexcept override { return p->ModelPath(); } + const std::filesystem::path& NodeUnit__ModelPath(const NodeUnit* p) noexcept override { return p->ModelPath(); } ProviderType NodeUnit__GetExecutionProviderType(const NodeUnit* p) noexcept override { return p->GetExecutionProviderType(); } @@ -1039,7 +1171,7 @@ struct ProviderHostImpl : ProviderHost { void Model__operator_delete(Model* p) override { delete p; } Graph& Model__MainGraph(Model* p) override { return p->MainGraph(); } std::unique_ptr Model__ToProto(Model* p) override { return std::make_unique(p->ToProto()); } - std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) override { return std::make_unique(p->ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold)); }; + std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold) override { return std::make_unique(p->ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold)); }; const ModelMetaData& Model__MetaData(const Model* p) const noexcept override { return p->MetaData(); }; Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) override { return Model::Load(file_path, model_proto); } @@ -1076,7 +1208,7 @@ struct ProviderHostImpl : ProviderHost { const Graph* Graph__ParentGraph(const Graph* p) const override { return p->ParentGraph(); } Graph* Graph__MutableParentGraph(Graph* p) override { return p->MutableParentGraph(); } const std::string& Graph__Name(const Graph* p) const noexcept override { return p->Name(); } - const Path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); } + const std::filesystem::path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); } const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); } bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); } const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } @@ -1132,7 +1264,7 @@ struct ProviderHostImpl : ProviderHost { } const std::string& GraphViewer__Name(const GraphViewer* p) noexcept override { return p->Name(); } - const Path& GraphViewer__ModelPath(const GraphViewer* p) noexcept override { return p->ModelPath(); } + const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept override { return p->ModelPath(); } const Node* GraphViewer__GetNode(const GraphViewer* p, NodeIndex node_index) override { return p->GetNode(node_index); } const NodeArg* GraphViewer__GetNodeArg(const GraphViewer* p, const std::string& name) override { return p->GetNodeArg(name); } @@ -1178,13 +1310,6 @@ struct ProviderHostImpl : ProviderHost { } const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } - // Path (wrapped) - PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); } - const std::vector& Path__GetComponents(const Path* p) noexcept override { return p->GetComponents(); } - bool Path__IsEmpty(const Path* p) noexcept override { return p->IsEmpty(); } - std::unique_ptr Path__construct() override { return std::make_unique(); } - void Path__operator_delete(ONNX_NAMESPACE::Path* p) override { delete p; }; - // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } @@ -1800,7 +1925,18 @@ std::shared_ptr OpenVINOProviderFactoryCreator::Creat return s_library_openvino.Get().CreateExecutionProviderFactory(&ov_options_converted_map); } -std::shared_ptr OpenVINOProviderFactoryCreator::Create(const ProviderOptions* provider_options_map) { +void ORTSessionOptionsToOrtOpenVINOProviderOptions(ProviderOptions& ov_options, + const SessionOptions* session_options) { + bool disable_cpu_fallback = session_options->config_options.GetConfigOrDefault( + kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; + if (disable_cpu_fallback) + ov_options["disable_cpu_fallback"] = "true"; +} + +std::shared_ptr OpenVINOProviderFactoryCreator::Create(ProviderOptions* provider_options_map, + const SessionOptions* session_options) { + if (session_options) + onnxruntime::ORTSessionOptionsToOrtOpenVINOProviderOptions(*provider_options_map, session_options); return s_library_openvino.Get().CreateExecutionProviderFactory(provider_options_map); } @@ -1900,6 +2036,20 @@ ProviderInfo_ROCM& GetProviderInfo_ROCM() { ORT_THROW("ROCM Provider not available, can't get interface for it"); } +ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX() try { + return reinterpret_cast(s_library_migraphx.Get().GetInfo()); +} catch (const std::exception& exception) { + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; +} + +ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX() { + if (auto* info = TryGetProviderInfo_MIGraphX()) + return *info; + + ORT_THROW("MIGraphX Provider not available, can't get interface for it"); +} + void CopyGpuToCpu( void* dst_ptr, const void* src_ptr, @@ -2075,7 +2225,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, provider_options[provider_options_keys[i]] = provider_options_values[i]; } - auto factory = onnxruntime::OpenVINOProviderFactoryCreator::Create(&provider_options); + auto factory = onnxruntime::OpenVINOProviderFactoryCreator::Create(&provider_options, &(options->value)); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_OpenVINO_V2: Failed to load shared library"); } @@ -2767,6 +2917,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ provider_options[provider_options_keys[i]] = provider_options_values[i]; } + // EP context related session config options. + provider_options["ep_context_enable"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0"); + provider_options["ep_context_embed_mode"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + provider_options["ep_context_file_path"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library"); diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 05408db9884cd..db8b97f6d2c13 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -108,11 +108,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, #endif } else if (strcmp(provider_name, "OpenVINO") == 0) { #if defined(USE_OPENVINO) - options->provider_factories.push_back(OpenVINOProviderFactoryCreator::Create(&provider_options)); + options->provider_factories.push_back(OpenVINOProviderFactoryCreator::Create(&provider_options, &(options->value))); #else status = create_not_supported_status(); #endif - } else if (strcmp(provider_name, "SNPE") == 0) { #if defined(USE_SNPE) options->provider_factories.push_back(SNPEProviderFactoryCreator::Create(provider_options)); @@ -128,11 +127,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, } else if (strcmp(provider_name, "WEBNN") == 0) { #if defined(USE_WEBNN) std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu"); - std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0"); - std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default"); provider_options["deviceType"] = deviceType; - provider_options["numThreads"] = numThreads; - provider_options["powerPreference"] = powerPreference; options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index 43843da3fb96e..dbf961ab5b035 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -15,7 +15,7 @@ #pragma once #include -#include "core/common/gsl.h" +#include #if defined(_MSC_VER) #define ORT_FORCEINLINE __forceinline diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index fcd1db31f95ef..c982a7aa2e7e0 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -552,53 +552,80 @@ struct BlockedQuantizeLinear { std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, const std::ptrdiff_t thread_block_size, bool saturate) { ORT_UNUSED_PARAMETER(saturate); + // to avoid a byte being writen from mutiple threads, use 2 * N as thread block + ORT_UNUSED_PARAMETER(thread_block_size); constexpr auto low = static_cast(TOut::min_val); constexpr auto high = static_cast(TOut::max_val); - const auto num_thread_block_N = (N + thread_block_size - 1) / thread_block_size; - const auto num_thread_block = M * K * num_thread_block_N; - const TensorOpCost unit_cost{static_cast(thread_block_size * sizeof(float) * 2), - static_cast(thread_block_size * sizeof(typename TOut::UnpackedType)), - static_cast(thread_block_size) * 2.0}; - auto KN = K * N; - auto num_quant_block_KN = (K + quant_block_size - 1) / quant_block_size * N; - const auto num_thread_block_KN = K * num_thread_block_N; + auto size_thread_block = 2 * N; + auto num_thread_block = (M * K + 1) / 2; + auto num_quant_block_K = (K + quant_block_size - 1) / quant_block_size; + auto num_quant_block_KN = num_quant_block_K * N; + auto MK = M * K; + const TensorOpCost unit_cost{static_cast(size_thread_block * sizeof(float) * 2), + static_cast(size_thread_block * sizeof(typename TOut::UnpackedType)), + static_cast(size_thread_block) * 2.0}; concurrency::ThreadPool::TryParallelFor( thread_pool, num_thread_block, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - auto m = begin / num_thread_block_KN, k = begin % num_thread_block_KN / num_thread_block_N; - auto n_blk = begin % num_thread_block_N, n = n_blk * thread_block_size; - auto output_idx = m * KN + k * N + n; - auto quant_param_idx = m * num_quant_block_KN + k / quant_block_size * N; - auto quant_param_idx_t = quant_param_idx + n; + begin <<= 1, end = std::min(end << 1, MK); + auto output_idx = begin * N; + auto m = begin / K, k = begin % K; + auto zp_idx = m * num_quant_block_KN + k / quant_block_size * N; for (; begin < end; ++begin) { - auto n_end = std::min(N, n + thread_block_size); - // TODO(fajin): 1> use SIMD, 2> set block to quant_block_size * thread_block_size - // TODO(fajin): process 2 elements at a time - for (; n < n_end; ++n, ++output_idx, ++quant_param_idx_t) { - // TODO(fajin): perf difference + auto zp_idx_t = zp_idx; + auto output_idx_end = output_idx + N; + + // leading unaligned output + if (output_idx & 1) { auto zp = zero_point - ? static_cast(zero_point[quant_param_idx_t >> 1].GetElem(quant_param_idx_t & 1)) + ? static_cast(zero_point[zp_idx_t >> 1].GetElem(zp_idx_t & 1)) : 0; - auto sc = scale[quant_param_idx_t]; + auto sc = scale[zp_idx_t]; auto v = std::clamp(static_cast(std::nearbyint(input[output_idx] / sc)) + zp, low, high); - output[output_idx >> 1].SetElem(output_idx & 1, static_cast(v)); + output[output_idx >> 1].SetElem(1, static_cast(v)); + ++output_idx; + ++zp_idx_t; } - if (n == N) { - n = 0; - ++k; - if (k == K) { - k = 0; - quant_param_idx += N; - } else if (k % quant_block_size == 0) { - quant_param_idx += N; - } + // TODO(fajin): use SIMD + // aligned output + auto output_t = reinterpret_cast(output); + for (; output_idx < output_idx_end - 1; output_idx += 2, zp_idx_t += 2) { + auto zp0 = zero_point + ? static_cast(zero_point[zp_idx_t >> 1].GetElem(zp_idx_t & 1)) + : 0; + auto zp1 = zero_point + ? static_cast(zero_point[(zp_idx_t + 1) >> 1].GetElem((zp_idx_t + 1) & 1)) + : 0; + auto sc0 = scale[zp_idx_t]; + auto sc1 = scale[zp_idx_t + 1]; + auto v0 = std::clamp(static_cast(std::nearbyint(input[output_idx] / sc0)) + zp0, low, high); + auto v1 = std::clamp(static_cast(std::nearbyint(input[output_idx + 1] / sc1)) + zp1, low, high); + output_t[output_idx >> 1] = static_cast((v0 & 0xF) | ((v1 & 0xF) << 4)); + } - quant_param_idx_t = quant_param_idx; + // tailing unaligned output + if (output_idx < output_idx_end) { + auto zp = zero_point + ? static_cast(zero_point[zp_idx_t >> 1].GetElem(zp_idx_t & 1)) + : 0; + auto sc = scale[zp_idx_t]; + auto v = std::clamp(static_cast(std::nearbyint(input[output_idx] / sc)) + zp, low, high); + output[output_idx >> 1].SetElem(0, static_cast(v)); + + ++output_idx; + } + + ++k; + if (k == K) { + k = 0; + zp_idx += N; + } else if (k % quant_block_size == 0) { + zp_idx += N; } } }); @@ -610,53 +637,59 @@ struct BlockedQuantizeLinear { ORT_UNUSED_PARAMETER(saturate); constexpr auto low = static_cast(TOut::min_val); constexpr auto high = static_cast(TOut::max_val); - // quant block size is used as thread block size - const auto num_thread_block_K = (K + quant_block_size - 1) / quant_block_size; - const auto num_thread_block = num_thread_block_K * M; - const TensorOpCost unit_cost{static_cast(quant_block_size * sizeof(float)), - static_cast(quant_block_size * sizeof(typename TOut ::UnpackedType)), - static_cast(quant_block_size) * 2.0}; + // to avoid a byte being writen from mutiple threads, use 2 * K as thread block + auto size_thread_block = 2 * K; + auto quant_block_num_K = (K + quant_block_size - 1) / quant_block_size; + auto num_thread_block = (M + 1) / 2; + TensorOpCost unit_cost{static_cast(size_thread_block * sizeof(float)), + static_cast(size_thread_block * sizeof(typename TOut ::UnpackedType)), + static_cast(size_thread_block) * 2.0}; concurrency::ThreadPool::TryParallelFor( thread_pool, num_thread_block, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - auto m = begin / num_thread_block_K, k_blk = begin % num_thread_block_K, k = k_blk * quant_block_size; - auto output_idx = m * K + k; - - for (; begin < end; ++begin) { - auto zp = zero_point ? static_cast(zero_point[begin >> 1].GetElem(begin & 1)) : 0; - auto sc = scale[begin]; - size_t output_idx_end = std::min(K - k, quant_block_size) + output_idx; - size_t out_start = output_idx, out_end = output_idx_end; - - if (out_start & 1) { - auto v = std::clamp(static_cast(std::nearbyint(input[out_start] / sc)) + zp, low, high); - output[out_start >> 1].SetElem(1, static_cast(v)); - ++out_start; - } + begin <<= 1, end = std::min(end << 1, M); + auto output_idx = begin * K; + auto zp_idx = begin * quant_block_num_K; + + for (; begin < end; ++begin, output_idx += K) { + auto output_row_idx_start = output_idx; + auto output_row_idx_end = output_row_idx_start + K; + + for (; output_row_idx_start < output_row_idx_end; output_row_idx_start += quant_block_size, ++zp_idx) { + auto zp = zero_point ? static_cast(zero_point[zp_idx >> 1].GetElem(zp_idx & 1)) : 0; + auto sc = scale[zp_idx]; + size_t out_start = output_row_idx_start; + size_t out_end = std::min(output_row_idx_start + quant_block_size, output_row_idx_end); + + if (out_start & 1) { + auto v = std::clamp(static_cast(std::nearbyint(input[out_start] / sc)) + zp, low, high); + output[out_start >> 1].SetElem(1, static_cast(v)); + ++out_start; + } - if (out_end & 1) { - --out_end; - auto v = std::clamp(static_cast(std::nearbyint(input[out_end] / sc)) + zp, low, high); - output[out_end >> 1].SetElem(0, static_cast(v)); - } + if (out_end & 1) { + --out_end; + auto v = std::clamp(static_cast(std::nearbyint(input[out_end] / sc)) + zp, low, high); + output[out_end >> 1].SetElem(0, static_cast(v)); + } - if constexpr (std::is_same::value) { - MlasQuantizeLinearS4(input + out_start, reinterpret_cast(&(output[out_start >> 1])), - out_end - out_start, sc, static_cast(zp)); - } else { - MlasQuantizeLinearU4(input + out_start, reinterpret_cast(&(output[out_start >> 1])), - out_end - out_start, sc, static_cast(zp)); + if constexpr (std::is_same::value) { + MlasQuantizeLinearS4(input + out_start, reinterpret_cast(&(output[out_start >> 1])), + out_end - out_start, sc, static_cast(zp)); + } else { + MlasQuantizeLinearU4(input + out_start, reinterpret_cast(&(output[out_start >> 1])), + out_end - out_start, sc, static_cast(zp)); + } } - - output_idx = output_idx_end; - k = output_idx % K; } }); } }; +// Bug(fajin): the same byte in output / zero_point must not be written by different threads, otherwise +// the result is undefined. This is not handled in the current implementation. template struct BlockedQuantizeLinear { static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, @@ -664,54 +697,84 @@ struct BlockedQuantizeLinear { std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, const std::ptrdiff_t thread_block_size, bool saturate) { ORT_UNUSED_PARAMETER(saturate); + // to avoid a byte being writen from mutiple threads, use 2 * N as thread block + ORT_UNUSED_PARAMETER(thread_block_size); constexpr auto low = static_cast(TOut::min_val); constexpr auto high = static_cast(TOut::max_val); - const auto num_thread_block_N = (N + thread_block_size - 1) / thread_block_size; - const auto num_thread_block = M * K * num_thread_block_N; - const TensorOpCost unit_cost{static_cast(thread_block_size * sizeof(MLFloat16) * 2), - static_cast(thread_block_size * sizeof(typename TOut::UnpackedType)), - static_cast(thread_block_size) * 2.0}; - auto KN = K * N; - auto num_quant_block_KN = (K + quant_block_size - 1) / quant_block_size * N; - const auto num_thread_block_KN = K * num_thread_block_N; + auto size_thread_block = 2 * N; + auto num_thread_block = (M * K + 1) / 2; + auto num_quant_block_K = (K + quant_block_size - 1) / quant_block_size; + auto num_quant_block_KN = num_quant_block_K * N; + auto MK = M * K; + const TensorOpCost unit_cost{static_cast(size_thread_block * sizeof(float) * 2), + static_cast(size_thread_block * sizeof(typename TOut::UnpackedType)), + static_cast(size_thread_block) * 2.0}; concurrency::ThreadPool::TryParallelFor( thread_pool, num_thread_block, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - auto m = begin / num_thread_block_KN, k = begin % num_thread_block_KN / num_thread_block_N; - auto n_blk = begin % num_thread_block_N, n = n_blk * thread_block_size; - auto output_idx = m * KN + k * N + n; - auto quant_param_idx = m * num_quant_block_KN + k / quant_block_size * N; - auto quant_param_idx_t = quant_param_idx + n; + begin <<= 1, end = std::min(end << 1, MK); + auto output_idx = begin * N; + auto m = begin / K, k = begin % K; + auto zp_idx = m * num_quant_block_KN + k / quant_block_size * N; for (; begin < end; ++begin) { - auto n_end = std::min(N, n + thread_block_size); - // TODO(fajin): 1> use SIMD, 2> set block to quant_block_size * thread_block_size - // TODO(fajin): process 2 elements at a time - for (; n < n_end; ++n, ++output_idx, ++quant_param_idx_t) { - // TODO(fajin): perf difference + auto zp_idx_t = zp_idx; + auto output_idx_end = output_idx + N; + + // leading unaligned output + if (output_idx & 1) { auto zp = zero_point - ? static_cast(zero_point[quant_param_idx_t >> 1].GetElem(quant_param_idx_t & 1)) + ? static_cast(zero_point[zp_idx_t >> 1].GetElem(zp_idx_t & 1)) : 0; - auto sc = scale[quant_param_idx_t].ToFloat(); - auto v = std::clamp(static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, - low, high); - output[output_idx >> 1].SetElem(output_idx & 1, static_cast(v)); + auto sc = scale[zp_idx_t].ToFloat(); + auto v = std::clamp( + static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, low, high); + output[output_idx >> 1].SetElem(1, static_cast(v)); + ++output_idx; + ++zp_idx_t; } - if (n == N) { - n = 0; - ++k; - if (k == K) { - k = 0; - quant_param_idx += N; - } else if (k % quant_block_size == 0) { - quant_param_idx += N; - } + // TODO(fajin): use SIMD + // aligned output + auto output_t = reinterpret_cast(output); + for (; output_idx < output_idx_end - 1; output_idx += 2, zp_idx_t += 2) { + auto zp0 = zero_point + ? static_cast(zero_point[zp_idx_t >> 1].GetElem(zp_idx_t & 1)) + : 0; + auto zp1 = zero_point + ? static_cast(zero_point[(zp_idx_t + 1) >> 1].GetElem((zp_idx_t + 1) & 1)) + : 0; + auto sc0 = scale[zp_idx_t].ToFloat(); + auto sc1 = scale[zp_idx_t + 1].ToFloat(); + auto v0 = std::clamp( + static_cast(std::nearbyint(input[output_idx].ToFloat() / sc0)) + zp0, low, high); + auto v1 = std::clamp( + static_cast(std::nearbyint(input[output_idx + 1].ToFloat() / sc1)) + zp1, low, high); + output_t[output_idx >> 1] = static_cast((v0 & 0xF) | ((v1 & 0xF) << 4)); + } - quant_param_idx_t = quant_param_idx; + // tailing unaligned output + if (output_idx < output_idx_end) { + auto zp = zero_point + ? static_cast(zero_point[zp_idx_t >> 1].GetElem(zp_idx_t & 1)) + : 0; + auto sc = scale[zp_idx_t].ToFloat(); + auto v = std::clamp( + static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, low, high); + output[output_idx >> 1].SetElem(0, static_cast(v)); + + ++output_idx; + } + + ++k; + if (k == K) { + k = 0; + zp_idx += N; + } else if (k % quant_block_size == 0) { + zp_idx += N; } } }); @@ -723,32 +786,55 @@ struct BlockedQuantizeLinear { ORT_UNUSED_PARAMETER(saturate); constexpr auto low = static_cast(TOut::min_val); constexpr auto high = static_cast(TOut::max_val); - // quant block size is used as thread block size - const auto num_thread_block_K = (K + quant_block_size - 1) / quant_block_size; - const auto num_thread_block = num_thread_block_K * M; - const TensorOpCost unit_cost{static_cast(quant_block_size * sizeof(MLFloat16)), - static_cast(quant_block_size * sizeof(typename TOut::UnpackedType)), - static_cast(quant_block_size) * 2.0}; + // to avoid a byte being writen from mutiple threads, use 2 * K as thread block + auto size_thread_block = 2 * K; + auto quant_block_num_K = (K + quant_block_size - 1) / quant_block_size; + auto num_thread_block = (M + 1) / 2; + TensorOpCost unit_cost{static_cast(size_thread_block * sizeof(float)), + static_cast(size_thread_block * sizeof(typename TOut ::UnpackedType)), + static_cast(size_thread_block) * 2.0}; concurrency::ThreadPool::TryParallelFor( thread_pool, num_thread_block, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - auto m = begin / num_thread_block_K, k_blk = begin % num_thread_block_K, k = k_blk * quant_block_size; - auto output_idx = m * K + k; + begin <<= 1, end = std::min(end << 1, M); + auto output_idx = begin * K; + auto zp_idx = begin * quant_block_num_K; + + for (; begin < end; ++begin, output_idx += K) { + auto output_row_idx_start = output_idx; + auto output_row_idx_end = output_row_idx_start + K; + + for (; output_row_idx_start < output_row_idx_end; output_row_idx_start += quant_block_size, ++zp_idx) { + auto zp = zero_point ? static_cast(zero_point[zp_idx >> 1].GetElem(zp_idx & 1)) : 0; + auto sc = scale[zp_idx].ToFloat(); + size_t out_start = output_row_idx_start; + size_t out_end = std::min(output_row_idx_start + quant_block_size, output_row_idx_end); + + if (out_start & 1) { + auto v = std::clamp( + static_cast(std::nearbyint(input[out_start].ToFloat() / sc)) + zp, low, high); + output[out_start >> 1].SetElem(1, static_cast(v)); + ++out_start; + } - for (; begin < end; ++begin) { - // each thread block is also a quantization block - auto zp = zero_point ? static_cast(zero_point[begin >> 1].GetElem(begin & 1)) : 0; - auto sc = scale[begin].ToFloat(); - auto output_idx_end = std::min(K - k, quant_block_size) + output_idx; - for (; output_idx < output_idx_end; ++output_idx) { - auto v = std::clamp(static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, - low, high); - output[output_idx >> 1].SetElem(output_idx & 1, static_cast(v)); - } + if (out_end & 1) { + --out_end; + auto v = std::clamp( + static_cast(std::nearbyint(input[out_end].ToFloat() / sc)) + zp, low, high); + output[out_end >> 1].SetElem(0, static_cast(v)); + } - k = output_idx % K; + auto output_t = reinterpret_cast(output); + for (; out_start < out_end; out_start += 2) { + auto v0 = std::clamp( + static_cast(std::nearbyint(input[out_start].ToFloat() / sc)) + zp, low, high); + auto v1 = std::clamp( + static_cast(std::nearbyint(input[out_start + 1].ToFloat() / sc)) + zp, low, high); + output_t[out_start >> 1] = static_cast((v0 & 0xF) | ((v1 & 0xF) << 4)); + } + } } }); } diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index ff76887e917cd..51a52af1b151e 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -66,6 +66,37 @@ void QuantizeMatMul4BitsBlockwise( tp.get()); } +template +bool QuantizeQDQMatMul4BitsBlockwise( + py::array_t dst, // shape: [K, N / 2] + py::array_t src, // shape: [K, N] + py::array_t scale, // shape: [block_per_K, N] + py::array_t zero_points, // shape: [block_per_K, N / 2] + int32_t quant_block_size, + int32_t N, + int32_t K, + bool is_symmetric) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + py::buffer_info dst_buf = dst.request(); + py::buffer_info src_buf = src.request(); + py::buffer_info scale_buf = scale.request(); + py::buffer_info zp_buf = zero_points.request(); + + return MlasQDQQuantizeBlockwise( + reinterpret_cast(src_buf.ptr), + reinterpret_cast(scale_buf.ptr), + is_symmetric ? nullptr : reinterpret_cast(zp_buf.ptr), + reinterpret_cast(dst_buf.ptr), + true, + K, + N, + quant_block_size, + tp.get()); +} + template void QuantizeMatMulBnb4Blockwise( py::array_t dst, @@ -99,6 +130,8 @@ void CreateQuantPybindModule(py::module& m) { m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); + m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise); + m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise); } } // namespace python diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index 4da25eac32040..218b59688b01c 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -40,7 +40,8 @@ void addGlobalSchemaFunctions(pybind11::module& m) { #ifdef USE_OPENVINO []() { ProviderOptions provider_options_map; - return onnxruntime::OpenVINOProviderFactoryCreator::Create(&provider_options_map); + SessionOptions session_options; + return onnxruntime::OpenVINOProviderFactoryCreator::Create(&provider_options_map, &session_options); }(), #endif #ifdef USE_TENSORRT diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index fa4c906dd054c..e13285c60e69f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1084,7 +1084,7 @@ std::unique_ptr CreateExecutionProviderInstance( } } if (std::shared_ptr openvino_provider_factory = onnxruntime::OpenVINOProviderFactoryCreator::Create( - &OV_provider_options_map)) { + &OV_provider_options_map, &session_options)) { auto p = openvino_provider_factory->CreateProvider(); // Reset global variables config to avoid it being accidentally passed on to the next session openvino_device_type.clear(); @@ -1114,6 +1114,9 @@ std::unique_ptr CreateExecutionProviderInstance( if (it != provider_options_map.end()) { info = it->second; } + info["ep_context_enable"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0"); + info["ep_context_embed_mode"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + info["ep_context_file_path"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); return onnxruntime::VitisAIProviderFactoryCreator::Create(info)->CreateProvider(); #endif } else if (type == kAclExecutionProvider) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index dc9394a83a4ea..4d6e411defae3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -65,7 +65,11 @@ struct OrtStatus { #elif OPENVINO_CONFIG_HETERO #define BACKEND_OPENVINO "-OPENVINO_HETERO" + +#elif OPENVINO_DISABLE_NPU_FALLBACK +#define BACKEND_OPENVINO "-OPENVINO_DISABLE_NPU_FALLBACK" #endif + #else #define BACKEND_OPENVINO "" #endif diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 3f5e4e660003f..10492ae419817 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -128,6 +128,9 @@ def __setitem__(self, key, value): def values(self): return self.data.values() + def items(self): + return self.data.items() + class CalibrationMethod(Enum): MinMax = 0 @@ -155,6 +158,12 @@ def __next__(self): raise StopIteration return result + def __len__(self): + raise NotImplementedError + + def set_range(self, start_index: int, end_index: int): + raise NotImplementedError + class CalibraterBase: def __init__( @@ -409,13 +418,31 @@ def merge_range(self, old_range, new_range): return new_range for key, value in old_range.items(): + # Handling for structured data types with TensorData + if isinstance(value, TensorData): + old_min = value.range_value[0] + old_max = value.range_value[1] + else: + old_min, old_max = value + + if isinstance(new_range[key], TensorData): + new_min = new_range[key].range_value[0] + new_max = new_range[key].range_value[1] + else: + new_min, new_max = new_range[key] + if self.moving_average: - min_value = value[0] + self.averaging_constant * (new_range[key][0] - value[0]) - max_value = value[1] + self.averaging_constant * (new_range[key][1] - value[1]) + min_value = old_min + self.averaging_constant * (new_min - old_min) + max_value = old_max + self.averaging_constant * (new_max - old_max) + else: + min_value = min(old_min, new_min) + max_value = max(old_max, new_max) + + # If structured as TensorData, wrap the result accordingly + if isinstance(value, TensorData) or isinstance(new_range[key], TensorData): + new_range[key] = TensorData(lowest=min_value, highest=max_value) else: - min_value = min(value[0], new_range[key][0]) - max_value = max(value[1], new_range[key][1]) - new_range[key] = (min_value, max_value) + new_range[key] = (min_value, max_value) return new_range diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 1ad56dc3ac455..eac5b3b78690b 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -52,6 +52,7 @@ def get_qnn_qdq_config( activation_symmetric: bool = False, weight_symmetric: bool | None = None, keep_removable_activations: bool = False, + stride: int | None = None, ) -> StaticQuantConfig: """ Returns a static quantization configuration suitable for running QDQ models on QNN EP. @@ -171,6 +172,7 @@ def get_qnn_qdq_config( "TensorQuantOverrides": overrides_helper.get_dict(), "ActivationSymmetric": activation_symmetric, "WeightSymmetric": weight_symmetric, + "CalibStridedMinMax": stride, } # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 11a830dc6d7f5..40a4a4d26dc1c 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -18,31 +18,36 @@ from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from packaging import version -from onnxruntime.capi._pybind_state import quantize_matmul_4bits +from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_qdq_matmul_4bits from .calibrate import CalibrationDataReader from .onnx_model import ONNXModel -from .quant_utils import attribute_to_kwarg +from .quant_utils import QuantFormat, attribute_to_kwarg logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) class WeightOnlyQuantConfig: - def __init__(self, algorithm): + def __init__(self, algorithm, quant_format): """This is the Base class for Weight Only Quant Configuration. Args: algorithm: weight only quantize algorithm name. + quant_format: QuantFormat{QOperator, QDQ}. + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. """ self.algorithm = algorithm + self.quant_format = quant_format class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, ratios=None, + quant_format=QuantFormat.QOperator, ): """ This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. @@ -51,11 +56,18 @@ def __init__( Args: ratios: percentile of clip. Defaults to {}. + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format" + if ratios is None: ratios = {} super().__init__( algorithm="RTN", + quant_format=quant_format, ) self.ratios = ratios @@ -69,6 +81,7 @@ def __init__( actorder=False, mse=False, perchannel=True, + quant_format=QuantFormat.QOperator, ): """ This is a class for GPTQ algorithm Weight Only Quant Configuration. @@ -87,9 +100,16 @@ def __init__( whether get scale and zero point with mse error. perchannel (bool, optional): whether quantize weight per-channel. + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format" + super().__init__( algorithm="GPTQ", + quant_format=quant_format, ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp @@ -105,6 +125,7 @@ def __init__( block_size=128, bits=4, axis=1, + quant_format=QuantFormat.QOperator, ): """ This is a class for HQQ algorithm Weight Only Quant Configuration. @@ -112,14 +133,21 @@ def __init__( Args: block_size (int, optional): - channel number in one block to execute a GPTQ quantization iteration. + channel number in one block to execute a HQQ quantization iteration. bits (int, optional): how many bits to represent weight. axis (int, optional): 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format" + super().__init__( algorithm="HQQ", + quant_format=quant_format, ) self.block_size = block_size self.bits = bits @@ -132,8 +160,26 @@ def __init__( block_size: int = 128, is_symmetric: bool = False, accuracy_level: int | None = None, + quant_format=QuantFormat.QOperator, ): - super().__init__(algorithm="DEFAULT") + """ + This is a class for weight only affine quantization configuration. + + Args: + block_size (int, optional): + channel number in one block to execute an affine quantization iteration. + is_symmetric (bool, optional): + whether quantize weight symmetrically. + accuracy_level (int, optional): + Accuracy level of the 4-bit quantized MatMul computation. + Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details. + (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits) + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. + """ + super().__init__(algorithm="DEFAULT", quant_format=quant_format) self.block_size = block_size self.is_symmetric = is_symmetric self.bits = 4 @@ -287,23 +333,26 @@ def quantize_internal( return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) - def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: + """ + If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. + If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. + """ if node.op_type != "MatMul": - return node # only care about MatMul for now + return [node] # only care about MatMul for now import torch logger.info(f"start to quantize {node.name} ...") - inputB = node.input[1] # noqa: N806 - b_pb, bs_graph = get_initializer(inputB, graph_stack) + input_b = node.input[1] + b_pb, bs_graph = get_initializer(input_b, graph_stack) if b_pb is None: logger.info("MatMul doesn't have const weight. Skip to quantize") - return node # only care about constant weight + return [node] # only care about constant weight b_array = onnx.numpy_helper.to_array(b_pb) if len(b_array.shape) != 2: logger.info("MatMul weight is not 2D. Skip to quantize") - return node # can only process 2-D matrix + return [node] # can only process 2-D matrix b_array_torch = torch.from_numpy(b_array) if torch.cuda.is_available(): b_array_torch = b_array_torch.cuda() @@ -334,7 +383,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) b_quant.name = b_pb.name + "_Q4" for input in bs_graph.input: - if input.name == inputB: + if input.name == input_b: bs_graph.input.remove(input) break @@ -366,7 +415,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): logger.info(f"complete quantization of {node.name} ...") - return matmul_q4_node + return [matmul_q4_node] def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: @@ -382,7 +431,7 @@ class DefaultWeightOnlyQuantizer: def __init__(self, config: DefaultWeightOnlyQuantConfig): self.config = config - def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: + def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """4b quantize fp32 weight to a blob""" if len(fp32weight.shape) != 2: @@ -390,83 +439,136 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: rows, cols = fp32weight.shape block_size = self.config.block_size - blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size - padded_rows = k_blocks * block_size - pad_len = padded_rows - rows - if pad_len > 0: - fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") - # block wise quantization, each block comes from a single column - packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") - scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) - zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") - quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric) + if self.config.quant_format == QuantFormat.QOperator: + blob_size = block_size // 2 + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + if pad_len > 0: + fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") + + # block wise quantization, each block comes from a single column + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") + scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) + quantize_matmul_4bits( + packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + ) + else: + packed = np.zeros((rows * cols + 1) // 2, dtype="uint8") + zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") + scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype) + quantize_qdq_matmul_4bits( + packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + ) return (packed, scales, zero_point) - def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: + """ + If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. + If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. + """ if node.op_type != "MatMul": - return node # only care about MatMul for now + return [node] # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") - inputB = node.input[1] # noqa: N806 - B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 - if B is None: + qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4 + input_b = node.input[1] + b_tensor, b_graph = get_initializer(input_b, graph_stack) + if b_tensor is None: logger.info("MatMul doesn't have const weight. Skip to quantize") - return node # only care about constant weight + return [node] # only care about constant weight - B_array = onnx.numpy_helper.to_array(B) # noqa: N806 - if len(B_array.shape) != 2: + b_ndarray = onnx.numpy_helper.to_array(b_tensor) + if len(b_ndarray.shape) != 2: logger.info("MatMul weight is not 2D. Skip to quantize") - return node # can only process 2-D matrix - - packed, scales, zero_points = self.int4_block_quant(B_array) - B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 - B_quant.name = B.name + "_Q4" - for input in Bs_graph.input: - if input.name == inputB: - Bs_graph.input.remove(input) - break + return [node] # can only process 2-D matrix - scales_tensor = onnx.numpy_helper.from_array(scales) - scales_tensor.name = B.name + "_scales" - Bs_graph.initializer.extend([B_quant, scales_tensor]) + packed, scales, zero_points = self.int4_block_quant(b_ndarray) - input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.config.is_symmetric: - zp_tensor = onnx.numpy_helper.from_array(zero_points) - zp_tensor.name = B.name + "_zero_points" - Bs_graph.initializer.extend([zp_tensor]) - input_names.append(zp_tensor.name) + if self.config.quant_format == QuantFormat.QOperator: + b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + "_Q4") + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales") + else: + b_quant = onnx.helper.make_tensor(b_tensor.name + "_DQ_Q4", qtype, b_ndarray.shape, packed.tobytes(), True) + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") - kwargs = {} - rows, cols = B_array.shape - kwargs["K"] = rows - kwargs["N"] = cols - kwargs["bits"] = 4 - kwargs["block_size"] = self.config.block_size - if self.config.accuracy_level is not None: - kwargs["accuracy_level"] = self.config.accuracy_level + for input in b_graph.input: + if input.name == input_b: + b_graph.input.remove(input) + break - matmul_q4_node = onnx.helper.make_node( - "MatMulNBits", - inputs=input_names, - outputs=[node.output[0]], - name=node.name + "_Q4" if node.name else "", - domain="com.microsoft", - **kwargs, - ) + b_graph.initializer.extend([b_quant, scales_tensor]) + + output_nodes = [] + + if self.config.quant_format == QuantFormat.QOperator: + input_names = [node.input[0], b_quant.name, scales_tensor.name] + if not self.config.is_symmetric: + zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points") + input_names.append(zp_tensor.name) + b_graph.initializer.extend([zp_tensor]) + kwargs = {} + rows, cols = b_ndarray.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = 4 + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: + kwargs["accuracy_level"] = self.config.accuracy_level + + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) - logger.info(f"complete quantization of {node.name} ...") + output_nodes.append(matmul_q4_node) + else: + dq_input_names = [b_quant.name, scales_tensor.name] + dq_output_names = [b_quant.name + "_output"] + matmul_input_names = [node.input[0], dq_output_names[0]] + matmul_output_names = [node.output[0]] + if not self.config.is_symmetric: + zp_tensor = onnx.helper.make_tensor( + b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True + ) + dq_input_names.append(zp_tensor.name) + b_graph.initializer.extend([zp_tensor]) + dq_kwargs = {"axis": 0, "block_size": self.config.block_size} + dq_node = onnx.helper.make_node( + "DequantizeLinear", + inputs=dq_input_names, + outputs=dq_output_names, + name=node.name + "_DQ_Q4" if node.name else "", + **dq_kwargs, + ) + matmul_node = onnx.helper.make_node( + "MatMul", + inputs=matmul_input_names, + outputs=matmul_output_names, + name=node.name + "_matmul_Q4" if node.name else "", + ) + output_nodes.extend([dq_node, matmul_node]) - return matmul_q4_node + logger.info(f"complete quantization of {node.name} ...") + return output_nodes class MatMul4BitsQuantizer: - """Perform 4b quantization of constant MatMul weights""" + """ + Perform 4b quantization of constant MatMul weights. + If algo_config.quant_format is QOperator, the quantized weight is stored in a MatMulNBits node, which relaces the + MatMul node. + If algo_config.quant_format is QDQ, the quantized weight is stored in a DeQuantizeLinear node. The MatMul node is + replaced by the DequantizeLinear + MatMul nodes. + """ def __init__( self, @@ -475,7 +577,8 @@ def __init__( is_symmetric: bool = False, accuracy_level: int | None = None, nodes_to_exclude=None, - algo_config: WeightOnlyQuantConfig = None, + quant_format=QuantFormat.QOperator, + algo_config: WeightOnlyQuantConfig | None = None, ): if nodes_to_exclude is None: nodes_to_exclude = [] @@ -488,7 +591,10 @@ def __init__( self.node_quantizer = None if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( - block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + block_size=block_size, + is_symmetric=is_symmetric, + accuracy_level=accuracy_level, + quant_format=quant_format, ) self.algo_config = algo_config if algo_config.algorithm == "HQQ": @@ -526,15 +632,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) - out_node = None + out_nodes = [] if node.name in self.nodes_to_exclude: logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") - out_node = node + out_nodes = [node] elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": - out_node = self.node_quantizer.quantize(node, graph_stack) + out_nodes = self.node_quantizer.quantize(node, graph_stack) else: - out_node = self.node_quantizer.quantize(node, graph_stack) - new_nodes.append(out_node) + out_nodes = self.node_quantizer.quantize(node, graph_stack) + new_nodes.extend(out_nodes) graph.ClearField("node") graph.node.extend(new_nodes) @@ -688,6 +794,15 @@ def parse_args(): default=[], help="Specify the nodes to be excluded from quantization with node names", ) + parser.add_argument( + "--quant_format", + default="QOperator", + type=QuantFormat, + choices=list(QuantFormat), + help="QuantFormat {QOperator, QDQ}" + "QOperator format quantizes the model with quantized operators directly." + "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.", + ) return parser.parse_args() @@ -699,6 +814,7 @@ def parse_args(): input_model_path = args.input_model output_model_path = args.output_model + quant_format = args.quant_format if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") @@ -713,7 +829,10 @@ def parse_args(): quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) elif args.quant_method == "default": quant_config = DefaultWeightOnlyQuantConfig( - block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + block_size=args.block_size, + is_symmetric=args.symmetric, + accuracy_level=args.accuracy_level, + quant_format=quant_format, ) elif args.quant_method == "rtn": quant_config = RTNWeightOnlyQuantConfig() diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index f8b74a7ae4c2e..2340c995d3d5b 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -381,6 +381,9 @@ def quantize_static( CalibTensorRangeSymmetric = True/False : Default is False. If enabled, the final range of tensor during calibration will be explicitly set to symmetric to central point "0". + CalibStridedMinMax = Optional[int] : + Default is None. If set to an integer, during calculation of the min-max, only stride amount of + data will be used and then all results will be merged in the end. CalibMovingAverage = True/False : Default is False. If enabled, the moving average of the minimum and maximum values will be computed when the calibration method selected is MinMax. @@ -522,7 +525,19 @@ def inc_dataloader(): use_external_data_format=use_external_data_format, extra_options=calib_extra_options, ) - calibrator.collect_data(calibration_data_reader) + + stride = extra_options.get("CalibStridedMinMax", None) + if stride: + total_data_size = len(calibration_data_reader) + if total_data_size % stride != 0: + raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).") + + for start in range(0, total_data_size, stride): + end_index = start + stride + calibration_data_reader.set_range(start_index=start, end_index=end_index) + calibrator.collect_data(calibration_data_reader) + else: + calibrator.collect_data(calibration_data_reader) tensors_range = calibrator.compute_data() if not isinstance(tensors_range, TensorsData): raise TypeError( diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9bc2328cc71b6..ac959d5c061f7 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -206,10 +206,9 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, "GroupQueryAttention": self._infer_GroupQueryAttention, - "SparseAttention": self._infer_SparseAttention, - "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, + "MatMulNBits": self._infer_MatMulNBits, "MultiHeadAttention": self._infer_MultiHeadAttention, "NhwcConv": self._infer_NhwcConv, "PackedAttention": self._infer_PackedAttention, @@ -223,8 +222,10 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "RestorePadding": self._infer_RestorePadding, "RotaryEmbedding": self._infer_RotaryEmbedding, "SimplifiedLayerNormalization": self._infer_LayerNormalization, + "SkipGroupNorm": self._infer_SkipGroupNorm, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + "SparseAttention": self._infer_SparseAttention, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -1256,6 +1257,25 @@ def _infer_MatMul(self, node): # noqa: N802 def _infer_MatMulInteger(self, node): # noqa: N802 self._compute_matmul_shape(node, onnx.TensorProto.INT32) + def _infer_MatMulNBits(self, node): # noqa: N802 + lhs_shape = self._get_shape(node, 0) + rhs_shape = [get_attribute(node, "K"), get_attribute(node, "N")] + lhs_rank = len(lhs_shape) + assert lhs_rank > 0 + if lhs_rank == 1: + new_shape = rhs_shape[1:] + else: + new_shape = lhs_shape[:-1] + rhs_shape[1:] + # merge reduce dim + self._check_merged_dims( + [lhs_shape[-1], rhs_shape[0]], + allow_broadcast=False, + ) + # infer output_dtype from input type when not specified + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + def _infer_NonMaxSuppression(self, node): # noqa: N802 selected = str(self._new_symbolic_dim_from_output(node)) vi = self.known_vi_[node.output[0]] diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index 9ee8f27df5c99..2f335009b59c6 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -15,12 +15,10 @@ from typing import List, Optional TRT_DOCKER_FILES = { - "8.4.cuda_11_6_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4", - "8.5.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5", "8.6.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6", "8.6.cuda_12_3_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6", - "10.0.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt10_0", - "10.0.cuda_12_4_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_4_tensorrt10_0", + "10.2.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10", + "10.2.cuda_12_5_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10", "BIN": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin", } diff --git a/onnxruntime/python/tools/transformers/dev_benchmark.cmd b/onnxruntime/python/tools/transformers/dev_benchmark.cmd index 7a9b3254a1708..82137de3c0f3b 100644 --- a/onnxruntime/python/tools/transformers/dev_benchmark.cmd +++ b/onnxruntime/python/tools/transformers/dev_benchmark.cmd @@ -41,7 +41,7 @@ set input_counts=1 REM Pretrained transformers models can be a subset of: bert-base-cased roberta-base gpt2 distilgpt2 distilbert-base-uncased set models_to_test=bert-base-cased -REM If you have mutliple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: +REM If you have multiple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: REM set CUDA_VISIBLE_DEVICES=1 REM This script will generate a logs file with a list of commands used in tests. @@ -163,4 +163,4 @@ IF %FileSize% LSS 10 goto :EOF python -c "import sys; lines=sys.stdin.readlines(); h=lines[0]; print(''.join([h]+list(sorted(set(lines)-set([h])))))" < %1 > sort_%1 FindStr "[^,]" sort_%1 > summary_%1 DEL sort_%1 -goto :EOF \ No newline at end of file +goto :EOF diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 6fba98c14e792..cd8a8756d681e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -27,8 +27,6 @@ Please note the package versions needed for using LLaMA-2 in the `requirements.t - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. - `requirements-quant.txt` - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) -- `requirements-70b-model.txt` - - For running the LLaMA-2 70B model on multiple GPUs - `requirements.txt` - Package versions needed in each of the above files @@ -221,18 +219,6 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu --use_gqa ``` -Export LLaMA-2 70B sharded model into 4 partitions -``` -# From source: -# 1. Install necessary packages from requirements-70b-model.txt -$ pip install -r requirements-70b-model.txt - -# 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command: -$ ./build.sh --config Release --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ - -# 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: -$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-distributed --precision fp16 --execution_provider cuda --use_gqa -``` ## Parity Checking LLaMA-2 @@ -395,18 +381,6 @@ CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \ --device cuda ``` -9. ONNX Runtime, FP16, convert_to_onnx, LLaMA-2 70B shard to 4 GPUs -``` -CUDA_VISIBLE_DEVICES=4,5,6,7 bash benchmark_70b_model.sh 4 \ - --benchmark-type ort-convert-to-onnx \ - --ort-model-path ./llama2-70b-dis/rank_{}_Llama-2-70b-hf_decoder_merged_model_fp16.onnx \ - --model-name meta-llama/Llama-2-70b-hf \ - --cache-dir ./model_cache \ - --precision fp16 \ - --device cuda \ - --warmup-runs 5 \ - --num-runs 100 -``` You can profile a variant by adding the `--profile` flag and providing one batch size and sequence length combination. diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh deleted file mode 100644 index 38f1916456658..0000000000000 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -NUM_GPUS=${1:-1} - -MPI="mpirun --allow-run-as-root - -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 - --tag-output --npernode $NUM_GPUS --bind-to numa - -x MIOPEN_FIND_MODE=1" - -CMD="$MPI python benchmark.py ${@:2}" - -$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh deleted file mode 100644 index 637d15c10e0c7..0000000000000 --- a/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -NUM_GPUS=${1:-1} - -MPI="mpirun --allow-run-as-root - -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 - --tag-output --npernode $NUM_GPUS --bind-to numa - -x MIOPEN_FIND_MODE=1" - -CMD="$MPI python convert_to_onnx.py ${@:2}" - -$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 8a33544654e05..f701e465b9153 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -1052,7 +1052,8 @@ def main(): logger.info(f"check parity with cmd: {parity_cmd}") parity_check(parity_cmd) except Exception as e: - logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) + logger.exception(f"An error occurred while verifying parity: {e}") + sys.exit(-1) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt deleted file mode 100644 index 572cfdb71be4a..0000000000000 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt +++ /dev/null @@ -1,4 +0,0 @@ --r requirements.txt -git+https://github.com/frankdongms/transformers.git@frdong/shard_llama -mpi4py -psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 95f40af3fd746..9cc4878e8022d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -694,7 +694,7 @@ def __init__(self, model, num_heads, hidden_size): self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask) self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self) self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self) - # TODO: consider retrive max_distance from model. + # TODO: consider retrieve max_distance from model. # math.log(max_distance / (num_buckets // 2)) self.rpb_fusion = FusionRelativePositionBiasBlock(self, 128) diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh index 64d6ecde618f6..77d0c3a76624f 100755 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -62,7 +62,7 @@ input_counts=1 # Pretrained transformers models can be a subset of: bert-base-cased roberta-base gpt2 distilgpt2 distilbert-base-uncased models_to_test="bert-base-cased roberta-base distilbert-base-uncased" -# If you have mutliple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: +# If you have multiple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: # export CUDA_VISIBLE_DEVICES=1 # This script will generate a logs file with a list of commands used in tests. diff --git a/onnxruntime/test/common/path_test.cc b/onnxruntime/test/common/path_test.cc deleted file mode 100644 index d097705773568..0000000000000 --- a/onnxruntime/test/common/path_test.cc +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/path.h" - -#include "gtest/gtest.h" - -#include "core/common/optional.h" -#include "test/util/include/asserts.h" - -namespace onnxruntime { -namespace test { - -TEST(PathTest, Parse) { - auto check_parse = - [](const std::string& path_string, - const std::string& expected_root, - const std::vector& expected_components) { - Path p{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - - std::vector expected_components_ps{}; - std::transform( - expected_components.begin(), expected_components.end(), - std::back_inserter(expected_components_ps), - [](const std::string& s) { return ToPathString(s); }); - EXPECT_EQ(p.GetComponents(), expected_components_ps); - EXPECT_EQ(p.GetRootPathString(), ToPathString(expected_root)); - }; - - check_parse( - "i/am/relative", - "", {"i", "am", "relative"}); -#ifdef _WIN32 - check_parse( - "/i/am/rooted", - R"(\)", {"i", "am", "rooted"}); - check_parse( - R"(\\server\share\i\am\rooted)", - R"(\\server\share\)", {"i", "am", "rooted"}); - check_parse( - R"(C:\i\am\rooted)", - R"(C:\)", {"i", "am", "rooted"}); - check_parse( - R"(C:i\am\relative)", - "C:", {"i", "am", "relative"}); -#else // POSIX - check_parse( - "/i/am/rooted", - "/", {"i", "am", "rooted"}); - check_parse( - "//root_name/i/am/rooted", - "//root_name/", {"i", "am", "rooted"}); -#endif -} - -TEST(PathTest, ParseFailure) { - auto check_parse_failure = - [](const std::string& path_string) { - Path p{}; - EXPECT_FALSE(Path::Parse(ToPathString(path_string), p).IsOK()); - }; - -#ifdef _WIN32 - check_parse_failure(R"(\\server_name_no_separator)"); - check_parse_failure(R"(\\server_name_no_share_name\)"); - check_parse_failure(R"(\\server_name\share_name_no_root_dir)"); -#else // POSIX - check_parse_failure("//root_name_no_root_dir"); -#endif -} - -TEST(PathTest, IsEmpty) { - auto check_empty = - [](const std::string& path_string, bool is_empty) { - Path p{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - - EXPECT_EQ(p.IsEmpty(), is_empty); - }; - - check_empty("", true); - check_empty(".", false); - check_empty("/", false); -} - -TEST(PathTest, IsAbsoluteOrRelative) { - auto check_abs_or_rel = - [](const std::string& path_string, bool is_absolute) { - Path p{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - - EXPECT_EQ(p.IsAbsolute(), is_absolute); - EXPECT_EQ(p.IsRelative(), !is_absolute); - }; - - check_abs_or_rel("relative", false); - check_abs_or_rel("", false); -#ifdef _WIN32 - check_abs_or_rel(R"(\root_relative)", false); - check_abs_or_rel(R"(\)", false); - check_abs_or_rel("C:drive_relative", false); - check_abs_or_rel("C:", false); - check_abs_or_rel(R"(C:\absolute)", true); - check_abs_or_rel(R"(C:\)", true); -#else // POSIX - check_abs_or_rel("/absolute", true); - check_abs_or_rel("/", true); -#endif -} - -TEST(PathTest, ParentPath) { - auto check_parent = - [](const std::string path_string, const std::string& expected_parent_path_string) { - Path p{}, p_expected_parent{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_parent_path_string), p_expected_parent)); - - EXPECT_EQ(p.ParentPath().ToPathString(), p_expected_parent.ToPathString()); - }; - - check_parent("a/b", "a"); - check_parent("/a/b", "/a"); - check_parent("", ""); - check_parent("/", "/"); -} - -TEST(PathTest, Normalize) { - auto check_normalize = - [](const std::string& path_string, - const std::string& expected_normalized_path_string) { - Path p{}, p_expected_normalized{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(path_string), p)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_normalized_path_string), p_expected_normalized)); - - EXPECT_EQ(p.Normalize().ToPathString(), p_expected_normalized.ToPathString()); - }; - - check_normalize("/a/b/./c/../../d/../e", "/a/e"); - check_normalize("a/b/./c/../../d/../e", "a/e"); - check_normalize("/../a/../../b", "/b"); - check_normalize("../a/../../b", "../../b"); - check_normalize("/a/..", "/"); - check_normalize("a/..", "."); - check_normalize("", ""); - check_normalize("/", "/"); - check_normalize(".", "."); -} - -TEST(PathTest, Append) { - auto check_append = - [](const std::string& a, const std::string& b, const std::string& expected_ab) { - Path p_a{}, p_b{}, p_expected_ab{}; - ASSERT_STATUS_OK(Path::Parse(ToPathString(a), p_a)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(b), p_b)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_ab), p_expected_ab)); - - EXPECT_EQ(p_a.Append(p_b).ToPathString(), p_expected_ab.ToPathString()); - }; - - check_append("/a/b", "c/d", "/a/b/c/d"); - check_append("/a/b", "/c/d", "/c/d"); - check_append("a/b", "c/d", "a/b/c/d"); - check_append("a/b", "/c/d", "/c/d"); -#ifdef _WIN32 - check_append(R"(C:\a\b)", R"(c\d)", R"(C:\a\b\c\d)"); - check_append(R"(C:\a\b)", R"(\c\d)", R"(C:\c\d)"); - check_append(R"(C:\a\b)", R"(D:c\d)", R"(D:c\d)"); - check_append(R"(C:\a\b)", R"(D:\c\d)", R"(D:\c\d)"); - check_append(R"(C:a\b)", R"(c\d)", R"(C:a\b\c\d)"); - check_append(R"(C:a\b)", R"(\c\d)", R"(C:\c\d)"); - check_append(R"(C:a\b)", R"(D:c\d)", R"(D:c\d)"); - check_append(R"(C:a\b)", R"(D:\c\d)", R"(D:\c\d)"); -#else // POSIX - check_append("//root_0/a/b", "c/d", "//root_0/a/b/c/d"); - check_append("//root_0/a/b", "/c/d", "/c/d"); - check_append("//root_0/a/b", "//root_1/c/d", "//root_1/c/d"); -#endif -} - -TEST(PathTest, RelativePath) { - auto check_relative = - [](const std::string& src, - const std::string& dst, - const std::string& expected_rel) { - Path p_src, p_dst, p_expected_rel, p_rel; - ASSERT_STATUS_OK(Path::Parse(ToPathString(src), p_src)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(dst), p_dst)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_rel), p_expected_rel)); - - ASSERT_STATUS_OK(RelativePath(p_src, p_dst, p_rel)); - EXPECT_EQ(p_rel.ToPathString(), p_expected_rel.ToPathString()); - }; - - check_relative( - "/a/b/c/d/e", "/a/b/c/d/e/f/g/h", - "f/g/h"); - check_relative( - "/a/b/c/d/e", "/a/b/f/g/h/i", - "../../../f/g/h/i"); - check_relative( - "a/b/../c/../d", "e/./f/../g/h", - "../../e/g/h"); -} - -TEST(PathTest, RelativePathFailure) { - auto check_relative_failure = - [](const std::string& src, - const std::string& dst) { - Path p_src, p_dst, p_rel; - ASSERT_STATUS_OK(Path::Parse(ToPathString(src), p_src)); - ASSERT_STATUS_OK(Path::Parse(ToPathString(dst), p_dst)); - - EXPECT_FALSE(RelativePath(p_src, p_dst, p_rel).IsOK()); - }; - - check_relative_failure("/rooted", "relative"); - check_relative_failure("relative", "/rooted"); -#ifdef _WIN32 - check_relative_failure("C:/a", "D:/a"); -#else // POSIX - check_relative_failure("//root_0/a", "//root_1/a"); -#endif -} - -#if !defined(ORT_NO_EXCEPTIONS) -TEST(PathTest, Concat) { - auto check_concat = - [](const optional& a, const std::string& b, const std::string& expected_a, bool expect_throw = false) { - Path p_a{}, p_expected_a{}; - if (a.has_value()) { - ASSERT_STATUS_OK(Path::Parse(ToPathString(*a), p_a)); - } - ASSERT_STATUS_OK(Path::Parse(ToPathString(expected_a), p_expected_a)); - - if (expect_throw) { - EXPECT_THROW(p_a.Concat(ToPathString(b)).ToPathString(), OnnxRuntimeException); - } else { - EXPECT_EQ(p_a.Concat(ToPathString(b)).ToPathString(), p_expected_a.ToPathString()); - } - }; - - check_concat({"/a/b"}, "c", "/a/bc"); - check_concat({"a/b"}, "cd", "a/bcd"); - check_concat({""}, "cd", "cd"); - check_concat({}, "c", "c"); -#ifdef _WIN32 - check_concat({"a/b"}, R"(c\d)", "", true /* expect_throw */); -#else - check_concat({"a/b"}, "c/d", "", true /* expect_throw */); -#endif -} -#endif - -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index cb1ce885d2d45..9ab4a82463d51 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -8,7 +8,7 @@ #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/optional.h" #include "core/common/type_utils.h" @@ -70,7 +70,7 @@ class RandomValueGenerator { // Random values generated are in the range [min, max). template typename std::enable_if< - std::is_same_v, + std::is_same_v || std::is_same_v, std::vector>::type Uniform(gsl::span dims, float min, float max) { std::vector val(detail::SizeFromDims(dims)); diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index b831507988da6..e0891c7ced63e 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "gtest/gtest.h" diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 6ce9f5de68f11..5f94d30112f0e 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -4,7 +4,7 @@ #include #include #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc index 73da82d4bb039..79070f0788f2a 100644 --- a/onnxruntime/test/contrib_ops/greedy_search_test.cc +++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc @@ -4,7 +4,7 @@ #include #include #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" diff --git a/onnxruntime/test/contrib_ops/sampling_test.cc b/onnxruntime/test/contrib_ops/sampling_test.cc index d987a1cae427d..69789b84832e0 100644 --- a/onnxruntime/test/contrib_ops/sampling_test.cc +++ b/onnxruntime/test/contrib_ops/sampling_test.cc @@ -4,7 +4,7 @@ #include #include #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" diff --git a/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc b/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc index 17e26a57f5f3e..b2ab2d9a5701b 100644 --- a/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc +++ b/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc @@ -27,7 +27,7 @@ void VerifyTensorProtoFileData(const PathString& tensor_proto_path, gsl::span actual_data{}; actual_data.resize(expected_data.size()); - ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, Path{}, actual_data.data(), actual_data.size())); + ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, tensor_proto_path, actual_data.data(), actual_data.size())); ASSERT_EQ(gsl::span(actual_data), expected_data); } @@ -48,7 +48,7 @@ void VerifyTensorProtoFileDataInt4(const PathString& tensor_proto_path, std::vector> actual_data{}; actual_data.resize(expected_data.size()); - ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, Path{}, actual_data.data(), num_elems)); + ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, tensor_proto_path, actual_data.data(), num_elems)); ASSERT_EQ(actual_data.size(), expected_data.size()); diff --git a/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc b/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc index f36dbaf3d1aca..467c5e773589a 100644 --- a/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc +++ b/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc @@ -9,11 +9,9 @@ #include "gtest/gtest.h" #include "core/common/common.h" -#include "core/common/path.h" #include "core/graph/graph_flatbuffers_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/cpu/cpu_execution_provider.h" - #include "test/flatbuffers/flatbuffers_utils_test.fbs.h" #include "test/util/include/asserts.h" @@ -117,6 +115,10 @@ ONNX_NAMESPACE::TensorProto CreateInitializer(const std::string& name, ORT_THROW("Unsupported data type: ", data_type); } + if constexpr (endian::native != endian::little) { + utils::ConvertRawDataInTensorProto(&tp); + } + return tp; } @@ -230,7 +232,7 @@ TEST(FlatbufferUtilsTest, ExternalWriteReadWithLoadInitializers) { std::vector> fbs_tensors; for (const auto& initializer : initializers) { flatbuffers::Offset fbs_tensor; - ASSERT_STATUS_OK(SaveInitializerOrtFormat(builder, initializer, Path(), fbs_tensor, writer)); + ASSERT_STATUS_OK(SaveInitializerOrtFormat(builder, initializer, std::filesystem::path(), fbs_tensor, writer)); fbs_tensors.push_back(fbs_tensor); } @@ -259,6 +261,9 @@ TEST(FlatbufferUtilsTest, ExternalWriteReadWithLoadInitializers) { for (const auto* fbs_tensor : *fbs_tensors2) { ONNX_NAMESPACE::TensorProto initializer; ASSERT_STATUS_OK(LoadInitializerOrtFormat(*fbs_tensor, initializer, options, reader)); + if constexpr (endian::native != endian::little) { + utils::ConvertRawDataInTensorProto(&initializer); + } loaded_initializers.emplace_back(std::move(initializer)); // also check that the loaded flatbuffer tensors have accurately written to the external_data_offset field if (fbs_tensor->data_type() != fbs::TensorDataType::STRING && fbs_tensor->name()->str() != "tensor_32_small") { @@ -313,7 +318,7 @@ TEST(FlatbufferUtilsTest, ExternalWriteReadWithLoadOrtTensor) { std::vector> fbs_tensors; for (const auto& initializer : initializers) { flatbuffers::Offset fbs_tensor; - ASSERT_STATUS_OK(SaveInitializerOrtFormat(builder, initializer, Path(), fbs_tensor, writer)); + ASSERT_STATUS_OK(SaveInitializerOrtFormat(builder, initializer, std::filesystem::path(), fbs_tensor, writer)); fbs_tensors.push_back(fbs_tensor); } diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 72eee5aca2638..bf15a9d35b56a 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1855,7 +1855,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { status = sess.RegisterExecutionProvider(DefaultCpuExecutionProvider()); ASSERT_TRUE(status.IsOK()); - ASSERT_TRUE(model.Save(model, "./simplified_ssd.onnx").IsOK()); + ASSERT_TRUE(model.Save(model, ORT_TSTR("./simplified_ssd.onnx")).IsOK()); std::string s1; const bool rc = model.ToProto().SerializeToString(&s1); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index d0520ebbcba5a..84389c1d9711c 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -1278,7 +1278,7 @@ TEST(InferenceSessionTests, TestOptionalInputs) { } } -static void CreateFuseOpModel(const std::string& model_file_name) { +static void CreateFuseOpModel(const PathString& model_file_name) { onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); @@ -1312,7 +1312,7 @@ static void CreateFuseOpModel(const std::string& model_file_name) { } TEST(ExecutionProviderTest, FunctionTest) { - std::string model_file_name = "execution_provider_test_graph.onnx"; + PathString model_file_name = ORT_TSTR("execution_provider_test_graph.onnx"); CreateFuseOpModel(model_file_name); SessionOptions so; @@ -1365,7 +1365,7 @@ TEST(ExecutionProviderTest, FunctionTest) { } TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) { - std::string model_file_name = "fused_node_shape_inference_test_graph.onnx"; + PathString model_file_name = ORT_TSTR("fused_node_shape_inference_test_graph.onnx"); CreateFuseOpModel(model_file_name); @@ -1393,7 +1393,7 @@ TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) { } TEST(ExecutionProviderTest, OpKernelInfoCanReadConfigOptions) { - std::string model_file_name = "OpKernelInfoCanReadConfigOptions.onnx"; + PathString model_file_name = ORT_TSTR("OpKernelInfoCanReadConfigOptions.onnx"); CreateFuseOpModel(model_file_name); SessionOptions so; @@ -1580,7 +1580,7 @@ TEST(InferenceSessionTests, Test3LayerNestedSubgraph) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); - std::string model_file_name = "3-layer-nested-subgraph-test.onnx"; + PathString model_file_name = ORT_TSTR("3-layer-nested-subgraph-test.onnx"); status = onnxruntime::Model::Save(model, model_file_name); ASSERT_TRUE(status.IsOK()); @@ -1732,7 +1732,7 @@ TEST(InferenceSessionTests, Test2LayerNestedSubgraph) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); - std::string model_file_name = "2-layer-nested-subgraph-test.onnx"; + PathString model_file_name = ORT_TSTR("2-layer-nested-subgraph-test.onnx"); status = onnxruntime::Model::Save(model, model_file_name); ASSERT_TRUE(status.IsOK()); diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 1804c09043c7b..9278541b07512 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -139,7 +139,7 @@ TEST(TransformerTest, CastRemovalDoesNotLowerPrecisionTest) { status = graph.Resolve(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - // When casting f64 -> f32 -> f64 we should not be optimising away the cast since there is a loss of precision. + // When casting f64 -> f32 -> f64 we should not be optimizing away the cast since there is a loss of precision. EXPECT_EQ(graph.NumberOfNodes(), 2); } @@ -171,7 +171,7 @@ TEST(TransformerTest, CastRemovalDoesNotRemoveSignednessTest) { status = graph.Resolve(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - // When casting i32 -> ui32 -> i32 we should not be optimising away the cast since applying the casts produces a very different result. + // When casting i32 -> ui32 -> i32 we should not be optimizing away the cast since applying the casts produces a very different result. EXPECT_EQ(graph.NumberOfNodes(), 2); } diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index 19c7bf476e6e1..447b0edef879b 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/common.h" +#include "core/common/status.h" #include "core/common/path_string.h" #include "core/framework/data_types.h" #include "core/graph/model.h" @@ -17,75 +19,77 @@ using namespace onnxruntime; namespace onnxruntime { namespace test { -void LoadSaveAndCompareModel(const std::string& input_onnx, - const std::string& input_external_init_file, - const std::string& output_onnx, - const std::string& output_external_init_file, - size_t initializer_size_threshold) { +Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, + const std::filesystem::path& input_external_init_file, + const std::filesystem::path& output_onnx, + const std::filesystem::path& output_external_init_file, + size_t initializer_size_threshold) { + auto logger = DefaultLoggingManager().CreateLogger("LoadSaveAndCompareModel"); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(input_onnx), model, nullptr, DefaultLoggingManager().DefaultLogger())); - std::remove(output_onnx.c_str()); - std::remove(output_external_init_file.c_str()); - ASSERT_STATUS_OK(Model::SaveWithExternalInitializers(*model, ToPathString(output_onnx), output_external_init_file, initializer_size_threshold)); + ORT_RETURN_IF_ERROR(Model::Load(input_onnx, model, nullptr, *logger)); + std::filesystem::remove(output_onnx); + std::filesystem::remove(output_external_init_file); + ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(*model, output_onnx, output_external_init_file, initializer_size_threshold)); std::shared_ptr model_from_external; - ASSERT_STATUS_OK(Model::Load(ToPathString(output_onnx), model_from_external, nullptr, DefaultLoggingManager().DefaultLogger())); + ORT_RETURN_IF_ERROR(Model::Load(output_onnx.native(), model_from_external, nullptr, *logger)); Graph& graph = model->MainGraph(); // Perform shape inference on the graph, if this succeeds then it means that we could correctly read the // integer initializers used by reshape and transpose. - ASSERT_STATUS_OK(graph.Resolve()); + ORT_RETURN_IF_ERROR(graph.Resolve()); Graph& graph_from_external = model_from_external->MainGraph(); InitializedTensorSet initializers = graph.GetAllInitializedTensors(); InitializedTensorSet initializers_from_external = graph_from_external.GetAllInitializedTensors(); - ASSERT_EQ(initializers.size(), initializers_from_external.size()); + ORT_RETURN_IF_NOT(initializers.size() == initializers_from_external.size(), "size mismatch"); // Compare the initializers of the two versions. - Path model_path{}; - Path external_data_path{}; - for (auto i : initializers) { + std::filesystem::path model_path{}; + std::filesystem::path external_data_path{}; + for (const auto& i : initializers) { const std::string kInitName = i.first; const ONNX_NAMESPACE::TensorProto* tensor_proto = i.second; const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = initializers_from_external[kInitName]; std::vector tensor_proto_data; - model_path = Path::Parse(ToPathString(input_onnx)); - external_data_path = (input_external_init_file.size()) ? model_path.ParentPath().Append(Path::Parse(ToPathString(input_external_init_file))) : Path(); - ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*tensor_proto, external_data_path, tensor_proto_data)); + model_path = input_onnx; + external_data_path = (!input_external_init_file.empty()) ? (model_path.parent_path() / input_external_init_file) : std::filesystem::path(); + ORT_RETURN_IF_ERROR(utils::UnpackInitializerData(*tensor_proto, external_data_path, tensor_proto_data)); size_t tensor_proto_size = tensor_proto_data.size(); std::vector from_external_tensor_proto_data; - model_path = Path::Parse(ToPathString(output_onnx)); - external_data_path = model_path.ParentPath().Append(Path::Parse(ToPathString(output_external_init_file))); - ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*from_external_tensor_proto, model_path, from_external_tensor_proto_data)); + model_path = output_onnx; + external_data_path = model_path.parent_path() / output_external_init_file; + ORT_RETURN_IF_ERROR(utils::UnpackInitializerData(*from_external_tensor_proto, model_path, from_external_tensor_proto_data)); size_t from_external_tensor_proto_size = from_external_tensor_proto_data.size(); if (from_external_tensor_proto_size < initializer_size_threshold) { // 'Small' tensors should be embedded in the onnx file. - EXPECT_EQ(from_external_tensor_proto->data_location(), ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_DEFAULT); + ORT_RETURN_IF_NOT(from_external_tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_DEFAULT, "location mismatch"); } else { // 'Large' tensors should be added to the external binary file. - EXPECT_EQ(from_external_tensor_proto->data_location(), ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); + ORT_RETURN_IF_NOT(from_external_tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL, "location mismatch"); } - ASSERT_EQ(tensor_proto_size, from_external_tensor_proto_size); - EXPECT_EQ(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size), 0); + ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); + ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); } // Cleanup. - ASSERT_EQ(std::remove(output_onnx.c_str()), 0); - ASSERT_EQ(std::remove(PathToUTF8String(external_data_path.ToPathString()).c_str()), 0); + ORT_RETURN_IF_NOT(std::filesystem::remove(output_onnx), "delete file failed"); + ORT_RETURN_IF_NOT(std::filesystem::remove(external_data_path), "delete file failed"); + return Status::OK(); } // Original model does not have external initializers TEST(SaveWithExternalInitializers, Mnist) { - LoadSaveAndCompareModel("testdata/mnist.onnx", "", "testdata/mnist_with_external_initializers.onnx", "mnist_external_initializers.bin", 100); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100)); } // Original model has external initializers TEST(SaveWithExternalInitializers, ModelWithOriginalExternalData) { - LoadSaveAndCompareModel("testdata/model_with_orig_ext_data.onnx", "model_with_orig_ext_data.onnx.data", "testdata/model_with_new_external_initializers.onnx", "model_with_new_external_initializers.bin", 0); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0)); } } // namespace test diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 80f23b054a4ad..7bd6b47f52b7d 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -705,6 +705,9 @@ struct InsertIndices { // Conversion on the fly to the target data type std::vector indices(indices_data.cbegin(), indices_data.cend()); indices_tp.mutable_raw_data()->assign(reinterpret_cast(indices.data()), indices.size() * sizeof(T)); + if constexpr (endian::native != endian::little) { + utils::ConvertRawDataInTensorProto((ONNX_NAMESPACE::TensorProto*)&indices_tp); + } } } }; @@ -795,7 +798,7 @@ static void TestConversion(bool use_1D_indices, int32_t indices_type, TensorProto dense; // Path is required for loading external data (if any) // When path is empty it will look for the data in current dir - ASSERT_STATUS_OK(utils::ConstantNodeProtoToTensorProto(node, Path(), dense)); + ASSERT_STATUS_OK(utils::ConstantNodeProtoToTensorProto(node, std::filesystem::path(), dense)); gsl::span expected_span = gsl::make_span(expected.data(), expected.size()); checker(expected_span, dense); @@ -810,7 +813,7 @@ static void TestConversionAllZeros(bool use_1D_indices, TensorProto dense; // Path is required for loading external data (if any) // When path is empty it will look for the data in current dir - ASSERT_STATUS_OK(utils::ConstantNodeProtoToTensorProto(node, Path(), dense)); + ASSERT_STATUS_OK(utils::ConstantNodeProtoToTensorProto(node, std::filesystem::path(), dense)); gsl::span expected_span = gsl::make_span(expected.data(), expected.size()); checker(expected_span, dense); @@ -837,7 +840,7 @@ static void TestConversion( template static void RawDataWriter(const std::vector& values, TensorProto& tp, TensorProto_DataType datatype) { tp.set_data_type(datatype); - tp.set_raw_data(values.data(), values.size() * sizeof(T)); + utils::SetRawDataInTensorProto(tp, values.data(), values.size() * sizeof(T)); } int64_t ActualSize(const TensorProto& actual) { @@ -1109,30 +1112,31 @@ void RawSparseDataChecker(gsl::span expected_bfloat, } template -static void TestDenseToSparseConversionValues(size_t indices_start, - std::function& values, TensorProto& tp)> inserter, - std::function expected, - gsl::span expected_indicies, - const SparseTensorProto& actual)> - checker) { +static Status TestDenseToSparseConversionValues(size_t indices_start, + std::function& values, TensorProto& tp)> inserter, + std::function expected, + gsl::span expected_indicies, + const SparseTensorProto& actual)> + checker) { std::vector expected_values; std::vector expected_indicies; // Path is required for loading external data // Using empty path here since the data is not external - Path model_path; + std::filesystem::path model_path; TensorProto dense_tensor = CreateDenseTensor(indices_start, inserter, expected_values, expected_indicies); SparseTensorProto sparse_tensor; - ASSERT_STATUS_OK(utils::DenseTensorToSparseTensorProto(dense_tensor, model_path, sparse_tensor)); + ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(dense_tensor, model_path, sparse_tensor)); gsl::span expected_values_span = gsl::make_span(expected_values.data(), expected_values.size()); gsl::span expected_ind_span = gsl::make_span(expected_indicies.data(), expected_indicies.size()); checker(expected_values_span, expected_ind_span, sparse_tensor); + return Status::OK(); } template -static void TestDenseAllZerosToSparseConversion( +static Status TestDenseAllZerosToSparseConversion( std::function& values, TensorProto& tp)> inserter, std::function expected, gsl::span expected_indicies, @@ -1142,55 +1146,56 @@ static void TestDenseAllZerosToSparseConversion( std::vector expected_indicies; // Path is required for loading external data // Using empty path here since the data is not external - Path model_path; + std::filesystem::path model_path; TensorProto dense_tensor = CreateDenseTensorAllZeros(inserter); SparseTensorProto sparse_tensor; - ASSERT_STATUS_OK(utils::DenseTensorToSparseTensorProto(dense_tensor, model_path, sparse_tensor)); + ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(dense_tensor, model_path, sparse_tensor)); gsl::span expected_values_span = gsl::make_span(expected_values.data(), expected_values.size()); gsl::span expected_ind_span = gsl::make_span(expected_indicies.data(), expected_indicies.size()); checker(expected_values_span, expected_ind_span, sparse_tensor); + return Status::OK(); } template -static void TestDenseToSparseConversion(size_t indices_start, - std::function& values, TensorProto& tp)> inserter, - std::function expected, - gsl::span expected_indicies, - const SparseTensorProto& actual)> - checker) { - TestDenseToSparseConversionValues(indices_start, inserter, checker); - TestDenseAllZerosToSparseConversion(inserter, checker); +static Status TestDenseToSparseConversion(size_t indices_start, + std::function& values, TensorProto& tp)> inserter, + std::function expected, + gsl::span expected_indicies, + const SparseTensorProto& actual)> + checker) { + ORT_RETURN_IF_ERROR(TestDenseToSparseConversionValues(indices_start, inserter, checker)); + return TestDenseAllZerosToSparseConversion(inserter, checker); } TEST(SparseTensorConversionTests, TestDenseToSparseConversion) { // This one will test indices that are less than max int8 value // which should result in int8 indices - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_data_type(TensorProto_DataType_FLOAT); tp.set_name("dense_float"); tp.mutable_float_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); // This one will test indices that are max(int8) < ind < max(int16) value // which should result in int16 indices - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( static_cast(std::numeric_limits::max()) + 20U, [](const std::vector& values, TensorProto& tp) { tp.set_data_type(TensorProto_DataType_DOUBLE); tp.set_name("dense_double"); tp.mutable_double_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); // This one will test indices that are max(int16) < ind < max(int32) value // which should result in int32 indices - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( static_cast(std::numeric_limits::max()) + 20U, [](const std::vector& values, TensorProto& tp) { tp.set_data_type(TensorProto_DataType_BFLOAT16); @@ -1199,12 +1204,12 @@ TEST(SparseTensorConversionTests, TestDenseToSparseConversion) { tp.mutable_int32_data()->Add(v.val); } }, - RawSparseDataChecker); + RawSparseDataChecker)); // Protobuf can not hold anything more than 2Gb and it overflows. Can't test 64-bit indices // on conversion unless explicitly created. // which should result in int32 indices - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_data_type(TensorProto_DataType_FLOAT16); @@ -1213,78 +1218,78 @@ TEST(SparseTensorConversionTests, TestDenseToSparseConversion) { tp.mutable_int32_data()->Add(v.val); } }, - RawSparseDataChecker); + RawSparseDataChecker)); - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_name("dense_int16"); tp.set_data_type(TensorProto_DataType_INT16); tp.mutable_int32_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_name("dense_uint16"); tp.set_data_type(TensorProto_DataType_UINT16); tp.mutable_int32_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_name("dense_int32"); tp.set_data_type(TensorProto_DataType_INT32); tp.mutable_int32_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_name("dense_uint32"); tp.set_data_type(TensorProto_DataType_UINT32); tp.mutable_uint64_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_name("dense_int64"); tp.set_data_type(TensorProto_DataType_INT64); tp.mutable_int64_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_name("dense_uint64"); tp.set_data_type(TensorProto_DataType_UINT64); tp.mutable_uint64_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_name("dense_int8"); tp.set_data_type(TensorProto_DataType_INT8); tp.mutable_int32_data()->Add(values.cbegin(), values.cend()); }, - RawSparseDataChecker); + RawSparseDataChecker)); - TestDenseToSparseConversion( + ASSERT_STATUS_OK(TestDenseToSparseConversion( 20U, [](const std::vector& values, TensorProto& tp) { tp.set_name("dense_int64"); RawDataWriter(values, tp, TensorProto_DataType_UINT8); }, - RawSparseDataChecker); + RawSparseDataChecker)); } TEST(SparseTensorConversionTests, CsrConversion) { diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 42f82c374348e..6821f582ce2de 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -21,7 +21,7 @@ namespace test { // T must be float for double, and it must match with the 'type' argument template -void TestUnpackFloatTensor(TensorProto_DataType type, const Path& model_path) { +void TestUnpackFloatTensor(TensorProto_DataType type, const std::filesystem::path& model_path) { TensorProto float_tensor_proto; float_tensor_proto.set_data_type(type); T f[4] = {1.1f, 2.2f, 3.3f, 4.4f}; @@ -30,7 +30,7 @@ void TestUnpackFloatTensor(TensorProto_DataType type, const Path& model_path) { for (int i = 0; i < 4; ++i) { memcpy(rawdata + i * sizeof(T), &(f[i]), sizeof(T)); } - float_tensor_proto.set_raw_data(rawdata, len); + utils::SetRawDataInTensorProto(float_tensor_proto, rawdata, len); T float_data2[4]; auto status = UnpackTensor(float_tensor_proto, model_path, float_data2, 4); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); @@ -45,7 +45,7 @@ TEST(TensorProtoUtilsTest, UnpackTensor) { // Path is required for loading external data. // Using empty path here since this test does not test // external data utils - Path model_path; + std::filesystem::path model_path; bool_tensor_proto.set_data_type(TensorProto_DataType_BOOL); bool_tensor_proto.add_int32_data(1); @@ -102,8 +102,25 @@ std::vector CreateValues() { return {BFloat16(0.f), BFloat16(1.f), BFloat16(2.f), BFloat16(3.f)}; } +template +void ConvertEndianessForVector(const std::vector& test_data) { + const size_t element_size = sizeof(T); + const size_t num_elements = test_data.size(); + char* bytes = reinterpret_cast(const_cast(test_data.data())); + for (size_t i = 0; i < num_elements; ++i) { + char* start_byte = bytes + i * element_size; + char* end_byte = start_byte + element_size - 1; + for (size_t count = 0; count < element_size / 2; ++count) { + std::swap(*start_byte++, *end_byte--); + } + } +} + template void WriteDataToFile(FILE* fp, const std::vector& test_data) { + if constexpr (endian::native != endian::little) { + ConvertEndianessForVector(test_data); + } size_t size_in_bytes = test_data.size() * sizeof(T); ASSERT_EQ(size_in_bytes, fwrite(test_data.data(), 1, size_in_bytes, fp)); } @@ -142,11 +159,14 @@ void CreateTensorWithExternalData(TensorProto_DataType type, const std::vector -void UnpackAndValidate(const TensorProto& tensor_proto, const Path& model_path, const std::vector& test_data) { +void UnpackAndValidate(const TensorProto& tensor_proto, const std::filesystem::path& model_path, const std::vector& test_data) { // Unpack tensor with external data std::vector val(test_data.size()); auto st = utils::UnpackTensor(tensor_proto, model_path, val.data(), test_data.size()); ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); + if constexpr (endian::native != endian::little) { + ConvertEndianessForVector(val); + } // Validate data for (size_t i = 0; i < test_data.size(); i++) { @@ -155,7 +175,7 @@ void UnpackAndValidate(const TensorProto& tensor_proto, const Path& model_path, } template <> -void UnpackAndValidate(const TensorProto& tensor_proto, const Path& model_path, +void UnpackAndValidate(const TensorProto& tensor_proto, const std::filesystem::path& model_path, const std::vector& test_data) { // Unpack tensor with external data auto arr = std::make_unique(test_data.size()); @@ -169,7 +189,7 @@ void UnpackAndValidate(const TensorProto& tensor_proto, const Path& model_ } template -void TestUnpackExternalTensor(TensorProto_DataType type, const Path& model_path) { +void TestUnpackExternalTensor(TensorProto_DataType type, const std::filesystem::path& model_path) { // Create external data std::basic_string filename(ORT_TSTR("tensor_XXXXXX")); TensorProto tensor_proto; @@ -181,7 +201,7 @@ void TestUnpackExternalTensor(TensorProto_DataType type, const Path& model_path) } } // namespace TEST(TensorProtoUtilsTest, UnpackTensorWithExternalData) { - Path model_path; + std::filesystem::path model_path; TestUnpackExternalTensor(TensorProto_DataType_FLOAT, model_path); TestUnpackExternalTensor(TensorProto_DataType_DOUBLE, model_path); TestUnpackExternalTensor(TensorProto_DataType_INT32, model_path); @@ -225,7 +245,7 @@ static void TestConstantNodeConversion(const std::string& attrib_name, [&input, &add_data](AttributeProto& attrib) { add_data(attrib, input); }); TensorProto tp; - Path model_path; + std::filesystem::path model_path; EXPECT_STATUS_OK(utils::ConstantNodeProtoToTensorProto(c, model_path, tp)); EXPECT_THAT(get_data(tp), ::testing::ContainerEq(input)); @@ -311,7 +331,7 @@ template static void TestConstantNodeConversionWithExternalData(TensorProto_DataType type) { // Create a constant node with external data auto test_data = CreateValues(); - Path model_path; + std::filesystem::path model_path; PathString tensor_filename(ORT_TSTR("tensor_XXXXXX")); auto c = CreateConstantNodeWithExternalData(type, tensor_filename, test_data); std::unique_ptr file_deleter(const_cast(tensor_filename.c_str()), @@ -325,6 +345,9 @@ static void TestConstantNodeConversionWithExternalData(TensorProto_DataType type std::vector val(test_data.size()); auto st = utils::UnpackTensor(tp, model_path, val.data(), test_data.size()); ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); + if constexpr (endian::native != endian::little) { + ConvertEndianessForVector(val); + } for (size_t i = 0; i < test_data.size(); i++) { ASSERT_EQ(val[i], test_data[i]); } diff --git a/onnxruntime/test/framework/test_tensor_loader.cc b/onnxruntime/test/framework/test_tensor_loader.cc index 71d70abceb82e..73bf351b6c556 100644 --- a/onnxruntime/test/framework/test_tensor_loader.cc +++ b/onnxruntime/test/framework/test_tensor_loader.cc @@ -34,7 +34,7 @@ TEST(CApiTensorTest, load_simple_float_tensor_not_enough_space) { OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); ASSERT_STATUS_NOT_OK( - utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value)); } @@ -55,7 +55,7 @@ TEST(CApiTensorTest, load_simple_float_tensor_membuffer) { OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); ASSERT_STATUS_OK( - utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value)); float* real_output; @@ -83,7 +83,7 @@ TEST(CApiTensorTest, load_simple_float_tensor_allocator) { AllocatorPtr tmp_allocator = std::make_shared(); OrtValue value; - ASSERT_STATUS_OK(utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, tmp_allocator, value)); + ASSERT_STATUS_OK(utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, tmp_allocator, value)); float* real_output; auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&real_output); @@ -104,6 +104,18 @@ static void run_external_data_test() { std::unique_ptr file_deleter(const_cast(filename.c_str()), DeleteFileFromDisk); float test_data[] = {1.0f, 2.2f, 3.5f}; + if constexpr (endian::native != endian::little) { + const int element_size = sizeof(float); + char* bytes = reinterpret_cast(test_data); + const size_t num_elements = std::size(test_data); + for (size_t i = 0; i < num_elements; ++i) { + char* start_byte = bytes + i * element_size; + char* end_byte = start_byte + element_size - 1; + for (size_t count = 0; count < element_size / 2; ++count) { + std::swap(*start_byte++, *end_byte--); + } + } + } ASSERT_EQ(sizeof(test_data), fwrite(test_data, 1, sizeof(test_data), fp)); ASSERT_EQ(0, fclose(fp)); // construct a tensor proto @@ -128,8 +140,12 @@ static void run_external_data_test() { len = GetCurrentDirectoryW(len, (ORTCHAR_T*)cwd.data()); ASSERT_NE(len, (DWORD)0); cwd.append(ORT_TSTR("\\fake.onnx")); +#else +#if defined(_AIX) + char* p = getcwd(nullptr, PATH_MAX); #else char* p = getcwd(nullptr, 0); +#endif ASSERT_NE(p, nullptr); cwd = p; free(p); @@ -139,7 +155,7 @@ static void run_external_data_test() { OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); ASSERT_STATUS_OK(utils::TensorProtoToOrtValue( - Env::Default(), nullptr, p, MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value)); + Env::Default(), std::filesystem::path(), p, MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value)); float* real_output; auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&real_output); @@ -190,7 +206,7 @@ TEST(CApiTensorTest, load_huge_tensor_with_external_data) { OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); ASSERT_STATUS_OK( - utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, MemBuffer(output.data(), output.size() * sizeof(int), cpu_memory_info), value)); int* buffer; diff --git a/onnxruntime/test/framework/test_utils.h b/onnxruntime/test/framework/test_utils.h index 0a99b4bc80212..51b02ee3e7f8c 100644 --- a/onnxruntime/test/framework/test_utils.h +++ b/onnxruntime/test/framework/test_utils.h @@ -9,7 +9,7 @@ #include "core/providers/cpu/cpu_execution_provider.h" #include "core/framework/ort_value.h" -#include "core/common/gsl.h" +#include #ifdef USE_CUDA #include "core/providers/providers.h" diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index ff10765741bbe..5fc036790b765 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -388,7 +388,7 @@ TEST_F(GraphTest, UnusedValueInfoSerializes) { std::shared_ptr model; ASSERT_STATUS_OK(Model::Load(std::move(m), model, nullptr, *logger_)); model->MainGraph().SetGraphProtoSyncNeeded(); - EXPECT_TRUE(Model::Save(*model, "graph_with_unused_value_info.onnx").IsOK()); + EXPECT_TRUE(Model::Save(*model, ORT_TSTR("graph_with_unused_value_info.onnx")).IsOK()); } TEST_F(GraphTest, WrongOpset) { @@ -762,7 +762,7 @@ TEST_F(GraphTest, GraphConstruction_CheckIsAcyclic) { auto status = graph.Resolve(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - EXPECT_TRUE(Model::Save(model, "graph_1.onnx").IsOK()); + EXPECT_TRUE(Model::Save(model, ORT_TSTR("graph_1.onnx")).IsOK()); std::shared_ptr model2; EXPECT_TRUE(Model::Load(ORT_TSTR("graph_1.onnx"), model2, nullptr, *logger_).IsOK()); @@ -1476,7 +1476,7 @@ TEST_F(GraphTest, GraphConstruction_TypeInference) { EXPECT_EQ("node_4_out_1", graph.GetOutputs()[0]->Name()); EXPECT_EQ(2u, graph.GetInputs().size()); - EXPECT_TRUE(Model::Save(model, "model_x.onnx").IsOK()); + EXPECT_TRUE(Model::Save(model, ORT_TSTR("model_x.onnx")).IsOK()); std::shared_ptr loaded_model; EXPECT_TRUE(Model::Load(ORT_TSTR("model_x.onnx"), loaded_model, nullptr, *logger_).IsOK()); EXPECT_EQ(2u, loaded_model->MainGraph().GetInputs().size()); @@ -1782,8 +1782,8 @@ TEST_F(GraphTest, InjectExternalInitializedTensors) { {initializer_name, ort_value}}; // We do not need actual files there since we are not going to load it. - const auto tensor_data_dir_path = Path::Parse(ToPathString(".")); - const auto tensor_data_dir_relative_path = Path::Parse(ToPathString("external_data.bin")); + const auto tensor_data_dir_path = ORT_TSTR("."); + const auto tensor_data_dir_relative_path = ORT_TSTR("external_data.bin"); const auto tensor_proto = [&]() { @@ -1792,7 +1792,7 @@ TEST_F(GraphTest, InjectExternalInitializedTensors) { tensor_proto.add_dims(tensor_data.size()); tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); - SetTensorProtoExternalData("location", ToUTF8String(tensor_data_dir_relative_path.ToPathString()), + SetTensorProtoExternalData("location", ToUTF8String(tensor_data_dir_relative_path), tensor_proto); SetTensorProtoExternalData("offset", "0", tensor_proto); SetTensorProtoExternalData("length", std::to_string(tensor_data.size() * sizeof(int32_t)), tensor_proto); @@ -1827,7 +1827,7 @@ TEST_F(GraphTest, InjectExternalInitializedTensors) { ASSERT_FALSE(utils::HasExternalData(*with_data)); const auto& original_tensor = ort_value.Get(); Tensor replaced_tensor(original_tensor.DataType(), data_shape, std::make_shared()); - ASSERT_STATUS_OK(utils::TensorProtoToTensor(Env::Default(), tensor_data_dir_path.ToPathString().c_str(), *with_data, + ASSERT_STATUS_OK(utils::TensorProtoToTensor(Env::Default(), tensor_data_dir_path, *with_data, replaced_tensor)); ASSERT_EQ(original_tensor.GetElementType(), replaced_tensor.GetElementType()); const auto original_span = original_tensor.DataAsSpan(); diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp new file mode 100644 index 0000000000000..9d15c9a6bf994 --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_q4dq.cpp @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/mlas/inc/mlas_q4.h" +#include "test/mlas/bench/bench_util.h" +#include "core/util/thread_utils.h" + +static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) { + int M = state.range(0); + int N = state.range(1); + int quant_block_size = state.range(2); + int threads = state.range(3); + size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; + + auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); + auto scales = std::vector(scale_size); + auto zero_points = std::vector((scale_size + 1) / 2); + auto dst = std::vector((M * N + 1) / 2); + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(threads); + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + for (auto _ : state) { + benchmark::DoNotOptimize(dst.data()); + MlasQDQQuantizeBlockwise( + src.data(), scales.data(), zero_points.data(), dst.data(), + true, M, N, quant_block_size, tp.get()); + benchmark::ClobberMemory(); + } +} + +static void BM_MlasQuantizeBlockwise(benchmark::State& state) { + int M = state.range(0); + int N = state.range(1); + int quant_block_size = state.range(2); + int threads = state.range(3); + size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; + + auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); + auto scales = std::vector(scale_size); + auto zero_points = std::vector((scale_size + 1) / 2); + auto dst = std::vector((M * N + 1) / 2); + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(threads); + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + for (auto _ : state) { + benchmark::DoNotOptimize(dst.data()); + MlasQuantizeBlockwise( + dst.data(), scales.data(), zero_points.data(), src.data(), + quant_block_size, true, M, N, N, tp.get()); + benchmark::ClobberMemory(); + } +} + +static void BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) { + int M = state.range(0); + int N = state.range(1); + int quant_block_size = state.range(2); + int threads = state.range(3); + bool add8 = state.range(4) != 0; + int quant_num_M = (M + quant_block_size - 1) / quant_block_size; + int blob_size = (quant_block_size + 1) / 2; + size_t scale_size = quant_num_M * N; + + auto scales = RandomVectorUniform(scale_size, -16.0f, 14.0f); + auto zero_points = RandomVectorUniform(static_cast((scale_size + 1) / 2), 0, 255); + auto dst = RandomVectorUniform(static_cast((M * N + 1) / 2), 0, 255); + auto scales_T = std::vector(scale_size); + auto zero_points_T = std::vector(((quant_num_M + 1) / 2) * N); + auto dst_T = std::vector(quant_num_M * blob_size * N); + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(threads); + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + if (add8) { + for (auto _ : state) { + benchmark::DoNotOptimize(dst.data()); + MlasQDQTransposeBlockwiseQuantized( + dst.data(), scales.data(), zero_points.data(), dst_T.data(), scales_T.data(), zero_points_T.data(), + true, M, N, quant_block_size, tp.get()); + benchmark::ClobberMemory(); + } + } else { + for (auto _ : state) { + benchmark::DoNotOptimize(dst.data()); + MlasQDQTransposeBlockwiseQuantized( + dst.data(), scales.data(), zero_points.data(), dst_T.data(), scales_T.data(), zero_points_T.data(), + true, M, N, quant_block_size, tp.get()); + benchmark::ClobberMemory(); + } + } +} + +BENCHMARK(BM_QDQBlockwiseQuantizer_QuantizeColumnwise) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "quant_block_size", "threads"}); + b->ArgsProduct({{1024, 4096}, {4096, 4095}, {64, 128}, {8}}); + }); + +BENCHMARK(BM_MlasQuantizeBlockwise) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "quant_block_size", "threads"}); + b->ArgsProduct({{1024, 4096}, {4096, 4095}, {64, 128}, {8}}); + }); + +BENCHMARK(BM_QDQBlockwiseQuantizer_TransposeColumnwise) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "quant_block_size", "threads", "add8"}); + b->ArgsProduct({{1024, 4096}, {4096, 4095}, {64, 128}, {2, 8, 16}, {0, 1}}); + }); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index 07f0748fb7ed1..f75002f715154 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -29,10 +29,18 @@ class MlasBlockwiseQdqTest : public MlasTestBase { MatrixGuardBuffer OutputElements; MatrixGuardBuffer OutputScales; MatrixGuardBuffer OutputOffsets; + MatrixGuardBuffer QDQOutputElements; + MatrixGuardBuffer QDQOutputScales; + MatrixGuardBuffer QDQOutputOffsets; + MatrixGuardBuffer QDQTransposedOutputElements; + MatrixGuardBuffer QDQTransposedOutputScales; + MatrixGuardBuffer QDQTransposedOutputOffsets; void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { float* dequant_buf = FpBuf.GetBuffer(rows * columns, true); float* transposed = FpBuf2.GetBuffer(rows * columns, true); + size_t scale_size = (rows + block_size - 1) / block_size * columns; + size_t zp_size = (scale_size + 1) / 2; MLAS_THREADPOOL* threadpool_ptr = GetMlasThreadPool(); @@ -49,6 +57,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase { q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); + uint8_t* qdq_weights = QDQOutputElements.GetBuffer((rows * columns + 1) / 2, true); + uint8_t* qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); int v = 7; for (int c = 0; c < columns; c++) { @@ -75,7 +85,11 @@ class MlasBlockwiseQdqTest : public MlasTestBase { } float* scales = InputScales.GetBuffer(q_scale_size); + float* qdq_scales = QDQOutputScales.GetBuffer(scale_size); + float* qdq_scales_T = QDQTransposedOutputScales.GetBuffer(q_scale_size); uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true); + uint8_t* qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); + uint8_t* qdq_zp_T = symmetric ? nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); if (zp) { for (int c = 0; c < meta_cols; c++) { for (int r = 0; r < meta_rows; r += 2) { @@ -112,16 +126,46 @@ class MlasBlockwiseQdqTest : public MlasTestBase { MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, columnwise, rows, columns, columns, threadpool_ptr); + if (columnwise) { + bool signed_quant = MlasQDQQuantizeBlockwise( + transposed, qdq_scales, qdq_zp, qdq_weights, + true, rows, columns, block_size, threadpool_ptr); + + ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; + + if (symmetric) { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + + } else { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + } + } + for (int c = 0; c < columns; c++) { for (int r = 0; r < rows; r += 2) { int idx = c * q_rows + r / 2; ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (columnwise) { + ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) + << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 1 < rows) { ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (columnwise) { + ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) + << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } } } @@ -132,6 +176,12 @@ class MlasBlockwiseQdqTest : public MlasTestBase { ASSERT_EQ(o_scales[idx], scales[idx]) << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + + if (columnwise) { + ASSERT_EQ(qdq_scales_T[idx], scales[idx]) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } } @@ -142,10 +192,20 @@ class MlasBlockwiseQdqTest : public MlasTestBase { ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (columnwise) { + ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } if (r + 1 < meta_rows) { ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (columnwise) { + ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) + << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } } } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 71a6123b868bc..f391027de4d51 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -419,9 +419,10 @@ static size_t SQNBitGemmRegisterAllShortExecuteTests() { return count; } -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (is_short_execute) { - return SQNBitGemmRegisterAllShortExecuteTests() > 0; - } - return false; -}); +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return SQNBitGemmRegisterAllShortExecuteTests(); + } + return 0; + }); diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 6d3e9c2cb7865..3319fdd34646b 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -796,7 +796,7 @@ void LoadTests(const std::vector>& input_paths auto test_case_dir = model_info->GetDir(); auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir.native(); -#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) +#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) && !defined(USE_VSINPU) // to skip some models like *-int8 or *-qdq if ((reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || (reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { diff --git a/onnxruntime/test/onnx/TestCase.h b/onnxruntime/test/onnx/TestCase.h index 0cb92056d378e..745a1fe9eeb50 100644 --- a/onnxruntime/test/onnx/TestCase.h +++ b/onnxruntime/test/onnx/TestCase.h @@ -53,7 +53,8 @@ class TestModelInfo { public: virtual const std::filesystem::path& GetModelUrl() const = 0; virtual std::filesystem::path GetDir() const { - return GetModelUrl().parent_path(); + const auto& p = GetModelUrl(); + return p.has_parent_path() ? p.parent_path() : std::filesystem::current_path(); } virtual const std::string& GetNodeName() const = 0; virtual const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t i) const = 0; diff --git a/onnxruntime/test/onnx/gen_test_models.py b/onnxruntime/test/onnx/gen_test_models.py index 1a64df2936810..a5224925251cf 100644 --- a/onnxruntime/test/onnx/gen_test_models.py +++ b/onnxruntime/test/onnx/gen_test_models.py @@ -144,7 +144,7 @@ def test_abs(output_dir): ) generate_abs_op_test( TensorProto.UINT16, - np.uint16([-32767, -4, 0, 3, 32767]), + np.uint16([0, 3, 32767, 65535]), os.path.join(output_dir, "test_abs_uint16"), ) generate_abs_op_test( diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 0356bf5218cc2..9886d98dcc6d6 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -8,6 +8,8 @@ #include #ifdef _WIN32 #include "getopt.h" +#elif defined(_AIX) +#include #else #include #include @@ -44,7 +46,7 @@ void usage() { "\t-r [repeat]: Specifies the number of times to repeat\n" "\t-v: verbose\n" "\t-n [test_case_name]: Specifies a single test case to run.\n" - "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', " + "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', 'vsinpu'" "'openvino', 'rocm', 'migraphx', 'acl', 'armnn', 'xnnpack', 'nnapi', 'qnn', 'snpe' or 'coreml'. " "Default: 'cpu'.\n" "\t-p: Pause after launch, can attach debugger and continue\n" @@ -169,6 +171,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool enable_mem_pattern = true; bool enable_qnn = false; bool enable_nnapi = false; + bool enable_vsinpu = false; bool enable_coreml = false; bool enable_snpe = false; bool enable_dml = false; @@ -248,6 +251,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { enable_qnn = true; } else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) { enable_nnapi = true; + } else if (!CompareCString(optarg, ORT_TSTR("vsinpu"))) { + enable_vsinpu = true; } else if (!CompareCString(optarg, ORT_TSTR("coreml"))) { enable_coreml = true; } else if (!CompareCString(optarg, ORT_TSTR("snpe"))) { @@ -561,6 +566,14 @@ int real_main(int argc, char* argv[], Ort::Env& env) { #else fprintf(stderr, "NNAPI is not supported in this build"); return -1; +#endif + } + if (enable_vsinpu) { +#ifdef USE_VSINPU + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_VSINPU(sf)); +#else + fprintf(stderr, "VSINPU is not supported in this build"); + return -1; #endif } if (enable_coreml) { diff --git a/onnxruntime/test/onnx/microbenchmark/quantize.cc b/onnxruntime/test/onnx/microbenchmark/quantize.cc index fda4324c0e83d..a6ab8484231b8 100644 --- a/onnxruntime/test/onnx/microbenchmark/quantize.cc +++ b/onnxruntime/test/onnx/microbenchmark/quantize.cc @@ -82,12 +82,11 @@ BENCHMARK(BM_Quantize) static void BM_BlockedQuantize_NotLastAxis(benchmark::State& state) { using Int4 = onnxruntime::Int4x2; using UnpackedType = Int4::UnpackedType; - const std::ptrdiff_t M[] = {96, 192, 192}; - const std::ptrdiff_t N[] = {2048, 2048, 4096}; - const int64_t size_idx = state.range(0); - const int64_t threads = state.range(1); + const int64_t M = state.range(0); + const int64_t N = state.range(1); const int64_t block_size = state.range(2); - size_t batch_size = M[size_idx] * N[size_idx]; + const int64_t threads = state.range(3); + size_t batch_size = M * N; size_t quant_block_size = 64; size_t scale_size = batch_size / quant_block_size; @@ -108,7 +107,7 @@ static void BM_BlockedQuantize_NotLastAxis(benchmark::State& state) { benchmark::DoNotOptimize(a_data_quant); onnxruntime::BlockedQuantizeLinear::opNotLastAxis( tp.get(), a_data, scale, reinterpret_cast(zero_point), reinterpret_cast(a_data_quant), - 1, M[size_idx], N[size_idx], static_cast(quant_block_size), + 1, M, N, static_cast(quant_block_size), static_cast(block_size), true); benchmark::ClobberMemory(); } @@ -121,12 +120,11 @@ static void BM_BlockedQuantize_NotLastAxis(benchmark::State& state) { static void BM_BlockedQuantize_LastAxis(benchmark::State& state) { using Int4 = onnxruntime::Int4x2; using UnpackedType = Int4::UnpackedType; - const std::ptrdiff_t M[] = {96, 192, 192}; - const std::ptrdiff_t N[] = {2048, 2048, 4096}; - const int64_t size_idx = state.range(0); - const int64_t threads = state.range(1); + const int64_t M = state.range(0); + const int64_t N = state.range(1); const int64_t quant_block_size = state.range(2); - size_t batch_size = M[size_idx] * N[size_idx]; + const int64_t threads = state.range(3); + size_t batch_size = M * N; size_t scale_size = batch_size / quant_block_size; float* a_data = GenerateArrayWithRandomValue(batch_size, -16, 14); @@ -146,7 +144,7 @@ static void BM_BlockedQuantize_LastAxis(benchmark::State& state) { benchmark::DoNotOptimize(a_data_quant); onnxruntime::BlockedQuantizeLinear::opLastAxis( tp.get(), a_data, scale, reinterpret_cast(zero_point), reinterpret_cast(a_data_quant), - M[size_idx], N[size_idx], static_cast(quant_block_size), true); + M, N, static_cast(quant_block_size), true); benchmark::ClobberMemory(); } aligned_free(a_data_quant); @@ -159,24 +157,14 @@ BENCHMARK(BM_BlockedQuantize_NotLastAxis) ->UseRealTime() ->Unit(benchmark::TimeUnit::kNanosecond) ->Apply([](benchmark::internal::Benchmark* b) { - for (int size_idx : {0, 1, 2}) { - for (int thread : {2, 4, 8}) { - for (int block_size : {64, 128}) { - b->Args({size_idx, thread, block_size}); - } - } - } + b->ArgNames({"M", "N", "block_size", "threads"}); + b->ArgsProduct({{1024, 4096}, {4096}, {128}, {2, 8}}); }); BENCHMARK(BM_BlockedQuantize_LastAxis) ->UseRealTime() ->Unit(benchmark::TimeUnit::kNanosecond) ->Apply([](benchmark::internal::Benchmark* b) { - for (int size_idx : {0, 1, 2}) { - for (int thread : {2, 4, 8}) { - for (int quant_block_size : {16, 64, 256}) { - b->Args({size_idx, thread, quant_block_size}); - } - } - } + b->ArgNames({"M", "N", "quant_block_size", "threads"}); + b->ArgsProduct({{1024, 4096}, {4096}, {64, 128}, {2, 8}}); }); diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index 5df055f862a86..50ab2290c6456 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include "mem_buffer.h" #include "core/common/safeint.h" @@ -68,11 +69,22 @@ static void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_length ORT_CXX_API_THROW(MakeString("UnpackTensor: the pre-allocated size does not match the raw data size, expected ", expected_size_in_bytes, ", got ", raw_data_length), OrtErrorCode::ORT_FAIL); + memcpy(p_data, raw_data, raw_data_length); if constexpr (endian::native != endian::little) { - ORT_CXX_API_THROW("UnpackTensorWithRawData only handles little-endian native byte order for now.", - OrtErrorCode::ORT_NOT_IMPLEMENTED); + /* Convert Endianness */ + char* bytes = reinterpret_cast(p_data); + size_t element_size = sizeof(T); + size_t num_elements = raw_data_length / element_size; + + for (size_t i = 0; i < num_elements; ++i) { + char* start_byte = bytes + i * element_size; + char* end_byte = start_byte + element_size - 1; + /* keep swapping */ + for (size_t count = 0; count < element_size / 2; ++count) { + std::swap(*start_byte++, *end_byte--); + } + } } - memcpy(p_data, raw_data, raw_data_length); } template <> diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 9609ec57c8b26..3e4e845440117 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -70,7 +70,6 @@ #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" -#include "core/optimizer/shape_input_merge.h" #include "core/optimizer/slice_elimination.h" #include "core/optimizer/unsqueeze_elimination.h" #include "core/optimizer/utils.h" @@ -4973,8 +4972,8 @@ TEST_F(GraphTransformationTests, CseWithConstantOfShape) { TensorProto value_tensor; value_tensor.add_dims(1); float value = 2.333f; - value_tensor.set_raw_data(reinterpret_cast(&value), sizeof(float)); value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + utils::SetRawDataInTensorProto(value_tensor, reinterpret_cast(&value), sizeof(float)); builder.AddNode("ConstantOfShape", {shape_out_1}, {constant_of_shape_out_1}).AddAttribute("value", value_tensor); builder.AddNode("ConstantOfShape", {shape_out_2}, {constant_of_shape_out_2}).AddAttribute("value", value_tensor); builder.AddNode("Mul", {input_arg, constant_of_shape_out_1}, {mul_out_1}); @@ -6216,9 +6215,10 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) { std::make_unique(strategy, level, test_case.allow_ops), TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - Path p = Path::Parse(test_case.model_uri); - ASSERT_FALSE(p.GetComponents().empty()); - PathString transformed_model_uri = temp_dir.Path() + GetPathSep() + ORT_TSTR("transformed_") + p.GetComponents().back(); + std::filesystem::path p = test_case.model_uri; + PathString model_filename = ORT_TSTR("transformed_"); + model_filename += p.filename(); + std::filesystem::path transformed_model_uri = std::filesystem::path(temp_dir.Path()) / model_filename; ASSERT_STATUS_OK(Model::Save(*p_model, transformed_model_uri)); // Load the transformed model to validate ASSERT_STATUS_OK(Model::Load(transformed_model_uri, p_model, nullptr, *logger_)); @@ -7691,80 +7691,6 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { } } -TEST_F(GraphTransformationTests, ShapeInputMerge) { - auto build_test_case = [&](ModelTestBuilder& builder) { - std::vector> input_shape; - input_shape.reserve(5); - input_shape.emplace_back("dim0"); - input_shape.emplace_back(512); - input_shape.emplace_back(1); - input_shape.emplace_back(1536); - input_shape.emplace_back("dim4"); - auto* input_arg = builder.MakeSymbolicInput(input_shape); - auto* neg_out = builder.MakeIntermediate(); - auto* axes_initializer = builder.MakeInitializer({1}, {static_cast(2)}); - auto* squeeze_out = builder.MakeIntermediate(); - auto* cast_out = builder.MakeIntermediate(); - auto* unsqueeze_out = builder.MakeOutput(); - auto* shape_1_out = builder.MakeOutput(); - auto* shape_2_out = builder.MakeOutput(); - auto* shape_3_out = builder.MakeOutput(); - auto* shape_4_out = builder.MakeOutput(); - auto* shape_5_out = builder.MakeOutput(); - builder.AddNode("Neg", {input_arg}, {neg_out}); - builder.AddNode("Squeeze", {neg_out, axes_initializer}, {squeeze_out}); - builder.AddNode("Cast", {squeeze_out}, {cast_out}).AddAttribute("to", static_cast(10)); - builder.AddNode("Unsqueeze", {cast_out, axes_initializer}, {unsqueeze_out}); - builder.AddNode("Shape", {input_arg}, {shape_1_out}); - builder.AddNode("Shape", {neg_out}, {shape_2_out}); - builder.AddNode("Shape", {squeeze_out}, {shape_3_out}); - builder.AddNode("Shape", {cast_out}, {shape_4_out}); - builder.AddNode("Shape", {unsqueeze_out}, {shape_5_out}); - }; - - auto pre_graph_checker = [&](Graph& graph) { - InlinedHashMap ref_count; - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Shape") { - std::string name = node.InputDefs()[0]->Name(); - if (ref_count.find(name) == ref_count.end()) { - ref_count[name] = 1; - } else { - ref_count[name]++; - } - } - } - TEST_RETURN_IF_NOT(ref_count.size() == 5); - return Status::OK(); - }; - - auto post_graph_checker = [&](Graph& graph) { - InlinedHashMap ref_count; - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Shape") { - std::string name = node.InputDefs()[0]->Name(); - if (ref_count.find(name) == ref_count.end()) { - ref_count[name] = 1; - } else { - ref_count[name]++; - } - } - } - TEST_RETURN_IF_NOT(ref_count.size() == 2); - int sum = 0, mul = 1; - for (auto& entry : ref_count) { - sum += entry.second; - mul *= entry.second; - } - TEST_RETURN_IF_NOT(sum == 5 && mul == 6); - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); -} - #if !defined(DISABLE_CONTRIB_OPS) TEST_F(GraphTransformationTests, MatMulNBitsBiasFusion) { diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index 73c8b3f119103..2cbfbbb317642 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -61,7 +61,7 @@ NodeArg* ModelTestBuilder::MakeInitializer(gsl::span shape, ONNX_NAMESPACE::TensorProto tensor_proto; tensor_proto.set_name(name); tensor_proto.set_data_type(elem_type); - tensor_proto.set_raw_data(raw_data.data(), raw_data.size()); + utils::SetRawDataInTensorProto(tensor_proto, raw_data.data(), raw_data.size()); for (auto& dim : shape) { tensor_proto.add_dims(dim); diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 1e2d34e5aefc4..6214094a26c4f 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -13,6 +13,7 @@ #include "core/framework/int4.h" #include "core/optimizer/graph_transformer_level.h" #include "core/graph/onnx_protobuf.h" +#include "core/framework/tensorprotoutils.h" #include "test/framework/test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" @@ -249,7 +250,7 @@ class ModelTestBuilder { tensor_proto.set_data_type(utils::ToTensorProtoElementType()); std::unique_ptr data_buffer = std::make_unique(data.size()); for (size_t i = 0; i < data.size(); ++i) data_buffer[i] = data[i]; - tensor_proto.set_raw_data(data_buffer.get(), data.size()); + utils::SetRawDataInTensorProto(tensor_proto, data_buffer.get(), data.size()); for (auto& dim : shape) { tensor_proto.add_dims(dim); @@ -339,8 +340,10 @@ class ModelTestBuilder { bool use_ms_domain = false) { std::vector input_args; input_args.push_back(input_arg); - input_args.push_back(Make1DInitializer(input_scales)); - input_args.push_back(Make1DInitializer(input_zero_points)); + + std::vector qparams_shape = {static_cast(input_scales.size())}; + input_args.push_back(MakeInitializer(qparams_shape, input_scales)); + input_args.push_back(MakeInitializer(qparams_shape, input_zero_points)); std::string domain = use_ms_domain ? kMSDomain : ""; return AddNode("QuantizeLinear", input_args, {output_arg}, domain, attributes); @@ -415,8 +418,10 @@ class ModelTestBuilder { bool use_ms_domain = false) { std::vector input_args; input_args.push_back(input_arg); - input_args.push_back(Make1DInitializer(input_scales)); - input_args.push_back(Make1DInitializer(input_zero_points)); + + std::vector qparams_shape = {static_cast(input_scales.size())}; + input_args.push_back(MakeInitializer(qparams_shape, input_scales)); + input_args.push_back(MakeInitializer(qparams_shape, input_zero_points)); std::string domain = use_ms_domain ? kMSDomain : ""; return AddNode("DequantizeLinear", input_args, {output_arg}, domain, attributes); diff --git a/onnxruntime/test/optimizer/initializer_test.cc b/onnxruntime/test/optimizer/initializer_test.cc index ee93cfaa67e2a..391942acfca35 100644 --- a/onnxruntime/test/optimizer/initializer_test.cc +++ b/onnxruntime/test/optimizer/initializer_test.cc @@ -8,7 +8,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "gtest/gtest.h" @@ -51,12 +51,12 @@ TEST(OptimizerInitializerTest, LoadExternalData) { return tensor_data; }(); const gsl::span tensor_data_span = gsl::make_span(tensor_data); - const auto tensor_data_dir_path = Path::Parse(ToPathString(".")); - const auto tensor_data_dir_relative_path = Path::Parse(ToPathString("OptimizerInitializerTest_LoadExternalData.bin")); + const std::filesystem::path tensor_data_dir_path = ORT_TSTR("."); + const std::filesystem::path tensor_data_dir_relative_path = ORT_TSTR("OptimizerInitializerTest_LoadExternalData.bin"); ScopedFileDeleter file_deleter{}; ASSERT_STATUS_OK(WriteExternalDataFile( - tensor_data_span, (tensor_data_dir_path / tensor_data_dir_relative_path).ToPathString(), file_deleter)); + tensor_data_span, tensor_data_dir_path / tensor_data_dir_relative_path, file_deleter)); const auto tensor_proto_base = [&]() { @@ -65,7 +65,7 @@ TEST(OptimizerInitializerTest, LoadExternalData) { tensor_proto.add_dims(tensor_data.size()); tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); - SetTensorProtoExternalData("location", ToUTF8String(tensor_data_dir_relative_path.ToPathString()), tensor_proto); + SetTensorProtoExternalData("location", ToUTF8String(tensor_data_dir_relative_path.native()), tensor_proto); SetTensorProtoExternalData("offset", "0", tensor_proto); SetTensorProtoExternalData("length", std::to_string(tensor_data.size() * sizeof(int32_t)), tensor_proto); return tensor_proto; @@ -95,8 +95,8 @@ TEST(OptimizerInitializerTest, LoadExternalData) { check_initializer_load(0, tensor_data.size() + 1); // bad model paths - EXPECT_THROW(Initializer i(tensor_proto_base, Path{}), OnnxRuntimeException); - EXPECT_THROW(Initializer i(tensor_proto_base, Path::Parse(ToPathString("invalid/directory"))), OnnxRuntimeException); + EXPECT_THROW(Initializer i(tensor_proto_base, std::filesystem::path()), OnnxRuntimeException); + EXPECT_THROW(Initializer i(tensor_proto_base, ORT_TSTR("invalid/directory")), std::filesystem::filesystem_error); // bad length { @@ -163,9 +163,9 @@ void TestInitializerRawData() { tensor_proto.set_name("OptimizerInitializerTest_RawData"); tensor_proto.add_dims(3); tensor_proto.add_dims(4); - tensor_proto.set_raw_data(data.data(), data.size() * sizeof(T)); - const Initializer init(tensor_proto, Path()); + utils::SetRawDataInTensorProto(tensor_proto, data.data(), data.size() * sizeof(T)); + const Initializer init(tensor_proto, std::filesystem::path()); for (size_t idx = 0; idx < data.size(); idx++) { EXPECT_EQ(data[idx], init.data()[idx]); @@ -220,35 +220,35 @@ void TestInitializerDataField() { AddData(data, idx, tensor_proto); } - const Initializer init(tensor_proto, Path()); + const Initializer init(tensor_proto, std::filesystem::path()); for (size_t idx = 0; idx < data.size(); idx++) { EXPECT_EQ(data[idx], init.data()[idx]); } } -#define TestInitializerDataFieldSpecialized(type) \ - template <> \ - void TestInitializerDataField() { \ - std::vector data{ \ - 0, 1, 2, 3, \ - 4, 5, 6, 7, \ - 8, 9, 10, 11}; \ - \ - ONNX_NAMESPACE::TensorProto tensor_proto; \ - tensor_proto.set_data_type(GetTensorProtoDataType()); \ - tensor_proto.set_name("OptimizerInitializerTest_DataField"); \ - tensor_proto.add_dims(3); \ - tensor_proto.add_dims(4); \ - for (size_t idx = 0; idx < data.size(); idx++) { \ - tensor_proto.add_##type##_data(data[idx]); \ - } \ - \ - const Initializer init(tensor_proto, Path()); \ - \ - for (size_t idx = 0; idx < data.size(); idx++) { \ - EXPECT_EQ(data[idx], init.data()[idx]); \ - } \ +#define TestInitializerDataFieldSpecialized(type) \ + template <> \ + void TestInitializerDataField() { \ + std::vector data{ \ + 0, 1, 2, 3, \ + 4, 5, 6, 7, \ + 8, 9, 10, 11}; \ + \ + ONNX_NAMESPACE::TensorProto tensor_proto; \ + tensor_proto.set_data_type(GetTensorProtoDataType()); \ + tensor_proto.set_name("OptimizerInitializerTest_DataField"); \ + tensor_proto.add_dims(3); \ + tensor_proto.add_dims(4); \ + for (size_t idx = 0; idx < data.size(); idx++) { \ + tensor_proto.add_##type##_data(data[idx]); \ + } \ + \ + const Initializer init(tensor_proto, std::filesystem::path()); \ + \ + for (size_t idx = 0; idx < data.size(); idx++) { \ + EXPECT_EQ(data[idx], init.data()[idx]); \ + } \ } typedef int64_t int64; diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 8e4edc9e0abbb..538f60040418c 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -6,6 +6,7 @@ #include "core/mlas/inc/mlas.h" #include "core/session/environment.h" #include "core/session/inference_session.h" +#include "core/framework/tensorprotoutils.h" #include "test/compare_ortvalue.h" #include "test/test_environment.h" #include "test/framework/test_utils.h" @@ -62,7 +63,7 @@ struct NchwcTestHelper { ONNX_NAMESPACE::TensorProto tensor_proto; tensor_proto.set_name(name); tensor_proto.set_data_type(utils::ToTensorProtoElementType()); - tensor_proto.set_raw_data(data.data(), data.size() * sizeof(T)); + utils::SetRawDataInTensorProto(tensor_proto, data.data(), data.size() * sizeof(T)); for (auto& dim : shape) { tensor_proto.add_dims(dim); diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index c338d542b0b79..1c77121ba9df1 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3675,7 +3675,9 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicQDQCleanup) { }; // if we block removal of the DQ node the Q node in the pair will not be removed either - const int expected_qdq_count = 0 + (block_removal_of_first_dq ? 1 : 0) + (block_removal_of_last_dq ? 1 : 0); + // TODO(yilyu): block_removal_of_first_dq is not functional, need to fix it + // TODO(yilyu): block_removal_of_last_dq is not functional, need to fix it + const int expected_qdq_count = 0; // + (block_removal_of_first_dq ? 1 : 0) + (block_removal_of_last_dq ? 1 : 0); // blocking removal of DQ by adding an additional edge will cause EnsureUniqueDQForNodeUnit to duplicate the DQ, // so we expect twice as many DQ's as original QDQ pairs const int expected_dq_count = expected_qdq_count * 2; @@ -3725,19 +3727,136 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicQDQCleanup) { }; test_case({{1, 2, 4}, {1, 3, 4}}, false, false); // Do not block removal - test_case({{1, 2, 4}, {1, 3, 4}}, true, false); // Block removal of first dq - test_case({{1, 2, 4}, {1, 3, 4}}, false, true); // Block removal of last dq + test_case({{1, 2, 4}, {1, 3, 4}}, true, false); // Block removal of last dq + test_case({{1, 2, 4}, {1, 3, 4}}, false, true); // Block removal of first dq test_case({{1, 2, 4}, {1, 3, 4}}, true, true); // Block removal of first and last dq #if !defined(DISABLE_CONTRIB_OPS) // Use contrib QDQ ops test_case({{1, 2, 4}, {1, 3, 4}}, false, false, true); // Do not block removal - test_case({{1, 2, 4}, {1, 3, 4}}, true, false, true); // Block removal of first dq - test_case({{1, 2, 4}, {1, 3, 4}}, false, true, true); // Block removal of last dq + test_case({{1, 2, 4}, {1, 3, 4}}, true, false, true); // Block removal of last dq + test_case({{1, 2, 4}, {1, 3, 4}}, false, true, true); // Block removal of first dq test_case({{1, 2, 4}, {1, 3, 4}}, true, true, true); // Block removal of first and last dq #endif } +// test removal of Q->DQs pairs by QDQFinalCleanupTransformer +TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicQDQsCleanup) { + auto test_case = [&](bool is_input_q, + bool is_dq1_output, + bool is_dq2_output, + bool use_contrib_qdq) { + // create model with float Input -> (Transpose1 ->) Q -> DQ1 -> (Transpose2 ->) Output1 + // -> DQ2 -> (Transpose3 ->) Output2 + // If we enable cleanup and don't run the QDQ transformer we should drop all the Q->DQ pairs + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_args = builder.MakeInput({1, 2, 4}, -1.f, 1.f); + + // Add Input -> (Transpose1 ->) + NodeArg* q_input = nullptr; + if (is_input_q) { + q_input = input_args; + } else { + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {input_args}, {transpose_output}); + q_input = transpose_output; + } + + // Add Q -> + auto* q_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(q_input, 0.05f, 128, q_output, use_contrib_qdq); + + // Add DQ1 -> (Transpose2 ->) Output1 + if (is_dq1_output) { + auto* dq1_output = builder.MakeOutput(); + builder.AddDequantizeLinearNode(q_output, 0.05f, 128, dq1_output, use_contrib_qdq); + } else { + auto* dq1_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(q_output, 0.05f, 128, dq1_output, use_contrib_qdq); + auto* output1 = builder.MakeOutput(); + builder.AddNode("Transpose", {dq1_output}, {output1}); + } + + // Add DQ2 -> (Transpose3 ->) Output2 + if (is_dq2_output) { + auto* dq2_output = builder.MakeOutput(); + builder.AddDequantizeLinearNode(q_output, 0.05f, 128, dq2_output, use_contrib_qdq); + } else { + auto* dq2_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(q_output, 0.05f, 128, dq2_output, use_contrib_qdq); + auto* output2 = builder.MakeOutput(); + builder.AddNode("Transpose", {dq2_output}, {output2}); + } + }; + + const int expected_transpose_count = (is_input_q ? 0 : 1) + (is_dq1_output ? 0 : 1) + (is_dq2_output ? 0 : 1); + + auto check_graph = [expected_transpose_count, + use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["Transpose"], expected_transpose_count); + }; + + auto add_session_options = [](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableQuantQDQCleanup, "1")); + }; + + // we increase the tolerance as removing the QDQ nodes means there's no round-trip to 8-bit and back + // essentially rounding the input values. + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + 0.025f /*per_sample_tolerance*/, + 0.01f /*relative_per_sample_tolerance*/, + std::make_unique(true /*enable_q_dq_cleanup*/), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + 0.025f /*per_sample_tolerance*/, + 0.01f /*relative_per_sample_tolerance*/, + std::make_unique(true /*enable_q_dq_cleanup*/), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + 0.025f /*per_sample_tolerance*/, + 0.01f /*relative_per_sample_tolerance*/, + std::make_unique(true /*enable_q_dq_cleanup*/), + add_session_options); + }; + + test_case(true, true, true, false); // is_input_q, is_dq1_output, is_dq2_output + test_case(false, true, true, false); // is_dq1_output, is_dq2_output + test_case(true, false, true, false); // is_input_q, is_dq2_output + test_case(false, false, true, false); // is_dq2_output + test_case(true, true, false, false); // is_input_q, is_dq1_output + test_case(false, true, false, false); // is_dq1_output + test_case(true, false, false, false); // is_input_q + test_case(false, false, false, false); + +#if !defined(DISABLE_CONTRIB_OPS) + // Use contrib QDQ ops + test_case(true, true, true, true); // is_input_q, is_dq1_output, is_dq2_output + test_case(false, true, true, true); // is_dq1_output, is_dq2_output + test_case(true, false, true, true); // is_input_q, is_dq2_output + test_case(false, false, true, true); // is_dq2_output + test_case(true, true, false, true); // is_input_q, is_dq1_output + test_case(false, true, false, true); // is_dq1_output + test_case(true, false, false, true); // is_input_q + test_case(false, false, false, true); +#endif +} + TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicDQQCleanUp) { auto test_case = [](bool use_matching_qdq_params, bool use_contrib_qdq) { // input -> Q -> DQ -> Q -> DQ -> output @@ -3801,6 +3920,129 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicDQQCleanUp) { #endif } +// test removal of DQ->Qs pairs by QDQFinalCleanupTransformer +TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicDQQsCleanup) { + auto test_case = [&](bool is_input_dq, + bool is_q1_output, + bool is_q2_output, + bool use_contrib_qdq) { + // create model with float Input -> (Transpose1 ->) DQ -> Q1 -> (Transpose2 ->) Output1 + // -> Q2 -> (Transpose3 ->) Output2 + // If we enable cleanup and don't run the QDQ transformer we should drop all the DQ->Q pairs + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_args = builder.MakeInput({1, 2, 4}, + std::numeric_limits::min(), + std::numeric_limits::max()); + + // Add Input -> (Transpose1 ->) + NodeArg* dq_input = nullptr; + if (is_input_dq) { + dq_input = input_args; + } else { + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {input_args}, {transpose_output}); + dq_input = transpose_output; + } + + // Add DQ -> + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(dq_input, 0.05f, 128, dq_output, use_contrib_qdq); + + // Add Q1 -> (Transpose2 ->) Output1 + if (is_q1_output) { + auto* q1_output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(dq_output, 0.05f, 128, q1_output, use_contrib_qdq); + } else { + auto* q1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(dq_output, 0.05f, 128, q1_output, use_contrib_qdq); + auto* output1 = builder.MakeOutput(); + builder.AddNode("Transpose", {q1_output}, {output1}); + } + + // Add Q2 -> (Transpose3 ->) Output2 + if (is_q2_output) { + auto* q2_output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(dq_output, 0.05f, 128, q2_output, use_contrib_qdq); + } else { + auto* q2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(dq_output, 0.05f, 128, q2_output, use_contrib_qdq); + auto* output2 = builder.MakeOutput(); + builder.AddNode("Transpose", {q2_output}, {output2}); + } + }; + + const int expected_transpose_count = (is_input_dq ? 0 : 1) + (is_q1_output ? 0 : 1) + (is_q2_output ? 0 : 1); + + auto check_graph = [expected_transpose_count, + use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["Transpose"], expected_transpose_count); + }; + + auto add_session_options = [](SessionOptions& so) { + // The function EnsureUniqueDQForEachExplicitOutputEdge does not account for this particular case. Disable it to + // prevent test failures. + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableQuantQDQCleanup, "1")); + }; + + // we increase the tolerance as removing the QDQ nodes means there's no round-trip to 8-bit and back + // essentially rounding the input values. + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + 0.025f /*per_sample_tolerance*/, + 0.01f /*relative_per_sample_tolerance*/, + std::make_unique(true /*enable_q_dq_cleanup*/), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + 0.025f /*per_sample_tolerance*/, + 0.01f /*relative_per_sample_tolerance*/, + std::make_unique(true /*enable_q_dq_cleanup*/), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + 0.025f /*per_sample_tolerance*/, + 0.01f /*relative_per_sample_tolerance*/, + std::make_unique(true /*enable_q_dq_cleanup*/), + add_session_options); + }; + + test_case(true, true, true, false); // is_input_dq, is_q1_output, is_q2_output + test_case(false, true, true, false); // is_q1_output, is_q2_output + test_case(true, false, true, false); // is_input_dq, is_q2_output + test_case(false, false, true, false); // is_q2_output + test_case(true, true, false, false); // is_input_dq, is_q1_output + test_case(false, true, false, false); // is_q1_output + test_case(true, false, false, false); // is_input_dq + test_case(false, false, false, false); + +#if !defined(DISABLE_CONTRIB_OPS) + // Use contrib QDQ ops + test_case(true, true, true, true); // is_input_dq, is_q1_output, is_q2_output + test_case(false, true, true, true); // is_q1_output, is_q2_output + test_case(true, false, true, true); // is_input_dq, is_q2_output + test_case(false, false, true, true); // is_q2_output + test_case(true, true, false, true); // is_input_dq, is_q1_output + test_case(false, true, false, true); // is_q1_output + test_case(true, false, false, true); // is_input_dq + test_case(false, false, false, true); +#endif +} + // test removal when we have graph input -> Q/DQ pair -> graph output TEST(QDQTransformerTests, QDQFinalCleanupTransformer_GraphInputToOutput) { auto test_case = [](bool is_q_dq, bool use_contrib_qdq) { diff --git a/onnxruntime/test/optimizer/resnet50_fusion_test.cc b/onnxruntime/test/optimizer/resnet50_fusion_test.cc index 04b11b46e5002..5cb0206156a84 100644 --- a/onnxruntime/test/optimizer/resnet50_fusion_test.cc +++ b/onnxruntime/test/optimizer/resnet50_fusion_test.cc @@ -61,13 +61,13 @@ TEST_F(ResNet50FusionTests, FuseConvIntegrationTest) { ASSERT_STATUS_OK(graph_transformation_mgr_32.Register(std::make_unique(), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr_32.Register(std::make_unique(), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr_32.ApplyTransformers(fp32_graph, TransformerLevel::Level3, *logger)); - ASSERT_STATUS_OK(Model::Save(*fp32_model, "resnet50_fused.onnx")); + ASSERT_STATUS_OK(Model::Save(*fp32_model, ORT_TSTR("resnet50_fused.onnx"))); onnxruntime::GraphTransformerManager graph_transformation_mgr_16{5}; ASSERT_STATUS_OK(graph_transformation_mgr_16.Register(std::make_unique(), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr_16.Register(std::make_unique(), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr_16.ApplyTransformers(fp16_graph, TransformerLevel::Level3, *logger)); - ASSERT_STATUS_OK(Model::Save(*fp16_model, "resnet50_fp16_fused.onnx")); + ASSERT_STATUS_OK(Model::Save(*fp16_model, ORT_TSTR("resnet50_fp16_fused.onnx"))); // std::cout << "-------Op Counts After Fusion---------" << std::endl; fp32_op_count = CountOpsInGraph(fp32_graph); fp16_op_count = CountOpsInGraph(fp16_graph); @@ -91,7 +91,7 @@ TEST_F(ResNet50FusionTests, FuseConvAddReluUnitTest) { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); op_to_count = CountOpsInGraph(graph); - ASSERT_STATUS_OK(Model::Save(*p_model, "conv_add_relu_fp16_fused.onnx")); + ASSERT_STATUS_OK(Model::Save(*p_model, ORT_TSTR("conv_add_relu_fp16_fused.onnx"))); ASSERT_TRUE(op_to_count["Add"] == 0); // Add removed from graph ASSERT_TRUE(op_to_count["Relu"] == 0); // Relu removed from graph } @@ -109,7 +109,7 @@ TEST_F(ResNet50FusionTests, FuseConvAddUnitTest) { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); op_to_count = CountOpsInGraph(graph); - ASSERT_STATUS_OK(Model::Save(*p_model, "conv_add_fp16_fused.onnx")); + ASSERT_STATUS_OK(Model::Save(*p_model, ORT_TSTR("conv_add_fp16_fused.onnx"))); ASSERT_TRUE(op_to_count["Add"] == 0); // Add removed from graph } TEST_F(ResNet50FusionTests, FuseConvReluUnitTest) { @@ -126,9 +126,9 @@ TEST_F(ResNet50FusionTests, FuseConvReluUnitTest) { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level3)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); op_to_count = CountOpsInGraph(graph); - ASSERT_STATUS_OK(Model::Save(*p_model, "conv_relu_fp16_fused.onnx")); + ASSERT_STATUS_OK(Model::Save(*p_model, ORT_TSTR("conv_relu_fp16_fused.onnx"))); ASSERT_TRUE(op_to_count["Relu"] == 0); // Add removed from graph } #endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && !defined(DISABLE_CONTRIB_OPS) } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 175079d8197bf..b7c99fa66a1ea 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -261,6 +261,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, test_config.machine_config.provider_type_name = onnxruntime::kSnpeExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) { test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("vsinpu"))) { + test_config.machine_config.provider_type_name = onnxruntime::kVSINPUExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("coreml"))) { test_config.machine_config.provider_type_name = onnxruntime::kCoreMLExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("dml"))) { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 81053bf400a63..ff782da35cbe6 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -47,6 +47,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device const TestModelInfo& m) : rand_engine_(rd()), input_names_(m.GetInputCount()), input_names_str_(m.GetInputCount()), input_length_(m.GetInputCount()) { Ort::SessionOptions session_options; + provider_name_ = performance_test_config.machine_config.provider_type_name; if (provider_name_ == onnxruntime::kDnnlExecutionProvider) { #ifdef USE_DNNL @@ -221,150 +222,6 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device session_options.AppendExecutionProvider_CUDA(cuda_options); #else ORT_THROW("TensorRT is not supported in this build\n"); -#endif - } else if (provider_name_ == onnxruntime::kOpenVINOExecutionProvider) { -#ifdef USE_OPENVINO -#ifdef _MSC_VER - std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); -#else - std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; -#endif - std::unordered_map ov_options; - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [OpenVINO] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - - if (key == "device_type") { - std::set ov_supported_device_types = {"CPU", "GPU", - "GPU.0", "GPU.1", "NPU"}; - std::set deprecated_device_types = {"CPU_FP32", "GPU_FP32", - "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16"}; - if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { - ov_options[key] = value; - } else if (deprecated_device_types.find(value) != deprecated_device_types.end()) { - ov_options[key] = value; - } else if (value.find("HETERO:") == 0) { - ov_options[key] = value; - } else if (value.find("MULTI:") == 0) { - ov_options[key] = value; - } else if (value.find("AUTO:") == 0) { - ov_options[key] = value; - } else { - ORT_THROW( - "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " - "Select from 'CPU', 'GPU', 'GPU.0', 'GPU.1', 'NPU' or from" - " HETERO/MULTI/AUTO options available. \n"); - } - } else if (key == "device_id") { - if (value == "CPU" || value == "GPU" || value == "NPU") { - ov_options[key] = value; - } else { - ORT_THROW("[ERROR] [OpenVINO] Unsupported device_id is selected. Select from available options."); - } - } else if (key == "precision") { - auto device_type = ov_options["device_type"]; - if (device_type.find("GPU") != std::string::npos) { - if (value == "") { - ov_options[key] = "FP16"; - continue; - } else if (value == "ACCURACY" || value == "FP16" || value == "FP32") { - ov_options[key] = value; - continue; - } else { - ORT_THROW( - "[ERROR] [OpenVINO] Unsupported inference precision is selected. " - "GPU only supported FP32 / FP16. \n"); - } - } else if (device_type.find("NPU") != std::string::npos) { - if (value == "" || value == "ACCURACY" || value == "FP16") { - ov_options[key] = "FP16"; - continue; - } else { - ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. NPU only supported FP16. \n"); - } - } else if (device_type.find("CPU") != std::string::npos) { - if (value == "" || value == "ACCURACY" || value == "FP32") { - ov_options[key] = "FP32"; - continue; - } else { - ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. CPU only supports FP32 . \n"); - } - } - } else if (key == "enable_npu_fast_compile") { - if (value == "true" || value == "True" || - value == "false" || value == "False") { - ov_options[key] = value; - } else { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_npu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "enable_opencl_throttling") { - if (value == "true" || value == "True" || - value == "false" || value == "False") { - ov_options[key] = value; - } else { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_opencl_throttling' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "enable_qdq_optimizer") { - if (value == "true" || value == "True" || - value == "false" || value == "False") { - ov_options[key] = value; - } else { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_qdq_optimizer' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "disable_dynamic_shapes") { - if (value == "true" || value == "True" || - value == "false" || value == "False") { - ov_options[key] = value; - } else { - ORT_THROW( - "[ERROR] [OpenVINO] The value for the key 'enable_dynamic_shapes' " - "should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "num_of_threads") { - if (std::stoi(value) <= 0) { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'num_of_threads' should be greater than 0\n"); - } else { - ov_options[key] = value; - } - } else if (key == "model_priority") { - ov_options[key] = value; - } else if (key == "cache_dir") { - ov_options[key] = value; - } else if (key == "context") { - ov_options[key] = value; - } else if (key == "num_streams") { - if (std::stoi(value) <= 0 && std::stoi(value) > 8) { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'num_streams' should be in the range of 1-8 \n"); - } else { - ov_options[key] = value; - } - } else if (key == "export_ep_ctx_blob") { - if (value == "true" || value == "True" || - value == "false" || value == "False") { - ov_options[key] = value; - } else { - ORT_THROW( - "[ERROR] [OpenVINO] The value for the key 'export_ep_ctx_blob' " - "should be a boolean i.e. true or false. Default value is false.\n"); - } - } else { - ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling', 'disable_dynamic_shapes'] \n"); - } - } - session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); -#else - ORT_THROW("OpenVINO is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kQnnExecutionProvider) { #ifdef USE_QNN @@ -540,6 +397,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, nnapi_flags)); #else ORT_THROW("NNAPI is not supported in this build\n"); +#endif + } else if (provider_name_ == onnxruntime::kVSINPUExecutionProvider) { +#ifdef USE_VSINPU + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_VSINPU(session_options)); +#else + ORT_THROW("VSINPU is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) { #ifdef __APPLE__ @@ -716,7 +579,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else ORT_THROW("VitisAI is not supported in this build\n"); #endif - } else if (!provider_name_.empty() && provider_name_ != onnxruntime::kCpuExecutionProvider) { + } else if (!provider_name_.empty() && + provider_name_ != onnxruntime::kCpuExecutionProvider && + provider_name_ != onnxruntime::kOpenVINOExecutionProvider) { ORT_THROW("This backend is not included in perf test runner.\n"); } @@ -805,6 +670,151 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } } } + if (provider_name_ == onnxruntime::kOpenVINOExecutionProvider) { +#ifdef USE_OPENVINO +#ifdef _MSC_VER + std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); +#else + std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; +#endif + std::unordered_map ov_options; + std::istringstream ss(ov_string); + std::string token; + while (ss >> token) { + if (token == "") { + continue; + } + auto pos = token.find("|"); + if (pos == std::string::npos || pos == 0 || pos == token.length()) { + ORT_THROW("[ERROR] [OpenVINO] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + } + + auto key = token.substr(0, pos); + auto value = token.substr(pos + 1); + + if (key == "device_type") { + std::set ov_supported_device_types = {"CPU", "GPU", + "GPU.0", "GPU.1", "NPU"}; + std::set deprecated_device_types = {"CPU_FP32", "GPU_FP32", + "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", + "GPU.0_FP16", "GPU.1_FP16"}; + if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { + ov_options[key] = value; + } else if (deprecated_device_types.find(value) != deprecated_device_types.end()) { + ov_options[key] = value; + } else if (value.find("HETERO:") == 0) { + ov_options[key] = value; + } else if (value.find("MULTI:") == 0) { + ov_options[key] = value; + } else if (value.find("AUTO:") == 0) { + ov_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " + "Select from 'CPU', 'GPU', 'GPU.0', 'GPU.1', 'NPU' or from" + " HETERO/MULTI/AUTO options available. \n"); + } + } else if (key == "device_id") { + if (value == "CPU" || value == "GPU" || value == "NPU") { + ov_options[key] = value; + } else { + ORT_THROW("[ERROR] [OpenVINO] Unsupported device_id is selected. Select from available options."); + } + } else if (key == "precision") { + auto device_type = ov_options["device_type"]; + if (device_type.find("GPU") != std::string::npos) { + if (value == "") { + ov_options[key] = "FP16"; + continue; + } else if (value == "ACCURACY" || value == "FP16" || value == "FP32") { + ov_options[key] = value; + continue; + } else { + ORT_THROW( + "[ERROR] [OpenVINO] Unsupported inference precision is selected. " + "GPU only supported FP32 / FP16. \n"); + } + } else if (device_type.find("NPU") != std::string::npos) { + if (value == "" || value == "ACCURACY" || value == "FP16") { + ov_options[key] = "FP16"; + continue; + } else { + ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. NPU only supported FP16. \n"); + } + } else if (device_type.find("CPU") != std::string::npos) { + if (value == "" || value == "ACCURACY" || value == "FP32") { + ov_options[key] = "FP32"; + continue; + } else { + ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. CPU only supports FP32 . \n"); + } + } + } else if (key == "enable_npu_fast_compile") { + if (value == "true" || value == "True" || + value == "false" || value == "False") { + ov_options[key] = value; + } else { + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_npu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); + } + } else if (key == "enable_opencl_throttling") { + if (value == "true" || value == "True" || + value == "false" || value == "False") { + ov_options[key] = value; + } else { + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_opencl_throttling' should be a boolean i.e. true or false. Default value is false.\n"); + } + } else if (key == "enable_qdq_optimizer") { + if (value == "true" || value == "True" || + value == "false" || value == "False") { + ov_options[key] = value; + } else { + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_qdq_optimizer' should be a boolean i.e. true or false. Default value is false.\n"); + } + } else if (key == "disable_dynamic_shapes") { + if (value == "true" || value == "True" || + value == "false" || value == "False") { + ov_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [OpenVINO] The value for the key 'enable_dynamic_shapes' " + "should be a boolean i.e. true or false. Default value is false.\n"); + } + } else if (key == "num_of_threads") { + if (std::stoi(value) <= 0) { + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'num_of_threads' should be greater than 0\n"); + } else { + ov_options[key] = value; + } + } else if (key == "model_priority") { + ov_options[key] = value; + } else if (key == "cache_dir") { + ov_options[key] = value; + } else if (key == "context") { + ov_options[key] = value; + } else if (key == "num_streams") { + if (std::stoi(value) <= 0 && std::stoi(value) > 8) { + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'num_streams' should be in the range of 1-8 \n"); + } else { + ov_options[key] = value; + } + } else if (key == "export_ep_ctx_blob") { + if (value == "true" || value == "True" || + value == "false" || value == "False") { + ov_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [OpenVINO] The value for the key 'export_ep_ctx_blob' " + "should be a boolean i.e. true or false. Default value is false.\n"); + } + } else { + ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling', 'disable_dynamic_shapes'] \n"); + } + } + session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); +#else + ORT_THROW("OpenVINO is not supported in this build\n"); +#endif + } session_ = Ort::Session(env, performance_test_config.model_info.model_file_path.c_str(), session_options); diff --git a/onnxruntime/test/platform/file_io_test.cc b/onnxruntime/test/platform/file_io_test.cc index 3611e3b33446c..ccc703716844f 100644 --- a/onnxruntime/test/platform/file_io_test.cc +++ b/onnxruntime/test/platform/file_io_test.cc @@ -14,7 +14,7 @@ #include #endif -#include "core/common/gsl.h" +#include #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 8d84c689cd23e..01de15e6f8ec8 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -73,7 +73,7 @@ void BaseTester::AddInitializers(onnxruntime::Graph& graph) { } } else { auto buffer_size = tensor.DataType()->Size() * shape.Size(); - tensor_proto.set_raw_data(tensor.DataRaw(), buffer_size); + utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), buffer_size); } // 4. name @@ -428,6 +428,7 @@ bool SetEpsForAllNodes(Graph& graph, if (provider_type == onnxruntime::kOpenVINOExecutionProvider || provider_type == onnxruntime::kTensorrtExecutionProvider || provider_type == onnxruntime::kNnapiExecutionProvider || + provider_type == onnxruntime::kVSINPUExecutionProvider || provider_type == onnxruntime::kCoreMLExecutionProvider || provider_type == onnxruntime::kDnnlExecutionProvider || provider_type == onnxruntime::kQnnExecutionProvider || @@ -649,6 +650,7 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, kAclExecutionProvider, kArmNNExecutionProvider, kNnapiExecutionProvider, + kVSINPUExecutionProvider, kRocmExecutionProvider, kCoreMLExecutionProvider, kCoreMLExecutionProviderMLProgram, @@ -688,6 +690,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultTensorrtExecutionProvider(); else if (provider_type == onnxruntime::kNnapiExecutionProvider) execution_provider = DefaultNnapiExecutionProvider(); + else if (provider_type == onnxruntime::kVSINPUExecutionProvider) + execution_provider = DefaultVSINPUExecutionProvider(); else if (provider_type == onnxruntime::kRknpuExecutionProvider) execution_provider = DefaultRknpuExecutionProvider(); else if (provider_type == onnxruntime::kAclExecutionProvider) diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index be049d1cf0ce3..ec9b1614488a7 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -256,7 +256,7 @@ TEST(Random, MultinomialGoodCase) { const std::vector output_dims{batch_size, num_samples}; #ifdef _WIN32 const std::vector expected_output{2, 0, 0, 2, 2, 2, 0, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 0}; -#elif defined(__MACH__) || defined(__ANDROID__) || defined(__FreeBSD__) || defined(__wasm__) +#elif defined(__MACH__) || defined(__ANDROID__) || defined(__FreeBSD__) || defined(__wasm__) || defined(_AIX) const std::vector expected_output{1, 1, 2, 2, 0, 2, 2, 2, 0, 2, 1, 1, 2, 0, 2, 2, 0, 2, 1, 1}; #else const std::vector expected_output{2, 0, 0, 1, 0, 1, 2, 0, 1, 0, 0, 1, 1, 0, 1, 0, 2, 0, 2, 0}; @@ -294,7 +294,7 @@ TEST(Random, MultinomialDefaultDType) { #ifdef _WIN32 const std::vector expected_output_1{2, 0, 0, 2, 2, 2, 0, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 0}; const std::vector expected_output_2{0, 0, 1, 0, 2, 2, 2, 0, 2, 1, 2, 1, 0, 2, 0, 2, 2, 1, 2, 1}; -#elif defined(__MACH__) || defined(__ANDROID__) || defined(__FreeBSD__) || defined(__wasm__) +#elif defined(__MACH__) || defined(__ANDROID__) || defined(__FreeBSD__) || defined(__wasm__) || defined(_AIX) const std::vector expected_output_1{1, 1, 2, 2, 0, 2, 2, 2, 0, 2, 1, 1, 2, 0, 2, 2, 0, 2, 1, 1}; const std::vector expected_output_2{1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 2, 0, 1, 1, 0, 2, 2, 2, 1}; #else diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index 423ea3f682f4c..73f31787e0597 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -274,16 +274,11 @@ TEST(Einsum, ExplicitEinsumAsMatmul_OutputTransposed) { } TEST(Einsum, ExplicitEinsumAsMatmul_2) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2068): The parameter is incorrect."; - } - OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "ij,jk->ik"); test.AddInput("x", {2, 1}, {2.f, 3.f}); - test.AddInput("y", {2, 2}, {1.f, 2.f, 3.f, 4.f}); - test.AddOutput("o", {2, 2}, {8.f, 12.f, 12.f, 18.f}); + test.AddInput("y", {1, 2}, {1.f, 2.f}); + test.AddOutput("o", {2, 2}, {2.f, 4.f, 3.f, 6.f}); test.Run(); } @@ -325,16 +320,11 @@ TEST(Einsum, ImplicitEinsumAsBatchedMatmulWithBroadcasting_0) { } TEST(Einsum, ImplicitEinsumAsMatmul_2) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2068): The parameter is incorrect."; - } - OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "ij,jk"); test.AddInput("x", {2, 1}, {2.f, 3.f}); - test.AddInput("y", {2, 2}, {1.f, 2.f, 3.f, 4.f}); - test.AddOutput("o", {2, 2}, {8.f, 12.f, 12.f, 18.f}); + test.AddInput("y", {1, 2}, {1.f, 2.f}); + test.AddOutput("o", {2, 2}, {2.f, 4.f, 3.f, 6.f}); test.Run(); } @@ -434,11 +424,6 @@ TEST(Einsum, ExplicitEinsumAsBatchedDiagonalOp_1) { // Implicit (Implicit diagonal ops will sum up diagonal values) TEST(Einsum, ImplicitEinsumAsDiagonalOp) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 5, which exceeds threshold"; - } - OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "ii"); test.AddInput("x", {2, 2}, {1.f, 2.f, 3.f, 4.f}); @@ -447,11 +432,6 @@ TEST(Einsum, ImplicitEinsumAsDiagonalOp) { } TEST(Einsum, ImplicitEinsumAsDiagonalOp_1) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: error: The difference between expected[i] and output[i] is 15, which exceeds threshold"; - } - OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "iii"); test.AddInput("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); @@ -576,8 +556,17 @@ TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeLeft) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "bsnh,btnh->bnts"); test.AddInput("x", {2, 1, 2, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); - test.AddInput("y", {2, 2, 2, 1}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - test.AddOutput("o", {2, 2, 2, 1}, {3.f, 9.f, 6.f, 12.f, 15.f, 21.f, 18.f, 24.f}); + test.AddInput("y", {2, 2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f}); + test.AddOutput("o", {2, 2, 2, 1}, {5.f, 17.f, 11.f, 23.f, 29.f, 41.f, 35.f, 47.f}); + test.Run(); +} + +TEST(Einsum, ExplicitEinsumAsTensorContractionSameInput) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "nchw,nchw->nch"); + test.AddInput("x", {1, 3, 2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + test.AddInput("y", {1, 3, 2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + test.AddOutput("o", {1, 3, 2}, {30.f, 174.f, 54.f, 230.f, 86.f, 294.f}); test.Run(); } diff --git a/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc b/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc index f7b8a61f48dab..c7fc73456dcba 100644 --- a/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc @@ -5,7 +5,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" -#include "core/common/gsl.h" +#include using namespace std; namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index dcb592a4a254e..cb9887314eb66 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -30,6 +30,10 @@ #include "core/providers/nnapi/nnapi_provider_factory.h" #endif +#ifdef USE_VSINPU +#include "core/providers/vsinpu/vsinpu_provider_factory.h" +#endif + #ifdef USE_RKNPU #include "core/providers/rknpu/rknpu_provider_factory.h" #endif @@ -238,6 +242,11 @@ TEST_P(ModelTest, Run) { ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0)); } #endif +#ifdef USE_VSINPU + else if (provider_name == "vsinpu") { + ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_VSINPU(ortso)); + } +#endif #ifdef USE_RKNPU else if (provider_name == "rknpu") { ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso)); @@ -406,6 +415,9 @@ static constexpr ORT_STRING_VIEW provider_name_dnnl = ORT_TSTR("dnnl"); #if defined(USE_NNAPI) && defined(__ANDROID__) static constexpr ORT_STRING_VIEW provider_name_nnapi = ORT_TSTR("nnapi"); #endif +#ifdef USE_VSINPU +static ORT_STRING_VIEW provider_name_vsinpu = ORT_TSTR("vsinpu"); +#endif #ifdef USE_RKNPU static constexpr ORT_STRING_VIEW provider_name_rknpu = ORT_TSTR("rknpu"); #endif @@ -447,6 +459,9 @@ ::std::vector<::std::basic_string> GetParameterStrings() { #if defined(USE_NNAPI) && defined(__ANDROID__) provider_names[provider_name_nnapi] = {opset7, opset8, opset9, opset10, opset11, opset12, opset13, opset14, opset15, opset16, opset17, opset18}; #endif +#ifdef USE_VSINPU + provider_names[provider_name_vsinpu] = {}; +#endif #ifdef USE_RKNPU provider_names[provider_name_rknpu] = {}; #endif diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 4c77908322df4..421561a5a859b 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -5,7 +5,7 @@ #include "boost/mp11.hpp" -#include "core/common/gsl.h" +#include #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc index bd97306142f18..4fc2e6c7c909b 100644 --- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc @@ -18,13 +18,17 @@ constexpr double DOUBLE_NINF = -std::numeric_limits::infinity(); constexpr double DOUBLE_NAN = std::numeric_limits::quiet_NaN(); template -void run_is_inf_test(int opset, int64_t detect_positive, int64_t detect_negative, const std::initializer_list& input, const std::initializer_list& output) { +void run_is_inf_test(int opset, int64_t detect_positive, int64_t detect_negative, const std::initializer_list& input, const std::initializer_list& output, bool skip_trt = false) { OpTester test("IsInf", opset); test.AddAttribute("detect_positive", detect_positive); test.AddAttribute("detect_negative", detect_negative); test.AddInput("X", {onnxruntime::narrow(input.size())}, input); test.AddOutput("Y", {onnxruntime::narrow(output.size())}, output); - test.Run(); + if (skip_trt) { + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + } else { + test.Run(); + } } TEST(IsInfTest, test_isinf_float10) { @@ -124,7 +128,7 @@ TEST(IsInfTest, test_isinf_bfloat16) { std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, BFloat16::NegativeInfinity, BFloat16::Infinity}; std::initializer_list output = {false, false, true, false, true, true}; - run_is_inf_test(20, 1, 1, input, output); + run_is_inf_test(20, 1, 1, input, output, true); // Skip as TRT10 supports BF16 but T4 GPU run on TRT CIs doesn't } TEST(IsInfTest, test_isinf_positive_bfloat16) { @@ -146,7 +150,7 @@ TEST(IsInfTest, test_Float8E4M3FN) { std::initializer_list input = { Float8E4M3FN(-1.0f), Float8E4M3FN(FLOAT_NAN, false), Float8E4M3FN(1.0f), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_INF, false)}; std::initializer_list output = {false, false, false, false, false, false}; - run_is_inf_test(20, 1, 1, input, output); + run_is_inf_test(20, 1, 1, input, output, true); // Skip as TRT10.1 supports Float8 but T4 GPU run on TRT CIs doesn't } TEST(IsInfTest, test_Float8E4M3FNUZ) { diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 1da9c9df299c9..af54ae96ef86b 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -35,8 +35,10 @@ void RunSliceTest(const std::vector& input_dims, excluded_providers.insert(excluded_providers_input.cbegin(), excluded_providers_input.cend()); // NNAPI EP does not support empty output + // VSINPU EP does not support empty output if (std::any_of(output_dims.cbegin(), output_dims.cend(), [](int64_t i) { return i == 0; })) { excluded_providers.insert(kNnapiExecutionProvider); + excluded_providers.insert(kVSINPUExecutionProvider); } // TODO: ORT behavior when step < 0 and end = INT_MAX is wrong. Fix it and @@ -515,6 +517,9 @@ TEST(SliceTest, Slice1D_ReverseAllAxes_1) { if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{2,2}] did not match run output shape [{0,0}] for output"; } + if (DefaultVSINPUExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{4}] did not match run output shape [{0}] for output"; + } RunSliceTest({4}, {1.0f, 2.0f, 3.0f, 4.0f}, diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index 15a7d7cd9fdbf..066302a4a37d1 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -815,5 +815,62 @@ TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) { RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false); } +TEST(SplitOperatorTest, Split3Inner) { + constexpr int64_t axis = -1; + using ShapeAndDataT = ShapeAndData; + std::vector outputs; + int64_t num_outputs = -1; // when provides split_sizes, then num_outputs should not be provided + const int batch = 16; + const int data_len = 96; // should be multiple of 3 + + // create input with shape {b, l}, and data from 1 ~ b*l + auto input = CreateInput({batch, data_len}); // input is 1.f ~ 48.f + + // slice the input data by start and end in axis of -1 + auto gen_output = [&](int start, int end) { + std::vector data0; + auto input_data = input.second; + for (int b = 0; b < batch; b++) { + for (int i = start; i < end; i++) { + data0.push_back(input_data[b * data_len + i]); + } + } + return ShapeAndDataT{{batch, end - start}, data0}; + }; + + auto do_test = [&](std::vector& splits) { + outputs.clear(); + outputs.push_back(gen_output(0, splits[0])); + outputs.push_back(gen_output(splits[0], splits[1])); + outputs.push_back(gen_output(splits[1], data_len)); + + RunTest(axis, {splits[0], splits[1] - splits[0], data_len - splits[1]}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs); + }; + + // split into 3 same size, and aligned to 16 + std::vector splits{data_len / 3, data_len / 3 * 2}; + do_test(splits); + + // test split with data alignment is 8 + splits[0] = splits[0] + 8; + splits[1] = splits[1] + 8; + do_test(splits); + + // test split with data alignment is 4 + splits[0] = splits[0] + 4; + splits[1] = splits[1] + 4; + do_test(splits); + + // test split with data alignment is 2 + splits[0] = splits[0] + 2; + splits[1] = splits[1] + 2; + do_test(splits); + + // test split with data alignment is 1 + splits[0] = splits[0] + 1; + splits[1] = splits[1] + 1; + do_test(splits); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/where_op_test.cc b/onnxruntime/test/providers/cpu/tensor/where_op_test.cc index 6237521b34dfd..03db49a313af6 100644 --- a/onnxruntime/test/providers/cpu/tensor/where_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/where_op_test.cc @@ -3,7 +3,7 @@ #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h index fa1c739c04e3a..f96c8ce9ce729 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -13,7 +13,7 @@ */ #pragma once - +#if defined(CUDA_VERSION) && CUDA_VERSION <= 12030 #include "test/cuda_host/blkq4_fp16_quant_sm80.h" #include @@ -197,3 +197,4 @@ void run_blkq4_small_gemm(int m, int n, int k); } // namespace test } // namespace cuda } // namespace onnxruntime +#endif diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index b95e093e41eab..3fcb9045ee7e6 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -10,7 +10,7 @@ * This part requires gtest header files, which do not play * well with CUTLASS headers. */ - +#if defined(CUDA_VERSION) && CUDA_VERSION <= 12030 #include "blkq4_fp16_gemm_sm80.h" #include "gtest/gtest.h" @@ -341,3 +341,4 @@ TEST(BlkQ4_GEMM, Sm80SmallTileKernelTest) { } // namespace test } // namespace onnxruntime +#endif diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index f5600ca9885a3..8b27c3d8c3aed 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -11,6 +11,9 @@ * well with gtest headers. */ +// This test has build error with cuda 12.5 +#if defined(CUDA_VERSION) && CUDA_VERSION <= 12030 + #include "blkq4_fp16_gemm_sm80.h" #include @@ -532,3 +535,5 @@ template void run_blkq4_small_gemm<128, false, false>(int m, int n, int k); } // namespace test } // namespace cuda } // namespace onnxruntime + +#endif diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc index d8384b432786b..022c6250138d4 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc @@ -10,7 +10,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_execution_provider_info.h" #include "core/providers/cuda/cuda_allocator.h" diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index b3e1025e7367c..1af7bdea68b67 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -14,6 +14,7 @@ #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" #include "test/util/include/asserts.h" +#include "test/util/include/current_test_name.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/test/test_environment.h" @@ -36,10 +37,6 @@ using namespace ::onnxruntime::logging; namespace onnxruntime { namespace test { -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - #if !defined(ORT_MINIMAL_BUILD) // Since NNAPI EP handles Reshape and Flatten differently, @@ -65,7 +62,8 @@ TEST(NnapiExecutionProviderTest, ReshapeFlattenTest) { feeds.insert(std::make_pair("X", ml_value_x)); feeds.insert(std::make_pair("Y", ml_value_y)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.ReshapeFlattenTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else @@ -88,7 +86,8 @@ TEST(NnapiExecutionProviderTest, SigmoidSupportedInputRankTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.SigmoidSupportedInputRankTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds, {ExpectedEPNodeAssignment::None} /* params */); #else @@ -115,7 +114,8 @@ TEST(NnapiExecutionProviderTest, DynamicGraphInputTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.DynamicGraphInputTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else @@ -144,7 +144,8 @@ TEST(NnapiExecutionProviderTest, InternalUint8SupportTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.InternalUint8SupportTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else @@ -208,7 +209,8 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { feeds.insert(std::make_pair("Y", ml_value_y)); feeds.insert(std::make_pair("Z", ml_value_z)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.FunctionTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else @@ -273,7 +275,8 @@ static void RunQDQModelTest( const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); #if defined(__ANDROID__) - RunAndVerifyOutputsWithEP(model_data_span, "NnapiExecutionProviderTest.TestQDQModel", + RunAndVerifyOutputsWithEP(model_data_span, + CurrentTestName(), std::make_unique(0), helper.feeds_, params); #else @@ -513,6 +516,31 @@ TEST(NnapiExecutionProviderTest, TestGather) { {ExpectedEPNodeAssignment::All}); } +TEST(NnapiExecutionProviderTest, SharedInitializersDoNotGetSkipped) { + // NNAPI EP's Clip op builder will mark the max initializer as skipped but it is also used by the Div op. + // Test that the shared initializer is still present in the NNAPI model for the Div op. + constexpr auto* model_file_name = ORT_TSTR("testdata/clip_div_shared_initializer.onnx"); + +#if defined(__ANDROID__) + AllocatorPtr cpu_allocator = std::make_shared(); + + std::vector x_dims{3, 2}; + std::vector x_values(3.0f, 3 * 2); + OrtValue ml_value_x; + CreateMLValue(cpu_allocator, x_dims, x_values, &ml_value_x); + + NameMLValMap feeds{{"input_0", ml_value_x}}; + + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), + std::make_unique(0), + feeds, + {ExpectedEPNodeAssignment::All}); +#else + TestModelLoad(model_file_name, std::make_unique(0), ExpectedEPNodeAssignment::All); +#endif +} + #endif // !(ORT_MINIMAL_BUILD) TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) { @@ -541,7 +569,8 @@ TEST(NnapiExecutionProviderTest, TestOrtFormatModel) { NameMLValMap feeds; feeds.insert(std::make_pair("Input3", ml_value)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.TestOrtFormatModel", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else diff --git a/onnxruntime/test/providers/qnn/argmaxmin_op_test.cc b/onnxruntime/test/providers/qnn/argmaxmin_op_test.cc index e86151008e24d..c514cf16b2f3c 100644 --- a/onnxruntime/test/providers/qnn/argmaxmin_op_test.cc +++ b/onnxruntime/test/providers/qnn/argmaxmin_op_test.cc @@ -15,28 +15,7 @@ namespace onnxruntime { namespace test { -// Builds a float32 model with ArgMin/ArgMax. -static GetTestModelFn BuildArgMxxTestCase(const std::string& op_type, TestInputDef input_def, - const std::vector& attrs) { - return [op_type, input_def, attrs](ModelTestBuilder& builder) { - auto* input = MakeTestInput(builder, input_def); - - auto* argm_output = builder.MakeIntermediate(); - Node& argm_node = builder.AddNode(op_type, {input}, {argm_output}); - for (const auto& attr : attrs) { - argm_node.AddAttributeProto(attr); - } - - // Add cast to uint32 - auto* output = builder.MakeOutput(); - Node& cast_node = builder.AddNode("Cast", {argm_output}, {output}); - const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; - cast_node.AddAttribute("to", static_cast(dst_type)); - }; -} - -// Builds a QDQ model with ArgMin/ArgMax and a Cast to uint32. The quantization parameters are computed from the provided -// input definition. +// Builds a QDQ model with ArgMin/ArgMax. The quantization parameters are computed from the provided input definition. template static GetTestQDQModelFn BuildQDQArgMxxTestCase(const std::string& op_type, TestInputDef input_def, const std::vector& attrs) { @@ -49,17 +28,11 @@ static GetTestQDQModelFn BuildQDQArgMxxTestCase(const std::string& op_typ // input -> Q -> DQ -> auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - auto* argm_output = builder.MakeIntermediate(); + auto* argm_output = builder.MakeOutput(); Node& argm_node = builder.AddNode(op_type, {input_qdq}, {argm_output}); for (const auto& attr : attrs) { argm_node.AddAttributeProto(attr); } - - // Cast to uint32 (HTP does not support int64 as graph output) - auto* output = builder.MakeOutput(); - Node& cast_node = builder.AddNode("Cast", {argm_output}, {output}); - const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; - cast_node.AddAttribute("to", static_cast(dst_type)); }; } @@ -77,7 +50,7 @@ static void RunCPUArgMxxOpTest(const std::string& op_type, TestInputDef i provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildArgMxxTestCase(op_type, input_def, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -98,7 +71,7 @@ static void RunQDQArgMxxOpTest(const std::string& op_type, TestInputDef i provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildArgMxxTestCase(op_type, input_def, attrs), // baseline float32 model + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, attrs), // baseline float32 model BuildQDQArgMxxTestCase(op_type, input_def, attrs), // QDQ model provider_options, opset, @@ -190,48 +163,6 @@ TEST_F(QnnHTPBackendTests, ArgMaxMinU8_RankGreaterThan4_Unsupported) { ExpectedEPNodeAssignment::None, 13); } -// Test that ArgMax/ArgMin are not supported if they generate a graph output. -TEST_F(QnnHTPBackendTests, ArgMaxMin_AsGraphOutputUnsupported) { - ProviderOptions provider_options; - -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - // Utility function that creates a QDQ model with ArgMax/ArgMin that produce a graph output. - auto model_builder_func = [](const std::string& op_type, const TestInputDef& input_def, - const std::vector& attrs) -> GetTestModelFn { - return [op_type, input_def, attrs](ModelTestBuilder& builder) { - QuantParams input_qparams = GetTestInputQuantParams(input_def); - - auto* input = MakeTestInput(builder, input_def); - auto* output = builder.MakeOutput(); - - // input -> Q -> DQ -> - auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - Node& argm_node = builder.AddNode(op_type, {input_qdq}, {output}); - for (const auto& attr : attrs) { - argm_node.AddAttributeProto(attr); - } - }; - }; - - const int expected_nodes_in_graph = -1; // Don't care exactly how many nodes in graph assigned to CPU EP. - RunQnnModelTest(model_builder_func("ArgMax", TestInputDef({1, 3, 4}, false, -1.0f, 1.0f), {}), - provider_options, - 13, - ExpectedEPNodeAssignment::None, // No nodes should be assigned to QNN EP! - expected_nodes_in_graph); - RunQnnModelTest(model_builder_func("ArgMin", TestInputDef({1, 3, 4}, false, -1.0f, 1.0f), {}), - provider_options, - 13, - ExpectedEPNodeAssignment::None, // No nodes should be assigned to QNN EP! - expected_nodes_in_graph); -} - #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/cast_test.cc b/onnxruntime/test/providers/qnn/cast_test.cc index 0b2d168b06ca1..f03782c33c30a 100644 --- a/onnxruntime/test/providers/qnn/cast_test.cc +++ b/onnxruntime/test/providers/qnn/cast_test.cc @@ -107,6 +107,20 @@ TEST_F(QnnHTPBackendTests, TestCastFloatToInt32HTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All, true); } + +// Cast int64_t to int32_t on HTP +// Supported in QNN SDK 2.23 +TEST_F(QnnHTPBackendTests, TestCastInt64ToInt32HTP) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, + ExpectedEPNodeAssignment::All, true); +} + +// Cast int32_t to int64_t on HTP +// Supported in QNN SDK 2.23 +TEST_F(QnnHTPBackendTests, TestCastInt32ToInt64HTP) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64, + ExpectedEPNodeAssignment::All, true); +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 5177a629ce292..b07951d2a2e6d 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -182,7 +182,12 @@ static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const s static_cast(weight_quant_axis), true); TensorShape weights_shape = weights_def.GetTensorShape(); - std::vector quantized_weights(weights_shape.Size()); + std::vector quantized_weights; + size_t num_weight_storage_elems = weights_shape.Size(); + if constexpr (std::is_same_v || std::is_same_v) { + num_weight_storage_elems = Int4x2::CalcNumInt4Pairs(weights_shape.Size()); + } + quantized_weights.resize(num_weight_storage_elems); QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, weight_scales, weight_zero_points, weight_quant_axis); @@ -727,6 +732,80 @@ TEST_F(QnnHTPBackendTests, ConvU8S8S32_PerChannel) { 13); // opset } +// Test per-channel QDQ Conv with INT4 weights. in0: u16, in1 (weight): s4, in2 (bias): s32, out: u8 +TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21); // opset +} + +// Test per-channel QDQ Conv with INT4 weights. in0: u16, in1 (weight): s4, in2 (bias): s32, out: u8 +// TODO(adrianlizarraga): Investigate inaccuracy for QNN EP. +// +// Output values for all EPs: +// CPU EP (f32 model): 25.143 21.554 17.964 10.785 7.195 3.605 -3.574 -7.164 -10.753 +// CPU EP (qdq model): 24.670 21.103 17.536 10.254 6.689 2.972 -4.161 -7.728 -10.700 +// QNN EP (qdq model): 27.186 27.186 27.186 21.541 6.685 -8.022 -10.548 -10.548 -10.548 +TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S4S32_PerChannel_AccuracyIssue) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + // Wrote out input data explicitly for easier reproduction. + // std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size()); + std::vector input_data = {-10.000f, -9.355f, -8.710f, -8.065f, -7.419f, -6.774f, -6.129f, -5.484f, -4.839f, + -4.194f, -3.548f, -2.903f, -2.258f, -1.613f, -0.968f, -0.323f, 0.323f, 0.968f, + 1.613f, 2.258f, 2.903f, 3.548f, 4.194f, 4.839f, 5.484f, 6.129f, 6.774f, + 7.419f, 8.065f, 8.710f, 9.355f, 10.000f}; + + // std::vector weight_data = GetFloatDataInRange(-1.0f, 1.0f, TensorShape(weight_shape).Size()); + std::vector weight_data = {-1.000f, -0.913f, -0.826f, -0.739f, -0.652f, -0.565f, -0.478f, -0.391f, -0.304f, + -0.217f, -0.130f, -0.043f, 0.043f, 0.130f, 0.217f, 0.304f, 0.391f, 0.478f, + 0.565f, 0.652f, 0.739f, 0.826f, 0.913f, 1.000f}; + + // std::vector bias_data = GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size()); + std::vector bias_data = {-1.000f, 0.000f, 1.000f}; + + TestInputDef input_def(input_shape, false, input_data); + TestInputDef weight_def(weight_shape, true, weight_data); + TestInputDef bias_def(bias_shape, true, bias_data); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21); // opset +} + // Test per-channel QDQ Conv is rejected with weight axis != 0 TEST_F(QnnHTPBackendTests, Conv_PerChannel_UnsupportedAxis) { std::vector input_shape = {1, 2, 4, 4}; diff --git a/onnxruntime/test/providers/qnn/pad_op_test.cpp b/onnxruntime/test/providers/qnn/pad_op_test.cpp index 4ef71457d5bfe..a6b8664c6c0c9 100644 --- a/onnxruntime/test/providers/qnn/pad_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pad_op_test.cpp @@ -98,18 +98,33 @@ static void RunPadOpTest(const TestInputDef& data_def, const std::vector& attrs, ExpectedEPNodeAssignment expected_ep_assignment, bool has_constant_value = true, - int opset = 18) { + int opset = 18, + bool use_htp = false, + bool enable_fp16_precision = false, + float f32_abs_err = 1e-5f) { ProviderOptions provider_options; + if (use_htp) { #if defined(_WIN32) - provider_options["backend_path"] = "QnnCpu.dll"; + provider_options["backend_path"] = "QnnHtp.dll"; #else - provider_options["backend_path"] = "libQnnCpu.so"; + provider_options["backend_path"] = "libQnnHtp.so"; #endif + } else { +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + } + + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } RunQnnModelTest(BuildPadTestCase(data_def, pads_def, constant_value_def, attrs, has_constant_value), provider_options, opset, - expected_ep_assignment); + expected_ep_assignment, f32_abs_err); } // Runs a QDQ Pad model on the QNN HTP backend. Checks the graph node assignment, and that inference @@ -229,6 +244,60 @@ TEST_F(QnnCPUBackendTests, Pad6d) { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // // HTP tests: +TEST_F(QnnHTPBackendTests, PadNoConstantValue_fp16_test) { + bool has_constant_value_input = false; + bool use_htp = true; + bool enable_fp16_precision = true; + RunPadOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}), + TestInputDef({4}, true, {0, 2, 0, 0}), + TestInputDef({1}, true, {0.0f}), + {utils::MakeAttribute("mode", "constant")}, + ExpectedEPNodeAssignment::All, + has_constant_value_input, + 18, // opset + use_htp, + enable_fp16_precision, + 2e-3f); +} + +TEST_F(QnnHTPBackendTests, PadReflectMode_fp16) { + bool has_constant_value_input = false; + bool use_htp = true; + bool enable_fp16_precision = true; + RunPadOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.6f}), + TestInputDef({4}, true, {0, 1, 0, 0}), + TestInputDef({1}, true, {0.0f}), + {utils::MakeAttribute("mode", "reflect")}, + ExpectedEPNodeAssignment::All, + has_constant_value_input, + 18, // opset + use_htp, + enable_fp16_precision, + 2e-3f); +} + +// HTP\HTP\src\hexagon\prepare\graph_prepare.cc:203:ERROR:could not create op: q::flat_from_vtcm +// HTP\HTP\src\hexagon\prepare\graph_prepare.cc:1238:ERROR:Op 0x104100000011 preparation failed with err:-1 +// Completed stage: Graph Transformations and Optimizations (13372 us) +// QnnDsp "node" generated: could not create op +// QnnDsp RouterWindows graph prepare failed 12 +// QnnDsp Failed to finalize graph (id: 1) with err 1002 +TEST_F(QnnHTPBackendTests, DISABLED_PadReflectMode_FP16_big_data) { + bool has_constant_value_input = false; + bool use_htp = true; + bool enable_fp16_precision = true; + RunPadOpTest(TestInputDef({1, 4, 512, 512}, false, GetFloatDataInRange(1.0f, 10.0f, 4 * 512 * 512)), + TestInputDef({8}, true, {0, 0, 3, 3, 0, 0, 3, 3}), + TestInputDef({1}, true, {0.0f}), + {utils::MakeAttribute("mode", "reflect")}, + ExpectedEPNodeAssignment::All, + has_constant_value_input, + 18, // opset + use_htp, + enable_fp16_precision, + 2e-3f); +} + // // QDQ Pad TEST_F(QnnHTPBackendTests, PadNoConstantValue) { diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 6173f46839a81..9489d354755e4 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -236,6 +236,51 @@ TEST_F(QnnHTPBackendTests, TestConvWithExternalData) { Ort::Session session(*ort_env, ort_model_path, so); } +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +TEST_F(QnnHTPBackendTests, RunConvInt4Model) { + Ort::SessionOptions so; + + so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Disable fallback to the CPU EP. + so.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + so.AppendExecutionProvider("QNN", options); + + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv.int4_weights.qdq.onnx"; + Ort::Session session(*ort_env, ort_model_path, so); + + TensorShape input_shape = {1, 3, 8, 8}; + std::vector input0_data(input_shape.Size(), 0.2f); + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add input0 + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), &input_shape[0], input_shape.NumDimensions())); + ort_input_names.push_back("input_0"); + + // Run session and get outputs + std::array output_names{"output_0"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output shape. + Ort::Value& ort_output = ort_outputs[0]; + auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); + std::vector output_shape = typeshape.GetShape(); + + EXPECT_THAT(output_shape, ::testing::ElementsAre(1, 5, 6, 6)); +} +#endif // #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + // Helper function that runs an ONNX model with a NHWC Resize operator to test that // type/shape inference succeeds during layout transformation. // Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h. diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index dd47e1df80000..ad54e644af3f7 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -34,6 +34,15 @@ struct QuantParams { QType zero_point; static QuantParams Compute(float rmin, float rmax, bool symmetric = false) { + return Compute( + rmin, + rmax, + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max()), + symmetric); + } + + static QuantParams Compute(float rmin, float rmax, QType qmin, QType qmax, bool symmetric = false) { // Ensure a minimum range of 0.0001 (required by QNN) rmax = std::max(rmax, rmin + 0.0001f); @@ -41,27 +50,27 @@ struct QuantParams { rmin = std::min(rmin, 0.0f); rmax = std::max(rmax, 0.0f); - constexpr float qmin = static_cast(std::numeric_limits::min()); - constexpr float qmax = static_cast(std::numeric_limits::max()); - if (symmetric) { const float abs_max = std::max(std::abs(rmin), std::abs(rmax)); rmax = abs_max; rmin = -abs_max; } - const float scale = (rmax - rmin) / (qmax - qmin); + float qmin_flt = static_cast(qmin); + float qmax_flt = static_cast(qmax); + const float scale = (rmax - rmin) / (qmax_flt - qmin_flt); float initial_zero_point = 0.0f; if (symmetric) { // Symmetric uses same formula for zero-point as asymmetric, but we can cancel out terms for // increased numerical accuracy. - initial_zero_point = (qmin + qmax) / 2.0f; + initial_zero_point = (qmin_flt + qmax_flt) / 2.0f; } else { - initial_zero_point = qmin - (rmin / scale); + initial_zero_point = qmin_flt - (rmin / scale); } - const QType zero_point = static_cast(RoundHalfToEven(std::max(qmin, std::min(qmax, initial_zero_point)))); + const QType zero_point = static_cast(RoundHalfToEven(std::max(qmin_flt, + std::min(qmax_flt, initial_zero_point)))); return QuantParams{scale, zero_point}; } @@ -238,7 +247,7 @@ struct TestInputDef { assert(which_type == 0); const std::vector& raw_data = std::get(data_info_).data; - std::pair init_range(std::numeric_limits::max(), std::numeric_limits::min()); + std::pair init_range(std::numeric_limits::max(), std::numeric_limits::lowest()); std::vector> per_axis_ranges(num_ranges, init_range); TensorShape shape(shape_); size_t num_blocks = shape.SizeToDimension(axis); @@ -292,6 +301,37 @@ static void GetTestInputQuantParamsPerChannel(const TestInputDef& input_d } } +// Define functions to get the quantization parameters (i.e., scale/zp) for input data that will be quantized +// as int4 per-channel. +#define DEF_GET_INPUT_QPARAMS_PER_CHAN_INT4_FUNC(INT4x2_TYPE) \ + template <> \ + inline void GetTestInputQuantParamsPerChannel(const TestInputDef& input_def, \ + std::vector& scales, \ + std::vector& zero_points, \ + size_t axis, bool symmetric) { \ + using UnpackedType = typename INT4x2_TYPE::UnpackedType; \ + const auto f32_ranges = input_def.GetRangePerChannel(axis); \ + const size_t num_ranges = f32_ranges.size(); \ + \ + scales.resize(num_ranges); \ + zero_points.resize(INT4x2_TYPE::CalcNumInt4Pairs(num_ranges)); \ + \ + for (size_t i = 0; i < num_ranges; i++) { \ + const auto& range = f32_ranges[i]; \ + QuantParams params = QuantParams::Compute(range.first, range.second, \ + INT4x2_TYPE::min_val, \ + INT4x2_TYPE::max_val, symmetric); \ + scales[i] = params.scale; \ + \ + size_t r = i >> 1; \ + size_t c = i & 0x1; \ + zero_points[r].SetElem(c, params.zero_point); \ + } \ + } + +DEF_GET_INPUT_QPARAMS_PER_CHAN_INT4_FUNC(Int4x2) +DEF_GET_INPUT_QPARAMS_PER_CHAN_INT4_FUNC(UInt4x2) + template static void QuantizeValues(gsl::span input, gsl::span output, const TensorShape& shape, gsl::span scales, gsl::span zero_points, @@ -332,6 +372,52 @@ static void QuantizeValues(gsl::span input, gsl::span \ + inline void QuantizeValues(gsl::span input, \ + gsl::span output, \ + const TensorShape& shape, \ + gsl::span scales, \ + gsl::span zero_points, \ + std::optional axis) { \ + using UnpackedType = typename INT4x2_TYPE::UnpackedType; \ + const size_t input_rank = shape.NumDimensions(); \ + const size_t num_int4_elems = static_cast(shape.Size()); \ + ORT_ENFORCE(input.size() == num_int4_elems); \ + ORT_ENFORCE(output.size() == INT4x2_TYPE::CalcNumInt4Pairs(num_int4_elems)); \ + \ + size_t block_count = 1; \ + size_t broadcast_dim = 1; \ + size_t block_size = num_int4_elems; \ + \ + if (axis.has_value()) { \ + size_t axis_no_neg = *axis < 0 ? static_cast(*axis) + input_rank : static_cast(*axis); \ + block_count = shape.SizeToDimension(axis_no_neg); \ + broadcast_dim = shape[axis_no_neg]; \ + block_size = shape.SizeFromDimension(axis_no_neg + 1); \ + } \ + \ + ORT_ENFORCE(scales.size() == broadcast_dim); \ + ORT_ENFORCE(zero_points.empty() || zero_points.size() == INT4x2_TYPE::CalcNumInt4Pairs(broadcast_dim)); \ + \ + size_t i = 0; \ + \ + for (size_t n = 0; n < block_count; n++) { \ + for (size_t bd = 0; bd < broadcast_dim; bd++) { \ + size_t bd_i = bd >> 1; /* bd / 2 */ \ + size_t bd_j = bd & 0x1; /* bd % 2 */ \ + UnpackedType zp = !zero_points.empty() ? zero_points[bd_i].GetElem(bd_j) : 0; \ + QUANT_FUNC(&input[i], output.data(), i, i + block_size, scales[bd], INT4x2_TYPE(zp, 0), nullptr); \ + i += block_size; \ + } \ + } \ + assert(i == (block_count * broadcast_dim * block_size)); \ + } + +DEF_QUANTIZE_VALUES_INT4_FUNC(Int4x2, ParQuantizeLinearStdS4) +DEF_QUANTIZE_VALUES_INT4_FUNC(UInt4x2, ParQuantizeLinearStdU4) + /** * Inferences a given serialized model. Returns output values via an out-param. * @@ -414,6 +500,10 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); + + // Uncomment to dump LOGGER() output to stdout. + // logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(log_severity); // Create float model and serialize it to a string. diff --git a/onnxruntime/test/providers/qnn/reduce_op_test.cc b/onnxruntime/test/providers/qnn/reduce_op_test.cc index e39ba5fb40cf7..13173d9a87f55 100644 --- a/onnxruntime/test/providers/qnn/reduce_op_test.cc +++ b/onnxruntime/test/providers/qnn/reduce_op_test.cc @@ -60,30 +60,42 @@ static GetTestModelFn BuildReduceOpTestCase(const std::string& reduce_op_type, } /** - * Runs a ReduceOp model on the QNN CPU backend. Checks the graph node assignment, and that inference + * Runs a ReduceOp model on the QNN CPU/NPU backend. Checks the graph node assignment, and that inference * outputs for QNN and CPU match. * * \param op_type The ReduceOp type (e.g., ReduceSum). * \param input_def The input definition (shape, data, etc.) * \param axes The axes of reduction. + * \param keepdims Common attribute for all reduce operations. * \param opset The opset version. Some opset versions have "axes" as an attribute or input. * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None) - * \param keepdims Common attribute for all reduce operations. + * \param fp32_abs_err Error tolerance. + * \param enable_fp16 Enable fp32 model with FP16 precision on NPU. */ template -static void RunReduceOpCpuTest(const std::string& op_type, - const TestInputDef& input_def, - const std::vector& axes, - bool keepdims, - int opset, - ExpectedEPNodeAssignment expected_ep_assignment, - float fp32_abs_err = 1e-5f) { +static void RunReduceTest(const std::string& op_type, + const TestInputDef& input_def, + const std::vector& axes, + bool keepdims, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment, + float fp32_abs_err = 1e-5f, + bool enable_fp16 = false) { ProviderOptions provider_options; + if (enable_fp16) { #if defined(_WIN32) - provider_options["backend_path"] = "QnnCpu.dll"; + provider_options["backend_path"] = "QnnHtp.dll"; #else - provider_options["backend_path"] = "libQnnCpu.so"; + provider_options["backend_path"] = "libQnnHtp.so"; #endif + provider_options["enable_htp_fp16_precision"] = "1"; + } else { +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + } RunQnnModelTest(BuildReduceOpTestCase(op_type, input_def, //{2, 2}, // input shape @@ -107,12 +119,12 @@ static void RunReduceOpCpuTest(const std::string& op_type, // - The input and output data type is int32. // - Uses opset 13, which has "axes" as an input. TEST_F(QnnCPUBackendTests, ReduceSumOpset13_Int32) { - RunReduceOpCpuTest("ReduceSum", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 13, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceSum", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 13, + ExpectedEPNodeAssignment::All); } // Test creates a graph with a ReduceSum node, and checks that all @@ -121,12 +133,12 @@ TEST_F(QnnCPUBackendTests, ReduceSumOpset13_Int32) { // - The input and output data type is int32. // - Uses opset 11, which has "axes" as an attribute. TEST_F(QnnCPUBackendTests, ReduceSumOpset11_Int32) { - RunReduceOpCpuTest("ReduceSum", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 11, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceSum", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 11, + ExpectedEPNodeAssignment::All); } // Test creates a graph with a ReduceSum node, and checks that all @@ -135,12 +147,12 @@ TEST_F(QnnCPUBackendTests, ReduceSumOpset11_Int32) { // - The input and output data type is float. // - Uses opset 13, which has "axes" as an input. TEST_F(QnnCPUBackendTests, ReduceSumOpset13_Float) { - RunReduceOpCpuTest("ReduceSum", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 13, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceSum", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 13, + ExpectedEPNodeAssignment::All); } // Test creates a graph with a ReduceSum node, and checks that all @@ -149,12 +161,12 @@ TEST_F(QnnCPUBackendTests, ReduceSumOpset13_Float) { // - The input and output data type is float. // - Uses opset 11, which has "axes" as an attribute. TEST_F(QnnCPUBackendTests, ReduceSumOpset11_Float) { - RunReduceOpCpuTest("ReduceSum", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 11, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceSum", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 11, + ExpectedEPNodeAssignment::All); } // @@ -167,24 +179,24 @@ TEST_F(QnnCPUBackendTests, ReduceSumOpset11_Float) { // - The input and output data type is float. // - Uses opset 18, which has "axes" as an input. TEST_F(QnnCPUBackendTests, ReduceProdOpset18) { - RunReduceOpCpuTest("ReduceProd", - TestInputDef({2, 2}, false, {-10.0f, -8.2f, 0.0f, 10.0f}), - std::vector{0, 1}, - true, // keepdims - 18, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceProd", + TestInputDef({2, 2}, false, {-10.0f, -8.2f, 0.0f, 10.0f}), + std::vector{0, 1}, + true, // keepdims + 18, + ExpectedEPNodeAssignment::All); } // TODO: Investigate slight inaccuracy. x64 Windows/Linux require a slightly larger error tolerance greater than 1.5e-5f. // LOG: ... the value pair (208.881729, 208.881744) at index #0 don't match, which is 1.52588e-05 from 208.882 TEST_F(QnnCPUBackendTests, ReduceProdOpset18_SlightlyInaccurate_WindowsLinuxX64) { - RunReduceOpCpuTest("ReduceProd", - TestInputDef({2, 2}, false, {3.21289f, -5.9981f, -1.72799f, 6.27263f}), - std::vector{0, 1}, - true, // keepdims - 18, - ExpectedEPNodeAssignment::All, - 2e-5f); // x64 Linux & Windows require larger tolerance. + RunReduceTest("ReduceProd", + TestInputDef({2, 2}, false, {3.21289f, -5.9981f, -1.72799f, 6.27263f}), + std::vector{0, 1}, + true, // keepdims + 18, + ExpectedEPNodeAssignment::All, + 2e-5f); // x64 Linux & Windows require larger tolerance. } // Test creates a graph with a ReduceProd node, and checks that all @@ -193,12 +205,12 @@ TEST_F(QnnCPUBackendTests, ReduceProdOpset18_SlightlyInaccurate_WindowsLinuxX64) // - The input and output data type is float. // - Uses opset 13, which has "axes" as an attribute. TEST_F(QnnCPUBackendTests, ReduceProdOpset13) { - RunReduceOpCpuTest("ReduceProd", - TestInputDef({2, 2}, false, {-10.0f, -8.2f, 0.0f, 10.0f}), - std::vector{0, 1}, - true, // keepdims - 13, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceProd", + TestInputDef({2, 2}, false, {-10.0f, -8.2f, 0.0f, 10.0f}), + std::vector{0, 1}, + true, // keepdims + 13, + ExpectedEPNodeAssignment::All); } // @@ -211,12 +223,12 @@ TEST_F(QnnCPUBackendTests, ReduceProdOpset13) { // - The input and output data type is float. // - Uses opset 18, which has "axes" as an input. TEST_F(QnnCPUBackendTests, ReduceMaxOpset18) { - RunReduceOpCpuTest("ReduceMax", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 18, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceMax", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 18, + ExpectedEPNodeAssignment::All); } // Test creates a graph with a ReduceMax node, and checks that all @@ -225,12 +237,12 @@ TEST_F(QnnCPUBackendTests, ReduceMaxOpset18) { // - The input and output data type is float. // - Uses opset 13, which has "axes" as an attribute. TEST_F(QnnCPUBackendTests, ReduceMaxOpset13) { - RunReduceOpCpuTest("ReduceMax", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 13, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceMax", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 13, + ExpectedEPNodeAssignment::All); } // @@ -243,12 +255,12 @@ TEST_F(QnnCPUBackendTests, ReduceMaxOpset13) { // - The input and output data type is float. // - Uses opset 18, which has "axes" as an input. TEST_F(QnnCPUBackendTests, ReduceMinOpset18) { - RunReduceOpCpuTest("ReduceMin", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 18, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceMin", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 18, + ExpectedEPNodeAssignment::All); } // Test creates a graph with a ReduceMin node, and checks that all @@ -257,12 +269,12 @@ TEST_F(QnnCPUBackendTests, ReduceMinOpset18) { // - The input and output data type is float. // - Uses opset 13, which has "axes" as an attribute. TEST_F(QnnCPUBackendTests, ReduceMinOpset13) { - RunReduceOpCpuTest("ReduceMin", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 13, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceMin", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 13, + ExpectedEPNodeAssignment::All); } // @@ -275,12 +287,12 @@ TEST_F(QnnCPUBackendTests, ReduceMinOpset13) { // - The input and output data type is float. // - Uses opset 18, which has "axes" as an input. TEST_F(QnnCPUBackendTests, ReduceMeanOpset18) { - RunReduceOpCpuTest("ReduceMean", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 18, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceMean", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 18, + ExpectedEPNodeAssignment::All); } // Test creates a graph with a ReduceMean node, and checks that all @@ -289,16 +301,33 @@ TEST_F(QnnCPUBackendTests, ReduceMeanOpset18) { // - The input and output data type is float. // - Uses opset 13, which has "axes" as an attribute. TEST_F(QnnCPUBackendTests, ReduceMeanOpset13) { - RunReduceOpCpuTest("ReduceMean", - TestInputDef({2, 2}, false, -10.0f, 10.0f), - std::vector{0, 1}, - true, // keepdims - 13, - ExpectedEPNodeAssignment::All); + RunReduceTest("ReduceMean", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 13, + ExpectedEPNodeAssignment::All); } #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// Test creates a graph with a ReduceSum node, and checks that all nodes are supported by the QNN EP +// HTP backend with FP16 precision, and that the inference results match the CPU EP results. +// +// Failed QNN Opvalidation because of 5D input. It runs OK if bypass the op validation +TEST_F(QnnHTPBackendTests, DISABLED_ReduceSumOpset11_5D_FP16) { + float fp32_abs_err = 3e-2f; + bool enable_fp16 = true; + RunReduceTest("ReduceSum", + TestInputDef({1, 12, 249, 2, 4}, false, -10.0f, 10.0f), + std::vector{-1}, + false, // keepdims + 13, + ExpectedEPNodeAssignment::All, + fp32_abs_err, + enable_fp16); +} + // Creates the following graph if axes is an input (newer opsets): // _______________________ // input (f32) -> Q -> DQ -> | | -> Q -> DQ -> output (f32) diff --git a/onnxruntime/test/providers/qnn/topk_op_test.cc b/onnxruntime/test/providers/qnn/topk_op_test.cc index 93e725af5f20e..5a9351b9366ec 100644 --- a/onnxruntime/test/providers/qnn/topk_op_test.cc +++ b/onnxruntime/test/providers/qnn/topk_op_test.cc @@ -18,27 +18,18 @@ namespace test { template inline GetTestModelFn BuildTopKTestCase(const TestInputDef& input_def, const TestInputDef& k_def, - const std::vector& attrs, - bool cast_output_indices = true) { - return [input_def, k_def, attrs, cast_output_indices](ModelTestBuilder& builder) { + const std::vector& attrs) { + return [input_def, k_def, attrs](ModelTestBuilder& builder) { NodeArg* input = MakeTestInput(builder, input_def); NodeArg* k_input = MakeTestInput(builder, k_def); NodeArg* values_output = builder.MakeOutput(); - NodeArg* indices_output = cast_output_indices ? builder.MakeIntermediate() : builder.MakeOutput(); + NodeArg* indices_output = builder.MakeOutput(); Node& topk_node = builder.AddNode("TopK", {input, k_input}, {values_output, indices_output}); for (const auto& attr : attrs) { topk_node.AddAttributeProto(attr); } - - // Cast indices to uint32 - if (cast_output_indices) { - auto* uint32_indices_output = builder.MakeOutput(); - Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); - const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; - cast_node.AddAttribute("to", static_cast(dst_type)); - } }; } @@ -58,7 +49,7 @@ static void RunTopKTestOnCPU(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildTopKTestCase(input_def, k_def, attrs, false /*cast_output_indices*/), + RunQnnModelTest(BuildTopKTestCase(input_def, k_def, attrs), provider_options, opset, expected_ep_assignment); @@ -131,26 +122,19 @@ GetTestQDQModelFn BuildQDQTopKTestCase(const TestInputDef& inp // K input NodeArg* k_input = MakeTestInput(builder, k_def); - // Reshape op + // TopK_values_output -> Q -> DQ -> output + // NOTE: Create output QDQ nodes before the TopK node so that TopK's 'values' output is the graph's first output. NodeArg* values_output = builder.MakeIntermediate(); - NodeArg* indices_output = builder.MakeIntermediate(); + output_qparams[0] = input_qparams; // Input and output qparams must be equal. + AddQDQNodePairWithOutputAsGraphOutput(builder, values_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + // TopK node + NodeArg* indices_output = builder.MakeOutput(); Node& topk_node = builder.AddNode("TopK", {input_qdq, k_input}, {values_output, indices_output}); for (const auto& attr : attrs) { topk_node.AddAttributeProto(attr); } - - // op_output -> Q -> DQ -> output - // NOTE: Input and output quantization parameters must be equal for Reshape. - output_qparams[0] = input_qparams; // Overwrite! - AddQDQNodePairWithOutputAsGraphOutput(builder, values_output, input_qparams.scale, - input_qparams.zero_point, use_contrib_qdq); - - // Cast indices to uint32 (HTP backend does not support int64 graph outputs) - auto* uint32_indices_output = builder.MakeOutput(); - Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); - const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; - cast_node.AddAttribute("to", static_cast(dst_type)); }; } @@ -171,7 +155,7 @@ static void RunQDQTopKTestOnHTP(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - auto f32_model_builder = BuildTopKTestCase(input_def, k_def, attrs, true /*cast_output_indices*/); + auto f32_model_builder = BuildTopKTestCase(input_def, k_def, attrs); auto qdq_model_builder = BuildQDQTopKTestCase(input_def, k_def, attrs, use_contrib_qdq); TestQDQModelAccuracy(f32_model_builder, qdq_model_builder, @@ -189,18 +173,12 @@ TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U8_LastAxis) { } // Test 16-bit QDQ TopK on HTP backend: top 2 largest floats from last axis -// TODO: Inaccuracy detected for output 'output_0', element 6. -// Output quant params: scale=0.00061036087572574615, zero_point=32768. -// Expected val: -7.2340402603149414 -// QNN QDQ val: -17.446556091308594 (err 10.212515830993652) -// CPU QDQ val: -7.2339968681335449 (err 4.3392181396484375e-05) -TEST_F(QnnHTPBackendTests, DISABLED_TopK_LargestFloats_U16_LastAxis) { +TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U16_LastAxis) { RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)), TestInputDef({1}, true /* is_initializer */, {2}), {}, // Attributes ExpectedEPNodeAssignment::All, - 19, // opset - true); // Use com.microsoft Q/DQ ops + 21); // opset } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 4d2538c947dcc..2b5b82d0fc16a 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -66,7 +66,7 @@ void VerifyOutputs(const std::vector& fetches, const std::vector dims, bool add_non_zero_node = false) { @@ -165,7 +165,7 @@ void RunSession2(InferenceSession& session_object, VerifyOutputs(fetches, expected_dims, expected_values); } -void RunWithOneSessionSingleThreadInference(std::string model_name, std::string sess_log_id) { +void RunWithOneSessionSingleThreadInference(PathString model_name, std::string sess_log_id) { SessionOptions so; so.session_logid = sess_log_id; RunOptions run_options; @@ -222,7 +222,7 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path)); } -void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string sess_log_id, bool has_non_zero_node = false) { +void RunWithOneSessionMultiThreadsInference(PathString model_name, std::string sess_log_id, bool has_non_zero_node = false) { SessionOptions so; so.session_logid = sess_log_id; RunOptions run_options; @@ -289,7 +289,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string TEST(TensorrtExecutionProviderTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) { std::vector threads; - std::string model_name = "trt_execution_provider_multithreading_test.onnx"; + PathString model_name = ORT_TSTR("trt_execution_provider_multithreading_test.onnx"); std::string graph_name = "multithreading_test"; std::string sess_log_id = "TRTEPMultiThreadingTestWithOneSessionSingleThread"; std::vector dims = {1, 3, 2}; @@ -305,7 +305,7 @@ TEST(TensorrtExecutionProviderTest, SessionCreationWithMultiThreadsAndInferenceW } TEST(TensorrtExecutionProviderTest, SessionCreationWithSingleThreadAndInferenceWithMultiThreads) { - std::string model_name = "trt_execution_provider_multithreading_test.onnx"; + PathString model_name = ORT_TSTR("trt_execution_provider_multithreading_test.onnx"); std::string graph_name = "multithreading_test"; std::string sess_log_id = "TRTEPMultiThreadingTestWithOneSessionMultiThreads"; std::vector dims = {1, 3, 2}; @@ -360,7 +360,7 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { } TEST(TensorrtExecutionProviderTest, EPContextNode) { - std::string model_name = "EPContextNode_test.onnx"; + PathString model_name = ORT_TSTR("EPContextNode_test.onnx"); std::string graph_name = "EPContextNode_test"; std::string sess_log_id = "EPContextNode_test"; std::vector dims = {1, 3, 2}; @@ -461,7 +461,7 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { */ InferenceSession session_object3{so, GetEnvironment()}; OrtTensorRTProviderOptionsV2 params3; - model_name = params.trt_ep_context_file_path; + model_name = ToPathString(params.trt_ep_context_file_path); params3.trt_engine_cache_enable = 1; execution_provider = TensorrtExecutionProviderWithOptions(¶ms3); EXPECT_TRUE(session_object3.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); @@ -490,7 +490,7 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { */ InferenceSession session_object4{so, GetEnvironment()}; OrtTensorRTProviderOptionsV2 params4; - model_name = "./context_model_folder/EPContextNode_test_ctx.onnx"; + model_name = ORT_TSTR("./context_model_folder/EPContextNode_test_ctx.onnx"); execution_provider = TensorrtExecutionProviderWithOptions(¶ms4); EXPECT_TRUE(session_object4.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); status = session_object4.Load(model_name); @@ -514,7 +514,7 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { params5.trt_dump_ep_context_model = 1; params5.trt_ep_context_embed_mode = 1; params5.trt_ep_context_file_path = "EP_Context_model_2.onnx"; - model_name = "EPContextNode_test.onnx"; + model_name = ORT_TSTR("EPContextNode_test.onnx"); execution_provider = TensorrtExecutionProviderWithOptions(¶ms5); EXPECT_TRUE(session_object5.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); status = session_object5.Load(model_name); @@ -528,7 +528,7 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { InferenceSession session_object6{so, GetEnvironment()}; OrtTensorRTProviderOptionsV2 params6; params6.trt_ep_context_embed_mode = 1; - model_name = params5.trt_ep_context_file_path; + model_name = ToPathString(params5.trt_ep_context_file_path); execution_provider = TensorrtExecutionProviderWithOptions(¶ms6); EXPECT_TRUE(session_object6.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); status = session_object6.Load(model_name); @@ -546,7 +546,7 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { } TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { - std::string model_name = "testdata/trt_plugin_custom_op_test.onnx"; + PathString model_name = ORT_TSTR("testdata/trt_plugin_custom_op_test.onnx"); SessionOptions so; so.session_logid = "TensorrtExecutionProviderTRTPluginsTest"; RunOptions run_options; @@ -575,7 +575,6 @@ TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { OrtTensorRTProviderOptionsV2 params; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - std::cout << model_name << std::endl; auto status = session_object.Load(model_name); ASSERT_TRUE(status.IsOK()); status = session_object.Initialize(); @@ -591,9 +590,12 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { size_t pos = param.find("_"); std::string input_type = param.substr(pos + 1); ASSERT_NE(pos, std::string::npos); - std::string cache_type = ToUTF8String(param.substr(0, pos)); - - std::string model_name = "trt_execution_provider_" + cache_type + "caching_test_" + input_type + ".onnx"; + std::string cache_type_mbs = param.substr(0, pos); + PathString cache_type = ToPathString(cache_type_mbs); + std::basic_ostringstream oss; + oss << ORT_TSTR("trt_execution_provider_") << cache_type << ORT_TSTR("_caching_test_") << ToPathString(input_type) + << ORT_TSTR(".onnx"); + PathString model_name = oss.str(); std::vector dims; if (input_type.compare("dynamic") == 0) { dims = {1, -1, -1}; // dynamic shape input @@ -601,10 +603,10 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { dims = {1, 3, 2}; } - CreateBaseModel(model_name, cache_type + "cachingtest", dims); + CreateBaseModel(model_name, cache_type_mbs + "cachingtest", dims); SessionOptions so; - so.session_logid = "TensorrtExecutionProvider" + cache_type + "cacheTest"; + so.session_logid = "TensorrtExecutionProvider" + cache_type_mbs + "cacheTest"; RunOptions run_options; run_options.run_tag = so.session_logid; InferenceSession session_object{so, GetEnvironment()}; @@ -633,7 +635,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { std::vector expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}; OrtTensorRTProviderOptionsV2 params; - if (cache_type.compare("engine") == 0) { + if (cache_type_mbs.compare("engine") == 0) { /* Following code block tests the functionality of engine and optimization profile of ORT TRT, including: * - engine cache serialization/de-serialization * - profile cache serialization/de-serialization @@ -807,7 +809,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { } } - } else if (cache_type.compare("timing") == 0) { + } else if (cache_type_mbs.compare("timing") == 0) { /* Following code block tests the functionality of timing cache, including: * - timing cache cache serialization/de-serialization * - TODO: benefir of usign a timing cache no matter if dynamic / static input @@ -917,7 +919,7 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); - std::string model_file_name = "trt_execution_provider_function_test.onnx"; + PathString model_file_name = ORT_TSTR("trt_execution_provider_function_test.onnx"); status = onnxruntime::Model::Save(model, model_file_name); SessionOptions so; @@ -1019,7 +1021,7 @@ TEST(TensorrtExecutionProviderTest, DISABLED_NodeIndexMappingTest) { // [W:onn auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); - std::string model_file_name = "trt_execution_provider_nodeindexmapping_test.onnx"; + PathString model_file_name = ORT_TSTR("trt_execution_provider_nodeindexmapping_test.onnx"); status = onnxruntime::Model::Save(model, model_file_name); SessionOptions so; @@ -1131,7 +1133,7 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); - std::string model_file_name = "trt_execution_provider_removecycle_test.onnx"; + PathString model_file_name = ORT_TSTR("trt_execution_provider_removecycle_test.onnx"); status = onnxruntime::Model::Save(model, model_file_name); std::vector dims_mul_x = {1, 3, 2}; diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index eca1430448e8e..29680c98fb4de 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -594,6 +594,55 @@ def test_dequantize_linear_ms_domain(self): ] self._check_shapes(graph, inferred.graph, expected_shapes) + def test_matmulnbits(self): + """ + Test ORT MatMulNBits op. + Check that the output shape is propagated from the inputs and that the output data + type comes from the first input. + """ + b_np = numpy.random.randint(0, 255, (4, 1, 8), numpy.uint8) + b = numpy_helper.from_array(b_np, name="b") + scale_np = numpy.random.rand(4).astype(numpy.float32) + scale = numpy_helper.from_array(scale_np, name="scale") + zero_point_np = numpy.random.randint(0, 255, (4), numpy.uint8) + zero_point = numpy_helper.from_array(zero_point_np, name="zero_point") + + initializers = [b, scale, zero_point] + + kwargs = {"K": 10, "N": 4, "block_size": 16} + + nodes = [ + helper.make_node( + "MatMulNBits", + inputs=[ + "input_f32", + "b", + "scale", + "zero_point", + ], + outputs=["output_f32"], + **kwargs, + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["x", 2, 3, 10]), + ] + + outputs = [ + helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "MatMulNBits_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["x", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + class TestSymbolicShapeInferenceForSlice(unittest.TestCase): def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim): diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index b30282f2ab41f..cf7fc292ea86b 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -217,10 +217,13 @@ def rewind(self): self.iter_next = iter(self.data_feeds) -def input_feeds_neg_one_zero_one(n, name2shape): +def input_feeds_neg_one_zero_one(n, name2shape, seed=None): """ randomize n feed according to shape, its values are from -1, 0, and 1 """ + if seed is not None: + np.random.seed(seed) + input_data_list = [] for _i in range(n): inputs = {} @@ -231,6 +234,120 @@ def input_feeds_neg_one_zero_one(n, name2shape): return dr +def input_feeds_neg_one_zero_one_list(n, name2shape, seed=None): + """ + randomize n feed according to shape, its values are from -1, 0, and 1 + """ + if seed is not None: + np.random.seed(seed) + + input_data_list = [] + for _i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + return input_data_list + + +class GenerateCalibrationData(CalibrationDataReader): + def __init__(self, data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last=None): + print("Generating calibration dataset from " + str(data_list)) + print("input nodes are ", input_nodes, "input shapes are ", input_shapes) + if inputs_conv_channel_last: + print(f"Inputs that will be converted to channel last: {inputs_conv_channel_last}") + + self.enum_data_dicts = [] + self.input_nodes = input_nodes + self.input_shapes = input_shapes + self.inputs_conv_channel_last = inputs_conv_channel_last + self.calibration_dataset = data_list + + def __len__(self): + return len(self.calibration_dataset) + + def get_next(self): + feed_dict = {} + inp = next(self.calibration_dataset, None) + if inp is not None: + for i in range(len(self.input_nodes)): + input_data = inp[i].reshape(self.input_shapes[i]) + if self.inputs_conv_channel_last is not None and self.input_nodes[i] in self.inputs_conv_channel_last: + input_data = np.moveaxis(input_data, 1, -1) + dict_item = {self.input_nodes[i]: input_data} + feed_dict.update(dict_item) + return feed_dict + else: + return None + + +class StridedDataReader(GenerateCalibrationData): + def __init__( + self, + data_list, + input_nodes, + input_shapes, + no_tensor_num, + in_dtypes, + inputs_conv_channel_last=None, + stride=1, + start_index=0, + end_index=None, + ): + super().__init__(data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last) + + self.stride = max(1, stride) # Ensure stride is at least 1 + self.start_index = start_index + self.end_index = ( + end_index if end_index is not None else len(self.calibration_dataset) + ) # Default to the end of the dataset + self.enum_data_dicts = iter([]) + + def get_next(self): + iter_data = next(self.enum_data_dicts, None) + if iter_data: + return iter_data + + self.enum_data_dicts = None + if self.start_index < self.end_index: + print(f"start index is {self.start_index}") + data = self.load_serial() + + self.start_index += self.stride + self.enum_data_dicts = iter(data) + + return next(self.enum_data_dicts, None) + else: + return None + + def load_serial(self): + batch_data = [] + end_loop = min(self.end_index, self.start_index + self.stride) + for i in range(self.start_index, end_loop): + print(f"debugging the load serial index {i}") + data_item = self.calibration_dataset[i] + processed_item = self.process_data_item(data_item) + batch_data.append(processed_item) + return batch_data + + def process_data_item(self, data_item): + feed_dict = {} + for _, node in enumerate(self.input_nodes): + # input_data = data_item[i].reshape(self.input_shapes[i]) + feed_dict[node] = data_item["input"] + return feed_dict + + def set_range(self, start_index, end_index=None): + self.start_index = start_index + self.end_index = end_index if end_index is not None else len(self.calibration_dataset) + self.enum_data_dicts = iter([]) + + def rewind(self): + """Rewind the data reader to the beginning of the dataset.""" + self.start_index = 0 + self.enum_data_dicts = iter([]) + + def check_op_type_order(testcase, model_to_check, ops): if isinstance(model_to_check, str): model = onnx.load(model_to_check) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 88e5052db4e2e..4cc8a0c151d14 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -14,7 +14,7 @@ import numpy as np import onnx from onnx import TensorProto, helper -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type from onnxruntime.quantization import quant_utils @@ -105,8 +105,9 @@ def make_matmul( [output_tensor], initializer=initializers, ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + # blocked quantization requires DQ op set >= 21 + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 # use stable onnx ir version onnx.save(model, output_model_path) @@ -116,9 +117,12 @@ def quant_test( data_reader: TestDataFeeds, block_size: int, is_symmetric: bool, + quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, ): + use_qdq = quant_format == quant_utils.QuantFormat.QDQ + name_prefix = "DQ_MatMul" if use_qdq else "MatMulNBits" model_int4_path = str( - Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute() + Path(self._tmp_model_dir.name).joinpath(f"{name_prefix}_{block_size}_{is_symmetric}.onnx").absolute() ) # Quantize fp32 model to int4 model @@ -126,15 +130,33 @@ def quant_test( model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig( - block_size=block_size, is_symmetric=is_symmetric + block_size=block_size, is_symmetric=is_symmetric, quant_format=quant_format ) quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config) quant.process() quant.model.save_model_to_file(model_int4_path, False) - quant_nodes = {"MatMulNBits": 1} + quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1} check_op_type_count(self, model_int4_path, **quant_nodes) + if use_qdq: + dq_qtype = TensorProto.INT4 if is_symmetric else TensorProto.UINT4 + dqnode_io_qtypes = ( + { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ] + } + if is_symmetric + else { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ["i", 2, dq_qtype], + ] + } + ) + check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes) + data_reader.rewind() try: @@ -211,6 +233,26 @@ def test_quantize_matmul_int4_offsets(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test(model_fp32_path, data_reader, 32, False) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_symmetric_qdq(self): + np.random.seed(13) + + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=True) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ) + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_offsets_qdq(self): + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test(model_fp32_path, data_reader, 32, False, quant_utils.QuantFormat.QDQ) + @unittest.skipIf( find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) diff --git a/onnxruntime/test/python/quantization/test_quant_util.py b/onnxruntime/test/python/quantization/test_quant_util.py index 7b3fc08982ac1..96d841654adbd 100644 --- a/onnxruntime/test/python/quantization/test_quant_util.py +++ b/onnxruntime/test/python/quantization/test_quant_util.py @@ -37,41 +37,45 @@ def _compute_scale_zp(rmin, rmax, qmin, qmax, qtype, symmetric=False, min_real_r assert isinstance(scale, numpy.ndarray) return [float(zp), float(scale)] - self.assertEqual(_compute_scale_zp(0.0, 0.0, -127, 127, numpy.int8, symmetric=True), [0, 1.0]) - self.assertEqual(_compute_scale_zp(1.0, -1.0, -127, 127, numpy.int8, symmetric=True), [0, 1.0]) - self.assertEqual(_compute_scale_zp(0.0, 0.0, 0, 255, numpy.uint8, symmetric=True), [0, 1.0]) - self.assertEqual(_compute_scale_zp(1.0, -1.0, 0, 255, numpy.uint8, symmetric=True), [0, 1.0]) + numpy.testing.assert_allclose(_compute_scale_zp(0.0, 0.0, -127, 127, numpy.int8, symmetric=True), [0, 1.0]) + numpy.testing.assert_allclose(_compute_scale_zp(1.0, -1.0, -127, 127, numpy.int8, symmetric=True), [0, 1.0]) + numpy.testing.assert_allclose(_compute_scale_zp(0.0, 0.0, 0, 255, numpy.uint8, symmetric=True), [0, 1.0]) + numpy.testing.assert_allclose(_compute_scale_zp(1.0, -1.0, 0, 255, numpy.uint8, symmetric=True), [0, 1.0]) - self.assertEqual( + numpy.testing.assert_allclose( _compute_scale_zp(-1.0, 2.0, -127, 127, numpy.int8, symmetric=True), [0, numpy.float32(2.0 / 127)] ) - self.assertEqual( + numpy.testing.assert_allclose( _compute_scale_zp(-1.0, 2.0, -127, 127, numpy.int8, symmetric=False), [-42, numpy.float32(3.0 / 254)] ) - self.assertEqual( + numpy.testing.assert_allclose( _compute_scale_zp(-1.0, 2.0, 0, 255, numpy.uint8, symmetric=True), [128, numpy.float32(4.0 / 255)] ) - self.assertEqual( + numpy.testing.assert_allclose( _compute_scale_zp(-1.0, 2.0, 0, 255, numpy.uint8, symmetric=False), [85, numpy.float32(3.0 / 255)] ) tiny_float = numpy.float32(numpy.finfo(numpy.float32).tiny * 0.1) - self.assertEqual(_compute_scale_zp(-tiny_float, tiny_float, 0, 255, numpy.uint8, symmetric=True), [0, 1.0]) - self.assertEqual(_compute_scale_zp(-tiny_float, 0.0, 0, 255, numpy.uint8, symmetric=False), [0, 1.0]) + numpy.testing.assert_allclose( + _compute_scale_zp(-tiny_float, tiny_float, 0, 255, numpy.uint8, symmetric=True), [0, 1.0] + ) + numpy.testing.assert_allclose( + _compute_scale_zp(-tiny_float, 0.0, 0, 255, numpy.uint8, symmetric=False), [0, 1.0] + ) # Test enforcing a minimum floatint-point range. - self.assertEqual( + numpy.testing.assert_allclose( _compute_scale_zp(0.0, 0.0, 0, 255, numpy.uint8, symmetric=False, min_real_range=0.0001), [0, 0.0001 / 255] ) - self.assertEqual( + numpy.testing.assert_allclose( _compute_scale_zp(0.0, 0.0, -128, 127, numpy.int8, symmetric=True, min_real_range=0.0001), [0, 0.0002 / 255] ) - self.assertEqual( + numpy.testing.assert_allclose( _compute_scale_zp(0.0, 0.0, 0, 65535, numpy.uint16, symmetric=False, min_real_range=0.0001), [0, 0.0001 / 65535], ) - self.assertEqual( + numpy.testing.assert_allclose( _compute_scale_zp(0.0, 0.0, -32768, 32767, numpy.int16, symmetric=True, min_real_range=0.0001), [0, 0.0002 / 65535], ) diff --git a/onnxruntime/test/python/quantization/test_quantize_static.py b/onnxruntime/test/python/quantization/test_quantize_static.py index 5ad5a49f00c14..01976ba633137 100644 --- a/onnxruntime/test/python/quantization/test_quantize_static.py +++ b/onnxruntime/test/python/quantization/test_quantize_static.py @@ -13,8 +13,15 @@ import numpy as np import onnx from onnx import TensorProto, helper -from op_test_utils import check_model_correctness, generate_random_initializer, input_feeds_neg_one_zero_one - +from op_test_utils import ( + StridedDataReader, + check_model_correctness, + generate_random_initializer, + input_feeds_neg_one_zero_one, + input_feeds_neg_one_zero_one_list, +) + +import onnxruntime as ort from onnxruntime.quantization import QuantType, StaticQuantConfig, quantize, quantize_static @@ -89,6 +96,51 @@ def test_save_as_external(self): check_model_correctness(self, self._model_fp32_path, quant_model_path, data_reader.get_next()) data_reader.rewind() + def run_inference(self, model_path, input_data): + session = ort.InferenceSession(model_path) + input_name = session.get_inputs()[0].name + output_name = session.get_outputs()[0].name + result = session.run([output_name], {input_name: input_data}) + return result + + def test_stride_effect_on_data_collection(self): + # Define the stride and test quantize_static with different stride values + stride = 5 + input_shapes = [1, self._channel_size, 1, 3] + data_list = input_feeds_neg_one_zero_one_list(10, {"input": [1, self._channel_size, 1, 3]}, 123) + input_nodes = ["input"] + in_dtypes = [np.float32] # Example dtype, adjust as needed + + # strided calibration + quant_model_path_1 = str(Path(self._tmp_model_dir.name) / "quant.strided.onnx") + data_reader_1 = StridedDataReader( + data_list, input_nodes, input_shapes, no_tensor_num=0, in_dtypes=in_dtypes, stride=stride + ) + quant_config_1 = StaticQuantConfig(data_reader_1, extra_options={"CalibStridedMinMax": stride}) + quantize(str(self._model_fp32_path), str(quant_model_path_1), quant_config_1) + + # non-strided calibration + quant_model_path_2 = str(Path(self._tmp_model_dir.name) / "quant.non.strided.onnx") + data_reader_2 = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]}, 123) + quant_config_2 = StaticQuantConfig(data_reader_2) + quantize(str(self._model_fp32_path), str(quant_model_path_2), quant_config_2) + + # Inference with both models and assert output closeness + np.random.seed(123) + input_data = np.random.choice([-1, 0, 1], size=[1, self._channel_size, 1, 3]).astype(np.float32) + + result_1 = self.run_inference(quant_model_path_1, input_data) + result_2 = self.run_inference(quant_model_path_2, input_data) + + # Assert that the outputs are close + np.testing.assert_allclose( + result_1, + result_2, + rtol=0.01, + atol=0.01, + err_msg="Outputs from strided and non-strided models are not close enough.", + ) + def test_static_quant_config(self): data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]}) quant_config = StaticQuantConfig(data_reader) diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py index 9e9d05fae027d..eafab0c03a951 100644 --- a/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py @@ -87,7 +87,7 @@ def quantize_blockwise_bnb4_ref(matrix_float: npt.ArrayLike, block_size: int, qu absmax[block_idx] = block_absmax if block_len % 2 != 0: - block = np.append(block, 0.0) + block = np.append(block, np.float32(0.0)) block_len += 1 block *= reciprocal_absmax @@ -131,8 +131,8 @@ def test_quantize_blockwise_bnb4(self): matrix_float = np.random.uniform(-1, 1, (k, n)).astype(type) quant_value_ref, absmax_ref = quantize_blockwise_bnb4_ref(matrix_float, block_size, quant_type) quant_value, absmax = quantize_blockwise_bnb4_target(matrix_float, block_size, quant_type) - assert np.allclose(quant_value_ref, quant_value) - assert np.allclose(absmax_ref, absmax) + np.testing.assert_allclose(quant_value_ref, quant_value) + np.testing.assert_allclose(absmax_ref, absmax) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py new file mode 100644 index 0000000000000..b781ccf03f138 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py @@ -0,0 +1,221 @@ +import argparse +import os +import time +from typing import Optional + +import torch +from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention + + +def save_results(results, filename): + import pandas as pd + + df = pd.DataFrame( + results, + columns=[ + "Inference Interval (ms)", + "Throughput (samples/second)", + "Batch Size", + "Max Sequence Length", + "Sequence Length", + "Past Sequence Length", + "Model Name", + ], + ) + # df = df.transpose() # This line swaps the rows and columns + df.to_csv(filename, header=True, index=False) + print(f"Results saved in {filename}!") + + +def benchmark( + batch_size: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + max_seq_len: int, + sequence_length: int = 1, + past_sequence_length: int = 0, + local_window_size: Optional[int] = None, + model_name: str = "Llama3-8B", +): + warmup = 15 + repeat = 100 + + config: GroupQueryAttentionConfig = GroupQueryAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + max_sequence_length=max_seq_len, + past_sequence_length=past_sequence_length, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + local_window_size=local_window_size if local_window_size else -1, + do_rotary=True, # Most models use rotary positional embeddings + is_packed_qkv=model_name in ["Phi-3-mini-128k", "Phi-3-small-128k"], + device="cuda", + ) + + obj = OrtGroupQueryAttention(config) + + for _ in range(warmup): + obj.infer() + + intervals = [] + for _ in range(repeat): + infer_start = time.perf_counter() + obj.infer() + infer_interval = time.perf_counter() - infer_start + intervals.append(infer_interval) + avg_infer_interval = sum(intervals) / len(intervals) + avg_infer_interval_ms = avg_infer_interval * 1000 + print(f"Average inference interval: {avg_infer_interval_ms:.6f} milliseconds") + avg_throughput = batch_size / avg_infer_interval + print(f"Average throughput: {avg_throughput:.6f} samples/second") + + return [avg_infer_interval_ms, avg_throughput] + + +def run_performance_tests(args): + device_id = torch.cuda.current_device() + memory_in_gb = torch.cuda.get_device_properties(device_id).total_memory / (1024 * 1024 * 1024) + + configures = [ + (32, 128, 8, 8192, None, "Llama3-8B"), + (64, 128, 8, 8192, None, "Llama3-70B"), + (48, 128, 8, 32768, None, "Mixtral-8x22B-v0.1"), + (32, 96, 32, 131072, None, "Phi-3-mini-128k"), + (32, 128, 8, 65536, None, "Phi-3-small-128k"), # Sparsity is not used in this test + (40, 128, 10, 32768, None, "Phi-3-medium-128K"), + ] + if args.kernel == "flash_attention": + configures.append((32, 128, 8, 32768, 4096, "Mistral-7B-v0.1")) + + # Reduce max sequence length when GPU memory is not enough. + threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768 + + all_metrics = [] + for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: + prompt_metrics_model = [] + token_metrics_model = [] + for batch_size in [1, 4]: + # Benchmark prompt + for sequence_length in [ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, + 65536, + 131072, + ]: + if sequence_length >= min(threshold, max_seq_len): + continue + print( + f"Prompt: batch_size={batch_size}, num_heads={num_heads}, kv_num_heads={kv_num_heads}, head_size={head_size}, sequence_length={sequence_length}, max_seq_len={max_seq_len}, local_window_size={local_window_size}, model_name={model_name}" + ) + metrics = benchmark( + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + sequence_length=sequence_length, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + model_name=model_name, + ) + metrics = [*metrics, batch_size, max_seq_len, sequence_length, 0, model_name] + prompt_metrics_model.append(metrics) + all_metrics.append(metrics) + # Benchmark token + for past_sequence_length in [ + 0, + 3, + 7, + 15, + 31, + 63, + 127, + 255, + 511, + 1023, + 2047, + 4095, + 8191, + 16383, + 32767, + 65535, + 131071, + ]: + if past_sequence_length >= min(threshold, max_seq_len): + continue + print( + f"Token: batch_size={batch_size}, num_heads={num_heads}, kv_num_heads={kv_num_heads}, head_size={head_size}, past_sequence_length={past_sequence_length}, max_seq_len={max_seq_len}, local_window_size={local_window_size}, model_name={model_name}" + ) + metrics = benchmark( + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + past_sequence_length=past_sequence_length, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + model_name=model_name, + ) + metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, model_name] + token_metrics_model.append(metrics) + all_metrics.append(metrics) + # Calculate average inference interval and throughput for each model + avg_prompt_infer_interval = sum([metrics[0] for metrics in prompt_metrics_model]) / len(prompt_metrics_model) + avg_prompt_throughput = sum([metrics[1] for metrics in prompt_metrics_model]) / len(prompt_metrics_model) + avg_token_infer_interval = sum([metrics[0] for metrics in token_metrics_model]) / len(token_metrics_model) + avg_token_throughput = sum([metrics[1] for metrics in token_metrics_model]) / len(token_metrics_model) + print(f"Average {model_name} prompt inference interval: {avg_prompt_infer_interval:.6f} milliseconds") + print(f"Average {model_name} prompt throughput: {avg_prompt_throughput:.6f} samples/second") + print(f"Average {model_name} token inference interval: {avg_token_infer_interval:.6f} milliseconds") + print(f"Average {model_name} token throughput: {avg_token_throughput:.6f} samples/second") + all_metrics.append( + [avg_prompt_infer_interval, avg_prompt_throughput, 0, max_seq_len, 0, 0, model_name + " (Average Prompt)"] + ) + all_metrics.append( + [avg_token_infer_interval, avg_token_throughput, 0, max_seq_len, 0, 0, model_name + " (Average Token)"] + ) + + save_results(all_metrics, args.output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="End-to-end benchmarking for gen-ai") + parser.add_argument( + "-o", + "--output", + type=str, + default="benchmark_results.csv", + help="Output CSV file name or path (with .csv extension)", + ) + parser.add_argument( + "-k", + "--kernel", + type=str, + default="flash_attention", + help="GQA Kernel to use for benchmarking. Options: flash_attention, memory_efficient", + ) + args = parser.parse_args() + + if args.kernel == "memory_efficient": + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + else: + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + + s = torch.cuda.Stream() + with torch.cuda.stream(s), torch.no_grad(): + run_performance_tests(args) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 22578175846f7..797461bae2efd 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -441,7 +441,7 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea sequence_length=sequence_length, num_heads=num_heads, head_size=head_size, - causal=True, + causal=causal, use_kv_cache=use_kv_cache, past_sequence_length=past_sequence_length, max_cache_sequence_length=None, diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 1d49583b3f20c..84bf30b65a742 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -20,6 +20,7 @@ from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper +from packaging import version from parameterized import parameterized from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -1878,11 +1879,31 @@ def parity_check_gqa_past_no_buff( numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) +def has_flash_attention(): + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 and ( + platform.system() == "Linux" + or (platform.system() == "Windows" and version.parse(torch.version.cuda) >= version.parse("12.0")) + ) + + +def has_memory_efficient(): + if not torch.cuda.is_available(): + return False + major, minor = torch.cuda.get_device_capability() + if major < 5 or (major == 5 and minor < 3): + return False + return True + + def packed_mha_test_cases(): batches = [2] if pipeline_mode else [1, 5] - seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048] + seqs = [1024, 1025] if pipeline_mode else [1024, 1025, 2048] num_h = [1, 3] if pipeline_mode else [1, 6, 16] h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: for s in seqs: for n in num_h: @@ -1911,6 +1932,7 @@ def mha_test_cases(): ) num_h = [1, 3] if pipeline_mode else [1, 6, 16] h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: for s, s2 in seqs: for n in num_h: @@ -1922,21 +1944,17 @@ def mha_test_cases(): class TestMHA(unittest.TestCase): @parameterized.expand(packed_mha_test_cases()) def test_packed_mha(self, _, config): - if not torch.cuda.is_available() or platform.system() != "Linux": - return - major, _ = torch.cuda.get_device_capability() - if major < 8: + if not has_flash_attention(): return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" print("-------- TEST PACKED MHA ---------") parity_check_mha(config, True) @parameterized.expand(mha_test_cases()) def test_mha(self, _, config): - if not torch.cuda.is_available() or platform.system() != "Linux": - return - major, _ = torch.cuda.get_device_capability() - if major < 8: + if not has_flash_attention(): return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" print("-------- TEST MHA ---------") parity_check_mha(config, False) @@ -2106,10 +2124,7 @@ def gqa_past_flash_attention_test_cases(): class TestGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed): - if not torch.cuda.is_available(): - return - major, minor = torch.cuda.get_device_capability() - if major < 5 or (major == 5 and minor < 3): + if not has_memory_efficient(): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") @@ -2135,10 +2150,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave @parameterized.expand(gqa_no_past_flash_attention_test_cases()) def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): - if not torch.cuda.is_available(): - return - major, _ = torch.cuda.get_device_capability() - if major < 8 or platform.system() != "Linux": + if not has_flash_attention(): return print("------- FLASH ATTENTION (PROMPT CASE) --------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" @@ -2162,10 +2174,7 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte @parameterized.expand(gqa_past_memory_efficient_test_cases()) def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed): - if not torch.cuda.is_available(): - return - major, minor = torch.cuda.get_device_capability() - if major < 5 or (major == 5 and minor < 3): + if not has_memory_efficient(): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") @@ -2191,10 +2200,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, @parameterized.expand(gqa_past_flash_attention_test_cases()) def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): - if not torch.cuda.is_available(): - return - major, _ = torch.cuda.get_device_capability() - if major < 8 or platform.system() != "Linux": + if not has_flash_attention(): return print("------- FLASH ATTENTION (TOKEN GEN) -------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py new file mode 100644 index 0000000000000..880f4175e00b7 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -0,0 +1,86 @@ +import platform +import unittest + +import torch +from parameterized import parameterized +from test_flash_attn_cuda import ( + Formats, + gqa_no_past_flash_attention_test_cases, + gqa_past_flash_attention_test_cases, + parity_check_gqa_past, + parity_check_gqa_past_no_buff, + parity_check_gqa_prompt, + parity_check_gqa_prompt_no_buff, +) + +import onnxruntime + + +class TestGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_flash_attention_test_cases()) + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + config.ep = "ROCMExecutionProvider" + if not torch.cuda.is_available(): + return + if platform.system() != "Linux": + return + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + return + print("------- FLASH ATTENTION (PROMPT CASE) --------") + + parity_check_gqa_prompt( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.001, + atol=0.005, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.001, + atol=0.005, + ) + + @parameterized.expand(gqa_past_flash_attention_test_cases()) + def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + config.ep = "ROCMExecutionProvider" + if not torch.cuda.is_available(): + return + if platform.system() != "Linux": + return + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + return + print("------- FLASH ATTENTION (TOKEN GEN) -------") + + parity_check_gqa_past( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.001, + atol=0.005, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.001, + atol=0.005, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py index aa480a1af4587..be288d8b6e360 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -10,6 +10,7 @@ # license information. # ------------------------------------------------------------------------- +import platform import time import unittest @@ -375,6 +376,8 @@ def benchmark(self): class TestMoE(unittest.TestCase): def test_moe_small(self): + if platform.system() == "Windows": + pytest.skip("Skip on Windows") rt = MoE( batch_size=2, num_rows=8, diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index f33a56ee4e1f9..f18bcdba65579 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -8,14 +8,16 @@ """ import math import unittest -from typing import Optional +from typing import Optional, Union import torch +from benchmark_mha import InputFormats from onnx import TensorProto, helper +from parameterized import parameterized from torch import Tensor -from onnxruntime import InferenceSession, SessionOptions -from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager +from onnxruntime import InferenceSession, SessionOptions, get_available_providers +from onnxruntime.transformers.io_binding_helper import CudaSession ENABLE_DEBUG = False @@ -34,6 +36,7 @@ def __init__( softmax_scale: Optional[float], do_rotary: bool, rotary_interleaved: bool, + provider: str = "CUDAExecutionProvider", device="cuda", dtype=torch.float16, share_buffer: bool = True, @@ -62,11 +65,13 @@ def __init__( self.do_rotary = do_rotary self.rotary_interleaved = rotary_interleaved + + self.provider = provider self.device = device + self.dtype = dtype self.share_buffer = share_buffer self.is_packed_qkv = is_packed_qkv - self.dtype = dtype def shape_dict(self): shapes = { @@ -106,7 +111,7 @@ def get_cos_sin_cache(self, dtype): def random_inputs(self): device = self.device # Since bfloat16 is not supported in ORT python I/O binding API, we always use float16 as model inputs. - dtype = torch.float16 + dtype = torch.float16 if self.dtype == torch.bfloat16 else self.dtype # Always use non-packed qkv to generate same inputs for Torch and ORT. packed = self.is_packed_qkv # Save the original value. @@ -153,7 +158,9 @@ def __init__( softmax_scale=None, do_rotary: bool = False, rotary_interleaved: bool = False, + provider: str = "CUDAExecutionProvider", device="cuda", + dtype=torch.float16, local_window_size: int = -1, attention_mask=None, is_packed_qkv=False, @@ -162,17 +169,19 @@ def __init__( ): super().__init__( "GroupQueryAttention", - batch_size, - sequence_length, - max_sequence_length, - past_sequence_length, - num_heads, - kv_num_heads, - head_size, - softmax_scale, - do_rotary, - rotary_interleaved, - device, + batch_size=batch_size, + sequence_length=sequence_length, + max_sequence_length=max_sequence_length, + past_sequence_length=past_sequence_length, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + do_rotary=do_rotary, + rotary_interleaved=rotary_interleaved, + provider=provider, + device=device, + dtype=dtype, is_packed_qkv=is_packed_qkv, max_cache_sequence_length=max_cache_sequence_length, max_rotary_sequence_length=max_rotary_sequence_length, @@ -220,24 +229,28 @@ def __init__( softmax_scale=None, do_rotary: bool = False, rotary_interleaved: bool = False, + provider: str = "CUDAExecutionProvider", device="cuda", + dtype=torch.float16, is_packed_qkv=False, max_cache_sequence_length=None, max_rotary_sequence_length=None, ): super().__init__( "SparseAttention", - batch_size, - sequence_length, - max_sequence_length, - past_sequence_length, - num_heads, - kv_num_heads, - head_size, - softmax_scale, - do_rotary, - rotary_interleaved, - device, + batch_size=batch_size, + sequence_length=sequence_length, + max_sequence_length=max_sequence_length, + past_sequence_length=past_sequence_length, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + do_rotary=do_rotary, + rotary_interleaved=rotary_interleaved, + provider=provider, + device=device, + dtype=dtype, is_packed_qkv=is_packed_qkv, max_cache_sequence_length=max_cache_sequence_length, max_rotary_sequence_length=max_rotary_sequence_length, @@ -288,17 +301,19 @@ def random_inputs(self): def get_comparable_ort_gqa_config(self, use_local=False) -> GroupQueryAttentionConfig: return GroupQueryAttentionConfig( - self.batch_size, - self.sequence_length, - self.max_sequence_length, - self.past_sequence_length, - self.num_heads, - self.kv_num_heads, - self.head_size, - self.softmax_scale, - self.do_rotary, - self.rotary_interleaved, - self.device, + batch_size=self.batch_size, + sequence_length=self.sequence_length, + max_sequence_length=self.max_sequence_length, + past_sequence_length=self.past_sequence_length, + num_heads=self.num_heads, + kv_num_heads=self.kv_num_heads, + head_size=self.head_size, + softmax_scale=self.softmax_scale, + do_rotary=self.do_rotary, + rotary_interleaved=self.rotary_interleaved, + provider=self.provider, + device=self.device, + dtype=self.dtype, local_window_size=self.local_blocks * self.sparse_block_size if use_local else -1, is_packed_qkv=self.is_packed_qkv, max_cache_sequence_length=self.max_cache_sequence_length, @@ -314,17 +329,19 @@ def get_comparable_torch_gqa_config(self, use_sparse=False) -> GroupQueryAttenti attention_mask = attention_mask[:, :, -self.sequence_length :, :] return GroupQueryAttentionConfig( - self.batch_size, - self.sequence_length, - self.max_sequence_length, - self.past_sequence_length, - self.num_heads, - self.kv_num_heads, - self.head_size, - self.softmax_scale, - self.do_rotary, - self.rotary_interleaved, - self.device, + batch_size=self.batch_size, + sequence_length=self.sequence_length, + max_sequence_length=self.max_sequence_length, + past_sequence_length=self.past_sequence_length, + num_heads=self.num_heads, + kv_num_heads=self.kv_num_heads, + head_size=self.head_size, + softmax_scale=self.softmax_scale, + do_rotary=self.do_rotary, + rotary_interleaved=self.rotary_interleaved, + provider=self.provider, + device=self.device, + dtype=self.dtype, attention_mask=attention_mask, is_packed_qkv=False, # torch reference implementation does not support packed qkv. max_cache_sequence_length=self.max_cache_sequence_length, @@ -375,7 +392,7 @@ def get_dense_mask(block_mask, total_seq_len, query_seq_len, block_size): def create_sparse_attention_onnx_model(config: SparseAttentionConfig): # ORT Python I/O binding API does not support bf16, so always use fp16 as graph inputs/outputs. - io_float_type = TensorProto.FLOAT16 + io_float_type = TensorProto.FLOAT if config.dtype == torch.float32 else TensorProto.FLOAT16 suffix = "_bf16" if config.dtype == torch.bfloat16 else "" nodes = [ @@ -487,9 +504,9 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig): def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): - assert config.dtype == torch.float16 + assert config.dtype in [torch.float16, torch.float32] - float_type = TensorProto.FLOAT16 + float_type = TensorProto.FLOAT16 if config.dtype in [torch.float16] else TensorProto.FLOAT nodes = [ helper.make_node( "GroupQueryAttention", @@ -599,7 +616,10 @@ def group_query_attention_reference( attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value) result = attn_output.transpose(1, 2).contiguous() - torch.cuda.synchronize() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + return result @@ -671,25 +691,42 @@ def infer(self): ) +def create_ort_session( + config: Union[SparseAttentionConfig, GroupQueryAttentionConfig], session_options=None, enable_cuda_graph=False +) -> CudaSession: + if isinstance(config, SparseAttentionConfig): + onnx_model_str = create_sparse_attention_onnx_model(config) + else: + onnx_model_str = create_group_query_attention_onnx_model(config) + + if config.provider == "CUDAExecutionProvider": + device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index + provider_options = CudaSession.get_cuda_provider_options( + device_id, enable_cuda_graph=enable_cuda_graph, stream=torch.cuda.current_stream().cuda_stream + ) + providers = [(config.provider, provider_options), "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + + ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + # Note that CudaSession could work with both CUDA and CPU providers. + cuda_session = CudaSession(ort_session, config.device, enable_cuda_graph=enable_cuda_graph) + shape_dict = config.shape_dict() + cuda_session.allocate_buffers(shape_dict) + + buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} + for input_name, output_name in buffer_sharing.items(): + cuda_session.set_buffer_sharing(input_name, output_name) + + return cuda_session + + class OrtGroupQueryAttention: """A wrapper of ORT GroupQueryAttention to test relevance and performance.""" def __init__(self, config: GroupQueryAttentionConfig): - cuda_provider_options = CudaSession.get_cuda_provider_options( - torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream - ) - onnx_model_str = create_group_query_attention_onnx_model(config) - self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options) - self.gpu_binding_manager = GpuBindingManager( - ort_session=self.ort_session, - device=config.device, - stream=torch.cuda.current_stream().cuda_stream, - max_cuda_graphs=2, - ) - buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} - self.gpu_binding = self.gpu_binding_manager.get_binding( - config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing - ) + self.session = create_ort_session(config) + self.feed_dict = config.random_inputs() if ENABLE_DEBUG and not config.is_packed_qkv: @@ -709,28 +746,14 @@ def __init__(self, config: GroupQueryAttentionConfig): print("seqlens_k (BSNH, GQA)", self.feed_dict["seqlens_k"]) def infer(self): - return self.gpu_binding.infer(self.feed_dict) + return self.session.infer(self.feed_dict) class OrtSparseAttention: """A wrapper of ORT SparseAttention to test relevance and performance.""" def __init__(self, config: SparseAttentionConfig): - cuda_provider_options = CudaSession.get_cuda_provider_options( - torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream - ) - onnx_model_str = create_sparse_attention_onnx_model(config) - self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options) - self.gpu_binding_manager = GpuBindingManager( - ort_session=self.ort_session, - device=config.device, - stream=torch.cuda.current_stream().cuda_stream, - max_cuda_graphs=2, - ) - buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} - self.gpu_binding = self.gpu_binding_manager.get_binding( - config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing - ) + self.session = create_ort_session(config) self.feed_dict = config.random_inputs() if ENABLE_DEBUG and not config.is_packed_qkv: @@ -753,19 +776,196 @@ def __init__(self, config: SparseAttentionConfig): print("key_total_sequence_lengths", self.feed_dict["key_total_sequence_lengths"]) def infer(self): - return self.gpu_binding.infer(self.feed_dict) + return self.session.infer(self.feed_dict) + + +def get_provider_support_info(provider: str, use_kv_cache: bool): + if provider == "CUDAExecutionProvider": + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.QKV_BSN3H] + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + else: + assert provider == "CPUExecutionProvider" + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.QKV_BSN3H] + device = torch.device("cpu") + dtype = torch.float + return device, dtype, formats + + +def has_cuda_support(): + if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers(): + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + return sm in [75, 80, 86, 89, 90] + + return False + + +def get_simple_test_case(provider: str, has_past_kv: bool): + """A simple test case for debugging purpose.""" + device, dtype, _formats = get_provider_support_info(provider, False) + if provider == "CPUExecutionProvider": + # A simple case for debugging purpose. + max_sequence_length = 16 + sequence_length = 15 + packed_qkv = False + config = SparseAttentionConfig( + batch_size=1, + sequence_length=1 if has_past_kv else sequence_length, + max_sequence_length=max_sequence_length, + past_sequence_length=sequence_length if has_past_kv else 0, + num_heads=4, + kv_num_heads=2, + head_size=8, + sparse_block_size=4, + num_layout=2, + local_blocks=2, + vert_stride=2, + softmax_scale=0.0, + provider=provider, + device=device, + dtype=dtype, + is_packed_qkv=packed_qkv, + max_cache_sequence_length=max_sequence_length, + ) + yield config + + +def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rotary=False): + if provider == "CUDAExecutionProvider" and not has_cuda_support(): + return + yield + + device, dtype, formats = get_provider_support_info(provider, False) + batch_sizes = [1, 2, 3] + sequence_lengths = [1, 64, 127, 128, 192, 256] + heads = [4, 8, 16] + + # SparseAttention CUDA kernel only supports head size 128 + head_sizes = [128] if provider == "CUDAExecutionProvider" else [128, 256] + + if comprehensive: + for batch_size in batch_sizes: + for sequence_length in sequence_lengths: + for num_heads in heads: + for head_size in head_sizes: + for format in formats: + packed_qkv = format == InputFormats.QKV_BSN3H + + non_prompt_len = 1 + if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary: + # Generate case of sequence_length > 1 when it is not prompt for CPU provider. + non_prompt_len = batch_size + + query_sequence_length = non_prompt_len if has_past_kv else sequence_length + config = SparseAttentionConfig( + batch_size=batch_size, + sequence_length=query_sequence_length, + max_sequence_length=256, + past_sequence_length=( + min(256 - query_sequence_length, sequence_length) if has_past_kv else 0 + ), + num_heads=num_heads, + kv_num_heads=num_heads // 2, + head_size=head_size, + sparse_block_size=64, + num_layout=2, + local_blocks=2, + vert_stride=2, + softmax_scale=1.8 / (128**0.5), + provider=provider, + device=device, + dtype=dtype, + is_packed_qkv=packed_qkv, + do_rotary=do_rotary, + rotary_interleaved=sequence_length <= 128, + max_cache_sequence_length=None if sequence_length >= 128 else 128, + ) + yield config + else: + test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) + for i in range(test_cases): + batch_size = batch_sizes[i % len(batch_sizes)] + sequence_length = sequence_lengths[i % len(sequence_lengths)] + num_heads = heads[i % len(heads)] + head_size = head_sizes[i % len(head_sizes)] + format = formats[i % len(formats)] + packed_qkv = format == InputFormats.QKV_BSN3H + + non_prompt_len = 1 + if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary: + # Generate case of sequence_length > 1 when it is not prompt for CPU provider. + non_prompt_len = batch_size + + query_sequence_length = non_prompt_len if has_past_kv else sequence_length + + config = SparseAttentionConfig( + batch_size=batch_size, + sequence_length=query_sequence_length, + max_sequence_length=256, + past_sequence_length=min(256 - query_sequence_length, sequence_length) if has_past_kv else 0, + num_heads=num_heads, + kv_num_heads=num_heads // 2, + head_size=head_size, + sparse_block_size=64, + num_layout=2, + local_blocks=2, + vert_stride=2, + softmax_scale=1.8 / (128**0.5), + provider=provider, + device=device, + dtype=dtype, + is_packed_qkv=packed_qkv, + do_rotary=do_rotary, + rotary_interleaved=sequence_length <= 128, + max_cache_sequence_length=None if sequence_length >= 128 else 128, # test smaller kv cache buffer. + ) + yield config + + +# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine. +comprehensive_mode = False class TestSparseAttention(unittest.TestCase): - @unittest.skipUnless(torch.cuda.is_available(), "cuda not available") + @unittest.skipUnless(has_cuda_support(), "cuda not available") def test_sparse_attention(self): major, minor = torch.cuda.get_device_capability() sm = major * 10 + minor + self.run_relevance_test(sm) - if sm not in [75, 80, 86, 89, 90]: - self.skipTest("SparseAttention is not supported on this GPU") + @parameterized.expand(get_simple_test_case("CPUExecutionProvider", True), skip_on_empty=True) + def test_simple_token_cpu(self, config: SparseAttentionConfig): + self.run_one_relevance_test(config) - self.run_relevance_test(sm) + @parameterized.expand(get_simple_test_case("CPUExecutionProvider", False), skip_on_empty=True) + def test_simple_prompt_cpu(self, config: SparseAttentionConfig): + self.run_one_relevance_test(config) + + @parameterized.expand( + get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True), skip_on_empty=True + ) + def test_sparse_att_token_cpu_rotary(self, config: SparseAttentionConfig): + # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense. + if config.sparse_block_size * config.local_blocks > config.total_sequence_length: + self.run_one_relevance_test(config) + + @parameterized.expand(get_test_cases("CUDAExecutionProvider", True, comprehensive_mode), skip_on_empty=True) + def test_sparse_att_token_gpu(self, config): + self.run_one_relevance_test(config) + + @parameterized.expand(get_test_cases("CPUExecutionProvider", True, comprehensive_mode), skip_on_empty=True) + def test_sparse_att_token_cpu(self, config): + self.run_one_relevance_test(config) + + @parameterized.expand(get_test_cases("CPUExecutionProvider", False, comprehensive_mode), skip_on_empty=True) + def test_sparse_att_prompt_cpu(self, config): + self.run_one_relevance_test(config) + + @parameterized.expand(get_test_cases("CUDAExecutionProvider", False, comprehensive_mode), skip_on_empty=True) + def test_sparse_att_prompt_gpu(self, config): + self.run_one_relevance_test(config) def run_one_relevance_test(self, config: SparseAttentionConfig): if (not config.do_rotary) and config.total_sequence_length <= 2048: @@ -774,6 +974,10 @@ def run_one_relevance_test(self, config: SparseAttentionConfig): obj = TorchGroupQueryAttention(gqa_config) expected_out = obj.infer() else: + if config.dtype == torch.bfloat16: + # Skip test since create_group_query_attention_onnx_model does not support bfloat16 right now. + return + # Run QGA by ORT (support packed QKV, rotary and very long sequence, but no mask so dense only). gqa_config: GroupQueryAttentionConfig = config.get_comparable_ort_gqa_config(use_local=False) obj = OrtGroupQueryAttention(gqa_config) @@ -881,6 +1085,8 @@ def run_relevance_no_past_128k(self, sm: int, device): local_blocks=2048, # use dense to compare with GQA vert_stride=8, softmax_scale=None, + provider="CUDAExecutionProvider", + dtype=torch.float16, device=device, is_packed_qkv=packed_qkv, ) @@ -907,6 +1113,8 @@ def run_relevance_past_128k(self, sm: int, device): local_blocks=2048, # use dense to compare with GQA vert_stride=8, softmax_scale=None, + provider="CUDAExecutionProvider", + dtype=torch.float16, device=device, is_packed_qkv=packed_qkv, ) @@ -921,7 +1129,8 @@ def run_relevance_test(self, sm: int): device = torch.device("cuda", device_id) with torch.no_grad(): # Test long sequence when GPU memory is enough (need about 12 GB for 128K sequence length) - if torch.cuda.get_device_properties(device_id).total_memory > 13 * 1024 * 1024 * 1024: + # The 128k tests fails randomly in T4 GPU, increase memory threshold for now. + if torch.cuda.get_device_properties(device_id).total_memory > 20 * 1024 * 1024 * 1024: self.run_relevance_no_past_128k(sm, device) self.run_relevance_past_128k(sm, device) self.run_relevance_no_past(sm, device) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 58c185f818df7..eacd41e6b9c6d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -31,7 +31,7 @@ #include "test_fixture.h" #include "utils.h" #include "custom_op_utils.h" -#include "core/common/gsl.h" +#include #ifdef _WIN32 #include diff --git a/onnxruntime/test/shared_lib/test_nontensor_types.cc b/onnxruntime/test/shared_lib/test_nontensor_types.cc index e8160d1619cb3..5fa4fb31e1c91 100644 --- a/onnxruntime/test/shared_lib/test_nontensor_types.cc +++ b/onnxruntime/test/shared_lib/test_nontensor_types.cc @@ -9,7 +9,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "test_allocator.h" -#include "core/common/gsl.h" +#include #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/testdata/clip_div_shared_initializer.onnx b/onnxruntime/test/testdata/clip_div_shared_initializer.onnx new file mode 100644 index 0000000000000..223d2d2febbb4 Binary files /dev/null and b/onnxruntime/test/testdata/clip_div_shared_initializer.onnx differ diff --git a/onnxruntime/test/testdata/clip_div_shared_initializer.py b/onnxruntime/test/testdata/clip_div_shared_initializer.py new file mode 100644 index 0000000000000..e3c4ab438b0fd --- /dev/null +++ b/onnxruntime/test/testdata/clip_div_shared_initializer.py @@ -0,0 +1,33 @@ +from onnx import TensorProto, checker, helper, save + +graph_proto = helper.make_graph( + [ + helper.make_node( + "Clip", + inputs=["input_0", "initializer_0", "initializer_1"], + outputs=["clip_output"], + name="clip", + ), + helper.make_node( + "Div", + inputs=["clip_output", "initializer_1"], + outputs=["output_0"], + name="div", + ), + ], + "Main_graph", + [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [3, 2]), + ], + [ + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, [3, 2]), + ], + [ + helper.make_tensor("initializer_0", TensorProto.FLOAT, [], [0.0]), + helper.make_tensor("initializer_1", TensorProto.FLOAT, [], [6.0]), + ], +) + +model = helper.make_model(graph_proto) +checker.check_model(model, True) +save(model, "clip_div_shared_initializer.onnx") diff --git a/onnxruntime/test/testdata/conv.int4_weights.qdq.onnx b/onnxruntime/test/testdata/conv.int4_weights.qdq.onnx new file mode 100644 index 0000000000000..56f965d0b4101 Binary files /dev/null and b/onnxruntime/test/testdata/conv.int4_weights.qdq.onnx differ diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc index 6a2d27ee95170..2d7d8e7cc735d 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "my_ep_factory.h" #include "my_execution_provider.h" -#include "core/common/gsl.h" +#include #include "core/providers/shared/common.h" #include #include "core/framework/provider_options_utils.h" diff --git a/onnxruntime/test/testdata/make_conv_int4_weights_model.py b/onnxruntime/test/testdata/make_conv_int4_weights_model.py new file mode 100644 index 0000000000000..004342b531026 --- /dev/null +++ b/onnxruntime/test/testdata/make_conv_int4_weights_model.py @@ -0,0 +1,98 @@ +import numpy as np +import onnx + +from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize +from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config + +INPUT0_SHAPE = (1, 3, 8, 8) +INPUT0_NAME = "input_0" + + +def create_f32_model(): + input_0 = onnx.helper.make_tensor_value_info(INPUT0_NAME, onnx.TensorProto.FLOAT, INPUT0_SHAPE) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None) + weight_data = [ + [ + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + ], + [ + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + ], + [ + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + ], + [ + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + ], + [ + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + ], + ] + weight = onnx.numpy_helper.from_array(np.array(weight_data, dtype=np.float32), "weight") + bias_data = [-10.0, -8.0, 0.0, 8.0, 10.0] + bias = onnx.numpy_helper.from_array(np.array(bias_data, dtype=np.float32), "bias") + + conv_node = onnx.helper.make_node("Conv", [INPUT0_NAME, "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convf32", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + + return model + + +class DataReader(CalibrationDataReader): + def __init__(self): + self.enum_data = None + self.data_list = [] + + # Generate 10 random input values for calibration + for _ in range(10): + input_data = {INPUT0_NAME: np.random.random(INPUT0_SHAPE).astype(np.float32)} + self.data_list.append(input_data) + + self.datasize = len(self.data_list) + + def get_next(self): + if self.enum_data is None: + self.enum_data = iter(self.data_list) + return next(self.enum_data, None) + + def rewind(self): + self.enum_data = None + + +def create_qdq_model(model_f32): + # Use tensor quantization overrides to quantize Conv's weight input to 4 bits on axis 0. + init_overrides = {"weight": [{"quant_type": QuantType.QInt4, "axis": 0, "symmetric": True}]} + qnn_config = get_qnn_qdq_config( + model_f32, + DataReader(), + init_overrides=init_overrides, + activation_type=QuantType.QUInt16, + weight_type=QuantType.QUInt8, + ) + + quantize(model_f32, "conv.int4_weights.qdq.onnx", qnn_config) + + +if __name__ == "__main__": + model_f32 = create_f32_model() + create_qdq_model(model_f32) diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index b7c3b38538421..1d89272680e47 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -26,9 +26,24 @@ #include "test/test_environment.h" std::unique_ptr ort_env; + +// ortenv_setup is used by /onnxruntime/test/xctest/xcgtest.mm so can't be file local void ortenv_setup() { OrtThreadingOptions tpo; - ort_env.reset(new Ort::Env(&tpo, ORT_LOGGING_LEVEL_WARNING, "Default")); + + // allow verbose logging to be enabled by setting this environment variable to a numeric log level + constexpr auto kLogLevelEnvironmentVariableName = "ORT_UNIT_TEST_MAIN_LOG_LEVEL"; + OrtLoggingLevel log_level = ORT_LOGGING_LEVEL_WARNING; + if (auto log_level_override = onnxruntime::ParseEnvironmentVariable(kLogLevelEnvironmentVariableName); + log_level_override.has_value()) { + *log_level_override = std::clamp(*log_level_override, + static_cast(ORT_LOGGING_LEVEL_VERBOSE), + static_cast(ORT_LOGGING_LEVEL_FATAL)); + std::cout << "Setting log level to " << *log_level_override << "\n"; + log_level = static_cast(*log_level_override); + } + + ort_env.reset(new Ort::Env(&tpo, log_level, "Default")); } #ifdef USE_TENSORRT @@ -76,17 +91,6 @@ int TEST_MAIN(int argc, char** argv) { ortenv_setup(); ::testing::InitGoogleTest(&argc, argv); - // allow verbose logging to be enabled by setting this environment variable to a numeric log level - constexpr auto kLogLevelEnvironmentVariableName = "ORT_UNIT_TEST_MAIN_LOG_LEVEL"; - if (auto log_level = onnxruntime::ParseEnvironmentVariable(kLogLevelEnvironmentVariableName); - log_level.has_value()) { - *log_level = std::clamp(*log_level, - static_cast(ORT_LOGGING_LEVEL_VERBOSE), - static_cast(ORT_LOGGING_LEVEL_FATAL)); - std::cout << "Setting log level to " << *log_level << "\n"; - ort_env->UpdateEnvWithCustomLogLevel(static_cast(*log_level)); - } - status = RUN_ALL_TESTS(); } ORT_CATCH(const std::exception& ex) { diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 6f07385729555..312aa86277994 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -109,7 +109,8 @@ std::unique_ptr OpenVINOExecutionProviderWithOptions(const O std::unique_ptr DefaultOpenVINOExecutionProvider() { #ifdef USE_OPENVINO ProviderOptions provider_options_map; - return OpenVINOProviderFactoryCreator::Create(&provider_options_map)->CreateProvider(); + SessionOptions session_options; + return OpenVINOProviderFactoryCreator::Create(&provider_options_map, &session_options)->CreateProvider(); #else return nullptr; #endif @@ -189,6 +190,14 @@ std::unique_ptr DefaultNnapiExecutionProvider() { #endif } +std::unique_ptr DefaultVSINPUExecutionProvider() { +#if defined(USE_VSINPU) + return VSINPUProviderFactoryCreator::Create()->CreateProvider(); +#else + return nullptr; +#endif +} + std::unique_ptr DefaultRknpuExecutionProvider() { #ifdef USE_RKNPU return RknpuProviderFactoryCreator::Create()->CreateProvider(); diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index ae8e89c386994..606dfc068d399 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -20,6 +20,7 @@ std::shared_ptr CreateExecutionProviderFactory_MIGrap std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); // std::shared_ptr CreateExecutionProviderFactory_Tvm(const char*); +std::shared_ptr CreateExecutionProviderFactory_VSINPU(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params); @@ -50,6 +51,7 @@ std::unique_ptr MIGraphXExecutionProviderWithOptions(const O std::unique_ptr OpenVINOExecutionProviderWithOptions(const OrtOpenVINOProviderOptions* params); std::unique_ptr DefaultOpenVINOExecutionProvider(); std::unique_ptr DefaultNnapiExecutionProvider(); +std::unique_ptr DefaultVSINPUExecutionProvider(); std::unique_ptr DefaultRknpuExecutionProvider(); std::unique_ptr DefaultAclExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultArmNNExecutionProvider(bool enable_arena = true); diff --git a/onnxruntime/test/util/include/providers.h b/onnxruntime/test/util/include/providers.h index aa489e6cd958b..a73b237ae10df 100644 --- a/onnxruntime/test/util/include/providers.h +++ b/onnxruntime/test/util/include/providers.h @@ -16,6 +16,9 @@ #ifdef USE_NNAPI #include "core/providers/nnapi/nnapi_provider_factory.h" #endif +#ifdef USE_VSINPU +#include "core/providers/vsinpu/vsinpu_provider_factory.h" +#endif #ifdef USE_COREML #include "core/providers/coreml/coreml_provider_factory.h" #endif diff --git a/onnxruntime/test/util/include/test_utils.h b/onnxruntime/test/util/include/test_utils.h index 48f0d7c2ab1f7..f55295ac8aec4 100644 --- a/onnxruntime/test/util/include/test_utils.h +++ b/onnxruntime/test/util/include/test_utils.h @@ -11,7 +11,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/execution_provider.h" #include "core/framework/framework_common.h" #include "core/framework/ort_value.h" diff --git a/onnxruntime/test/util/test_utils.cc b/onnxruntime/test/util/test_utils.cc index 598147b81dd89..6bc0f8d105495 100644 --- a/onnxruntime/test/util/test_utils.cc +++ b/onnxruntime/test/util/test_utils.cc @@ -233,7 +233,7 @@ void CheckShapeEquality(const ONNX_NAMESPACE::TensorShapeProto* shape1, #if !defined(DISABLE_SPARSE_TENSORS) void SparseIndicesChecker(const ONNX_NAMESPACE::TensorProto& indices_proto, gsl::span expected_indicies) { using namespace ONNX_NAMESPACE; - Path model_path; + std::filesystem::path model_path; std::vector unpack_buffer; gsl::span ind_span; std::vector converted_indices; diff --git a/onnxruntime/tool/etw/eparser.cc b/onnxruntime/tool/etw/eparser.cc index 526ba6de81966..ff0348d721eb7 100644 --- a/onnxruntime/tool/etw/eparser.cc +++ b/onnxruntime/tool/etw/eparser.cc @@ -7,7 +7,7 @@ // Get the length of the property data. For MOF-based events, the size is inferred from the data type // of the property. For manifest-based events, the property can specify the size of the property value -// using the length attribute. The length attribue can specify the size directly or specify the name +// using the length attribute. The length attribute can specify the size directly or specify the name // of another property in the event data that contains the size. If the property does not include the // length attribute, the size is inferred from the data type. The length will be zero for variable // length, null-terminated strings and structures. @@ -16,7 +16,7 @@ DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, // Get the size of the array. For MOF-based events, the size is specified in the declaration or using // the MAX qualifier. For manifest-based events, the property can specify the size of the array -// using the count attribute. The count attribue can specify the size directly or specify the name +// using the count attribute. The count attribute can specify the size directly or specify the name // of another property in the event data that contains the size. DWORD GetArraySize(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT ArraySize); diff --git a/orttraining/orttraining/core/framework/checkpoint_common.cc b/orttraining/orttraining/core/framework/checkpoint_common.cc index 295f17b894095..2c36895de2ac5 100644 --- a/orttraining/orttraining/core/framework/checkpoint_common.cc +++ b/orttraining/orttraining/core/framework/checkpoint_common.cc @@ -16,7 +16,7 @@ namespace onnxruntime { namespace training { /** - * @brief Create OrtValues From TensorProto objects + * @brief Create OrtValues From TensorProto objects. Doesn't support external tensor. * * @param tensor_protos vector of TensorProto * @param name_to_ort_value saved results. @@ -42,7 +42,7 @@ Status CreateOrtValuesFromTensorProtos( tensor_proto.data_type()) ->GetElementType(); auto p_tensor = std::make_unique(tensor_dtype, tensor_shape, cpu_allocator); - ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), nullptr, tensor_proto, *p_tensor)); + ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), std::filesystem::path(), tensor_proto, *p_tensor)); OrtValue ort_value; ort_value.Init(p_tensor.release(), diff --git a/orttraining/orttraining/core/framework/checkpoint_common.h b/orttraining/orttraining/core/framework/checkpoint_common.h index 316417829e43b..5ff96fa77753d 100644 --- a/orttraining/orttraining/core/framework/checkpoint_common.h +++ b/orttraining/orttraining/core/framework/checkpoint_common.h @@ -6,7 +6,6 @@ #include "core/framework/tensorprotoutils.h" #include "core/common/logging/logging.h" #include "core/common/logging/sinks/clog_sink.h" -#include "core/common/path.h" #include "core/common/path_string.h" #include "core/common/status.h" #include "core/framework/framework_common.h" diff --git a/orttraining/orttraining/core/framework/checkpointing.cc b/orttraining/orttraining/core/framework/checkpointing.cc index 9e1aa8d17e3ee..462a05d9db562 100644 --- a/orttraining/orttraining/core/framework/checkpointing.cc +++ b/orttraining/orttraining/core/framework/checkpointing.cc @@ -10,7 +10,6 @@ #include "core/common/common.h" #include "core/common/logging/logging.h" -#include "core/common/path.h" #include "core/framework/data_transfer_utils.h" #include "core/framework/endian_utils.h" #include "core/framework/ort_value.h" @@ -260,13 +259,10 @@ Status LoadModelCheckpoint( ORT_RETURN_IF_ERROR(Env::Default().GetCanonicalPath( checkpoint_path, checkpoint_canonical_path)); - Path relative_tensors_data_path_obj{}; - ORT_RETURN_IF_ERROR(RelativePath( - Path::Parse(model_directory_canonical_path), - Path::Parse(GetCheckpointTensorsDataFilePath(checkpoint_canonical_path)), - relative_tensors_data_path_obj)); + std::filesystem::path relative_tensors_data_path_obj = std::filesystem::relative( + GetCheckpointTensorsDataFilePath(checkpoint_canonical_path), model_directory_canonical_path); ORT_RETURN_IF_ERROR(UpdateTensorsExternalDataLocations( - relative_tensors_data_path_obj.ToPathString(), loaded_tensor_protos)); + relative_tensors_data_path_obj.native(), loaded_tensor_protos)); } // read properties file diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index c5948e563fcf8..e01456ee3d769 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -183,7 +183,7 @@ Status OrtModuleGraphBuilder::OptimizeForwardGraph(const TrainingGraphTransforme } if (!config.optimized_pre_grad_filepath.empty()) { - ORT_RETURN_IF_ERROR(Model::Save(*forward_model_, config.optimized_pre_grad_filepath)); + ORT_RETURN_IF_ERROR(Model::Save(*forward_model_, ToPathString(config.optimized_pre_grad_filepath))); } return Status::OK(); diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.h b/orttraining/orttraining/core/framework/torch/custom_function_register.h index 67a991ea2cce3..762258a45221e 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.h +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.h @@ -9,9 +9,7 @@ #include #include -namespace onnxruntime { -namespace language_interop_ops { -namespace torch { +namespace onnxruntime::language_interop_ops::torch { typedef std::vector (*CustomFunctionRunnerType)(const char* func_name_char, void* callback, @@ -124,6 +122,4 @@ class OrtTorchFunctionPool final { std::mutex mutex_; }; -} // namespace torch -} // namespace language_interop_ops -} // namespace onnxruntime +} // namespace onnxruntime::language_interop_ops::torch diff --git a/orttraining/orttraining/core/framework/torch/dlpack_python.cc b/orttraining/orttraining/core/framework/torch/dlpack_python.cc index d512dc72a438f..f9b237f051258 100644 --- a/orttraining/orttraining/core/framework/torch/dlpack_python.cc +++ b/orttraining/orttraining/core/framework/torch/dlpack_python.cc @@ -3,10 +3,7 @@ #include "orttraining/core/framework/torch/dlpack_python.h" -namespace onnxruntime { -namespace training { -namespace framework { -namespace torch { +namespace onnxruntime::training::framework::torch { static void DlpackCapsuleDestructor(PyObject* data) { DLManagedTensor* dlmanged_tensor = reinterpret_cast( @@ -35,7 +32,4 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) { return ort_value; } -} // namespace torch -} // namespace framework -} // namespace training -} // namespace onnxruntime +} // namespace onnxruntime::training::framework::torch diff --git a/orttraining/orttraining/core/framework/torch/dlpack_python.h b/orttraining/orttraining/core/framework/torch/dlpack_python.h index 37bae2ab37025..9b641971dceac 100644 --- a/orttraining/orttraining/core/framework/torch/dlpack_python.h +++ b/orttraining/orttraining/core/framework/torch/dlpack_python.h @@ -8,10 +8,7 @@ #include "core/dlpack/dlpack_converter.h" #include "orttraining/core/framework/torch/python_common.h" -namespace onnxruntime { -namespace training { -namespace framework { -namespace torch { +namespace onnxruntime::training::framework::torch { // Allocate a new Capsule object, which takes the ownership of OrtValue. // Caller is responsible for releasing. @@ -22,7 +19,4 @@ PyObject* ToDlpack(OrtValue ort_value); // create a OrtValue. This function calls DlpackToOrtValue(...) to do the conversion. OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor); -} // namespace torch -} // namespace framework -} // namespace training -} // namespace onnxruntime +} // namespace onnxruntime::training::framework::torch diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.h b/orttraining/orttraining/core/framework/torch/torch_proxy.h index 450a5048aea44..b80acd6c4791a 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.h +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.h @@ -22,7 +22,7 @@ namespace torch { // For handling temporary PyObject pointer newly created with Py_XXX APIs, here is our practice: // Convention: // Wrap those PyObject* in format of "PythonObjectPtr(Py_XXX(), PythonObjectDeleter)". -// Explaination: +// Explanation: // That means, for the PyObject* created by Py_XXX(), its refcnt will be decreased by one // in the PythonObjectDeleter which is triggered once lifetime of PythonObjectPtr instance // ends. diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.cc b/orttraining/orttraining/core/graph/gradient_builder_base.cc index d57675e8b8e20..4262acb636582 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_base.cc @@ -63,16 +63,20 @@ void ComputeBroadcastBackwardAxes( auto A_dim = A_dims[i].dim_param(), B_dim = B_dims[j].dim_param(); if (A_dim != B_dim) { - LOGS_DEFAULT(INFO) << "Gradient building for node " << node_name << ": symbolic dimension expects to match. " - << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; + LOGS_DEFAULT(INFO) + << "Gradient building for node " << node_name << ": symbolic dimension expects to match. " + << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) + << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; } } else if (A_dims[i].has_dim_param() && B_dims[j].has_dim_value()) { auto A_dim = A_dims[i].dim_param(); auto B_dim = B_dims[j].dim_value(); if (B_dim != 1) { - LOGS_DEFAULT(INFO) << "Gradient building for node " << node_name << ": symbolic broadcasting expects the B_dimension to be 1. " - << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; + LOGS_DEFAULT(INFO) + << "Gradient building for node " << node_name << ": symbolic broadcasting expects the B_dimension to be 1. " + << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) + << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; } else { if (B_axes) { B_axes->push_back(gsl::narrow_cast(k)); @@ -83,8 +87,10 @@ void ComputeBroadcastBackwardAxes( auto B_dim = B_dims[j].dim_param(); if (A_dim != 1) { - LOGS_DEFAULT(INFO) << "Gradient building for node " << node_name << ": symbolic broadcasting expects the A_dimension to be 1. " - << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; + LOGS_DEFAULT(INFO) + << "Gradient building for node " << node_name << ": symbolic broadcasting expects the A_dimension to be 1. " + << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) + << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; } else { if (A_axes) { A_axes->push_back(gsl::narrow_cast(k)); diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index 2d8a87f6d4427..a4aa70c99eec1 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -225,7 +225,9 @@ class GradientBuilderBase { } int OnnxOpSetVersion() const { - return graph_ != nullptr && graph_->DomainToVersionMap().find(kOnnxDomain) != graph_->DomainToVersionMap().end() ? graph_->DomainToVersionMap().at(kOnnxDomain) : -1; + return graph_ != nullptr && graph_->DomainToVersionMap().find(kOnnxDomain) != graph_->DomainToVersionMap().end() + ? graph_->DomainToVersionMap().at(kOnnxDomain) + : -1; } template diff --git a/orttraining/orttraining/core/graph/loss_function_registry.h b/orttraining/orttraining/core/graph/loss_function_registry.h index 76242276080ca..6c880b6a3d9f5 100644 --- a/orttraining/orttraining/core/graph/loss_function_registry.h +++ b/orttraining/orttraining/core/graph/loss_function_registry.h @@ -18,7 +18,7 @@ struct LossFunctionUsingOperator : public ILossFunction { class LossFunctionRegistry : public GenericRegistry { public: - // Register a list of non-operator loss functions stacitally. + // Register a list of non-operator loss functions statically. void RegisterNonOperatorLossFunctions(); // Register a operator loss function. diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 0d4291a3b8b31..4b6a9a6e594cd 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -508,7 +508,7 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev *embedding_node); // Add flatten pattern to each input node of the subgraph - // to flattern the shape of [batch_size, seqlen, ...] to [valid_token_count, ...] + // to flatten the shape of [batch_size, seqlen, ...] to [valid_token_count, ...] InsertFlattenPatternForInput(graph, *embedding_node, 1, squeeze_out_arg, logger); handled_input_count++; for (auto& node : candidate_inputs) { diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h index cc3c90dac2d58..a09ee75c73aaf 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h @@ -25,7 +25,7 @@ namespace onnxruntime { * * This transformer is implemented in the following steps: * 1. Iterate the graph and find the Embedding node that matches these requirements: - * 1.1 The 2nd input is a graph input and its rank > 2, with the first two dimensions, are: + * 1.1 Following a PythonOp(FlagAndPrintDensity) node, and its rank > 2, with the first two dimensions, are: * [batch_size, sequence_length]. Both dimensions can be symbolic or concrete dim values. * 1.2 The 3rd input(padding idx) is a scalar constant initializer, and should >= 0. * 2. Append embedding node in node_to_scan_list. @@ -54,6 +54,8 @@ namespace onnxruntime { * \ \ / / / * \_________________\_________________________/________________/______________________/ * | + * PythonOp (FlagAndPrintDensity) + * | * ATen:embedding * | * - - - - - - - - - - - -| @@ -68,7 +70,7 @@ namespace onnxruntime { * output * * - * After the transformation: + * After the transformation (PythonOp (FlagAndPrintDensity) is removed unless user need to print density for each step): * * input_ids [batch_size, seq_length] * | \ diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h b/orttraining/orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h index cf34706115894..2204724bacf55 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h @@ -32,10 +32,10 @@ namespace onnxruntime { * 2. Its 2nd output (log_prob) MUST NOT be a graph output and MUST NOT be consumed by other nodes. * 3. Its ignore_index exists and is a constant scalar value. * 4. Its 2nd input label's input node is not a `ShrunkGather` node (to avoid this transformer duplicated applied). - * 5. Its 2nd input label is 1) a graph input or 2) output of a Reshape node taking a graph input as its data input. + * 5. Following PythonOp (FlagAndPrintDensity). * * - * After the transformation: + * After the transformation (PythonOp (FlagAndPrintDensity) is removed unless user need to print density for each step): * labels [token_count] * \_______ * \ \ diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_config.h b/orttraining/orttraining/core/optimizer/graph_transformer_config.h index f72dbfa3fdfc3..c496e36689de1 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_config.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_config.h @@ -28,6 +28,7 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat bool print_input_density{false}; // Path for serialization of the transformed optimized model. If empty, serialization is disabled. + // A UTF-8 string. std::string optimized_pre_grad_filepath; }; diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 436a24c34ddfd..589e7be455dbc 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -44,7 +44,6 @@ #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" -#include "core/optimizer/shape_input_merge.h" #include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/slice_elimination.h" #include "core/optimizer/unsqueeze_elimination.h" @@ -117,11 +116,10 @@ std::vector> GeneratePreTrainingTransformers( ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); #endif - // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as it can create - // more opportunities for CSE. For example, if A and B nodes consume same different args but produce same output - // or consume different initializers with same value, by default, CSE will not merge them. + // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for + // CSE. For example, if A and B nodes consume different initializers with same value, by default, + // CSE will not merge them. transformers.emplace_back(std::make_unique(compatible_eps)); - transformers.emplace_back(std::make_unique(compatible_eps)); // LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input. transformers.emplace_back(std::make_unique(compatible_eps)); // Remove duplicate nodes. Must be applied before any recompute transformations. diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h index 011a007ab6e74..5196f5256fedd 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/optimizer/graph_transformer.h" #include "orttraining/core/optimizer/graph_transformer_config.h" diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.h b/orttraining/orttraining/core/optimizer/insert_output_rewriter.h index 5e4bf5c5ce7a9..de000e00f1bf8 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.h +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.h @@ -7,7 +7,7 @@ namespace onnxruntime { -// Rewrite rule that insert an addtional output to the matched node. +// Rewrite rule that insert an additional output to the matched node. class InsertMaxPoolOutput : public RewriteRule { public: InsertMaxPoolOutput() noexcept @@ -24,7 +24,7 @@ class InsertMaxPoolOutput : public RewriteRule { Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; -// Rewrite rule that insert an addtional output to the matched node. +// Rewrite rule that insert an additional output to the matched node. // Adding this second output to expose FW intermediate result for speeding up BW computation class InsertSoftmaxCrossEntropyLossOutput : public RewriteRule { public: diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 088fd345135db..8d110c692751e 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -60,9 +60,9 @@ using OpsetToIgnorableIndicesMap = InlinedHashMap; * Most recent revisited for ONNX v1.15.0 release - https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/docs/Operators.md * * We defined supported list explicitly instead of using a excluding list for the following reasons: - * 1. Some ops generate indeterministic results (for example using random number generator). We need evaluate whether + * 1. Some ops generate non-deterministic results (for example using random number generator). We need evaluate whether * this is a problem for recompute before adding the support, instead of fixing this after we find and try to - * fix convergence issues (which will be very hard if we have multiple indeterministic operators by default supported.) + * fix convergence issues (which will be very hard if we have multiple non-deterministic operators by default supported.) * 2. Some ops schema will be changed in new opsets, we need also check manually whether it is applicable to recompute * or not. * 3. Some ops are not supported in older opsets, we need to check whether it is applicable to recompute or not. diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index 5aa05b0f02e0f..d87706ea98061 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -151,7 +151,7 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase { * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the * size of stashed activation. - * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a + * @param can_compromise_stashed_activation A bool return value, to indicate there are opportunities for finding a * compromised subgraph. */ std::unique_ptr CheckNodeForRecompute(const GraphViewer& graph_viewer, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index a4fbacc8a1f4c..dd585068a23ca 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -88,7 +88,7 @@ std::tuple IsResidualNodeArg(const GraphViewer& ----------------------| | | | | - | SimplifiedLayerNormalization (layer boudary node) + | SimplifiedLayerNormalization (layer boundary node) | | | | | MistralAttention diff --git a/orttraining/orttraining/core/optimizer/qdq_fusion.h b/orttraining/orttraining/core/optimizer/qdq_fusion.h index 722565cffa803..3bf9a7909f7e2 100644 --- a/orttraining/orttraining/core/optimizer/qdq_fusion.h +++ b/orttraining/orttraining/core/optimizer/qdq_fusion.h @@ -14,7 +14,7 @@ This transformer will be used during QAT (Quantization Aware Training). For QAT an onnx graph that has Q->DQ nodes needs to be made ready for training. The output of the Q node is a quantized type. Backpropagation on quantized type is not supported in ort. So, we replace the occurrences of Q->DQ with FakeQuant which internally will perform the -Q->DQ opeeration and at the same time can support backpropagation. +Q->DQ operation and at the same time can support backpropagation. from: x (fp32) diff --git a/orttraining/orttraining/core/optimizer/transpose_replacement.h b/orttraining/orttraining/core/optimizer/transpose_replacement.h index c38e402339823..d2bbe2fdcfc19 100644 --- a/orttraining/orttraining/core/optimizer/transpose_replacement.h +++ b/orttraining/orttraining/core/optimizer/transpose_replacement.h @@ -13,10 +13,10 @@ namespace onnxruntime { Transpose is equivalent to a Reshape if: empty dimensions (which dim_value=1) can change place, not empty dimensions must be in - the same order in the permuted tenosr. + the same order in the permuted tensor. Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). -This Rewrite rule replaces Transpose which meets the requirments with Reshape. +This Rewrite rule replaces Transpose which meets the requirements with Reshape. Because Transpose need memory copy while Reshape needn't, this replacement can save overhead for memory copy. It is attempted to be triggered only on nodes with op type "Transpose". diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 15c74d40926ee..6421f7c81f7fb 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -1019,9 +1019,8 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz optimizer = optimizer.substr(0, pos); perf_metrics["Optimizer"] = optimizer; - Path model_path{}; - ORT_RETURN_IF_ERROR(Path::Parse(params_.model_path, model_path)); - PathString leaf = model_path.GetComponents().back(); + std::filesystem::path model_path = params_.model_path; + PathString leaf = model_path.filename(); std::string model_name = ToUTF8String(leaf.c_str()); perf_metrics["ModelName"] = model_name; diff --git a/orttraining/orttraining/models/runner/training_util.cc b/orttraining/orttraining/models/runner/training_util.cc index 7764508d9a091..6af3bf4410065 100644 --- a/orttraining/orttraining/models/runner/training_util.cc +++ b/orttraining/orttraining/models/runner/training_util.cc @@ -53,7 +53,7 @@ common::Status DataSet::AddData(const vector& featu OrtMemoryInfo info("Cpu", OrtDeviceAllocator, OrtDevice{}, 0, OrtMemTypeDefault); std::unique_ptr buffer = std::make_unique(cpu_tensor_length); ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue( - Env::Default(), nullptr, tensor_proto, MemBuffer(buffer.get(), cpu_tensor_length, info), ort_value)); + Env::Default(), std::filesystem::path(), tensor_proto, MemBuffer(buffer.get(), cpu_tensor_length, info), ort_value)); sample->push_back(ort_value); ortvalue_buffers_.emplace_back(std::move(buffer)); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index e2616e1b441f7..a81ea76e807ca 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -621,7 +621,7 @@ void addObjectMethodsForTraining(py::module& m) { ORT_THROW_IF_ERROR(gradient_graph_builder->builder_->Build()); }) .def("save", [](PyGradientGraphBuilderContext* gradient_graph_builder, const std::string& path) { - ORT_THROW_IF_ERROR(Model::Save(*(gradient_graph_builder->model_), path)); + ORT_THROW_IF_ERROR(Model::Save(*(gradient_graph_builder->model_), ToPathString(path))); }) .def("get_model", [](PyGradientGraphBuilderContext* gradient_graph_builder) { std::string model_str; diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 624b30ffdab3b..c98e5bcd97092 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -13,6 +13,9 @@ from onnxruntime.tools.convert_onnx_models_to_ort import OptimizationStyle, convert_onnx_models_to_ort from onnxruntime.training import onnxblock +# threshold for the size of the modelproto where you should use a path instead +USE_PATH_THRESHOLD = 2147483648 + class LossType(Enum): """Loss type to be added to the training model. @@ -37,7 +40,7 @@ class OptimType(Enum): def generate_artifacts( - model: onnx.ModelProto, + model: Union[onnx.ModelProto, str], requires_grad: Optional[List[str]] = None, frozen_params: Optional[List[str]] = None, loss: Optional[Union[LossType, onnxblock.Block]] = None, @@ -61,7 +64,8 @@ def generate_artifacts( All generated ModelProtos will use the same opsets defined by *model*. Args: - model: The base model to be used for gradient graph generation. + model: The base model or path to the base model to be used for gradient graph generation. For models >2GB, + use the path to the base model. requires_grad: List of names of model parameters that require gradient computation frozen_params: List of names of model parameters that should be frozen. loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph. @@ -86,6 +90,22 @@ def generate_artifacts( RuntimeError: If the optimizer provided is not one of the supported optimizers. """ + loaded_model = None + model_path = None + + if isinstance(model, str): + loaded_model = onnx.load(model) + model_path = model + elif isinstance(model, onnx.ModelProto): + if model.ByteSize() > USE_PATH_THRESHOLD: + # infer_shapes and check_model from ONNX both require paths to be used for >2GB models. + raise RuntimeError("This model is > 2GB. Please pass in a path to the ONNX file instead.") + + loaded_model = model + model_path = None + else: + raise RuntimeError("Please pass in either a string or an ONNX ModelProto for the model.") + loss_blocks = { LossType.MSELoss: onnxblock.loss.MSELoss, LossType.CrossEntropyLoss: onnxblock.loss.CrossEntropyLoss, @@ -165,12 +185,12 @@ def build(self, *inputs_to_loss): logging.info("Custom op library provided: %s", custom_op_library) custom_op_library_path = pathlib.Path(custom_op_library) - with onnxblock.base(model), ( + with onnxblock.base(loaded_model, model_path), ( onnxblock.custom_op_library(custom_op_library_path) if custom_op_library is not None else contextlib.nullcontext() ): - _ = training_block(*[output.name for output in model.graph.output]) + _ = training_block(*[output.name for output in loaded_model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() @@ -220,7 +240,7 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_ return opset_version = None - for domain in model.opset_import: + for domain in loaded_model.opset_import: if domain.domain == "" or domain.domain == "ai.onnx": opset_version = domain.version break diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index 149d0a360f7d3..80f07c3738a7e 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -4,6 +4,7 @@ import contextlib import copy import logging +import os from abc import ABC, abstractmethod from typing import Any, List, Optional @@ -28,8 +29,13 @@ class Block(ABC): base (onnx.ModelProto): The base model that the subclass can manipulate. """ - def __init__(self): + def __init__(self, temp_file_name="temp.onnx"): + if os.path.isabs(temp_file_name): + raise RuntimeError("Please pass in a relative path for the temp_file_name.") self.base = None + self.temp_onnx_file_path = os.path.join(os.getcwd(), temp_file_name) + # onnx.save location parameter requires a relative path to the model path + self.temp_external_data_file_name = temp_file_name + ".data" @abstractmethod def build(self, *args, **kwargs): @@ -47,10 +53,58 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - onnx.checker.check_model(self.base, True) + if accessor._GLOBAL_ACCESSOR.has_path: + onnx.save( + accessor._GLOBAL_ACCESSOR.model, + self.temp_onnx_file_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=self.temp_external_data_file_name, + ) + + onnx.checker.check_model(self.temp_onnx_file_path, True) + else: + onnx.checker.check_model(self.base, True) return output + def infer_shapes_on_base(self): + """ + Performs shape inference on the global model. If a path was used, then uses the + infer_shapes_path API to support models with external data. + + Returns the shape-inferenced ModelProto. + """ + if accessor._GLOBAL_ACCESSOR.has_path: + onnx.save( + accessor._GLOBAL_ACCESSOR.model, + self.temp_onnx_file_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=self.temp_external_data_file_name, + ) + + onnx.shape_inference.infer_shapes_path(self.temp_onnx_file_path) + # shape inferenced model is saved to original path + model = onnx.load(self.temp_onnx_file_path) + + return model + else: + return onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + + def __del__(self): + # since the ModelProto does not store the external data parameters themselves, just the metadata + # for where the external data can be found, we retain the external data files for the intermediate + # calls until the Block no longer needs to be used. + if os.path.exists(self.temp_onnx_file_path): + os.remove(self.temp_onnx_file_path) + # get absolute path for the external data file + external_data_file_path = os.path.join( + os.path.dirname(self.temp_onnx_file_path), self.temp_external_data_file_name + ) + if os.path.exists(external_data_file_path): + os.remove(external_data_file_path) + class _BinaryOp(Block): def __init__(self, op_name): diff --git a/orttraining/orttraining/python/training/onnxblock/loss/loss.py b/orttraining/orttraining/python/training/onnxblock/loss/loss.py index e719301e13f48..09429dd844187 100644 --- a/orttraining/orttraining/python/training/onnxblock/loss/loss.py +++ b/orttraining/orttraining/python/training/onnxblock/loss/loss.py @@ -93,19 +93,20 @@ def build(self, scores_input_name: str, labels_name: str = "labels"): labels_input = copy.deepcopy(_graph_utils.get_output_from_output_name(self.base, scores_input_name)) labels_input.name = labels_name labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64 - # If the predictions are (num_examples x num_classes) - # labels should be (num_examples,) - del labels_input.type.tensor_type.shape.dim[1] + # Assumes classes is the last dimension + # e.g., predictions: (num_examples, num_classes) -> labels: (num_examples,) + # or predictions: (batch_size, seq_len, vocab) -> labels: (batch_size, seq_len) + del labels_input.type.tensor_type.shape.dim[-1] self.base.graph.input.append(labels_input) loss_node_input_names = [scores_input_name, labels_name] if self._weight: loss_node_input_names.append(weight_name) + loss_node_output_name = _graph_utils.generate_graph_name("loss") - loss_node_output_names = [ - loss_node_output_name, - _graph_utils.generate_graph_name("log_prob"), - ] + log_prob_output_name = _graph_utils.generate_graph_name("log_prob") + + loss_node_output_names = [loss_node_output_name, log_prob_output_name] loss_node = onnx.helper.make_node( "SoftmaxCrossEntropyLoss", loss_node_input_names, diff --git a/orttraining/orttraining/python/training/onnxblock/model_accessor.py b/orttraining/orttraining/python/training/onnxblock/model_accessor.py index ac7a53a554e0a..302573064be6e 100644 --- a/orttraining/orttraining/python/training/onnxblock/model_accessor.py +++ b/orttraining/orttraining/python/training/onnxblock/model_accessor.py @@ -15,10 +15,12 @@ class ModelAccessor: Attributes: model: The onnx model that is manipulated by the onnx blocks. + model_path: The path to the base model. Can be None. """ - def __init__(self, model: onnx.ModelProto): + def __init__(self, model: onnx.ModelProto, model_path: str | None = None): self._model = model + self._path = model_path @property def model(self) -> onnx.ModelProto: @@ -30,6 +32,22 @@ def model(self) -> onnx.ModelProto: ) return self._model + @property + def path(self) -> str: + """ModelAccessor property that gets the path to the base model.""" + + if self._path is None: + raise RuntimeError( + "The path to the onnx model was not set. Please use the context manager onnxblock.onnx_model to create the model and pass in a string." + ) + return self._path + + @property + def has_path(self) -> bool: + """Returns True if ModelAccessor has a path to a model, False otherwise.""" + + return self._path is not None + # These variable resides in the global namespace. # Different methods can access this global model and manipulate it. @@ -39,7 +57,7 @@ def model(self) -> onnx.ModelProto: @contextmanager -def base(model: onnx.ModelProto): +def base(model: onnx.ModelProto, model_path: str | None = None): """Registers the base model to be manipulated by the onnx blocks. Example: @@ -53,6 +71,7 @@ def base(model: onnx.ModelProto): Args: model: The base model to be manipulated by the onnx blocks. + model_path: The path to the base model. None if there is no model path to pass in. Returns: ModelAccessor: The model accessor that contains the modified model. @@ -69,7 +88,7 @@ def base(model: onnx.ModelProto): "model from scratch." ) - _GLOBAL_ACCESSOR = ModelAccessor(model_clone) + _GLOBAL_ACCESSOR = ModelAccessor(model_clone, model_path) try: yield _GLOBAL_ACCESSOR finally: @@ -112,7 +131,7 @@ def empty_base(opset_version: int | None = None): ) ) - _GLOBAL_ACCESSOR = ModelAccessor(model) + _GLOBAL_ACCESSOR = ModelAccessor(model, None) try: yield _GLOBAL_ACCESSOR finally: diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index a2922353ac70e..64f7acf4dc02c 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -70,7 +70,7 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - self._model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + self._model = self.infer_shapes_on_base() _graph_utils.register_graph_outputs(self._model, output) @@ -187,7 +187,7 @@ def __call__(self, *args, **kwargs): output = self.build(*args, **kwargs) - model = onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model) + model = self.infer_shapes_on_base() _graph_utils.register_graph_outputs(model, output) diff --git a/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py b/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py index ff128c4da4259..12eba90170fb9 100644 --- a/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py +++ b/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py @@ -26,7 +26,7 @@ def override_function(m_self): # noqa: N805 from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops - warnings.warn("Apex AMP fp16_optimizer functions are overrided with faster implementation.", UserWarning) + warnings.warn("Apex AMP fp16_optimizer functions are overridden with faster implementation.", UserWarning) # Implementation adapted from https://github.com/NVIDIA/apex/blob/082f999a6e18a3d02306e27482cc7486dab71a50/apex/amp/_process_optimizer.py#L161 def post_backward_with_master_weights(self, scaler): diff --git a/orttraining/orttraining/python/training/optim/_ds_modifier.py b/orttraining/orttraining/python/training/optim/_ds_modifier.py index 20f4f814e5476..55e2e08432137 100644 --- a/orttraining/orttraining/python/training/optim/_ds_modifier.py +++ b/orttraining/orttraining/python/training/optim/_ds_modifier.py @@ -140,7 +140,7 @@ def can_be_modified(self): ) def override_function(self): - warnings.warn("DeepSpeed fp16_optimizer functions are overrided with faster implementation.", UserWarning) + warnings.warn("DeepSpeed fp16_optimizer functions are overridden with faster implementation.", UserWarning) def get_grad_norm_direct(target, gradients, params, norm_type=2): from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops diff --git a/orttraining/orttraining/python/training/optim/_megatron_modifier.py b/orttraining/orttraining/python/training/optim/_megatron_modifier.py index 707727120c5cd..702eba77cb74a 100644 --- a/orttraining/orttraining/python/training/optim/_megatron_modifier.py +++ b/orttraining/orttraining/python/training/optim/_megatron_modifier.py @@ -27,7 +27,7 @@ def can_be_modified(self): ) def override_function(self): - warnings.warn("Megatron-LM fp16_optimizer functions are overrided with faster implementation.", UserWarning) + warnings.warn("Megatron-LM fp16_optimizer functions are overridden with faster implementation.", UserWarning) def clip_master_grads(target, max_norm, norm_type=2): """ diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 75512cb8e8c88..a8590cea22887 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -238,7 +238,7 @@ def native_group_norm_gradient(): # PyTorch removed related backward functions with "vec" overload name since 1.13. The functions with no overload name -# are available for all versions, though they are not that convienent to use. +# are available for all versions, though they are not that convenient to use. def _upsample_gradient(backward_fn, dims): scales = ["" for _ in range(dims)] if "bicubic" in backward_fn: diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 0bd29b8d155c4..10e7f60b7da0f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -437,7 +437,7 @@ def permute_and_reshape_tensor( shape_tensor, ): # If matmul_output_axes and contraction_axes are contiguous in input tensor, - # we can move Reshape to before Transpose, so it's possible that the Transpoase is fused to MatMul. + # we can move Reshape to before Transpose, so it's possible that the Transpose is fused to MatMul. # Otherwise, we have to Transpose first to move those axes together and then Reshape. is_matmul_output_axes_contiguous = is_axes_contiguous(matmul_output_axes) is_contraction_axes_contiguous = is_axes_contiguous(contraction_axes) @@ -525,7 +525,7 @@ def permute_and_reshape_tensor( @register_symbolic("einsum", torch_version_end="1.13.0") @parse_args("s", "v") -def einsum_pre_troch_113(g, equation, tensor_list): +def einsum_pre_torch_113(g, equation, tensor_list): return einsum_internal(g, equation, tensor_list) @@ -540,12 +540,12 @@ def einsum_internal(g, equation, tensor_list): num_ops = len(tensors) assert num_ops > 0 - # Doesn't support implicit output is ellipsis or more than 2 oprands for now. - # Doesn't support ellipsis ('...') for now as not easy to get sizes of oprands. + # Doesn't support implicit output is ellipsis or more than 2 operands for now. + # Doesn't support ellipsis ('...') for now as not easy to get sizes of operands. if num_ops != 2 or equation.find("->") == -1 or "." in equation: return g.op("Einsum", *tensors, equation_s=equation) - # Take "ks,ksm->sm" as example. After prcoess inputs, + # Take "ks,ksm->sm" as example. After process inputs, # lhs_labels = [k,s], rhs_labels = [k,s,m], result_labels = [s,m]. lhs_labels, rhs_labels, result_labels = parse_equation(equation) diff --git a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py index 84d7bf6410966..047cd4c59d636 100644 --- a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py +++ b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py @@ -102,7 +102,7 @@ def __init__( ): """ :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string - :param fw_feed_names: Feed names for foward pass. + :param fw_feed_names: Feed names for forward pass. :param fw_outputs_device_info: Device info for fetches in forward pass. :param bw_fetches_names: Fetch names for backward pass. :param bw_outputs_device_info: Device info for fetches in backward pass. diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 8e383a5545e42..c1ff62a5faea7 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -4,13 +4,9 @@ # -------------------------------------------------------------------------- import copy -import inspect -import io import logging import os from abc import ABC, abstractmethod # noqa: F401 -from functools import partial -from hashlib import md5 as hash_fn from typing import Dict, List, Optional, Tuple import onnx @@ -19,24 +15,16 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype - -from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils, export_context -from ._fallback import ( - ORTModuleDeviceException, - ORTModuleONNXModelException, - ORTModuleTorchModelException, - _FallbackManager, - _FallbackPolicy, - wrap_exception, -) +from onnxruntime.training.utils import PTable, onnx_dtype_to_pytorch_dtype + +from . import _are_deterministic_algorithms_enabled, _logger, _onnx_models, _utils +from ._fallback import ORTModuleTorchModelException, _FallbackManager, _FallbackPolicy, wrap_exception from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface -from ._io import _FlattenedModule, _InputInfo -from ._logger import LogColor -from ._runtime_inspector import FlagAndPrintDensity, RuntimeInspector -from ._utils import check_function_has_param, get_rank +from ._graph_transition_manager import GraphTransitionManager, PostExportProcessedModelInfo +from ._io import _FlattenedModule +from ._runtime_inspector import RuntimeInspector +from ._utils import get_rank from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -85,15 +73,12 @@ def __init__( # Original and flattened (transformed) output module self._flattened_module = module - # onnx models self._onnx_models = _onnx_models.ONNXModels() + self._graph_transition_manager: Optional[GraphTransitionManager] = None - # Model after inference optimization or gradient building. + # Model after inference optimization and then gradient building. self._graph_builder = None self._graph_info = None - self._graph_initializer_names = set() - self._graph_initializer_names_to_train = set() - self._graph_initializers: List[torch.nn.parameter.Parameter] = [] # TrainingAgent or InferenceAgent self._execution_agent = None @@ -107,36 +92,11 @@ def __init__( # To be instantiated in the concrete implementation of GraphExecutionManager self._export_mode = export_mode - # Exporter can take extra arguments for ORTModule extensions - # It cannot overlap with required/immutable arguments (validated in runtime) - self._export_extra_kwargs = {} - - # Input and output infos (including schema) for exported model. - self._input_info: Optional[_InputInfo] = None - self._module_output_schema: Optional[ORTModelInputOutputSchemaType] = None - - # Device where the model is placed. - self._device: Optional[torch.device] = _utils.get_device_from_module(module) - - # Forward function input parameters of the original module. - self._module_parameters: List[inspect.Parameter] = list( - inspect.signature(self._original_module.forward).parameters.values() - ) - - # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. - for input_parameter in self._module_parameters: - if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: - self._logger.info("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!") - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) # WIP feature to enable caching in Gradient accumulation scenario. self._gradient_accumulation_manager = GradientAccumulationManager() - # Flag to re-export the model due to attribute change on the original module. - # Re-export will be avoided if _skip_check is enabled. - self._original_model_has_changed = False - # Inspector for runtime information, for example input data, memory usage, etc. self._runtime_inspector = RuntimeInspector( self._logger, self._original_module, self._export_mode == torch.onnx.TrainingMode.TRAINING @@ -163,9 +123,7 @@ def __init__( configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) - # Will be reset everytime we re-initialize the graph builder. - # Be noted, we will never enable this feature for inference mode. - self._mem_efficient_grad_management_is_enabled = False + self._initialize_graph_transition_manager() def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): @@ -176,6 +134,18 @@ def _get_torch_gpu_allocator_function_addresses(self): self._torch_free = torch_gpu_allocator.gpu_caching_allocator_raw_delete_address() self._torch_empty_cache = torch_gpu_allocator.gpu_caching_allocator_empty_cache_address() + def _initialize_graph_transition_manager(self): + """Creates a new GraphTransitionManager, initializes it and saves it to self._graph_transition_manager""" + self._graph_transition_manager = GraphTransitionManager( + flatten_module=self._flattened_module, + export_mode=self._export_mode, + debug_options=self._debug_options, + runtime_options=self._runtime_options, + time_tracker=self.time_tracker, + runtime_inspector=self._runtime_inspector, + logger=self._logger, + ) + def _validate_module_type(self, module): """Raises ORTModuleTorchModelException if the module is not a torch.nn.Module""" @@ -205,7 +175,9 @@ def forward(self): def _build_graph(self, config): if self._runtime_options.use_static_shape: - self._graph_builder.build(config, self._input_info.shape) + self._graph_builder.build( + config, self._graph_transition_manager._model_info_for_export.onnx_graph_input_shapes + ) else: self._graph_builder.build(config) @@ -259,7 +231,8 @@ def _get_session_config(self): # Enable memory efficient execution order for training if 1). memory efficient grad management is enabled # or 2). memory optimizer is enabled. use_memory_efficient_topo_sort = (self._export_mode == torch.onnx.TrainingMode.TRAINING) and ( - self._mem_efficient_grad_management_is_enabled or self._runtime_options.memory_optimizer_is_enabled() + self._graph_transition_manager._post_export_processed_model_info.is_mem_efficient_grad_management_enabled + or self._runtime_options.memory_optimizer_is_enabled() ) session_options.execution_order = ( onnxruntime.ExecutionOrder.MEMORY_EFFICIENT @@ -283,266 +256,6 @@ def _get_session_config(self): return session_options, providers, provider_options - @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) - @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) - def _export_model(self, *inputs, **kwargs) -> bool: - # 1. Set the self._device from the user module - # 2. Verify input schema matches the schema used on the previous model export - # 3. Export the user model under self._export_training_flag mode - # Return True if the model needs to be exported, False if no export is required. - - # Note: Model is only exported when: - # 1. Model has never been exported before. - # 2. Model input schema has changed (changes in inputs requiring gradient, shape, boolean inputs values change, etc) - # Model is not re-exported when the model parameters change. This can happen when the model is stateful, - # or the user explicitly changed model parameters after the onnx export. - - # Record random states here and restore later in case any of them gets changed during the export, - # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. - random_states = _utils.get_random_states() - - schema = _io._extract_schema({"args": copy.copy(inputs), "kwargs": copy.copy(kwargs)}, self._device) - if ( - self._onnx_models.exported_model - and schema == self._input_info.schema - and not self._original_model_has_changed - ): - # All required models have already been exported previously - return False - - self._set_device_from_module(inputs, kwargs) - embedding_hook_handles = self._add_check_embedding_sparsity_hook() - label_hook_handles = self._add_check_label_sparsity_hook() - - from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step - - with export_context(), no_increase_global_step(): - self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) - - for hook in embedding_hook_handles: - hook.remove() - for hook in label_hook_handles: - hook.remove() - - if self._debug_options.save_onnx_models.save: - self._onnx_models.save_exported_model( - self._debug_options.save_onnx_models.path, - self._debug_options.save_onnx_models.name_prefix, - self._export_mode, - ) - - if self._runtime_options.run_symbolic_shape_infer: - self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes( - self._onnx_models.exported_model, auto_merge=True, guess_output_rank=True - ) - - # Restore the recorded random states - _utils.set_random_states(random_states) - - return True - - def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inputs, **kwargs) -> onnx.ModelProto: - """Exports PyTorch `self._flattened_module` to ONNX for inferencing or training, - using `*inputs` and `**kwargs` as input - - TODO: How to support dynamic axes? Dimensions are determined by samples - """ - - # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) - # DEVINFO -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) - # INFO -> [Rank 0] FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) - # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) - # Be noted: rank 0 log only is controlled by logger configured in _logger.py - torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO - - # Setup dynamic axes for onnx model - self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) - need_deep_copy = self._runtime_options.deepcopy_before_model_export and _io.can_module_be_deep_cloned( - self._original_module, self._device - ) - if not need_deep_copy: - if self._runtime_options.deepcopy_before_model_export: - self._logger.warning( - "Since the user requested not to deep copy this model, " - "the initial weights may not be preserved and could change slightly during the forward run. " - "This could cause a minor difference between the ORTModule and the PyTorch run for the " - "first iteration. The computation will proceed as normal, but this should be noted." - ) - else: - self._logger.warning( - "Due to the limited GPU memory execution manager does not create a deep copy of this model. " - "Therefore, the initial weights might be slightly altered during the forward run. " - "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the " - "first iteration. The computation will continue as usual, but this should be noted." - ) - ( - output_names, - output_dynamic_axes, - self._module_output_schema, - ) = _io.parse_outputs_for_onnx_export_and_extract_schema( - self._original_module, inputs, kwargs, self._logger, self._device, need_deep_copy - ) - self._input_info.dynamic_axes.update(output_dynamic_axes) - - # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs - self._flattened_module._input_info = self._input_info - - self._logger.info("Exporting the PyTorch model to ONNX...") - - # Leverage cached model if available - cache_dir = self._runtime_options.ortmodule_cache_dir - if cache_dir: - filename = os.path.join( - cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" - ) - if os.path.exists(cache_dir) and os.path.isfile(filename): - self._logger.warning( - f"Cached model detected! Cached model will be used to save export and initialization time." - f"If you want the model to be re-exported then DELETE {filename}." - ) - exported_model = onnx.load(filename) - return exported_model - - # Export torch.nn.Module to ONNX - f = io.BytesIO() - - # Deepcopy inputs, since input values may change after model run. - # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). - # Therefore, deepcopy only the data component of the input tensors for export. - sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs) - # NOTE: Flattening the input will change the 'input schema', resulting in a re-export - sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)) - # Ops behaving differently under train/eval mode need to be exported with the - # correct training flag to reflect the expected behavior. - # For example, the Dropout node in a model is dropped under eval mode. - assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" - - try: - from ._zero_stage3_compatibility import stage3_export_context - - with torch.no_grad(), stage3_export_context(self._runtime_options.enable_zero_stage3_support, self): - required_export_kwargs = { - "input_names": self._input_info.names, - "output_names": output_names, - "opset_version": self._runtime_options.onnx_opset_version, - "do_constant_folding": False, - "training": self._export_mode, - "dynamic_axes": self._input_info.dynamic_axes, - "verbose": torch_exporter_verbose_log, - "export_params": False, - "keep_initializers_as_inputs": True, - } - - if check_function_has_param(torch.onnx.export, "autograd_inlining"): - # From some PyTorch version, autograd_inlining is a valid argument. - # We allow it to be True if custom autograd function is disabled (where autograd.Function - # anyway is not supported in ONNX until it can be inlined). - required_export_kwargs["autograd_inlining"] = ( - not self._runtime_options.enable_custom_autograd_function - ) - - invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() - - if len(invalid_args) != 0: - error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." - raise RuntimeError(error_msg) - - torch.onnx.export( - self._flattened_module, - sample_inputs_as_tuple, - f, - **required_export_kwargs, - **self._export_extra_kwargs, - ) - except Exception as e: - message = _utils.get_exception_as_string(e) - - # Special handling when Huggingface transformers gradient checkpoint usage pattern found. - # For new versions of PyTorch 2, tracing torch.utils.checkpoint.checkpoint will be failed like this: - # File "microsoft/phi-2/b10c3eba545ad279e7208ee3a5d644566f001670/modeling_phi.py", line 919, in forward - # layer_outputs = self._gradient_checkpointing_func( - # File "/site-packages/torch/_compile.py", line 24, in inner - # return torch._dynamo.disable(fn, recursive)(*args, **kwargs) - # File "/site-packages/torch/_dynamo/eval_frame.py", line 470, in _fn - # raise RuntimeError( - # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment. - if ( - "_gradient_checkpointing_func" in message - and "Detected that you are using FX to torch.jit.trace a dynamo-optimized function" in message - ): - is_ckpt_activation_allowed = int(os.getenv("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", "0")) == 1 - notes = ( - " Your model is running with gradient checkpointing, yet the PyTorch exporter\n" - " failed during tracing the graph. Try to enable ORTModule's\n" - " gradient checkpointing (a.k.a. Transformer layerwise subgraph recompute)\n" - " using `export ORTMODULE_MEMORY_OPT_LEVEL=1` for similar or even better memory efficiency.\n" - ) - if is_ckpt_activation_allowed: - # If the user allows the gradient checkpointing export, we should inform the user to disable it, - # to make layerwise recompute work. - notes += ( - " We also notice your setting `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`,\n" - " which enables gradient checkpointing torch.autograd.Functions(s) to export.\n" - " To enable ORTModule's layerwise recompute, it needs to be turned OFF by\n" - " `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=0`.\n" - ) - - self._logger.error( - f"{LogColor.RED}\n" - "******************************** IMPORTANT NOTE *******************************\n" - f"{notes}" - "*******************************************************************************\n" - f"{LogColor.ENDC}\n" - ) - - raise wrap_exception( # noqa: B904 - ORTModuleONNXModelException, - RuntimeError(f"There was an error while exporting the PyTorch model to ONNX: \n\n{message}"), - ) - exported_model = onnx.load_model_from_string(f.getvalue()) - - if self._runtime_options.enable_custom_autograd_function: - from ._custom_autograd_function_exporter import post_process_enabling_autograd_function - - exported_model = post_process_enabling_autograd_function(exported_model) - - if self._runtime_options.enable_zero_stage3_support: - from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat - - exported_model = post_processing_enable_zero_stage3_compat( - exported_model, - self._zero_stage3_param_map, - [name for name, _ in self._flattened_module.named_parameters()], - ) - - # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( - # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) - # find input info mismatch, will re-initialize the graph builder. - # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) - - # Cache model for future runs - if cache_dir: - if not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - filename = os.path.join( - cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" - ) - self._logger.info(f"Caching model for future runs to {filename}.") - onnx.save(exported_model, filename) - - return exported_model - - def _set_device_from_module(self, inputs, kwargs): - """Get the device from the module and save it to self._device""" - - device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(inputs, kwargs) - if not self._device or self._device != device: - self._device = device - if not self._device: - raise wrap_exception( - ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") - ) - def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfiguration: graph_transformer_config = C.TrainingGraphTransformerConfiguration() graph_transformer_config.propagate_cast_ops_config = C.PropagateCastOpsConfiguration() @@ -563,68 +276,25 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati return graph_transformer_config @_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT) - def _initialize_graph_builder(self): + def _initialize_graph_builder(self, post_export_processed_model_info: PostExportProcessedModelInfo): """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder""" - self._mem_efficient_grad_management_is_enabled = ( - self._export_mode != torch.onnx.TrainingMode.EVAL - and self._runtime_options.enable_mem_efficient_grad_management - ) - - # We post process the exported model because the trainable parame might be changed, so this path is - # re-triggered by reinitialize_graph_builder. - exported_model = copy.deepcopy(self._onnx_models.exported_model) - self._onnx_models.processed_exported_model = exported_model - - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training - - # Override the options if model is not modified. - (self._mem_efficient_grad_management_is_enabled, exported_model, self._param_trigger_grad) = ( - post_processing_enable_mem_efficient_training( - exported_model, self._flattened_module.named_parameters(), self._device - ) - ) - - if self._runtime_options.run_symbolic_shape_infer: - exported_model = SymbolicShapeInference.infer_shapes( - exported_model, auto_merge=True, guess_output_rank=True - ) - - # All initializer names along with user inputs are a part of the onnx graph inputs - # since the onnx model was exported with the flag keep_initializers_as_inputs=True - # We need to use the raw exported model here since the graph inputs include both user inputrs and - # parameters. - onnx_initializer_names = {p.name for p in exported_model.graph.input} - - # TODO: PyTorch exporter bug: changes the initializer order in ONNX model - initializer_names = [ - name for name, _ in self._flattened_module.named_parameters() if name in onnx_initializer_names - ] - initializer_names_to_train = [ - name - for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in onnx_initializer_names - ] - # Build and optimize the full graph grad_builder_config = C.OrtModuleGraphBuilderConfiguration() - grad_builder_config.initializer_names = initializer_names - grad_builder_config.initializer_names_to_train = initializer_names_to_train - - input_names_require_grad = self._input_info.require_grad_names + grad_builder_config.initializer_names = ( + post_export_processed_model_info.onnx_graph_input_names + ) # containing both user defined and buffers/parameters. + grad_builder_config.initializer_names_to_train = ( + post_export_processed_model_info.onnx_graph_input_names_require_grad + ) # containing both user defined and parameters requiring gradients. + + input_names_require_grad = post_export_processed_model_info.onnx_graph_input_names_require_grad_user_defined if self._runtime_options.enable_zero_stage3_support: from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph. input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - - # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. - input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) - grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization @@ -636,41 +306,18 @@ def _initialize_graph_builder(self): # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way # and are kept as they appear in the exported onnx model. - self._graph_builder.initialize(exported_model.SerializeToString(), grad_builder_config) - - raw_onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} - - raw_initializer_names = [ - name for name, _ in self._flattened_module.named_parameters() if name in raw_onnx_initializer_names - ] - raw_initializer_names_to_train = [ - name - for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in raw_onnx_initializer_names - ] - - # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train - # a set (unordered_set in the backend) that does not require a copy on each reference. - self._graph_initializer_names = set(raw_initializer_names) - self._graph_initializer_names_to_train = set(raw_initializer_names_to_train) - - # Initializers can be cached and used since they are expected not to be re-instantiated - # between forward calls. - self._graph_initializers = [ - param for name, param in self._flattened_module.named_parameters() if name in self._graph_initializer_names - ] - - def signal_model_changed(self): - """Signals the execution manager to re-export the model on the next forward call""" - self._original_model_has_changed = True + self._graph_builder.initialize( + post_export_processed_model_info._post_export_processed_model.SerializeToString(), grad_builder_config + ) def __getstate__(self): state = copy.copy(self.__dict__) - # Remove any re-contructible/pybound object from the state + # Remove any re-constructible/pybound object from the state serialization_deny_list = [ "_onnx_models", "_graph_builder", "_graph_info", + "_graph_transition_manager", # Not pickled as it is re-constructed in __setstate__ "_execution_agent", "_torch_alloc", "_torch_free", @@ -686,82 +333,12 @@ def __setstate__(self, state): _utils.reinitialize_graph_execution_manager(self) - def _add_check_embedding_sparsity_hook(self): - """ - Add hook to check embedding sparsity and enable padding elimination if applicable. - 1. Iterate through all modules to find Embedding modules with padding_idx >= 0. - 2. Register forward pre hook to the Embedding module and the hook will check sparsity of the embedding input. - 3. If the sparsity is below a threshold, enable padding elimination by adding FlagAndPrintDensity after the - output. GraphTransformer of PaddingElimination will check the FlagAndPrintDensity and do the actual - padding elimination graph modification. - 4. Return the hook handles for later removal. + self._initialize_graph_transition_manager() - """ - if not self._runtime_options.enable_embedding_sparse_optimizer or self._device.type != "cuda": - return [] - - def _embedding_hook(name, module, args): - ebd_input = args[0] - if ebd_input is None or not isinstance(ebd_input, torch.Tensor): - self._logger.warning("Embedding input is not a tensor.") - return None - - valid_token = torch.count_nonzero(ebd_input - module.padding_idx) - total_token = ebd_input.numel() - embed_density = float(valid_token) / float(total_token) * 100 - - if embed_density < 90: - self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) - self._runtime_inspector._embedding_module_to_padding_density_map[name] = embed_density - return FlagAndPrintDensity.apply(args[0], module.padding_idx, "embedding") - else: - self._logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density) - return None - - embedding_hook_handles = [] - for name, sub_module in self._flattened_module.named_modules(): - if isinstance(sub_module, torch.nn.modules.sparse.Embedding): - if sub_module.padding_idx is not None and sub_module.padding_idx >= 0: - embedding_hook_handles.append(sub_module.register_forward_pre_hook(partial(_embedding_hook, name))) - - return embedding_hook_handles - - def _add_check_label_sparsity_hook(self): - """ - Add hook to check label sparsity and enable sceloss compute optimization if applicable. - 1. Register forward pre hook to the sceloss module in the model and the hook will check sparsity of the label input. - 2. If the sparsity is below a threshold, enable sceloss compute optimization by adding FlagAndPrintDensity after the - output. GraphTransformer of InsertGatherBeforeSceLoss will check the FlagAndPrintDensity and do the actual - sceloss compute optimization graph modification. - - """ - if not self._runtime_options.enable_label_sparse_optimizer: - return None - - def _label_hook(name, module, args): - label_input = args[1] - if label_input is None or not isinstance(label_input, torch.Tensor): - self._logger.warning("Label input is not a tensor.") - return None - - valid_token = torch.count_nonzero(label_input - module.ignore_index) - total_token = label_input.numel() - label_density = float(valid_token) / float(total_token) * 100 - - if label_density < 90: - self._logger.info("Label sparsity-based optimization is ON for density: %.0f%%", label_density) - self._runtime_inspector._sceloss_module_to_ignore_density_map[name] = label_density - return (args[0], FlagAndPrintDensity.apply(args[1], module.ignore_index, "label")) - else: - self._logger.info("Label sparsity-based optimization is OFF for density: %.0f%%", label_density) - return None - - label_check_hook_handles = [] - for name, sub_module in self._flattened_module.named_modules(): - if isinstance(sub_module, torch.nn.modules.loss.CrossEntropyLoss): - label_check_hook_handles.append(sub_module.register_forward_pre_hook(partial(_label_hook, name))) - - return label_check_hook_handles + @property + def _device(self): + # Graph transition manager is responsible for detecting and managing the device to use. + return self._graph_transition_manager._device @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): @@ -775,34 +352,24 @@ def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): enable sparsity-based optimization. """ - detected_device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( - inputs, kwargs - ) - - if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: - self._append_pull_weight_trigger_as_input(kwargs, detected_device) - - param_to_append_as_onnx_graph_inputs = [] - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger + if ( + self._runtime_options.enable_zero_stage3_support + or self._graph_transition_manager._post_export_processed_model_info.is_mem_efficient_grad_management_enabled + ): + self._append_pull_weight_trigger_as_input(kwargs, self._device) - param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( - self._flattened_module.named_parameters(), self._onnx_models.exported_model + if ( + self._runtime_inspector.memory_ob.is_enabled() + and not self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed + ): + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True, self._device ) - else: - param_to_append_as_onnx_graph_inputs = self._graph_initializers - - _io._combine_input_buffers_initializers( - param_to_append_as_onnx_graph_inputs, - self._graph_builder.get_graph_info().user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - detected_device, - self._runtime_inspector, - self._zero_stage3_param_map, - ) + self._runtime_inspector.memory_ob.collect_symbolic_dim_values( + self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_dynamic_axes_map, + prepared_input_map, + ) + self._runtime_inspector.memory_ob.symbolic_dim_collecting_completed = True if self._runtime_inspector._sceloss_module_to_ignore_density_map: self._runtime_options.label_sparsity_ratio = ",".join( @@ -828,19 +395,6 @@ def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.devic device=device, ).requires_grad_() - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import ( - MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, - ) - - kwargs[MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME] = torch.zeros( - MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, - dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), - device=device, - ).requires_grad_() - def _log_feature_stats(self): if get_rank() != 0: return diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py new file mode 100755 index 0000000000000..22627749c316c --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -0,0 +1,1058 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from __future__ import annotations + +import copy +import inspect +import io +import logging +import os +from collections import OrderedDict +from functools import partial +from hashlib import md5 as hash_fn +from typing import Mapping, Sequence + +import onnx +import torch + +from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference +from onnxruntime.training.utils import ( + ORTModelInputOutputSchemaType, + ORTModelInputOutputType, + PrimitiveType, + onnx_dtype_to_pytorch_dtype, + unflatten_data_using_schema, +) + +from . import _io, _utils, export_context +from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception +from ._logger import LogColor, LogLevel, ORTModuleInitPhase, SuppressLogs, TimeTracker, TrackTimeForStaticFunction +from ._onnx_models import _get_onnx_file_name, _save_model +from ._runtime_inspector import FlagAndPrintDensity, RuntimeInspector +from ._utils import check_function_has_param, get_rank +from ._zero_stage3_compatibility import stage3_export_context +from .options import DebugOptions, _RuntimeOptions + + +class ExportedModelInfo: + """Encapsulates the information of the exported model. + + After ONNX model export, the model info is collected and encapsulated in this class, including: + 1. The ONNX graph input names. + 2. Graph input requiring gradient information. + 3. The model's forward function signature and args/kwargs schema, used as a cache key to compare with the current + inputs to see if the model needs to be re-exported. + + This data structure is returned by the GraphTransitionManager._export_model method. + + """ + + def __init__( + self, + module_forward_args_schema: ORTModelInputOutputSchemaType, + module_forward_kwargs_schema: ORTModelInputOutputSchemaType, + onnx_graph_input_names: list[str], + onnx_graph_input_names_require_grad: list[str], + onnx_graph_input_names_user_defined: list[str], + onnx_graph_input_names_require_grad_user_defined: list[str], + exported_model: onnx.ModelProto, + module_forward_output_schema: ORTModelInputOutputSchemaType, + ): + # Used as a baseline to compare with the current inputs (args/kwargs) to see if the model needs to be re-exported. + self.module_forward_args_schema: ORTModelInputOutputSchemaType | None = module_forward_args_schema + self.module_forward_kwargs_schema: ORTModelInputOutputSchemaType | None = module_forward_kwargs_schema + + # Input names parsed and then flatten from the model's forward function signature + buffers + parameters (since we use + # keep_initializers_as_inputs=True for model export) + # Be noted: all inputs are used by the model for its compute. + self.onnx_graph_input_names: list[str] = copy.deepcopy(onnx_graph_input_names) + + # A subset of onnx_graph_input_names. + # Input names that require gradient parsed and then flatten from the model's forward function signature + # This should contain both the user-defined input names, the buffer names, and the parameter names (since we use + # keep_initializers_as_inputs=True for model export) + # Be noted: all inputs are used by the model for its compute. + self.onnx_graph_input_names_require_grad: list[str] = copy.deepcopy(onnx_graph_input_names_require_grad) + + # Input names parsed from the model's forward function signature. + # Be noted: all inputs are used by the model for its compute. + # The ONNX graph input names exclude the parameters, and buffers. + self.onnx_graph_input_names_user_defined = copy.deepcopy(onnx_graph_input_names_user_defined) + + # A subset of onnx_graph_input_names_user_defined. + self.onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + onnx_graph_input_names_require_grad_user_defined + ) + + # Exported model proto. + self.exported_model: onnx.ModelProto | None = exported_model + + # Used for unflattening the outputs from the ORT forward run. + self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema + + def __str__(self): + return f"""ExportedModelInfo class: + \tonnx_graph_input_names: {self.onnx_graph_input_names} + \tonnx_graph_input_names_require_grad: {self.onnx_graph_input_names_require_grad} + \tmodule_forward_args_schema: {self.module_forward_args_schema} + \tmodule_forward_kwargs_schema: {self.module_forward_kwargs_schema} + \tmodule_forward_output_schema: {self.module_forward_output_schema} + """ + + def __repr__(self): + return self.__str__() + + +class PostExportProcessedModelInfo: + """Encapsulates the information of the post-export processed model. + + After ONNX model post-export processing, the model info is collected and encapsulated in this class, including: + 1. The ONNX graph input names, dynamic axes, and input data accessor functions. + 2. Graph input requiring gradient information. + 3. The interface to construct the inputs for the ORT forward run, from original given inputs running for PyTorch. + 4. The interface to restore the outputs from the ORT forward run, back to the original data structure. + + """ + + def __init__( + self, + flatten_module: torch.nn.Module, + onnx_graph_input_names_user_defined: list[str], + onnx_graph_input_names_require_grad_user_defined: list[str], + onnx_graph_input_names: list[str], + onnx_graph_input_names_require_grad: list[str], + onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]], + module_forward_output_schema: ORTModelInputOutputSchemaType, + post_export_processed_model: onnx.ModelProto, + onnx_graph_input_data_accessor_user_defined: dict[str, callable], + onnx_graph_input_const_as_tensor: dict[str, torch.device], + enable_mem_efficient_grad_management: bool, + ): + self._flattened_module = flatten_module + + # Input names parsed from the model's forward function signature. + # Be noted: all inputs are used by the model for its compute. + # The ONNX graph input names exclude the parameters, and buffers. + self.onnx_graph_input_names_user_defined = copy.deepcopy(onnx_graph_input_names_user_defined) + + # A subset of onnx_graph_input_names_user_defined. + self.onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + onnx_graph_input_names_require_grad_user_defined + ) + + # Input names for the pre-gradient-build graph. + # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed + # for example when memory efficient gradient management is enabled. + self.onnx_graph_input_names: list[str] = copy.deepcopy(onnx_graph_input_names) + + # A subset of onnx_graph_input_names. + # Input names that require gradients for the pre-gradient-build graph. + self.onnx_graph_input_names_require_grad: list[str] = copy.deepcopy(onnx_graph_input_names_require_grad) + + # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). + # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} + # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} + self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map + + self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model + + # A function to access the input data from the args and kwargs. + # If it is not None, the length is same as onnx_graph_input_names_user_defined. + # For i-th input name, we can use the i-th function to get the input data from args and kwargs. + self.onnx_graph_input_data_accessor_user_defined: dict[str, callable] | None = ( + onnx_graph_input_data_accessor_user_defined + ) + + self.onnx_graph_input_const_as_tensor: dict[str, torch.device] | None = onnx_graph_input_const_as_tensor + + self.is_mem_efficient_grad_management_enabled = enable_mem_efficient_grad_management + + # Used for unflattening the outputs from the ORT forward run. + self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema + + # A buffer to hold the inputs for the ORT forward run. For performance, we reuse the same buffer for each run. + self._buffer_for_ort_runs: dict[str, torch.Tensor] | None = None + + def __str__(self): + return f"""PostExportProcessedModelInfo class: + \tonnx_graph_input_names: {self.onnx_graph_input_names} + \tonnx_graph_input_names_require_grad: {self.onnx_graph_input_names_require_grad} + \tonnx_graph_input_dynamic_axes_map: {self.onnx_graph_input_dynamic_axes_map} + \tonnx_graph_input_names_user_defined: {self.onnx_graph_input_names_user_defined} + \tonnx_graph_input_names_require_grad_user_defined: {self.onnx_graph_input_names_require_grad_user_defined} + \tbuffer_for_ort_runs.keys(): {self._buffer_for_ort_runs.keys() if self._buffer_for_ort_runs else None} + """ + + def __repr__(self): + return self.__str__() + + def construct_inputs( + self, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + constant_as_tensor: bool, + device: torch.device, + ): + """Constructs the inputs for the forward method + + The inputs are constructed in the order they appear in the model's forward function signature + """ + from ._mem_efficient_grad_mgmt import ( + MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE, + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + ) + + # First time construct the buffer for the ORT forward run. + if self._buffer_for_ort_runs is None: + self._buffer_for_ort_runs = OrderedDict() + + # Create the buffers for the inputs that are either parameters or buffers in the original module. + # For user inputs, fill with None for now, and will be filled dynamically during the forward run. + + parameter_names = {k: v for k, v in self._flattened_module.named_parameters()} + buffer_names = {k: v for k, v in self._flattened_module.named_buffers()} + + for input_name in self.onnx_graph_input_names: + if input_name in parameter_names: + self._buffer_for_ort_runs[input_name] = parameter_names[input_name] + elif input_name in buffer_names: + self._buffer_for_ort_runs[input_name] = buffer_names[input_name] + else: + self._buffer_for_ort_runs[input_name] = ( + None # Fill None for user input first, will be overridden later. + ) + + for name in self.onnx_graph_input_names_user_defined: + if self.is_mem_efficient_grad_management_enabled and name == MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: + self._buffer_for_ort_runs[name] = torch.zeros( + MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() + continue + + if name in self.onnx_graph_input_data_accessor_user_defined: + assert name in self._buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs" + data = self.onnx_graph_input_data_accessor_user_defined[name](args, kwargs) + if name in self.onnx_graph_input_const_as_tensor: + data = PrimitiveType.get_tensor(data, device) + self._buffer_for_ort_runs[name] = data + else: + raise wrap_exception( + ORTModuleONNXModelException, + RuntimeError(f"Input is present in ONNX graph but not provided: {name}."), + ) + + return self._buffer_for_ort_runs + + def restore_outputs(self, ort_flatten_outputs: list[torch.Tensor]): + """Restores the outputs from the ORT forward run, back to the original data structure""" + + try: + return unflatten_data_using_schema(ort_flatten_outputs, self.module_forward_output_schema) + except TypeError as e: + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule fails to unflatten user output: {e}"), + ) from None + + +class GraphTransitionManager: + """Manage the graph transition from 1). PyTorch to ONNX export and 2). ONNX to ONNX post-export processing.""" + + def __init__( + self, + flatten_module: torch.nn.Module, + export_mode: int, + debug_options: DebugOptions, + runtime_options: _RuntimeOptions, + time_tracker: TimeTracker, + runtime_inspector: RuntimeInspector, + logger: logging.Logger, + ): + self._device = _utils._get_device_from_module(flatten_module) + self._export_mode = export_mode + + self._debug_options = debug_options + self._runtime_options = runtime_options + + self._export_extra_kwargs = {} + + self._logger = logger + + # Tracker for ORTModule model export. + self._time_tracker = time_tracker + + self._runtime_inspector = runtime_inspector + + # A signal to indicate if the original model has changed and need a re-export. + self._original_model_has_changed = False + + self._flatten_module = flatten_module + + # Forward function input parameters of the original module. + self._module_forward_func_parameters: list[inspect.Parameter] = list( + inspect.signature(self._flatten_module._original_module.forward).parameters.values() + ) + # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. + for input_parameter in self._module_forward_func_parameters: + if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: + logger.info("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!") + + # Model info collected from the original module's forward function signature and args/kwargs, used for ONNX export. + self._model_info_for_export: _io.ModelInfoForExport | None = None + self._exported_model_info: ExportedModelInfo | None = None + + # Model info after export and post export processing. + self._post_export_processed_model_info = None + + def get_post_processed_model( + self, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType] + ) -> tuple[bool, PostExportProcessedModelInfo]: + """Check if the post-export processed ONNX model can be reused, otherwise, reconstruct the model. + + Return True if the model can be reused, otherwise, return False. + The model can be reused when the following conditions are met: + a. The model has been exported before, and the inputs (args/outputs) schemas are the same as the previous ones. + b. If it is in training mode, the graph inputs requiring gradient are the same as the previous ones. + + """ + + if self._device is None: + device = _utils.get_device_from_module_and_inputs(self._flatten_module._original_module, args, kwargs) + if not self._device or self._device != device: + self._device = device + if not self._device: + raise wrap_exception( + ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") + ) + + # Extract the schema from the args and kwargs, and compare it with the pre-exported one if already exported. + cur_model_info_for_export = _io.parse_inputs_for_onnx_export( + self._module_forward_func_parameters, + args, + kwargs, + True, + self._device, + self._export_mode, + self._logger, + self._export_extra_kwargs, + ) + + need_export_model = GraphTransitionManager._export_check( + prev_exported_model_info=self._exported_model_info, + original_model_has_changed=self._original_model_has_changed, + cur_args_schema=cur_model_info_for_export.onnx_graph_input_arg_schema, + cur_kwargs_schema=cur_model_info_for_export.onnx_graph_input_kwarg_schema, + logger=self._logger, + ) + + if need_export_model: + # Note related to the _io.FlattenedModule export!!! + # + # The _io.FlattenedModule serves as a module wrapper designed to support tuple inputs and outputs for + # PyTorch run during ONNX export. (Remember the PyTorch exporter handles tuple inputs and outputs better.) + # Internally, it facilitates the acceptance of tuple inputs and the generation of tuple outputs by invoking + # the original module's forward function. The workflow involves the following steps: + + # 1. Prior to export, both args and kwargs are flattened into a 1-D tensor list, and schemas for the + # flattened args and kwargs are generated. This schemas are essential for the subsequent un-flattening + # process. + + # 2. The flattened inputs (args + kwargs) are passed to the _io.FlattenedModule's forward run. + + # 3. The args schema and kwargs schema, etc are conveyed to the _io.FlattenedModule by setting the + # corresponding attributes. + + # 4. Within the _io.FlattenedModule's forward run, the inputs are un-flattened to the original args and + # kwargs using the associated schemas, and then they are passed to the original module's forward function. + + # 5. Upon the completion of the forward function, the outputs from the original module are flattened and + # returned to the caller. + + # 6. The 1-D flattened output tensors retain the same order as the outputs from the ONNX Runtime (ORT) + # forward run. To facilitate un-flattening during subsequent ORT runs, the output schema is saved as + # an attribute named `_output_schema` in the _io.FlattenedModule. + + copied_args = copy.copy(args) + copied_kwargs = copy.copy(kwargs) + flatten_inputs = [] + + # This looks a bit duplicated with `extract_data_and_schema` function, but this might be better to + # defined as a specialized logic that is the counter-part of `parse_inputs_for_onnx_export`, which handles + # args and kwargs separately. + for name, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor_user_defined.items(): + d = data_accessor(copied_args, copied_kwargs) + if name in cur_model_info_for_export.onnx_graph_input_const_as_tensor: + flatten_inputs.append( + PrimitiveType.get_tensor( + d, + cur_model_info_for_export.onnx_graph_input_const_as_tensor[name], + ) + ) + else: + if isinstance(d, torch.Tensor): + flatten_inputs.append(d) + + # Ignore all other non-tensor inputs. + + self._flatten_module._device = self._device + self._flatten_module._args_schema = cur_model_info_for_export.onnx_graph_input_arg_schema + self._flatten_module._kwargs_schema = cur_model_info_for_export.onnx_graph_input_kwarg_schema + self._flatten_module._num_positionals = cur_model_info_for_export.num_positional_args + + self._logger.info(f"do_export started, model info for export: {cur_model_info_for_export}") + + ( + exported_model, + module_output_schema, # Retrieved from _io.FlattenedModule's _output_schema + onnx_graph_input_names, + onnx_graph_input_names_require_grad, + ) = GraphTransitionManager._export_model( + flattened_module=self._flatten_module, + model_info_for_export=cur_model_info_for_export, + flatten_module_inputs=flatten_inputs, + deepcopy_before_model_export=self._runtime_options.deepcopy_before_model_export, + device=self._device, + ortmodule_cache_dir=self._runtime_options.ortmodule_cache_dir, + enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, + enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + enable_embedding_sparse_optimizer=self._runtime_options.enable_embedding_sparse_optimizer, + onnx_opset_version=self._runtime_options.onnx_opset_version, + stage3_param_handle=self, + debug_options=self._debug_options, + time_tracker=self._time_tracker, + runtime_inspector=self._runtime_inspector, + logger=self._logger, + ) + + # Get the intersection of all user-defined input names (parsed from forward function signature) and + # the exported model input names including both user-defined input names and training parameter/buffer names. + # It is possible some user-defined input names are not in the exported model input names, if it is not used + # by the model for its compute. + onnx_graph_input_names_user_defined = [ + input_name + for input_name in cur_model_info_for_export.onnx_graph_input_names + if input_name in onnx_graph_input_names + ] + onnx_graph_input_names_require_grad_user_defined = [ + input_name + for input_name in cur_model_info_for_export.onnx_graph_input_names_require_grad + if input_name in onnx_graph_input_names_require_grad + ] + + self._exported_model_info = ExportedModelInfo( + module_forward_args_schema=cur_model_info_for_export.onnx_graph_input_arg_schema, + module_forward_kwargs_schema=cur_model_info_for_export.onnx_graph_input_kwarg_schema, + onnx_graph_input_names=onnx_graph_input_names, + onnx_graph_input_names_require_grad=onnx_graph_input_names_require_grad, + onnx_graph_input_names_user_defined=onnx_graph_input_names_user_defined, + onnx_graph_input_names_require_grad_user_defined=onnx_graph_input_names_require_grad_user_defined, + exported_model=exported_model, + module_forward_output_schema=module_output_schema, + ) + + self._model_info_for_export = cur_model_info_for_export + + # Reset the signal to indicate the original model has changed. + self._original_model_has_changed = False + + # Save the exported model + if self._debug_options.save_onnx_models.save: + _save_model( + self._exported_model_info.exported_model, + os.path.join( + self._debug_options.save_onnx_models.path, + _get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "torch_exported", self._export_mode + ), + ), + ) + + self._logger.info(f"do_export completed, exported graph infos: {self._exported_model_info}") + + need_re_processed = False + if need_export_model: + need_re_processed = True + else: + need_re_processed, updated_onnx_graph_input_requires_grads = GraphTransitionManager._reprocess_check( + flatten_module=self._flatten_module, + exported_model_info=self._exported_model_info, + export_mode=self._export_mode, + model_info_for_export=self._model_info_for_export, + args=args, + kwargs=kwargs, + ) + if need_re_processed: + # Update the onnx_graph_input_names_require_grads to make it a new default baseline to compare + # using new iteration data. + self._exported_model_info.onnx_graph_input_names_require_grad = updated_onnx_graph_input_requires_grads + + if need_re_processed: + # At this point, the exported model is ready, and we can start post-export processing. + self._post_export_processed_model_info = GraphTransitionManager._post_export_process( + flatten_module=self._flatten_module, + export_mode=self._export_mode, + exported_model_info=self._exported_model_info, + model_info_for_export=self._model_info_for_export, + enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function, + enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support, + run_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer, + stage3_param_handle=self, + enable_mem_efficient_grad_management=self._export_mode != torch.onnx.TrainingMode.EVAL + and self._runtime_options.enable_mem_efficient_grad_management, + logger=self._logger, + ) + + # Save the post_processed model + if self._debug_options.save_onnx_models.save: + _save_model( + self._post_export_processed_model_info._post_export_processed_model, + os.path.join( + self._debug_options.save_onnx_models.path, + _get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "post_processed", self._export_mode + ), + ), + ) + + return need_re_processed, self._post_export_processed_model_info + + @staticmethod + def _export_check( + prev_exported_model_info: ExportedModelInfo | None, + original_model_has_changed: bool, + cur_args_schema: ORTModelInputOutputSchemaType, + cur_kwargs_schema: ORTModelInputOutputSchemaType, + logger: logging.Logger, + ): + """Check if the model needs to be exported, if yes, return True. + + If either of the following conditions is met, return True: + 1. The model has never been exported before. + 2. The original_model_has_changed is True. + 3. The model input schema parsed from args and kwargs has changed. + """ + + need_export_model = prev_exported_model_info is None # never exported before + + need_export_model = need_export_model or original_model_has_changed + + need_export_model = ( + need_export_model + or cur_args_schema != prev_exported_model_info.module_forward_args_schema + or cur_kwargs_schema != prev_exported_model_info.module_forward_kwargs_schema + ) + + logger.info(f"_export_check completed - need_export_model: {need_export_model}") + + return need_export_model + + @staticmethod + def _reprocess_check( + flatten_module: _io._FlattenedModule, + exported_model_info: ExportedModelInfo, + export_mode: int, + model_info_for_export: _io.ModelInfoForExport, + args: Sequence[ORTModelInputOutputType], + kwargs: Mapping[str, ORTModelInputOutputType], + ) -> bool: + """Check if the exported model needs to be re-processed, if yes, + return True and the updated onnx_graph_input_requires_grads. + + For the following cases, return True: + 1. The export mode is TRAINING and the model's input names (including both user input and module parameters) + requiring gradient change. + """ + if export_mode == torch.onnx.TrainingMode.TRAINING: + # If inputs requiring gradient change from forward to the next, the gradient graph builder + # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad + + # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. + # This can happen when the user changes the model parameters after the onnx export. + # Model may have unused params dropped after export, so we only check those inputs existing in onnx graph. + + onnx_graph_input_requires_grads = [] + parameter_names = {k: v for k, v in flatten_module.named_parameters()} + for input_name in exported_model_info.onnx_graph_input_names: + if input_name in exported_model_info.onnx_graph_input_names_user_defined: + assert ( + input_name in model_info_for_export.onnx_graph_input_data_accessor_user_defined + ), f"{input_name} model_info_for_export.onnx_graph_input_data_accessor_user_defined" + # We assume the data accessor should be the same as the one used for the previous export, because + # there is args and kwargs schema check during export check phase. + if model_info_for_export.onnx_graph_input_data_accessor_user_defined[input_name]( + args, kwargs + ).requires_grad: + onnx_graph_input_requires_grads.append(input_name) + else: + assert input_name in parameter_names, f"{input_name} not exist parameter_names" + if parameter_names[input_name].requires_grad: + onnx_graph_input_requires_grads.append(input_name) + + if onnx_graph_input_requires_grads == exported_model_info.onnx_graph_input_names_require_grad: + return False, [] + return True, onnx_graph_input_requires_grads + + return False, [] + + @staticmethod + def _post_export_process( + flatten_module, + export_mode, + exported_model_info: ExportedModelInfo, + model_info_for_export: _io.ModelInfoForExport, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + run_symbolic_shape_infer: bool, + stage3_param_handle: type, + enable_mem_efficient_grad_management: bool, + logger: logging.Logger, + ): + """Post process the exported model, generate the processed model which will be used for initializing graph builder.""" + + # Deepcopy the exported model, in case modification affects the exported model. + post_processed_model = copy.deepcopy(exported_model_info.exported_model) + + if enable_custom_autograd_function: + from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + + post_processed_model = post_process_enabling_autograd_function(post_processed_model) + + if run_symbolic_shape_infer: + # MUST call symbolic shape inference after custom autograd function post-processing is done, + # Otherwise, there is no ctx output for PythonOp. + post_processed_model = GraphTransitionManager._infer_shapes(post_processed_model) + + if export_mode == torch.onnx.TrainingMode.TRAINING: + if enable_zero_stage3_support: + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + + post_processed_model = post_processing_enable_zero_stage3_compat( + post_processed_model, + stage3_param_handle._zero_stage3_param_map, + [name for name, _ in flatten_module.named_parameters()], + ) + + onnx_graph_input_names_user_defined = copy.deepcopy(exported_model_info.onnx_graph_input_names_user_defined) + onnx_graph_input_names_require_grad_user_defined = copy.deepcopy( + exported_model_info.onnx_graph_input_names_require_grad_user_defined + ) + onnx_graph_input_names = copy.deepcopy(exported_model_info.onnx_graph_input_names) + onnx_graph_input_names_require_grad = copy.deepcopy(exported_model_info.onnx_graph_input_names_require_grad) + + if enable_mem_efficient_grad_management: + # Remove those trainable parameters from graph input, as they will be retrieved from weight pull node. + from ._mem_efficient_grad_mgmt import get_params_connected_to_pull_param_trigger + + # MUST call this before post_processing_enable_mem_efficient_training, otherwise, the onnx graph input + # will be modified. + parameter_not_as_graph_input_names = get_params_connected_to_pull_param_trigger( + flatten_module.named_parameters(), post_processed_model + ) + + if len(parameter_not_as_graph_input_names) > 0: + for k in parameter_not_as_graph_input_names: + if k in onnx_graph_input_names: + onnx_graph_input_names.remove(k) + + if k in onnx_graph_input_names_require_grad: + onnx_graph_input_names_require_grad.remove(k) + + from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME + + # Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph. + onnx_graph_input_names_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names_require_grad_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + onnx_graph_input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME) + + from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training + + # Override the options if model is not modified. + + ( + enable_mem_efficient_grad_management, # Update the flag to indicate the mem efficient grad management is enabled. + post_processed_model, + stage3_param_handle._param_trigger_grad, + ) = post_processing_enable_mem_efficient_training( + post_processed_model, flatten_module.named_parameters(), parameter_not_as_graph_input_names + ) + + if run_symbolic_shape_infer: + post_processed_model = SymbolicShapeInference.infer_shapes( + post_processed_model, auto_merge=True, guess_output_rank=True + ) + + post_export_processed_model_info = PostExportProcessedModelInfo( + flatten_module, + onnx_graph_input_names_user_defined, + onnx_graph_input_names_require_grad_user_defined, + onnx_graph_input_names, + onnx_graph_input_names_require_grad, + model_info_for_export.onnx_graph_input_dynamic_axes_map, + exported_model_info.module_forward_output_schema, + post_processed_model, + model_info_for_export.onnx_graph_input_data_accessor_user_defined, + model_info_for_export.onnx_graph_input_const_as_tensor, + enable_mem_efficient_grad_management, + ) + + logger.info( + f"_post_export_process completed, post-export processed graph infos: {post_export_processed_model_info}" + ) + + return post_export_processed_model_info + + @staticmethod + def _infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto: + """Infer shapes for the exported model.""" + + model = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True) + return model + + @staticmethod + @TrackTimeForStaticFunction(ORTModuleInitPhase.EXPORT) + @SuppressLogs(ORTModuleInitPhase.EXPORT, is_ort_filter=False) + def _export_model( + *, + flattened_module: torch.nn.Module, + model_info_for_export: _io.ModelInfoForExport, + flatten_module_inputs: Sequence[ORTModelInputOutputType], + deepcopy_before_model_export: bool, + device: torch.device, + ortmodule_cache_dir: str, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + enable_embedding_sparse_optimizer: bool, + onnx_opset_version: int, + stage3_param_handle: type, + debug_options: DebugOptions, + time_tracker: TimeTracker, # time_tracker MUST be provided here to support TrackTimeForStaticFunction + runtime_inspector: RuntimeInspector, + logger: logging.Logger, + ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: + + # Add hooks to check the sparsity of the embedding and label inputs during the export. + embedding_hook_handles = GraphTransitionManager._add_check_embedding_sparsity_hook( + enable_embedding_sparse_optimizer, device, logger, runtime_inspector, flattened_module + ) + label_hook_handles = GraphTransitionManager._add_check_label_sparsity_hook( + enable_embedding_sparse_optimizer, logger, runtime_inspector, flattened_module + ) + + # Record random states here and restore later in case any of them gets changed during the export, + # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. + random_states = _utils.get_random_states() + + torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING + from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step + + with export_context(), no_increase_global_step(): + exported_model, module_output_schema = GraphTransitionManager._get_exported_model( + flattened_module=flattened_module, + model_info_for_export=model_info_for_export, + flatten_module_inputs=flatten_module_inputs, + deepcopy_before_model_export=deepcopy_before_model_export, + device=device, + ortmodule_cache_dir=ortmodule_cache_dir, + enable_custom_autograd_function=enable_custom_autograd_function, + enable_zero_stage3_support=enable_zero_stage3_support, + onnx_opset_version=onnx_opset_version, + torch_exporter_verbose_log=torch_exporter_verbose_log, + stage3_param_handle=stage3_param_handle, + logger=logger, + ) + + onnx_graph_input_names = [input.name for input in exported_model.graph.input] + parameter_names = [name for name, _ in flattened_module.named_parameters()] + onnx_graph_input_names_require_grad = [ + input.name + for input in exported_model.graph.input + if input.name in parameter_names or input.name in model_info_for_export.onnx_graph_input_names_require_grad + ] + + # Restore the recorded random states + _utils.set_random_states(random_states) + + # Clean up all hooks. + for hook in embedding_hook_handles: + hook.remove() + + for hook in label_hook_handles: + hook.remove() + + return exported_model, module_output_schema, onnx_graph_input_names, onnx_graph_input_names_require_grad + + @staticmethod + def _get_exported_model( + flattened_module: torch.nn.Module, + model_info_for_export: _io.ModelInfoForExport, + flatten_module_inputs: Sequence[ORTModelInputOutputType], + deepcopy_before_model_export: bool, + device: torch.device, + ortmodule_cache_dir: str, + enable_custom_autograd_function: bool, + enable_zero_stage3_support: bool, + onnx_opset_version: int, + torch_exporter_verbose_log: bool, + stage3_param_handle: type, + logger: logging.Logger, + ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType]: + """Exports PyTorch `flattened_module` to ONNX for inferencing or training.""" + + need_deep_copy = deepcopy_before_model_export and _io.can_module_be_deep_cloned(flattened_module, device) + if not need_deep_copy: + if deepcopy_before_model_export: + logger.warning( + "Since the user requested not to deep copy this model, " + "the initial weights may not be preserved and could change slightly during the forward run. " + "This could cause a minor difference between the ORTModule and the PyTorch run for the " + "first iteration. The computation will proceed as normal, but this should be noted." + ) + else: + logger.warning( + "Due to the limited GPU memory execution manager does not create a deep copy of this model. " + "Therefore, the initial weights might be slightly altered during the forward run. " + "This could result in a minor discrepancy between the ORTModule and the PyTorch run for the " + "first iteration. The computation will continue as usual, but this should be noted." + ) + ( + output_names, + output_dynamic_axes, + module_output_schema, + ) = _io.parse_outputs_for_onnx_export_and_extract_schema( + flattened_module, flatten_module_inputs, logger, need_deep_copy + ) + + # Combine the dynamic axes from inputs and outputs + dynamic_axes = copy.deepcopy(model_info_for_export.onnx_graph_input_dynamic_axes_map) + + dynamic_axes.update(output_dynamic_axes) + + logger.info("Exporting the PyTorch model to ONNX...") + + # Leverage cached model if available + cache_dir = ortmodule_cache_dir + if cache_dir: + filename = os.path.join( + cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" + ) + if os.path.exists(cache_dir) and os.path.isfile(filename): + logger.warning( + f"Cached model detected! Cached model will be used to save export and initialization time." + f"If you want the model to be re-exported then DELETE {filename}." + ) + exported_model = onnx.load(filename) + return exported_model, module_output_schema + + # Export torch.nn.Module to ONNX + f = io.BytesIO() + + # Deepcopy inputs, since input values may change after model run. + # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). + # Therefore, deepcopy only the data component of the input tensors for export. + kwargs = {} + sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*flatten_module_inputs, **kwargs) + assert len(sample_kwargs_copy) == 0, "Currently, kwargs are not supported for ONNX export." + sample_inputs_as_tuple = sample_inputs_copy + + # Ops behaving differently under train/eval mode need to be exported with the + # correct training flag to reflect the expected behavior. + # For example, the Dropout node in a model is dropped under eval mode. + assert model_info_for_export.export_mode is not None, "Please use a concrete instance of ExecutionManager" + + try: + with torch.no_grad(), stage3_export_context( + enable_zero_stage3_support, stage3_param_handle, flattened_module + ): + required_export_kwargs = { + "input_names": model_info_for_export.onnx_graph_input_names, # did not contains parameters as its input yet + "output_names": output_names, + "opset_version": onnx_opset_version, + "do_constant_folding": False, + "training": model_info_for_export.export_mode, + "dynamic_axes": dynamic_axes, + "verbose": torch_exporter_verbose_log, + "export_params": False, + "keep_initializers_as_inputs": True, + } + + if check_function_has_param(torch.onnx.export, "autograd_inlining"): + # From some PyTorch version, autograd_inlining is a valid argument. + # We allow it to be True if custom autograd function is disabled (where autograd.Function + # anyway is not supported in ONNX until it can be inlined). + required_export_kwargs["autograd_inlining"] = not enable_custom_autograd_function + + invalid_args = model_info_for_export.export_extra_kwargs.keys() & required_export_kwargs.keys() + + if len(invalid_args) != 0: + error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." + raise RuntimeError(error_msg) + + torch.onnx.export( + flattened_module, + sample_inputs_as_tuple, + f, + **required_export_kwargs, + **model_info_for_export.export_extra_kwargs, + ) + except Exception as e: + message = _utils.get_exception_as_string(e) + + # Special handling when Huggingface transformers gradient checkpoint usage pattern found. + # For new versions of PyTorch 2, tracing torch.utils.checkpoint.checkpoint will be failed like this: + # File "microsoft/phi-2/b10c3eba545ad279e7208ee3a5d644566f001670/modeling_phi.py", line 919, in forward + # layer_outputs = self._gradient_checkpointing_func( + # File "/site-packages/torch/_compile.py", line 24, in inner + # return torch._dynamo.disable(fn, recursive)(*args, **kwargs) + # File "/site-packages/torch/_dynamo/eval_frame.py", line 470, in _fn + # raise RuntimeError( + # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment. + if ( + "_gradient_checkpointing_func" in message + and "Detected that you are using FX to torch.jit.trace a dynamo-optimized function" in message + ): + is_ckpt_activation_allowed = int(os.getenv("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", "0")) == 1 + notes = ( + " Your model is running with gradient checkpointing, yet the PyTorch exporter\n" + " failed during tracing the graph. Try to enable ORTModule's\n" + " gradient checkpointing (a.k.a. Transformer layerwise subgraph recompute)\n" + " using `export ORTMODULE_MEMORY_OPT_LEVEL=1` for similar or even better memory efficiency.\n" + ) + if is_ckpt_activation_allowed: + # If the user allows the gradient checkpointing export, we should inform the user to disable it, + # to make layerwise recompute work. + notes += ( + " We also notice your setting `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1`,\n" + " which enables gradient checkpointing torch.autograd.Functions(s) to export.\n" + " To enable ORTModule's layerwise recompute, it needs to be turned OFF by\n" + " `export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=0`.\n" + ) + + logger.error( + f"{LogColor.RED}\n" + "******************************** IMPORTANT NOTE *******************************\n" + f"{notes}" + "*******************************************************************************\n" + f"{LogColor.ENDC}\n" + ) + + raise wrap_exception( # noqa: B904 + ORTModuleONNXModelException, + RuntimeError( + f"There was an error while exporting the PyTorch model to ONNX: " + f"\n\n{_utils.get_exception_as_string(e)}" + ), + ) + exported_model = onnx.load_model_from_string(f.getvalue()) + + # Cache model for future runs + if cache_dir: + if not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.join( + cache_dir, f"{hash_fn(str(flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" + ) + logger.info(f"Caching model for future runs to {filename}.") + onnx.save(exported_model, filename) + + return exported_model, module_output_schema + + def signal_model_changed(self): + """Signals the execution manager to re-export the model on the next forward call""" + self._original_model_has_changed = True + + @staticmethod + def _add_check_embedding_sparsity_hook( + enable_embedding_sparse_optimizer: bool, + device: torch.device, + logger: logging.Logger, + runtime_inspector: RuntimeInspector, + flattened_module: torch.nn.Module, + ) -> list: + """ + Add hook to check embedding sparsity and enable padding elimination if applicable. + 1. Iterate through all modules to find Embedding modules with padding_idx >= 0. + 2. Register forward pre hook to the Embedding module and the hook will check sparsity of the embedding input. + 3. If the sparsity is below a threshold, enable padding elimination by adding FlagAndPrintDensity after the + output. GraphTransformer of PaddingElimination will check the FlagAndPrintDensity and do the actual + padding elimination graph modification. + 4. Return the hook handles for later removal. + + """ + if not enable_embedding_sparse_optimizer or device.type != "cuda": + return [] + + def _embedding_hook(name, module, args): + ebd_input = args[0] + if ebd_input is None or not isinstance(ebd_input, torch.Tensor): + logger.warning("Embedding input is not a tensor.") + return None + + valid_token = torch.count_nonzero(ebd_input - module.padding_idx) + total_token = ebd_input.numel() + embed_density = float(valid_token) / float(total_token) * 100 + + if embed_density < 90: + logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) + runtime_inspector._embedding_module_to_padding_density_map[name] = embed_density + return FlagAndPrintDensity.apply(args[0], module.padding_idx, "embedding") + else: + logger.info("Embedding sparsity-based optimization is OFF for density: %.0f%%", embed_density) + return None + + embedding_hook_handles = [] + for name, sub_module in flattened_module.named_modules(): + if isinstance(sub_module, torch.nn.modules.sparse.Embedding): + if sub_module.padding_idx is not None and sub_module.padding_idx >= 0: + embedding_hook_handles.append(sub_module.register_forward_pre_hook(partial(_embedding_hook, name))) + + return embedding_hook_handles + + @staticmethod + def _add_check_label_sparsity_hook( + enable_label_sparse_optimizer: bool, + logger: logging.Logger, + runtime_inspector: RuntimeInspector, + flattened_module: torch.nn.Module, + ) -> list: + """ + Add hook to check label sparsity and enable sceloss compute optimization if applicable. + 1. Register forward pre hook to the sceloss module in the model and the hook will check sparsity of the label input. + 2. If the sparsity is below a threshold, enable sceloss compute optimization by adding FlagAndPrintDensity after the + output. GraphTransformer of InsertGatherBeforeSceLoss will check the FlagAndPrintDensity and do the actual + sceloss compute optimization graph modification. + + """ + if not enable_label_sparse_optimizer: + return None + + def _label_hook(name, module, args): + label_input = args[1] + if label_input is None or not isinstance(label_input, torch.Tensor): + logger.warning("Label input is not a tensor.") + return None + + valid_token = torch.count_nonzero(label_input - module.ignore_index) + total_token = label_input.numel() + label_density = float(valid_token) / float(total_token) * 100 + + if label_density < 90: + logger.info("Label sparsity-based optimization is ON for density: %.0f%%", label_density) + runtime_inspector._sceloss_module_to_ignore_density_map[name] = label_density + return (args[0], FlagAndPrintDensity.apply(args[1], module.ignore_index, "label")) + else: + logger.info("Label sparsity-based optimization is OFF for density: %.0f%%", label_density) + return None + + label_check_hook_handles = [] + for name, sub_module in flattened_module.named_modules(): + if isinstance(sub_module, torch.nn.modules.loss.CrossEntropyLoss): + label_check_hook_handles.append(sub_module.register_forward_pre_hook(partial(_label_hook, name))) + + return label_check_hook_handles diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 642dc9b0f4dd6..61db462ad3bb8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -11,11 +11,10 @@ from onnxruntime.capi import _pybind_state as C -from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils +from . import _are_deterministic_algorithms_enabled, _use_deterministic_algorithms, _utils from ._execution_agent import InferenceAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._io import unflatten_user_output from ._logger import ORTModuleInitPhase, TrackTime from ._utils import save_tuning_results, set_tuning_results from .options import DebugOptions, _SkipCheck @@ -109,15 +108,19 @@ def forward(self, *inputs, **kwargs): build_graph = False if ( self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False - or not self._onnx_models.exported_model + or not self._graph_transition_manager._exported_model_info ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) # Exporting module to ONNX for the first time - build_graph = self._export_model(*inputs, **kwargs) + + ( + build_graph, + post_export_processed_model_info, + ) = self._graph_transition_manager.get_post_processed_model(inputs, kwargs) if build_graph: - # If model was exported, then initialize the graph builder. - self._initialize_graph_builder() + # TODO(): do we need call it for inferencing mode??? + self._initialize_graph_builder(post_export_processed_model_info) # Build the inference graph if build_graph: @@ -134,7 +137,7 @@ def forward(self, *inputs, **kwargs): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent ): - module_device = _utils.get_device_from_module(self._original_module) + module_device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs) create_execution_session = ( build_graph @@ -144,7 +147,7 @@ def forward(self, *inputs, **kwargs): _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) if self._device != module_device: - self._device = module_device + self._graph_transition_manager._device = module_device if create_execution_session: # Create execution session creates the inference_session @@ -160,23 +163,15 @@ def forward(self, *inputs, **kwargs): if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, self._device) - prepared_input_list = _io._combine_input_buffers_initializers( - self._graph_initializers, - self._graph_info.user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - self._device, - self._runtime_inspector, - self._zero_stage3_param_map, + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True, self._device ) user_outputs, _ = InferenceManager.execution_session_run_forward( self._execution_agent, self._onnx_models.optimized_model, self._device, - *prepared_input_list, + *prepared_input_map.values(), ) if ( @@ -188,7 +183,8 @@ def forward(self, *inputs, **kwargs): self._execution_agent._inference_session, False, self._runtime_options.tuning_results_path ) - return unflatten_user_output(self._module_output_schema, user_outputs) + return self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) + except ORTModuleFallbackException as e: # Exceptions subject to fallback are handled here self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 1ba62194bf63e..8ad3d0df3e4fa 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -7,10 +7,10 @@ import gc import inspect from collections import OrderedDict, abc +from functools import partial from logging import Logger -from typing import Dict, Iterator, List, Mapping, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple -import onnx import torch from onnxruntime.training.utils import ( @@ -20,9 +20,9 @@ extract_data_and_schema, unflatten_data_using_schema, ) +from onnxruntime.training.utils.torch_io_helper import _TensorStub -from ._fallback import ORTModuleIOError, ORTModuleONNXModelException, wrap_exception -from ._runtime_inspector import RuntimeInspector +from ._fallback import ORTModuleIOError, wrap_exception class _OutputIdentityOp(torch.autograd.Function): @@ -76,195 +76,6 @@ def symbolic(g, self): return g.op("Identity", self) -def flatten_kwargs(kwargs, device): - def _flatten_kwargs(value, name): - if PrimitiveType.is_primitive_type(value): - flattened_kwargs[name] = PrimitiveType.get_tensor(value, device) - elif isinstance(value, torch.Tensor): - flattened_kwargs[name] = value - elif isinstance(value, abc.Sequence): - # If the input is a sequence (like a list), expand the list so that - # each element of the list has a corresponding entry in the flattened - # kwargs dict - for idx, val in enumerate(value): - _flatten_kwargs(val, f"{name}_{idx}") - elif isinstance(value, abc.Mapping): - # If the input is a mapping (like a dict), expand the dict so that - # each element of the dict has an entry in the flattened kwargs dict - for key, val in value.items(): - _flatten_kwargs(val, f"{name}_{key}") - - flattened_kwargs = {} - for key, value in kwargs.items(): - _flatten_kwargs(value, key) - - return flattened_kwargs - - -class _InputInfo: - def __init__( - self, - names: List[str], - shape: List[List[int]], - require_grad_names: Optional[List[str]] = None, - dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, - schema: Optional[ORTModelInputOutputSchemaType] = None, - num_positionals=0, - ): - self.names: List[str] = names - self.shape: List[List[int]] = shape - self.require_grad_names: List[str] = require_grad_names if require_grad_names else [] - self.dynamic_axes: Dict[str, Dict[int, str]] = dynamic_axes if dynamic_axes else {} - self.schema: ORTModelInputOutputSchemaType = schema if schema else [] - self.num_positionals = num_positionals - self.kwargs = None - - def __repr__(self) -> str: - return f"""_InputInfo class: - \tNames: {self.names} - \tShape: {self.shape} - \tRequire gradient: {self.require_grad_names} - \tDynamic axes: {self.dynamic_axes} - \tSchema: {self.schema} - \t#Positionals (total): {self.num_positionals}""" - - def flatten( - self, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - device: torch.device, - ) -> Sequence[ORTModelInputOutputType]: - """Flatten args and kwargs in a single tuple of tensors with strict ordering""" - - ret = [PrimitiveType.get_tensor(arg, device) if PrimitiveType.is_primitive_type(arg) else arg for arg in args] - flattened_kwargs = flatten_kwargs(kwargs, device) - ret += [flattened_kwargs[name] for name in self.names if name in flattened_kwargs] - self.kwargs = kwargs - - # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter - # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise. - if not kwargs: - ret.append({}) - - return ret - - def unflatten( - self, flat_args: Sequence[ORTModelInputOutputType] - ) -> Tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: - """Unflatten tuple of tensors into args and kwargs""" - - args = tuple(flat_args[: self.num_positionals]) - kwargs = self.kwargs - self.kwargs = None - return args, kwargs - - -def _combine_input_buffers_initializers( - params: List[torch.nn.parameter.Parameter], - onnx_input_names: List[str], - input_info: Optional[_InputInfo], - named_buffer: Iterator[Tuple[str, torch.Tensor]], - inputs: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], - device: torch.device, - rt_inspector: RuntimeInspector, - zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]], -): - """Creates forward `*inputs` list from user input and PyTorch initializers - - ONNX Runtime forward requires an ordered list of: - * User input: computed from forward InferenceSession - * Initializers: computed from original PyTorch model parameters. - """ - - def _expand_inputs(current_input, non_none_inputs, name=""): - # The exporter handles input lists by expanding them so that each - # element of the list is its own input. - # ORTModule must match this behavior by also expanding the inputs. - if current_input is None or isinstance(current_input, str): - # Drop all None and string inputs - return - if isinstance(current_input, abc.Sequence): - # If the input is a sequence (like a list), expand the list so that - # each element of the list is an input by itself - for i, inp in enumerate(current_input): - _expand_inputs(inp, non_none_inputs, f"{name}_{i}" if name else str(i)) - elif isinstance(current_input, abc.Mapping): - # If the input is a mapping (like a dict), expand the dict so that - # each element of the dict is an input by itself - for key, val in current_input.items(): - _expand_inputs(val, non_none_inputs, f"{name}_{key}" if name else key) - else: - # else just collect all the non none inputs within non_none_inputs - if isinstance(non_none_inputs, abc.Sequence): - non_none_inputs.append(current_input) - elif isinstance(non_none_inputs, abc.Mapping): - non_none_inputs[name] = current_input - - # User inputs - non_none_inputs = [] - _expand_inputs(inputs, non_none_inputs) - flattened_kwargs_inputs = {} - _expand_inputs(kwargs, flattened_kwargs_inputs) - buffer_names_dict = None - result = [] - onnx_input_to_value_map = OrderedDict() - - for input_idx, name in enumerate(onnx_input_names): - inp = None - if name in flattened_kwargs_inputs and flattened_kwargs_inputs[name] is not None: - # Only use keywords coming from user that are expected by ONNX model - inp = flattened_kwargs_inputs[name] - - if inp is None: - try: - # Only use positionals coming from user that are expected by ONNX model - # if input_idx >= len(input_info.names), IndexError will be thrown - if name != input_info.names[input_idx]: - # When ONNX drops unused inputs, get correct index from user input - # if name is not in input_info.names, ValueError will be thrown - input_idx = input_info.names.index(name) # noqa: PLW2901 - inp = non_none_inputs[input_idx] - except (IndexError, ValueError): - # ONNX input name is not present in input_info.names. - pass - - if inp is None: - # Registered buffers are translated to user_input+initializer in ONNX - if buffer_names_dict is None: - buffer_names_dict = {buffer_name: i for buffer_name, i in named_buffer} - try: # noqa: SIM105 - inp = buffer_names_dict[name] - except KeyError: - # ONNX input name is not present in the registered buffer dict. - pass - - if inp is not None: - if PrimitiveType.is_primitive_type(inp): - inp = PrimitiveType.get_tensor(inp, device) - - result.append(inp) - onnx_input_to_value_map[name] = inp - else: - raise wrap_exception( - ORTModuleONNXModelException, RuntimeError(f"Input is present in ONNX graph but not provided: {name}.") - ) - - # params is a list of all initializers known to the onnx graph - if zero_stage3_offload_param_map: - for p in params: - if p not in zero_stage3_offload_param_map.values(): - result.append(p) - else: - result.extend(params) - - if rt_inspector.memory_ob.is_enabled() and not rt_inspector.memory_ob.symbolic_dim_collecting_completed: - rt_inspector.memory_ob.collect_symbolic_dim_values(input_info.dynamic_axes, onnx_input_to_value_map) - rt_inspector.memory_ob.symbolic_dim_collecting_completed = True - - return result - - def deepcopy_model_input( *args, **kwargs ) -> Tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: @@ -288,113 +99,153 @@ def extract_tensor(value): return sample_args_copy, sample_kwargs_copy -def unflatten_user_output(output_schema: Optional[ORTModelInputOutputSchemaType], outputs: List[torch.Tensor]): +def _extract_schema( + data: ORTModelInputOutputType, device +) -> Tuple[Sequence[ORTModelInputOutputType], ORTModelInputOutputSchemaType]: try: - return unflatten_data_using_schema(outputs, output_schema) + flatten_data, schema = extract_data_and_schema(data, constant_as_tensor=True, device=device) + return flatten_data, schema except TypeError as e: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule fails to unflatten user output: {e}"), - ) from None + raise wrap_exception(ORTModuleIOError, TypeError(f"ORTModule fails to extract schema from data: {e}")) from None -def _extract_schema(data: ORTModelInputOutputType, device) -> ORTModelInputOutputSchemaType: - try: - _, schema = extract_data_and_schema(data, constant_as_tensor=True, device=device) - return schema - except TypeError as e: - raise wrap_exception(ORTModuleIOError, TypeError(f"ORTModule fails to extract schema from data: {e}")) from None +class _FlattenedModule(torch.nn.Module): + def __init__(self, original_module: torch.nn.Module): + super().__init__() + self._original_module: torch.nn.Module = original_module + # Before ONNX export, we flatten the args and kwargs into a 1-D list of tensors to make torch.export happy. + # As a result, we need to unflatten the args and kwargs back to the original structure before calling the + # original module's forward function. + # So we need set those information that are needed to unflatten the args and kwargs, before calling the + # torch.export. + self._device: Optional[torch.device] = None + self._args_schema: Optional[ORTModelInputOutputSchemaType] = None + self._kwargs_schema: Optional[ORTModelInputOutputSchemaType] = None + self._num_positionals: Optional[int] = None + + # Similarly, to make torch.export happy, we need to flatten the original module's outputs into a 1-D list of tensors. + # Need to keep the output schema to unflatten the outputs back to the original structure. + # Then those code depends on the original structure of the outputs can work properly. + self._output_schema: Optional[ORTModelInputOutputSchemaType] = None -def _parse_outputs_and_extract_names_and_dynamic_axes(module_output) -> Tuple[List[str], Dict[str, Dict[int, str]]]: - """Parses through the module output and returns output names and dynamic axes""" + def forward(self, *args): + new_args = unflatten_data_using_schema(args[: self._num_positionals], self._args_schema) - def _populate_output_names_and_dynamic_axes( - output, output_names: List[str], output_dynamic_axes: Dict[str, Dict[int, str]], output_idx: List[int] - ): - # Depth first traversal to traverse through the entire output collecting output names and dynamic axes - - if output is None: - return - elif isinstance(output, torch.Tensor): - # Naming the outputs with a hyphen ensures that there can be no input with the same - # name, preventing collisions with other NodeArgs (for example an input to forward called output0) - output_name = f"output-{output_idx[0]}" - output_idx[0] += 1 - output_names.append(output_name) - output_dynamic_axes[output_name] = {} - for dim_idx in range(len(output.shape)): - output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) - return - - if isinstance(output, abc.Sequence): - for value in output: - _populate_output_names_and_dynamic_axes(value, output_names, output_dynamic_axes, output_idx) - elif isinstance(output, abc.Mapping): - for _, value in sorted(output.items()): - _populate_output_names_and_dynamic_axes(value, output_names, output_dynamic_axes, output_idx) - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(output)}"), - ) + new_kwargs = unflatten_data_using_schema(args[self._num_positionals :], self._kwargs_schema) - output_names: List[str] = [] - output_dynamic_axes: Dict[str, Dict[int, str]] = {} - output_idx: List[int] = [0] - _populate_output_names_and_dynamic_axes(module_output, output_names, output_dynamic_axes, output_idx) + original_outputs = self._original_module(*new_args, **new_kwargs) - return output_names, output_dynamic_axes + # Flatten the outputs + flatten_outputs, self._output_schema = _extract_schema(original_outputs, self._device) + # Append _OutputIdentityOp to the outputs to support passthrough outputs + final_flatten_outputs = [] + for output in flatten_outputs: + final_flatten_outputs.append(_OutputIdentityOp.apply(output)) -def _transform_output_to_flat_tuple(data): - """Converts the data to a flat tuple by iterating over the entire data structure""" + return final_flatten_outputs - def _flatten_data(data, flat_data): - # Recursively traverse over the data and populate the flat_data with torch.Tensors - if data is None: - return - elif isinstance(data, torch.Tensor): - identity = _OutputIdentityOp.apply - flat_data.append(identity(data)) - elif isinstance(data, abc.Sequence): - for value in data: - _flatten_data(value, flat_data) - elif isinstance(data, abc.Mapping): - for _, value in sorted(data.items()): - _flatten_data(value, flat_data) - else: - raise wrap_exception( - ORTModuleIOError, TypeError(f"ORTModule does not support the following data type {type(data)}.") - ) +class ModelInfoForExport: + def __init__( + self, + onnx_graph_input_names: List[str], + onnx_graph_input_names_require_grad: List[str], + onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]], + onnx_graph_input_shapes: List[List[int]], + onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = None, + onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = None, + onnx_graph_input_arg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, + onnx_graph_input_kwarg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, + num_positional_args: int = 0, + export_mode: Optional[int] = None, + export_extra_kwargs: Optional[Dict[str, any]] = None, + ): + # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL + self.export_mode = export_mode + + # Exporter can take extra arguments for ORTModule extensions + # It cannot overlap with required/immutable arguments (validated in runtime) + self.export_extra_kwargs = export_extra_kwargs + + # Input names parsed and then flatten from the model's forward function signature. + # This should contains ONLY the user defined input names + # Be noted: some of the input might not be used by the model for its compute. + self.onnx_graph_input_names: List[str] = onnx_graph_input_names + + # A subset of onnx_graph_input_names. + # Input names that require gradient parsed and then flatten from the model's forward function signature + # This should contains ONLY the user defined input names + # Be noted: some of the input might not be used by the model for its compute. + self.onnx_graph_input_names_require_grad: List[str] = onnx_graph_input_names_require_grad + + # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). + # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} + # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} + self.onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]] = onnx_graph_input_dynamic_axes_map + + self.onnx_graph_input_shapes: List[List[int]] = onnx_graph_input_shapes + + # The input args schema for the original model's forward function. + # Only contains the schema for those inputs used by the model for its compute (e.g. as the inputs + # of the export model). + self.onnx_graph_input_arg_schema: Dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_arg_schema + + # The input kwargs schema for the original model's forward function. + # Only contains the schema for those inputs used by the model for its compute (e.g. as the inputs + # of the export model). + self.onnx_graph_input_kwarg_schema: Dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_kwarg_schema + + self.num_positional_args: int = num_positional_args + + # A function to access the input data from the args and kwargs. + # If it is not None, the length is same as onnx_graph_input_names. + # For i-th input name, we can use the i-th function to get the input data from args and kwargs. + self.onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = ( + onnx_graph_input_data_accessor_user_defined + ) + + self.onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = onnx_graph_input_const_as_tensor + + def __str__(self) -> str: + return f"""ModelInfoForExport class: + \tExport mode: {self.export_mode} + \tExport extra kwargs: {self.export_extra_kwargs} + \tInput names: {self.onnx_graph_input_names} + \tInput names require grad: {self.onnx_graph_input_names_require_grad} + \tInput dynamic axes: {self.onnx_graph_input_dynamic_axes_map} + \tInput shapes: {self.onnx_graph_input_shapes} + \tInput args schema: {self.onnx_graph_input_arg_schema} + \tInput kwargs schema: {self.onnx_graph_input_kwarg_schema} + \tNum input args: {self.num_positional_args}""" - flat_data = [] - _flatten_data(data, flat_data) - return tuple(flat_data) + def __repr__(self) -> str: + return self.__str__() -class _FlattenedModule(torch.nn.Module): - def __init__(self, original_module: torch.nn.Module): - super().__init__() - self._original_module: torch.nn.Module = original_module +def _arg_access_with_index_func(arg_index, args, kwargs): + return args[arg_index] - # Before `forward` is called, _ort_module must be assigned - # Updated input info is needed to expand args into *args, **kwargs - self._input_info: Optional[_InputInfo] = None - def forward(self, *args): - new_args, new_kwargs = self._input_info.unflatten(args) - return _transform_output_to_flat_tuple(self._original_module(*new_args, **new_kwargs)) +def _kwarg_access_with_name_func(name, args, kwargs): + return kwargs[name] + + +class SkipRetValue: + """A placeholder class to indicate that the return value of a function should be skipped""" def parse_inputs_for_onnx_export( all_input_parameters: List[inspect.Parameter], - onnx_graph: Optional[onnx.ModelProto], - schema: ORTModelInputOutputSchemaType, args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], -) -> _InputInfo: + constant_as_tensor: bool, + device: torch.device, + export_mode: int, + logger: Logger, + export_extra_kwargs: Optional[Dict[str, any]] = None, +) -> ModelInfoForExport: """Parses through the model inputs and returns _InputInfo. Loop through all input parameters, try to flatten them into a 1-D list of inputs. For nested data in the inputs, @@ -414,67 +265,149 @@ def parse_inputs_for_onnx_export( Args: all_input_parameters: All inspected input parameters from the original model forward function signature. - onnx_graph: (optional) The onnx graph. - schema: The schema extracted from the positional arguments and keyword arguments of the model. args: The positional arguments of the model. kwargs: The keyword arguments of the model. + constant_as_tensor: Whether to treat constant inputs as tensors. + device: The device to be used for constant inputs. """ + arg_tensor_idx = [-1] + kwarg_tensor_idx = [-1] + def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: dynamic_axes[name] = {} for dim_idx in range(len(input.shape)): dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"}) return dynamic_axes - def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): + def _warn_of_constant_inputs(data): + logger.info(f"Received input of type {type(data)} is treated as a constant by ORT by default.") + + def _add_input( + name: str, input_value, onnx_graph_input_names: List[str], cur_func: Callable, tensor_idx: List[int] + ): """Returns number of expanded non none inputs that _add_input processed""" - if name in input_names or input_value is None or isinstance(input_value, str): - # Drop all None and string inputs and return 0. - return + # in case the input is already handled. + if name in visited_input_names: + return SkipRetValue() + + visited_input_names.append(name) + + value = input_value + primitive_dtype = None + if value is None: + _warn_of_constant_inputs(value) + data_accessors[name] = cur_func + return value + elif isinstance(value, str): + _warn_of_constant_inputs(value) + data_accessors[name] = cur_func + return value + elif PrimitiveType.is_primitive_type(value): + if constant_as_tensor: + # This has special handling for bool type to string conversion. + primitive_dtype = PrimitiveType.get_primitive_dtype(value) + value = PrimitiveType.get_tensor(value, device) + const_to_tensor_inputs[name] = device + + else: + data_accessors[name] = cur_func + _warn_of_constant_inputs(value) + return value + elif isinstance(value, abc.Sequence): + sequence_type = type(value) + stubbed_schema = [] - if isinstance(input_value, abc.Sequence): # If the input is a sequence (like a list), expand the list so that # each element of the list is an input by itself. - for i, val in enumerate(input_value): + for i, val in enumerate(value): # Name each input with the index appended to the original name of the # argument. - _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names) + + def _access_func(i, cur_func, args, kwargs): + return cur_func(args, kwargs)[i] + + input_schema = _add_input( + f"{name}_{i}", + val, + onnx_graph_input_names, + partial(_access_func, i, cur_func), + tensor_idx, + ) + + if not isinstance(input_schema, SkipRetValue): + stubbed_schema.append(input_schema) # Return here since the list by itself is not a valid input. # All the elements of the list have already been added as inputs individually. - return - elif isinstance(input_value, abc.Mapping): + + try: + # namedtuple can be created by passing the list sequence to method _make + stubbed_schema = sequence_type._make(stubbed_schema) + except AttributeError: + # If attribute error is encountered, create the sequence directly + stubbed_schema = sequence_type(stubbed_schema) + return stubbed_schema + + elif isinstance(value, abc.Mapping): + dict_type = type(value) + stubbed_schema = OrderedDict() + # If the input is a mapping (like a dict), expand the dict so that # each element of the dict is an input by itself. - for key, val in input_value.items(): - _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names) + for key, val in value.items(): + + def _access_func(key, cur_func, args, kwargs): + return cur_func(args, kwargs)[key] + + input_schema = _add_input( + f"{name}_{key}", + val, + onnx_graph_input_names, + partial(_access_func, key, cur_func), + tensor_idx, + ) + + if not isinstance(input_schema, SkipRetValue): + stubbed_schema[key] = input_schema # Return here since the dict by itself is not a valid input. # All the elements of the dict have already been added as inputs individually. - return - # InputInfo should contain all the names irrespective of whether they are - # a part of the onnx graph or not. - input_names.append(name) + stubbed_schema = dict_type(**stubbed_schema) + return stubbed_schema - if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input_value, torch.Tensor): - if input_value.requires_grad: + if isinstance(value, torch.Tensor): + onnx_graph_input_names.append(name) + data_accessors[name] = cur_func + if value.requires_grad: input_names_require_grad.append(name) - dynamic_axes.update(_add_dynamic_shape(name, input_value)) - input_shape.append(list(input_value.size())) + dynamic_axes.update(_add_dynamic_shape(name, value)) + input_shape.append(list(value.size())) + tensor_idx[0] += 1 + return _TensorStub( + tensor_idx[0], + dtype=primitive_dtype if primitive_dtype else str(value.dtype), # special handle for bool primitive + shape_dims=len(value.size()), + name=name, + ) - # Ignore optional inputs explicitly specified as None - # ONNX exporter may remove unused inputs - onnx_graph_input_names: List[str] = [] - if onnx_graph is not None: - onnx_graph_input_names = {inp.name for inp in onnx_graph.graph.input} + raise ORTModuleIOError(f"ORTModule does not support input type {type(value)} for input {name}") - input_names: List[str] = [] + visited_input_names: List[str] = [] + + onnx_graph_input_names: List[str] = [] dynamic_axes: Dict[str, Dict[int, str]] = {} input_names_require_grad: List[str] = [] input_shape: List[List[int]] = [] + input_arg_schema: ORTModelInputOutputSchemaType = [] + input_kwarg_schema: ORTModelInputOutputSchemaType = OrderedDict() + data_accessors: Dict[str, Callable] = OrderedDict() + const_to_tensor_inputs: Dict[str, torch.device] = OrderedDict() + num_positional_args: int = 0 + var_positional_idx = 0 # Be noted, all_input_parameters is a list of inspect.Parameters parsed from the original module's forward method. @@ -504,7 +437,17 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): name = f"{input_parameter.name}_{var_positional_idx}" var_positional_idx += 1 inp = args[args_i] - _add_input(name, inp, onnx_graph, onnx_graph_input_names) + pre_tensor_idx = arg_tensor_idx[0] + schema = _add_input( + name, + inp, + onnx_graph_input_names, + partial(_arg_access_with_index_func, args_i), + arg_tensor_idx, + ) + num_positional_args += arg_tensor_idx[0] - pre_tensor_idx + if not isinstance(schema, SkipRetValue): + input_arg_schema.append(schema) elif ( input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -514,25 +457,51 @@ def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): name = input_parameter.name inp = None input_idx += var_positional_idx # noqa: PLW2901 - if input_idx < len(args) and args[input_idx] is not None: + access_func = None + if input_idx < len(args): inp = args[input_idx] - elif name in kwargs and kwargs[name] is not None: + access_func = partial(_arg_access_with_index_func, input_idx) + pre_tensor_idx = arg_tensor_idx[0] + schema = _add_input(name, inp, onnx_graph_input_names, access_func, arg_tensor_idx) + num_positional_args += arg_tensor_idx[0] - pre_tensor_idx + if not isinstance(schema, SkipRetValue): + input_arg_schema.append(schema) + elif name in kwargs: inp = kwargs[name] - _add_input(name, inp, onnx_graph, onnx_graph_input_names) + access_func = partial(_kwarg_access_with_name_func, name) + schema = _add_input(name, inp, onnx_graph_input_names, access_func, kwarg_tensor_idx) + if not isinstance(schema, SkipRetValue): + input_kwarg_schema[name] = schema + elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD: # **kwargs is always the last argument of forward() for name, inp in kwargs.items(): - _add_input(name, inp, onnx_graph, onnx_graph_input_names) - - return _InputInfo( - names=input_names, - shape=input_shape, - require_grad_names=input_names_require_grad, - dynamic_axes=dynamic_axes, - schema=schema, - num_positionals=len(args), + schema = _add_input( + name, + inp, + onnx_graph_input_names, + partial(_kwarg_access_with_name_func, name), + kwarg_tensor_idx, + ) + if not isinstance(schema, SkipRetValue): + input_kwarg_schema[name] = schema + + exported_graph = ModelInfoForExport( + onnx_graph_input_names=onnx_graph_input_names, + onnx_graph_input_names_require_grad=input_names_require_grad, + onnx_graph_input_dynamic_axes_map=dynamic_axes, + onnx_graph_input_shapes=input_shape, + onnx_graph_input_data_accessor_user_defined=data_accessors, + onnx_graph_input_const_as_tensor=const_to_tensor_inputs, + onnx_graph_input_arg_schema=input_arg_schema, + onnx_graph_input_kwarg_schema=input_kwarg_schema, + num_positional_args=num_positional_args, + export_mode=export_mode, + export_extra_kwargs=export_extra_kwargs, ) + return exported_graph + def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int: """Calculate the total parameter size in bytes""" @@ -568,20 +537,19 @@ def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.de def parse_outputs_for_onnx_export_and_extract_schema( module, - args: Sequence[ORTModelInputOutputType], - kwargs: Mapping[str, ORTModelInputOutputType], + flatten_args: Sequence[ORTModelInputOutputType], logger: Logger, - device: Optional[torch.device], clone_module: bool, ): # Perform a forward call to grab outputs output_names = None output_dynamic_axes = None deep_copied = False + kwargs = {} logger.info("Running model forward to infer output schema and dynamic axes...") with torch.no_grad(): # Deepcopy inputs, since input values may change after model run. - sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*args, **kwargs) + sample_args_copy, sample_kwargs_copy = deepcopy_model_input(*flatten_args, **kwargs) try: if clone_module: # Deepcopy model, in case model is stateful and changes after model run. @@ -600,9 +568,17 @@ def parse_outputs_for_onnx_export_and_extract_schema( sample_outputs = model_copy(*sample_args_copy, **sample_kwargs_copy) # Parse the output and extract the output_names and output_dynamic_axes to be used for onnx export - output_names, output_dynamic_axes = _parse_outputs_and_extract_names_and_dynamic_axes(sample_outputs) + output_names: List[str] = [] + output_dynamic_axes: Dict[str, Dict[int, str]] = {} + for output_idx, output in enumerate(sample_outputs): + output_name = f"output-{output_idx}" + output_names.append(output_name) + output_dynamic_axes[output_name] = {} + for dim_idx in range(len(output.shape)): + output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) + + original_module_output_schema = model_copy._output_schema - output_schema = _extract_schema(sample_outputs, device) if deep_copied: del model_copy gc.collect() @@ -611,4 +587,4 @@ def parse_outputs_for_onnx_export_and_extract_schema( # Release the memory cached by torch. torch.cuda.empty_cache() # Return output names, output dynamic axes and output schema - return output_names, output_dynamic_axes, output_schema + return output_names, output_dynamic_axes, original_module_output_schema diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index 91b99d4323d6f..4d54e8e59fb50 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -165,6 +165,24 @@ def wrapper(graph_execution_manager, *args, **kwargs): return wrapper +class TrackTimeForStaticFunction: + """A function decorator to track time spent in different phases of ORT backend first-time initialization.""" + + def __init__(self, phase: ORTModuleInitPhase): + self.phase = phase + + def __call__(self, func: Callable): + def wrapper(*args, **kwargs): + if "time_tracker" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'time_tracker' kwarg.") + kwargs["time_tracker"].start(self.phase) + result = func(*args, **kwargs) + kwargs["time_tracker"].end(self.phase) + return result + + return wrapper + + @contextmanager def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None): """Suppress output from being printed to stdout and stderr. @@ -255,27 +273,27 @@ def __init__(self, phase: ORTModuleInitPhase, is_ort_filter=True): self.is_ort_filter = is_ort_filter def __call__(self, func: Callable): - def wrapper(graph_execution_manager, *args, **kwargs): - if not hasattr(graph_execution_manager, "_logger"): - raise RuntimeError("The class of the function to be tracked must have a '_logger' attribute.") + def wrapper(*args, **kwargs): + if "logger" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'logger' kwarg.") - if not hasattr(graph_execution_manager, "_debug_options"): - raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.") + if "debug_options" not in kwargs: + raise RuntimeError("The function to be tracked must have a 'debug_options' kwarg.") with _suppress_os_stream_output( - enable=graph_execution_manager._debug_options.log_level >= LogLevel.DEVINFO, + enable=kwargs["debug_options"].log_level >= LogLevel.DEVINFO, on_exit=partial( _log_with_filter, - graph_execution_manager._logger, + kwargs["logger"], ( - graph_execution_manager._debug_options.onnxruntime_log_filter + kwargs["debug_options"].onnxruntime_log_filter if self.is_ort_filter - else graph_execution_manager._debug_options.torch_exporter_filter + else kwargs["debug_options"].torch_exporter_filter ), self.phase.to_string(), ), ): - result = func(graph_execution_manager, *args, **kwargs) + result = func(*args, **kwargs) return result return wrapper diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py index 241b7a5c5344b..fcab32f4356bc 100644 --- a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py +++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py @@ -26,14 +26,6 @@ def get_params_connected_to_pull_param_trigger( return {k: v for k, v in named_params if v.requires_grad and k in onnx_initializer_names} -def get_params_not_connected_to_pull_param_trigger( - named_params: dict[str, torch.nn.parameter.Parameter], exported_model: ModelProto -): - # Be noted, some parameters might not in graph input because they are not used in forward, so we filtered them also. - onnx_initializer_names = {p.name for p in exported_model.graph.input} - return [v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names] - - def post_processing_enable_mem_efficient_training( exported_model: ModelProto, named_params: dict[str, torch.nn.parameter.Parameter], @@ -138,7 +130,7 @@ def _create_param_retrieval_function( Args: trainable_named_params: The trainable named parameters. - param_trigger: The trigger tensor for pulling the weights. param_trigger is pre-alloced just once + param_trigger: The trigger tensor for pulling the weights. param_trigger is pre-allocated just once before model execution, later it will be reused by each iteration. This could save the unnecessary overhead allocating for each iteration run. diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py index a0001a2f201f1..4b6011f0786ec 100644 --- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -23,8 +23,7 @@ def _save_model(model: onnx.ModelProto, file_path: str): class ONNXModels: """Encapsulates all ORTModule onnx models. - 1. exported_model: Model that is exported by torch.onnx.export - 2. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed, + 1. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed, for training mode, it's an optimized model after the gradients graph has been built. In addition, ORTModule also saves two other models, to the user-provided path: a. the pre_grad_model which is the model before the gradients graph is built. @@ -32,16 +31,8 @@ class ONNXModels: It has further optimizations done by the InferenceSession and is saved by the InferenceSession. """ - exported_model: Optional[onnx.ModelProto] = None - processed_exported_model: Optional[onnx.ModelProto] = None optimized_model: Optional[onnx.ModelProto] = None - def save_exported_model(self, path, name_prefix, export_mode): - # save the ortmodule exported model - _save_model( - self.exported_model, os.path.join(path, _get_onnx_file_name(name_prefix, "torch_exported", export_mode)) - ) - def save_optimized_model(self, path, name_prefix, export_mode): # save the ortmodule optimized model _save_model( diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index f35e3f74ba60a..d5d5ce672224c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -12,12 +12,12 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type -from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils +from . import _are_deterministic_algorithms_enabled, _use_deterministic_algorithms, _utils from ._execution_agent import TrainingAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._io import _FlattenedModule, _InputInfo, unflatten_user_output +from ._io import _FlattenedModule from ._logger import ORTModuleInitPhase, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results @@ -247,27 +247,17 @@ def forward(self, *inputs, **kwargs): build_gradient_graph = False if ( self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False - or not self._onnx_models.exported_model + or not self._graph_transition_manager._exported_model_info ): self.time_tracker.start(ORTModuleInitPhase.EndToEnd) - build_gradient_graph = self._export_model(*inputs, **kwargs) + ( + build_gradient_graph, + post_export_processed_model_info, + ) = self._graph_transition_manager.get_post_processed_model(inputs, kwargs) if build_gradient_graph: - # If model was exported, then initialize the graph builder - self._initialize_graph_builder() - - # Since the schema was just extracted while trying to export the model and it was either - # saved to self._input_info.schema or checked for equality with the self._input_info.schema - # it should not need to be updated again. Pass it inside parse_inputs_for_onnx_export. - input_info = _io.parse_inputs_for_onnx_export( - self._module_parameters, self._onnx_models.exported_model, self._input_info.schema, inputs, kwargs - ) - - # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. - # Order of or operation is important here because we always need to call - # _reinitialize_graph_builder irrespective of the value of build_gradient_graph. - build_gradient_graph = self._reinitialize_graph_builder(input_info) or build_gradient_graph + self._initialize_graph_builder(post_export_processed_model_info) # Build the gradient graph if build_gradient_graph: @@ -284,9 +274,7 @@ def forward(self, *inputs, **kwargs): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent ): - device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( - inputs, kwargs - ) + device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs) create_execution_session = ( build_gradient_graph or self._device != device @@ -294,7 +282,7 @@ def forward(self, *inputs, **kwargs): ) _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) if self._device != device: - self._device = device + self._graph_transition_manager._device = device if create_execution_session: # Create execution session creates the training_session @@ -309,36 +297,16 @@ def forward(self, *inputs, **kwargs): self._gradient_accumulation_manager.maybe_update_cache_before_run() - if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled: + if self._runtime_options.enable_zero_stage3_support: self._append_pull_weight_trigger_as_input(kwargs, self._device) - param_to_append_as_onnx_graph_inputs = [] - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger - - param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger( - self._flattened_module.named_parameters(), self._onnx_models.exported_model - ) - - else: - param_to_append_as_onnx_graph_inputs = self._graph_initializers - - prepared_input_list = _io._combine_input_buffers_initializers( - param_to_append_as_onnx_graph_inputs, - self._graph_info.user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - self._device, - self._runtime_inspector, - self._zero_stage3_param_map, + prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs( + inputs, kwargs, True, self._device ) - outputs = unflatten_user_output( - self._module_output_schema, - self._forward_class.apply(*prepared_input_list), - ) + user_outputs = self._forward_class.apply(*prepared_input_map.values()) + + outputs = self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs) if ( create_execution_session @@ -390,21 +358,15 @@ def _build_graph(self, graph_transformer_config): # Map each input/initializer to its gradient index in the graph output, or -1 is gradient is not required. self._gradient_map = [] - num_user_input_grads = len(self._input_info.require_grad_names) - require_grad_names_set = set(self._input_info.require_grad_names) - require_grad_names_index = 0 - for input_name in self._graph_info.user_input_names: - if input_name in require_grad_names_set: - self._gradient_map.append(require_grad_names_index) - require_grad_names_index += 1 - else: - self._gradient_map.append(-1) - initializer_index = num_user_input_grads - for initializer_name in self._graph_info.initializer_names: - if initializer_name in self._graph_initializer_names_to_train: - self._gradient_map.append(initializer_index) - initializer_index += 1 + index_for_input_requires_grad = 0 + for input_name in self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names: + if ( + input_name + in self._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names_require_grad + ): + self._gradient_map.append(index_for_input_requires_grad) + index_for_input_requires_grad += 1 else: self._gradient_map.append(-1) @@ -414,7 +376,8 @@ def _create_execution_agent(self): session_options, providers, provider_options = self._get_session_config() fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input] - device_type = self._device if type(self._device) is str else self._device.type.lower() # noqa: E721 + device_type = self._device if isinstance(self._device, str) else self._device.type.lower() + if device_type == "ort": fw_outputs_device_info = [C.get_ort_device(self._device.index)] * ( len(self._graph_info.user_output_names) + len(self._graph_info.frontier_node_arg_map) @@ -491,44 +454,11 @@ def _create_execution_agent(self): self._execution_agent._inference_session, True, self._runtime_options.tuning_results_path ) - def _reinitialize_graph_builder(self, input_info: _InputInfo): - """Return true if the module graph builder was reinitialized""" - - # Model may have unused params dropped after export and not part of self._graph_initializer_names_to_train - # To see if any trainable initializers changed, compare self._graph_initializer_names_to_train - # with initializers in module named_parameters that are known to the onnx graph. - initializer_names_to_train_set_user_model = { - name - for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in self._graph_initializer_names - } - - if self._mem_efficient_grad_management_is_enabled: - from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - - # Remove the inputs we added during model post-processing. - existing_require_grad_names = [ - n for n in self._input_info.require_grad_names if n != MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME - ] - else: - existing_require_grad_names = self._input_info.require_grad_names - - # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder - # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad - if ( - input_info.require_grad_names != existing_require_grad_names - or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train - ): - self._input_info = input_info - self._initialize_graph_builder() - return True - return False - def __getstate__(self): state = super().__getstate__() - # Only top level classes are pickleable. So, _ORTModuleFunction is - # not pickleable. So, let's not pickle it, and redefine it when + # Only top level classes are picklable. So, _ORTModuleFunction is + # not picklable. So, let's not pickle it, and redefine it when # loading the state. del state["_forward_class"] return state diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 5faa1c62bae4f..c299d1c5db4e7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -153,7 +153,15 @@ def get_device_str(device: Union[str, int, torch.device]) -> str: return device -def get_device_from_module(module) -> Optional[torch.device]: +def get_device_from_module_and_inputs(module, inputs, kwargs): + """Get the device from the module and save it to self._device""" + + device = _get_device_from_module(module) or _get_device_from_inputs(inputs, kwargs) + + return device + + +def _get_device_from_module(module) -> Optional[torch.device]: """Returns the first device found in the `module`'s parameters or None Args: @@ -179,7 +187,7 @@ def get_device_from_module(module) -> Optional[torch.device]: return device -def get_device_from_inputs(args, kwargs) -> Optional[torch.device]: +def _get_device_from_inputs(args, kwargs) -> Optional[torch.device]: """Returns device from first PyTorch Tensor within args or kwargs Args: @@ -192,9 +200,12 @@ def get_device_from_inputs(args, kwargs) -> Optional[torch.device]: device = None if args: - device = torch.device(args[0].device) + if args[0] is not None and hasattr(args[0], "device"): + device = torch.device(args[0].device) elif kwargs: - device = torch.device(next(iter(kwargs.values())).device) + v = next(iter(kwargs.values())) + if v is not None and hasattr(v, "device"): + device = torch.device(v.device) return device diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index ff110c431d300..11d978e71d8a8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -395,7 +395,7 @@ def _update_python_op_input_related_attributes( @contextmanager -def stage3_export_context(enable: bool, graph_execution_manager): +def stage3_export_context(enable: bool, stage3_param_handle, flattened_module): """Context manager for stage3 specific model export. Some export functions are overridden when entering the context; the original functions are restored when exiting the context. @@ -411,9 +411,7 @@ def stage3_export_context(enable: bool, graph_execution_manager): # Delay collecting stage3 parameters here instead of in the graph execution manager, # to make sure DeepSpeed initialization is done, so that the parameters ds_status are correct. - graph_execution_manager._zero_stage3_param_map = _get_all_zero_stage3_params( - graph_execution_manager._flattened_module - ) + stage3_param_handle._zero_stage3_param_map = _get_all_zero_stage3_params(flattened_module) try: from torch.onnx._internal import _beartype @@ -428,8 +426,8 @@ def _get_tensor_rank(x) -> Optional[int]: from torch.onnx.symbolic_helper import _is_tensor input_name = x.debugName() - if input_name in graph_execution_manager._zero_stage3_param_map: - rank = len(graph_execution_manager._zero_stage3_param_map[input_name].ds_shape) + if input_name in stage3_param_handle._zero_stage3_param_map: + rank = len(stage3_param_handle._zero_stage3_param_map[input_name].ds_shape) return rank if not _is_tensor(x) or x.type() is None: diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py index e6e5ce56773e1..fbd98675aebe6 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py @@ -41,13 +41,13 @@ class GraphMatcher: * Second bool indicates it's producer node or consumer node for source node. * There is a list to describe the edge infos of this node to other nodes, each edge is a tuple with 3 integers, first integer is the index of the target node in the list, second integer is the output index of the edge, - and thrid integer is the input index of the edge. + and third integer is the input index of the edge. For each entry, GraphMatcher used the first edge to lookup target node, and try to use make sure the sug-graph also matches rest edge infos. Note that when lookup target node, it will only take the first matched node as target node. For example, if a source - node has multiple "MatMul" consumers nodes comsuming same output, only the first "MatMul" node will be returned. + node has multiple "MatMul" consumers nodes consuming same output, only the first "MatMul" node will be returned. You need to avoid using such confusing edge info as the first edge info for node lookup. Try to use other edge to avoid such confusion if possible. """ diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 9145fb1712e88..2d036a5abcb5c 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -75,7 +75,7 @@ def __init__(self, log_level): def _extract_info(self, log_level): # get the log_level from os env variable - # OS environment variable log level superseeds the locally provided one + # OS environment variable log level supersedes the locally provided one self._validate(log_level) log_level = LogLevel[os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)] return log_level @@ -197,7 +197,7 @@ class _MemoryOptimizationLevel(IntFlag): USER_SPECIFIED = 0 # Fully respect user-specified config TRANSFORMER_LAYERWISE_RECOMPUTE = ( - 1 # Enable all recomputable subgraphs (excluding compromised recomptable graphs) per layer + 1 # Enable all recomputable subgraphs (excluding compromised recomputable graphs) per layer ) TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE = 2 # Enable all recomputable subgraphs per layer diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index b5c52bdaef3c6..b291bfb2ba03c 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -124,7 +124,7 @@ def forward(self, *inputs, **kwargs): The first call to forward performs setup and checking steps. During this call, ORTModule determines whether the module can be trained with ONNX Runtime. For this reason, the first forward call execution takes longer than subsequent calls. - Execution is interupted if ONNX Runtime cannot process the model for training. + Execution is interrupted if ONNX Runtime cannot process the model for training. Args: inputs: positional, variable positional inputs to the PyTorch module's forward method. @@ -306,7 +306,9 @@ def __setattr__(self, name: str, value) -> None: # Re-export will be avoided if _skip_check is enabled. if isinstance(self._torch_module, TorchModuleORT): for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode).signal_model_changed() + self._torch_module._execution_manager( + training_mode + )._graph_transition_manager.signal_model_changed() else: # Setting any new attributes should be done on ORTModule only when 'torch_module' is not defined diff --git a/orttraining/orttraining/python/training/utils/torch_io_helper.py b/orttraining/orttraining/python/training/utils/torch_io_helper.py index 34cc1ca942a8c..4824ed7137021 100644 --- a/orttraining/orttraining/python/training/utils/torch_io_helper.py +++ b/orttraining/orttraining/python/training/utils/torch_io_helper.py @@ -5,7 +5,7 @@ import copy import warnings -from collections import abc +from collections import OrderedDict, abc from typing import List, Mapping, Optional, Sequence, Tuple, Union import torch @@ -221,8 +221,8 @@ def _flatten_from_data(data: ORTModelInputOutputType, prefix_name: str = ""): return stubbed_schema elif isinstance(data, abc.Mapping): dict_type = type(data) - stubbed_schema = {} - for key, val in sorted(data.items()): + stubbed_schema = OrderedDict() + for key, val in data.items(): stubbed_schema[key] = _flatten_from_data(val, f"{prefix_name}_{key}" if prefix_name else f"{key}") stubbed_schema = dict_type(**stubbed_schema) return stubbed_schema @@ -305,7 +305,7 @@ def _replace_stub_with_tensor_value(data_schema: ORTModelInputOutputSchemaType, return data_schema elif isinstance(data_schema, abc.Mapping): new_user_output = copy.copy(data_schema) - for key, schema_val in sorted(data_schema.items()): + for key, schema_val in data_schema.items(): new_user_output[key] = _replace_stub_with_tensor_value(schema_val, data) data_schema = new_user_output diff --git a/orttraining/orttraining/test/gradient/allreduce_op_test.cc b/orttraining/orttraining/test/gradient/allreduce_op_test.cc index 82f01a3c43681..1b1bd680a1191 100644 --- a/orttraining/orttraining/test/gradient/allreduce_op_test.cc +++ b/orttraining/orttraining/test/gradient/allreduce_op_test.cc @@ -472,7 +472,7 @@ TEST(AllreduceTest, GPUHierarchicalAdasumAllreduceOptimizerTest) { build_allreduce_graph(graph, adasum_graph_configs, training::AdasumReductionType::GpuHierarchicalReduction, true /*build_optimizer*/, false /*half_precision*/); - std::string model_file_name = "GPUHierarchicalAdasumAllreduceOptimizerTest.onnx"; + PathString model_file_name = ORT_TSTR("GPUHierarchicalAdasumAllreduceOptimizerTest.onnx"); auto status = onnxruntime::Model::Save(model, model_file_name); SessionOptions so; @@ -649,7 +649,7 @@ TEST(AllreduceTest, GPUHierarchicalAdasumAllreduceOptimizerFP16Test) { build_allreduce_graph(graph, adasum_graph_configs, training::AdasumReductionType::GpuHierarchicalReduction, true /*build_optimizer*/, true /*half_precision*/); - std::string model_file_name = "GPUHierarchicalAdasumAllreduceOptimizerFP16Test.onnx"; + PathString model_file_name = ORT_TSTR("GPUHierarchicalAdasumAllreduceOptimizerFP16Test.onnx"); auto status = onnxruntime::Model::Save(model, model_file_name); SessionOptions so; @@ -791,7 +791,7 @@ TEST(AllreduceTest, GPUHierarchicalAdasumAllreduceTest) { adasum_graph_configs.push_back(adasum_graph_config); build_allreduce_graph(graph, adasum_graph_configs, training::AdasumReductionType::GpuHierarchicalReduction); - std::string model_file_name = "GPUHierarchicalAdasumAllreduceTest.onnx"; + PathString model_file_name = ORT_TSTR("GPUHierarchicalAdasumAllreduceTest.onnx"); auto status = onnxruntime::Model::Save(model, model_file_name); SessionOptions so; @@ -896,7 +896,7 @@ TEST(AllreduceTest, GPUHierarchicalAdasumFP16AllreduceTest) { false /*build_optimizer*/, true /*half_precision*/); - std::string model_file_name = "GPUHierarchicalAdasumFP16AllreduceTest.onnx"; + PathString model_file_name = ORT_TSTR("GPUHierarchicalAdasumFP16AllreduceTest.onnx"); auto status = onnxruntime::Model::Save(model, model_file_name); SessionOptions so; @@ -1003,7 +1003,7 @@ TEST(AllreduceTest, GPUAdasumAllreduceTest) { build_allreduce_graph(graph, adasum_graph_configs, training::AdasumReductionType::CpuReduction); - std::string model_file_name = "GPUAdasumAllreduceTest.onnx"; + PathString model_file_name = ORT_TSTR("GPUAdasumAllreduceTest.onnx"); auto status = onnxruntime::Model::Save(model, model_file_name); SessionOptions so; @@ -1110,7 +1110,7 @@ TEST(AllreduceTest, GPUAdasumFP16AllreduceTest) { build_allreduce_graph(graph, adasum_graph_configs, training::AdasumReductionType::CpuReduction, true /*half_precision*/); - std::string model_file_name = "GPUAdasumFP16AllreduceTest.onnx"; + PathString model_file_name = ORT_TSTR("GPUAdasumFP16AllreduceTest.onnx"); auto status = onnxruntime::Model::Save(model, model_file_name); SessionOptions so; diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 4ab035a171430..b2ab4891f2e1e 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -627,7 +627,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - auto model_uri2 = "mlp_megatron_basic_test_partition_rank0.onnx"; + PathString model_uri2 = ORT_TSTR("mlp_megatron_basic_test_partition_rank0.onnx"); ASSERT_STATUS_OK(Model::Save(*p_model, model_uri2)); { @@ -705,7 +705,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) { TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - auto model_uri2 = "mlp_megatron_basic_test_partition_rank1.onnx"; + PathString model_uri2 = ORT_TSTR("mlp_megatron_basic_test_partition_rank1.onnx"); ASSERT_STATUS_OK(Model::Save(*p_model, model_uri2)); { @@ -765,7 +765,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) { } TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { - auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; + PathString model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); @@ -781,7 +781,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - auto model_uri2 = "self_attention_megatron_basic_test_partition_rank0.onnx"; + PathString model_uri2 = ORT_TSTR("self_attention_megatron_basic_test_partition_rank0.onnx"); ASSERT_STATUS_OK(Model::Save(*p_model, model_uri2)); { @@ -838,7 +838,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { } TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { - auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; + PathString model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); @@ -856,7 +856,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); - auto model_uri2 = "self_attention_megatron_basic_test_partition_rank1.onnx"; + PathString model_uri2 = ORT_TSTR("self_attention_megatron_basic_test_partition_rank1.onnx"); ASSERT_STATUS_OK(Model::Save(*p_model, model_uri2)); { @@ -913,7 +913,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { } TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) { - auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion_recompute.onnx"; + PathString model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion_recompute.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); @@ -1397,7 +1397,7 @@ static void RunPartitionCorrectnessTest(std::string model_path, TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger)); graphs.push_back(&graph); - auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_rank_") + ToPathString(std::to_string(i)) + ORT_TSTR(".onnx"); + PathString model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_rank_") + ToPathString(std::to_string(i)) + ORT_TSTR(".onnx"); ASSERT_STATUS_OK(Model::Save(*p_models[i], model_uri2)); } @@ -1405,7 +1405,7 @@ static void RunPartitionCorrectnessTest(std::string model_path, auto& combine_graph = combine_model.MainGraph(); auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph); ORT_ENFORCE(ret.IsOK()); - auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_combine.onnx"); + PathString model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_combine.onnx"); ASSERT_STATUS_OK(Model::Save(combine_model, model_uri2)); float scale = 1.f; @@ -1790,7 +1790,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionTwoInputs) { #ifdef ENABLE_TRITON TEST_F(GraphTransformationTests, TritonFusion) { - auto model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; + PathString model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; std::shared_ptr model; ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); @@ -1805,7 +1805,7 @@ TEST_F(GraphTransformationTests, TritonFusion) { ASSERT_TRUE(op_to_count["LayerNormalization"] == 4); { - auto model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; + PathString model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; std::shared_ptr model; ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); @@ -1845,7 +1845,7 @@ TEST_F(GraphTransformationTests, TritonFusion) { // No Dropout. { - auto model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; + PathString model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; std::shared_ptr model; ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); @@ -1884,7 +1884,7 @@ TEST_F(GraphTransformationTests, TritonFusion) { // Ignore min nodes. { - auto model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; + PathString model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; std::shared_ptr model; ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); @@ -1924,7 +1924,7 @@ TEST_F(GraphTransformationTests, TritonFusion) { // Exclude Softmax using axis attribute. { - auto model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; + PathString model_uri = MODEL_FOLDER "bert_toy_opset14.onnx"; std::shared_ptr model; ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index ac49c1c2834c7..5c63be92d2b2f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1099,3 +1099,91 @@ def test_custom_optimizer_block(): for attr in node.attribute: if attr.name == "weight_decay": assert attr.f == weight_decay + + +def test_generate_artifacts_path(): + + with tempfile.TemporaryDirectory() as temp_dir: + _, simple_net = _get_models("cpu", 32, 28, 10, 10) + + requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] + + onnx.save_model( + simple_net, + os.path.join(temp_dir, "simple_net.onnx"), + ) + + artifacts.generate_artifacts( + os.path.join(temp_dir, "simple_net.onnx"), + requires_grad=requires_grad_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + ) + + # generate_artifacts should have thrown if it didn't complete successfully. + # Below is a sanity check to validate that all the expected files were created. + assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "checkpoint")) + + +def test_generate_artifacts_external_data_one_file(): + with tempfile.TemporaryDirectory() as temp_dir: + _, simple_net = _get_models("cpu", 32, 28, 10, 10) + + requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] + + onnx.save_model( + simple_net, + os.path.join(temp_dir, "simple_net.onnx"), + save_as_external_data=True, + all_tensors_to_one_file=True, + size_threshold=0, + ) + + artifacts.generate_artifacts( + os.path.join(temp_dir, "simple_net.onnx"), + requires_grad=requires_grad_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + ) + + # generate_artifacts should have thrown if it didn't complete successfully. + # Below is a sanity check to validate that all the expected files were created. + assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "checkpoint")) + + +def test_generate_artifacts_external_data_separate_files(): + with tempfile.TemporaryDirectory() as temp_dir: + _, simple_net = _get_models("cpu", 32, 28, 10, 10) + + requires_grad_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"] + + onnx.save_model( + simple_net, + os.path.join(temp_dir, "simple_net.onnx"), + save_as_external_data=True, + all_tensors_to_one_file=False, + size_threshold=0, + ) + + artifacts.generate_artifacts( + os.path.join(temp_dir, "simple_net.onnx"), + requires_grad=requires_grad_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + ) + + # generate_artifacts should have thrown if it didn't complete successfully. + # Below is a sanity check to validate that all the expected files were created. + assert os.path.exists(os.path.join(temp_dir, "training_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) + assert os.path.exists(os.path.join(temp_dir, "checkpoint")) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index f35bb47f6b41d..541473b1561db 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -30,7 +30,7 @@ import onnxruntime.training.ortmodule as ortmodule_module from onnxruntime.training.optim import AdamWMode, FusedAdam -from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils +from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _utils from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck from onnxruntime.training.utils import pytorch_type_to_onnx_dtype @@ -463,9 +463,11 @@ def test_forward_call_single_positional_argument(): N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) + # Check that the original forward signature is preserved. assert inspect.signature(model.forward) == inspect.signature(ort_model.forward) x = torch.randn(N, D_in, device=device) + # Make sure model runs without any exception prediction = ort_model(x) assert prediction is not None @@ -699,7 +701,15 @@ def test_input_requires_grad_saved(device): model = ORTModule(model) x = torch.randn(N, D_in, device=device, requires_grad=True) + 1 model(x) - assert "input1" in model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names + assert model._torch_module._execution_manager( + model._is_training() + )._graph_transition_manager._model_info_for_export.onnx_graph_input_names_require_grad == ["input1"] + assert ( + "input1" + in model._torch_module._execution_manager( + model._is_training() + )._graph_transition_manager._post_export_processed_model_info.onnx_graph_input_names_require_grad + ) @pytest.mark.parametrize("device", ["cuda", "cpu"]) @@ -841,6 +851,7 @@ def forward(self, input): device = "cuda" pt_model = NeuralNetTranspose(perm).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) def run_step(model, x): @@ -2655,7 +2666,10 @@ def test_exception_raised_for_custom_class_return_value_module(device): # ORT backend with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -3482,6 +3496,13 @@ def train_step(model, x): _test_helpers.assert_values_are_close(pt_out, ort_out) +def _repr_schema(ortmodule): + tm = ortmodule._torch_module._execution_manager(ortmodule._is_training())._graph_transition_manager + return repr(tm._exported_model_info.module_forward_args_schema) + repr( + tm._exported_model_info.module_forward_kwargs_schema + ) + + def test_forward_dynamic_args(): os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" @@ -3506,21 +3527,21 @@ def test_forward_dynamic_args(): for _ in range(10): output = model(*args_size1) assert output is not None - hash_args_size1 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size1 = hash(_repr_schema(model)) assert hash_args_size1 is not None # Decrease number of inputs and train some more for _ in range(10): output = model(*args_size2) assert output is not None - hash_args_size2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size2 = hash(_repr_schema(model)) assert hash_args_size2 != hash_args_size1 # Increase number of inputs and train some more for _ in range(10): output = model(*args_size3) assert output is not None - hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_args_size3 = hash(_repr_schema(model)) assert hash_args_size3 != hash_args_size2 del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -3545,35 +3566,35 @@ def test_forward_dynamic_kwargs(): for _ in range(10): output = model(one) assert output is not None - hash_x = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x = hash(_repr_schema(model)) assert hash_x is not None # Train with x and y as inputs for _ in range(10): output = model(one, y=one) assert output is not None - hash_x_y = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_y = hash(_repr_schema(model)) assert hash_x_y != hash_x # Train with x and z as inputs for _ in range(10): output = model(one, z=one) assert output is not None - hash_x_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_z = hash(_repr_schema(model)) assert hash_x_z != hash_x_y # Train with x, y and z as inputs for _ in range(10): output = model(one, y=one, z=one) assert output is not None - hash_x_y_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x_y_z = hash(_repr_schema(model)) assert hash_x_y_z != hash_x_z # Return to original input with x as input for _ in range(10): output = model(one) assert output is not None - hash_x2 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) + hash_x2 = hash(_repr_schema(model)) assert hash_x2 != hash_x_y_z assert hash_x2 == hash_x @@ -4003,10 +4024,14 @@ def forward(self, input1, bool_argument): input1 = torch.randn(N, D_in, device=device) ort_model(input1, bool_arguments[0]) - exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model1 = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model ort_model(input1, bool_arguments[1]) - exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model2 = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model assert exported_model1 != exported_model2 @@ -4172,36 +4197,6 @@ def test_stateless_model_unspecified_device(): _test_helpers.assert_values_are_close(pt_y, ort_y) -@pytest.mark.parametrize( - "model", - [ - (UnusedBeginParameterNet(784, 500, 400, 10)), - (UnusedMiddleParameterNet(784, 500, 400, 10)), - (UnusedEndParameterNet(784, 500, 400, 10)), - ], -) -def test_unused_parameters_does_not_unnecessarily_reinitialize(model): - device = "cuda" - - N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 # noqa: F841, N806 - model = model.to(device) - ort_model = ORTModule(copy.deepcopy(model)) - training_manager = ort_model._torch_module._execution_manager(ort_model._is_training()) - - x = torch.randn(N, D_in, device=device) - _ = ort_model(x) - - input_info = _io.parse_inputs_for_onnx_export( - training_manager._module_parameters, - training_manager._onnx_models.exported_model, - training_manager._input_info.schema, - x, - {}, - ) - - assert not training_manager._reinitialize_graph_builder(input_info) - - def test_load_state_dict_for_wrapped_ortmodule(): class WrapperModule(torch.nn.Module): def __init__(self, ortmodule): @@ -4264,14 +4259,23 @@ def test_hf_save_pretrained(): assert p1.data.ne(p2.data).sum() == 0 -def test_ortmodule_string_inputs_are_ignored(): +def test_ortmodule_string_inputs_are_ignored(caplog): pt_model = MyStrNet() - target_str = "Received input of type which may be treated as a constant by ORT by default." - with pytest.warns(UserWarning, match=target_str): - ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)) - x = torch.randn(1, 2) - out = ort_model(x, "hello") - _test_helpers.assert_values_are_close(out, x + 1) + target_str = "Received input of type is treated as a constant by ORT by default." + + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO)) + x = torch.randn(1, 2) + out = ort_model(x, "hello") + _test_helpers.assert_values_are_close(out, x + 1) + + found_log = False + for record in caplog.records: + msg = record.getMessage() + if target_str in msg: + found_log = True + break + + assert found_log, f"Expected to find log message '{target_str}' in the logs, but didn't find it." def test_ortmodule_list_input(): @@ -4831,17 +4835,31 @@ def forward(self, a): ort_model = ORTModule(pt_model) _ = ort_model(torch.randn(N, D_in, device=device)) - exported_model1 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model + exported_model1 = ort_model._torch_module._execution_manager( + True + )._graph_transition_manager._exported_model_info.exported_model for training_mode in [False, True]: - assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is False + assert ( + ort_model._torch_module._execution_manager( + training_mode + )._graph_transition_manager._original_model_has_changed + is False + ) ort_model.input_flag = False for training_mode in [False, True]: - assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed is True + assert ( + ort_model._torch_module._execution_manager( + training_mode + )._graph_transition_manager._original_model_has_changed + is True + ) _ = ort_model(torch.randn(N, D_in, device=device)) - exported_model2 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model + exported_model2 = ort_model._torch_module._execution_manager( + True + )._graph_transition_manager._exported_model_info.exported_model assert exported_model1 != exported_model2 @@ -4999,7 +5017,9 @@ def test_override_pytorch_exporter_kwargs(): model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) - ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {"custom_opsets": None} + ort_model._torch_module._execution_manager(True)._graph_transition_manager._export_extra_kwargs = { + "custom_opsets": None + } # Make sure model runs without any exception prediction = ort_model(x) @@ -5016,7 +5036,7 @@ def test_override_pytorch_exporter_kwargs__invalid(): model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) - ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {"verbose": False} + ort_model._torch_module._execution_manager(True)._graph_transition_manager._export_extra_kwargs = {"verbose": False} with pytest.raises(_fallback.ORTModuleONNXModelException) as type_error: _ = ort_model(x) assert "The following PyTorch exporter arguments cannot be specified: '{'verbose'}'." in str(type_error.value) @@ -5029,7 +5049,9 @@ class ORTModuleExtension(ORTModule): def __init__(self, module, debug_options=None): super().__init__(module, debug_options) for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {"verbose": None} + self._torch_module._execution_manager(training_mode)._graph_transition_manager._export_extra_kwargs = { + "verbose": None + } N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 x = torch.randn(N, D_in, device=device) @@ -5049,7 +5071,9 @@ def __init__(self, module, debug_options=None): super().__init__(module, debug_options) # modify GraphExecutionManager internally for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {"custom_opsets": None} + self._torch_module._execution_manager(training_mode)._graph_transition_manager._export_extra_kwargs = { + "custom_opsets": None + } N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 x = torch.randn(N, D_in, device=device) @@ -5280,11 +5304,12 @@ def run_step(model, x): ort_prediction, ort_loss = run_step(ort_model, ort_x) pt_prediction, pt_loss = run_step(pt_model, pt_x) if step == 0: - model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models - for name in ["exported_model", "optimized_model"]: - onx = getattr(model_onx, name) + for onnx_model in [ + ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model, + ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model, + ]: opv = None - for op in onx.opset_import: + for op in onnx_model.opset_import: if not op.domain: opv = op.version assert opv == 13 @@ -5323,7 +5348,9 @@ def test_opset_version_change(opset_version): prediction.backward() # Check opset version on ONNX model - exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model + exported_model = ort_model._torch_module._execution_manager( + ort_model._is_training() + )._graph_transition_manager._exported_model_info.exported_model assert exported_model.opset_import[0].version == opset_version if original_env is not None: @@ -5334,6 +5361,7 @@ def test_serialize_ortmodule(): device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 pt_model = SerializationNet(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) x_1 = torch.randn(N, D_in, device=device) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 99c15034cdafe..95012aa0507a5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1415,10 +1415,12 @@ def check_pythonop_training_mode(model, is_eval_mode): ## make sure the ort's PythonOp's training_mode is correct if is_eval_mode: onnx_nodes = ( - model._torch_module._execution_manager._inference_manager._onnx_models.exported_model.graph.node + model._torch_module._execution_manager._inference_manager._graph_transition_manager._exported_model_info.exported_model.graph.node ) else: - onnx_nodes = model._torch_module._execution_manager._training_manager._onnx_models.exported_model.graph.node + onnx_nodes = ( + model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node + ) found_pythonop = False for node in onnx_nodes: @@ -1837,7 +1839,9 @@ def forward(self, model_input): ortmodule = ORTModule(TestModel(output_size)).train() _ = ortmodule(torch.randn(output_size, dtype=torch.float)) - onnx_nodes = ortmodule._torch_module._execution_manager._training_manager._onnx_models.exported_model.graph.node + onnx_nodes = ( + ortmodule._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model.graph.node + ) found_pythonop = False for node in onnx_nodes: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index 34453c89157a8..4e0fcafecffe5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -49,6 +49,7 @@ def test_ortmodule_fallback_forward(is_training, fallback_enabled, matching_poli class Point: x: int y: int + device: str = "cpu" # Otherwise, no device can be found from inputs, and the test will fail earlier. class UnsupportedInputModel(torch.nn.Module): def __init__(self): @@ -78,11 +79,17 @@ def forward(self, point): else: with pytest.raises(_fallback.ORTModuleFallbackException) as type_error: ort_model(inputs) - assert "ORTModule fails to extract schema from data" in str(type_error.value) + assert ( + "ORTModule does not support input type .Point'> for input point" + in str(type_error.value) + ) else: with pytest.raises(_fallback.ORTModuleFallbackException) as type_error: ort_model(inputs) - assert "ORTModule fails to extract schema from data" in str(type_error.value) + assert ( + "ORTModule does not support input type .Point'> for input point" + in str(type_error.value) + ) @pytest.mark.parametrize( @@ -250,11 +257,17 @@ def test_ortmodule_fallback_output(is_training, fallback_enabled, matching_polic else: with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) else: with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert "ORTModule does not support the following model output type" in str(runtime_error.value) + assert ( + "ORTModule fails to extract schema from data: Unsupported flatten data type: " + in str(runtime_error.value) + ) @pytest.mark.parametrize( @@ -302,20 +315,18 @@ def __init__(self, x): with pytest.raises(_fallback.ORTModuleIOError) as ex_info: _ = ort_model(torch.randn(1, 2), CustomClass(1)) assert ( - "ORTModule fails to extract schema from data: " - "Unsupported flatten data type: " - ".CustomClass'>" in str(ex_info.value) + "ORTModule does not support input type " + ".CustomClass'> " + "for input custom_class_obj" in str(ex_info.value) ) else: with pytest.raises(_fallback.ORTModuleIOError) as ex_info: _ = ort_model(torch.randn(1, 2), CustomClass(1)) - assert ( - "ORTModule fails to extract schema from data: " - "Unsupported flatten data type: " - ".CustomClass'>" in str(ex_info.value) - ) + assert ( + "ORTModule does not support input type " + ".CustomClass'> " + "for input custom_class_obj" in str(ex_info.value) + ) @pytest.mark.parametrize( diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py index df0b5f195f0b9..e1def2022d63f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py @@ -69,7 +69,9 @@ def run_step(model, x): self.assert_values_are_close(ort_prediction, pt_prediction, **kwargs) self.assert_gradients_match_and_reset_gradient(ort_model, pt_model, **kwargs) - onnx_graph_inf = ort_model._torch_module._execution_manager._training_manager._onnx_models.exported_model + onnx_graph_inf = ( + ort_model._torch_module._execution_manager._training_manager._graph_transition_manager._exported_model_info.exported_model + ) onnx_graph_train = ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model if debug: with open("debug_%s_ortmodule_infer.onnx" % name, "wb") as f: @@ -146,6 +148,38 @@ def test_onnx_ops(self): device = torch.device(device_name) self.gradient_correctness(name, device) + @unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support") + def test_softmax_bf16_large(self): + if torch.version.cuda is None: + # Only run this test when CUDA is available, as on ROCm BF16 is not supported by MIOpen. + return + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.softmax(input, dim=-1) + return out + + device = "cuda:0" + input_shape = [2, 4096] + # run torch to get the expected result + data_torch = torch.randn(size=input_shape, device=device, dtype=torch.bfloat16) + 10 + data_torch.requires_grad = True + torch_model = Model() + torch_res = torch_model(input=data_torch) + init_grad = torch.ones_like(torch_res) + torch_res.backward(gradient=init_grad) + # run ort + ort_model = ORTModule(torch_model) + data_ort = data_torch.detach().clone() + data_ort.requires_grad = True + ort_res = ort_model(input=data_ort) + ort_res.backward(gradient=init_grad) + # compare result + torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h b/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h index 542acb2e337ab..2340104840a6e 100644 --- a/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h +++ b/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h @@ -11,7 +11,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index cb355ed04e907..56029b34c24d7 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -330,7 +330,7 @@ Status FromTensorProtos(gsl::span trainable_t for (const auto& tensor_proto : tensor_protos) { flatbuffers::Offset fbs_tensor; ORT_RETURN_IF_ERROR( - fbs::utils::SaveInitializerOrtFormat(builder, tensor_proto, Path(), fbs_tensor, external_data_writer)); + fbs::utils::SaveInitializerOrtFormat(builder, tensor_proto, std::filesystem::path(), fbs_tensor, external_data_writer)); fbs_tensors.push_back(fbs_tensor); } diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 347673628e106..dc724fbae48eb 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -685,7 +685,7 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path ORT_THROW_IF_ERROR( Model::SaveWithExternalInitializers(*inference_model, inference_model_pathstring, external_data_name, 64)); } else { - ORT_THROW_IF_ERROR(Model::Save(*inference_model, inference_model_path)); + ORT_THROW_IF_ERROR(Model::Save(*inference_model, ToPathString(inference_model_path))); } // Save the model at the desired location. return Status::OK(); diff --git a/orttraining/orttraining/training_ops/cpu/activation/activations_grad.cc b/orttraining/orttraining/training_ops/cpu/activation/activations_grad.cc index f42fa0c4dc984..d3f2f9c7a8767 100644 --- a/orttraining/orttraining/training_ops/cpu/activation/activations_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/activation/activations_grad.cc @@ -3,7 +3,7 @@ #include "orttraining/training_ops/cpu/activation/activations_grad.h" -#include "core/common/gsl.h" +#include #if defined(_MSC_VER) #pragma warning(push) diff --git a/orttraining/orttraining/training_ops/cpu/loss/cross_entropy.cc b/orttraining/orttraining/training_ops/cpu/loss/cross_entropy.cc index 9aad7edf68bfa..3dd52f36ec72a 100644 --- a/orttraining/orttraining/training_ops/cpu/loss/cross_entropy.cc +++ b/orttraining/orttraining/training_ops/cpu/loss/cross_entropy.cc @@ -7,7 +7,7 @@ #include "core/providers/common.h" #include #include "core/providers/cpu/math/matmul_helper.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc index b3c04af1a2659..c74bf06a77d6e 100644 --- a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc +++ b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc @@ -10,7 +10,7 @@ #include "core/providers/cpu/controlflow/scan_utils.h" #include "orttraining/training_ops/cpu/loss/cross_entropy.h" #include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.cc b/orttraining/orttraining/training_ops/cpu/op_gradients.cc index c3476161c1e53..f4b9c08bd90cd 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.cc +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.cc @@ -3,7 +3,7 @@ #include "orttraining/training_ops/cpu/op_gradients.h" -#include "core/common/gsl.h" +#include #include "core/mlas/inc/mlas.h" #include "core/providers/common.h" #include "core/providers/cpu/math/element_wise_ops.h" diff --git a/orttraining/orttraining/training_ops/cpu/tensor/split.cc b/orttraining/orttraining/training_ops/cpu/tensor/split.cc index d361f3ec64e39..1edfdac7631ee 100644 --- a/orttraining/orttraining/training_ops/cpu/tensor/split.cc +++ b/orttraining/orttraining/training_ops/cpu/tensor/split.cc @@ -3,7 +3,7 @@ #include "orttraining/training_ops/cpu/tensor/split.h" -#include "core/common/gsl.h" +#include #include "core/common/narrow.h" #include "core/providers/common.h" #include "core/util/math.h" diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h index 3fdb4306bfbbb..ec227c4641766 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -54,15 +54,15 @@ struct ConvArgs { }; struct ConvParamsHash { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); + // ConvParams must be a trivial type because we read out its memory contents as char* when hashing. + static_assert(std::is_trivial::value, "ConvParams is not a trivial type"); size_t operator()(const ConvParams& conv_params) const; }; struct ConvParamsEqual { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); + // ConvParams must be a trivial type because we read out its memory contents as char* when hashing. + static_assert(std::is_trivial::value, "ConvParams is not a trivial type"); bool operator()(const ConvParams& a, const ConvParams& b) const; }; diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc index 7bd759e8976c1..f3feef4391bb5 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc @@ -35,7 +35,8 @@ struct PadAndUnflattenFunctor { typedef typename ToCudaType::MappedType CudaT; const CudaT* input_data = reinterpret_cast(input_tensor.Data()); - CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT))); + CUDA_CALL_THROW(cudaMemsetAsync(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT), + stream)); PadAndUnflattenImpl(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound, input_data, indices_tensor.Data(), reinterpret_cast(output_tensor.MutableData())); @@ -48,6 +49,7 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { const Tensor* input_tensor = context->Input(0); const Tensor* indices_tensor = context->Input(1); const Tensor* unflatten_dims_tensor = context->Input(2); // Parse the 1-D shape tensor. + ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1, "unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions()); ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2, diff --git a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc index 9da0725274641..22fa5b6f55a5d 100644 --- a/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/rocm/nn/conv_grad.cc @@ -71,8 +71,8 @@ std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { } struct ConvParamsHash { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); + // ConvParams must be a trivial type because we read out its memory contents as char* when hashing. + static_assert(std::is_trivial::value, "ConvParams is not a trivial type"); size_t operator()(const ConvParams& conv_params) const { auto ptr = reinterpret_cast(&conv_params); uint32_t value = 0x811C9DC5; @@ -85,8 +85,8 @@ struct ConvParamsHash { }; struct ConvParamsEqual { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); + // ConvParams must be a trivial type because we read out its memory contents as char* when hashing. + static_assert(std::is_trivial::value, "ConvParams is not a trivial type"); bool operator()(const ConvParams& a, const ConvParams& b) const { auto ptr1 = reinterpret_cast(&a); auto ptr2 = reinterpret_cast(&b); diff --git a/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.1.json b/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.1.json new file mode 100644 index 0000000000000..05fcf08cd3232 --- /dev/null +++ b/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.1.json @@ -0,0 +1,57 @@ +{ + "steps": [ + { + "step": 20, + "loss": 2.0136 + }, + { + "step": 40, + "loss": 1.8466 + }, + { + "step": 60, + "loss": 1.7525 + }, + { + "step": 80, + "loss": 1.6682 + }, + { + "step": 100, + "loss": 1.658 + }, + { + "step": 120, + "loss": 1.6749 + }, + { + "step": 140, + "loss": 1.6263 + }, + { + "step": 160, + "loss": 1.6828 + }, + { + "step": 180, + "loss": 1.6145 + }, + { + "step": 200, + "loss": 1.6197 + }, + { + "step": 220, + "loss": 1.6353 + }, + { + "step": 240, + "loss": 1.5266 + }, + { + "step": 260, + "loss": 1.5441 + } + ], + "samples_per_second": 34.561 +} diff --git a/requirements.txt.in b/requirements.txt similarity index 60% rename from requirements.txt.in rename to requirements.txt index 89242061fb119..2fd9362c949dd 100644 --- a/requirements.txt.in +++ b/requirements.txt @@ -1,6 +1,6 @@ coloredlogs flatbuffers -numpy >= @Python_NumPy_VERSION@ +numpy >= 1.21.6 packaging protobuf sympy diff --git a/setup.py b/setup.py index 3203993e0c4d4..51feedcfd3286 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ def parse_arg_remove_string(argv, arg_name_equal): cuda_version = None rocm_version = None +is_migraphx = False is_rocm = False is_openvino = False # The following arguments are mutually exclusive @@ -64,8 +65,9 @@ def parse_arg_remove_string(argv, arg_name_equal): cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=") elif parse_arg_remove_boolean(sys.argv, "--use_rocm"): is_rocm = True - package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly" rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=") +elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"): + is_migraphx = True elif parse_arg_remove_boolean(sys.argv, "--use_openvino"): is_openvino = True package_name = "onnxruntime-openvino" @@ -87,6 +89,9 @@ def parse_arg_remove_string(argv, arg_name_equal): elif parse_arg_remove_boolean(sys.argv, "--use_qnn"): package_name = "onnxruntime-qnn" +if is_rocm or is_migraphx: + package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly" + # PEP 513 defined manylinux1_x86_64 and manylinux1_i686 # PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686 # PEP 599 defines the following platform tags: @@ -280,10 +285,21 @@ def finalize_options(self): return ret -providers_cuda_or_rocm = "libonnxruntime_providers_" + ("rocm.so" if is_rocm else "cuda.so") -providers_tensorrt_or_migraphx = "libonnxruntime_providers_" + ("migraphx.so" if is_rocm else "tensorrt.so") -providers_openvino = "libonnxruntime_providers_openvino.so" -providers_cann = "libonnxruntime_providers_cann.so" +providers_cuda_or_rocm = "onnxruntime_providers_" + ("rocm" if is_rocm else "cuda") +providers_tensorrt_or_migraphx = "onnxruntime_providers_" + ("migraphx" if is_migraphx else "tensorrt") +providers_openvino = "onnxruntime_providers_openvino" +providers_cann = "onnxruntime_providers_cann" + +if platform.system() == "Linux": + providers_cuda_or_rocm = "lib" + providers_cuda_or_rocm + ".so" + providers_tensorrt_or_migraphx = "lib" + providers_tensorrt_or_migraphx + ".so" + providers_openvino = "lib" + providers_openvino + ".so" + providers_cann = "lib" + providers_cann + ".so" +elif platform.system() == "Windows": + providers_cuda_or_rocm = providers_cuda_or_rocm + ".dll" + providers_tensorrt_or_migraphx = providers_tensorrt_or_migraphx + ".dll" + providers_openvino = providers_openvino + ".dll" + providers_cann = providers_cann + ".dll" # Additional binaries dl_libs = [] @@ -297,11 +313,13 @@ def finalize_options(self): "libmklml_gnu.so", "libiomp5.so", "mimalloc.so", + "libonnxruntime.so*", ] dl_libs = ["libonnxruntime_providers_shared.so"] dl_libs.append(providers_cuda_or_rocm) dl_libs.append(providers_tensorrt_or_migraphx) dl_libs.append(providers_cann) + dl_libs.append("libonnxruntime.so*") # DNNL, TensorRT & OpenVINO EPs are built as shared libs libs.extend(["libonnxruntime_providers_shared.so"]) libs.extend(["libonnxruntime_providers_dnnl.so"]) @@ -313,7 +331,12 @@ def finalize_options(self): if nightly_build: libs.extend(["libonnxruntime_pywrapper.so"]) elif platform.system() == "Darwin": - libs = ["onnxruntime_pybind11_state.so", "libdnnl.2.dylib", "mimalloc.so"] # TODO add libmklml and libiomp5 later. + libs = [ + "onnxruntime_pybind11_state.so", + "libdnnl.2.dylib", + "mimalloc.so", + "libonnxruntime.dylib*", + ] # TODO add libmklml and libiomp5 later. # DNNL & TensorRT EPs are built as shared libs libs.extend(["libonnxruntime_providers_shared.dylib"]) libs.extend(["libonnxruntime_providers_dnnl.dylib"]) @@ -323,7 +346,16 @@ def finalize_options(self): if nightly_build: libs.extend(["libonnxruntime_pywrapper.dylib"]) else: - libs = ["onnxruntime_pybind11_state.pyd", "dnnl.dll", "mklml.dll", "libiomp5md.dll"] + libs = [ + "onnxruntime_pybind11_state.pyd", + "dnnl.dll", + "mklml.dll", + "libiomp5md.dll", + providers_cuda_or_rocm, + providers_tensorrt_or_migraphx, + providers_cann, + "onnxruntime.dll", + ] # DNNL, TensorRT & OpenVINO EPs are built as shared libs libs.extend(["onnxruntime_providers_shared.dll"]) libs.extend(["onnxruntime_providers_dnnl.dll"]) @@ -376,7 +408,7 @@ def finalize_options(self): dl_libs.append("plugins.xml") dl_libs.append("usb-ma2x8x.mvcmd") data = ["capi/libonnxruntime_pywrapper.so"] if nightly_build else [] - data += [path.join("capi", x) for x in dl_libs if path.isfile(path.join("onnxruntime", "capi", x))] + data += [path.join("capi", x) for x in dl_libs if glob(path.join("onnxruntime", "capi", x))] ext_modules = [ Extension( "onnxruntime.capi.onnxruntime_pybind11_state", @@ -384,7 +416,7 @@ def finalize_options(self): ), ] else: - data = [path.join("capi", x) for x in libs if path.isfile(path.join("onnxruntime", "capi", x))] + data = [path.join("capi", x) for x in libs if glob(path.join("onnxruntime", "capi", x))] ext_modules = [] # Additional examples diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 6571073963c70..42f669069e365 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -79,6 +79,7 @@ def _openvino_verify_device_type(device_read): "CPU_NO_PARTITION", "GPU_NO_PARTITION", "NPU_NO_PARTITION", + "NPU_NO_CPU_FALLBACK", ] status_hetero = True res = False @@ -557,6 +558,7 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--use_snpe", action="store_true", help="Build with SNPE support.") parser.add_argument("--snpe_root", help="Path to SNPE SDK root.") parser.add_argument("--use_nnapi", action="store_true", help="Build with NNAPI support.") + parser.add_argument("--use_vsinpu", action="store_true", help="Build with VSINPU support.") parser.add_argument( "--nnapi_min_api", type=int, help="Minimum Android API level to enable NNAPI, should be no less than 27" ) @@ -602,17 +604,13 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument( "--enable_msvc_static_runtime", action="store_true", help="Enable static linking of MSVC runtimes." ) - parser.add_argument( - "--enable_language_interop_ops", - action="store_true", - help="Enable operator implemented in language other than cpp", - ) parser.add_argument( "--cmake_generator", choices=[ "MinGW Makefiles", "Ninja", "NMake Makefiles", + "NMake Makefiles JOM", "Unix Makefiles", "Visual Studio 17 2022", "Xcode", @@ -1011,6 +1009,7 @@ def generate_build_tree( "-Donnxruntime_BUILD_APPLE_FRAMEWORK=" + ("ON" if args.build_apple_framework else "OFF"), "-Donnxruntime_USE_DNNL=" + ("ON" if args.use_dnnl else "OFF"), "-Donnxruntime_USE_NNAPI_BUILTIN=" + ("ON" if args.use_nnapi else "OFF"), + "-Donnxruntime_USE_VSINPU=" + ("ON" if args.use_vsinpu else "OFF"), "-Donnxruntime_USE_RKNPU=" + ("ON" if args.use_rknpu else "OFF"), "-Donnxruntime_USE_LLVM=" + ("ON" if args.use_tvm else "OFF"), "-Donnxruntime_ENABLE_MICROSOFT_INTERNAL=" + ("ON" if args.enable_msinternal else "OFF"), @@ -1040,7 +1039,6 @@ def generate_build_tree( else "OFF" ), "-Donnxruntime_REDUCED_OPS_BUILD=" + ("ON" if is_reduced_ops_build(args) else "OFF"), - "-Donnxruntime_ENABLE_LANGUAGE_INTEROP_OPS=" + ("ON" if args.enable_language_interop_ops else "OFF"), "-Donnxruntime_USE_DML=" + ("ON" if args.use_dml else "OFF"), "-Donnxruntime_USE_WINML=" + ("ON" if args.use_winml else "OFF"), "-Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=" + ("ON" if args.ms_experimental else "OFF"), @@ -1066,7 +1064,7 @@ def generate_build_tree( "-Donnxruntime_USE_NCCL=" + ("ON" if args.enable_nccl else "OFF"), "-Donnxruntime_BUILD_BENCHMARKS=" + ("ON" if args.build_micro_benchmarks else "OFF"), "-Donnxruntime_USE_ROCM=" + ("ON" if args.use_rocm else "OFF"), - "-DOnnxruntime_GCOV_COVERAGE=" + ("ON" if args.code_coverage else "OFF"), + "-Donnxruntime_GCOV_COVERAGE=" + ("ON" if args.code_coverage else "OFF"), "-Donnxruntime_USE_MPI=" + ("ON" if args.use_mpi else "OFF"), "-Donnxruntime_ENABLE_MEMORY_PROFILE=" + ("ON" if args.enable_memory_profile else "OFF"), "-Donnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO=" + ("ON" if args.enable_cuda_line_info else "OFF"), @@ -1220,6 +1218,7 @@ def generate_build_tree( if args.use_openvino: cmake_args += [ "-Donnxruntime_USE_OPENVINO=ON", + "-Donnxruntime_NPU_NO_FALLBACK=" + ("ON" if args.use_openvino == "NPU_NO_CPU_FALLBACK" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU=" + ("ON" if args.use_openvino == "GPU" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU=" + ("ON" if args.use_openvino == "CPU" else "OFF"), "-Donnxruntime_USE_OPENVINO_NPU=" + ("ON" if args.use_openvino == "NPU" else "OFF"), @@ -2207,6 +2206,7 @@ def build_python_wheel( use_cuda, cuda_version, use_rocm, + use_migraphx, rocm_version, use_dnnl, use_tensorrt, @@ -2258,6 +2258,8 @@ def build_python_wheel( args.append("--use_rocm") if rocm_version: args.append(f"--rocm_version={rocm_version}") + elif use_migraphx: + args.append("--use_migraphx") elif use_openvino: args.append("--use_openvino") elif use_dnnl: @@ -2583,13 +2585,19 @@ def main(): if args.use_tensorrt: args.use_cuda = True - if args.use_migraphx: - args.use_rocm = True - if args.build_wheel or args.gen_doc or args.use_tvm or args.enable_training: args.enable_pybind = True - if args.build_csharp or args.build_nuget or args.build_java or args.build_nodejs: + if ( + args.build_csharp + or args.build_nuget + or args.build_java + or args.build_nodejs + or (args.enable_pybind and not args.enable_training) + ): + # If pyhon bindings are enabled, we embed the shared lib in the python package. + # If training is enabled, we don't embed the shared lib in the python package since training requires + # torch interop. args.build_shared_lib = True if args.build_nuget and cross_compiling: @@ -2875,7 +2883,8 @@ def main(): # fail unexpectedly. Similar, if your packaging step forgot to copy a file into the package, we don't know it # either. if args.build: - # TODO: find asan DLL and copy it to onnxruntime/capi folder when args.enable_address_sanitizer is True and the target OS is Windows + # TODO: find asan DLL and copy it to onnxruntime/capi folder when args.enable_address_sanitizer is True and + # the target OS is Windows if args.build_wheel: nightly_build = bool(os.getenv("NIGHTLY_BUILD") == "1") default_training_package_device = bool(os.getenv("DEFAULT_TRAINING_PACKAGE_DEVICE") == "1") @@ -2886,6 +2895,7 @@ def main(): args.use_cuda, args.cuda_version, args.use_rocm, + args.use_migraphx, args.rocm_version, args.use_dnnl, args.use_tensorrt, diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index b53fb33659120..fe47d8dbe57fe 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -79,6 +79,7 @@ def parse_arguments(): "cann", "dnnl", "tensorrt", + "azure", ): file.write(f"#include \n") file.write("void* GetFunctionEntryByName(const char* name){\n") diff --git a/tools/ci_build/github/android/default_full_aar_build_settings.json b/tools/ci_build/github/android/default_full_aar_build_settings.json index 467f7048942cb..b0eff75812673 100644 --- a/tools/ci_build/github/android/default_full_aar_build_settings.json +++ b/tools/ci_build/github/android/default_full_aar_build_settings.json @@ -8,6 +8,7 @@ "android_min_sdk_version": 21, "android_target_sdk_version": 24, "build_params": [ + "--enable_lto", "--android", "--parallel", "--cmake_generator=Ninja", diff --git a/tools/ci_build/github/android/training_full_aar_build_settings.json b/tools/ci_build/github/android/training_full_aar_build_settings.json index 76cb9f0b17a8c..013804e2d63e9 100644 --- a/tools/ci_build/github/android/training_full_aar_build_settings.json +++ b/tools/ci_build/github/android/training_full_aar_build_settings.json @@ -8,6 +8,7 @@ "android_min_sdk_version": 21, "android_target_sdk_version": 24, "build_params": [ + "--enable_lto", "--android", "--parallel", "--cmake_generator=Ninja", diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index dec05ae066a4a..1bbb933f66ba4 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -18,3 +18,6 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Relu|| |ai.onnx:Reshape|| |ai.onnx:Sub|| +|ai.onnx:Sigmoid|| +|ai:onnx:Tanh|| +|ai:onnx:Transpose|| diff --git a/tools/ci_build/github/apple/test_ios_framework_build_settings.json b/tools/ci_build/github/apple/test_ios_framework_build_settings.json new file mode 100644 index 0000000000000..0572df6ecf72e --- /dev/null +++ b/tools/ci_build/github/apple/test_ios_framework_build_settings.json @@ -0,0 +1,30 @@ +{ + "build_osx_archs": { + "iphoneos": [ + "arm64" + ], + "iphonesimulator": [ + "arm64", + "x86_64" + ] + }, + "build_params": { + "base": [ + "--parallel", + "--use_xcode", + "--build_apple_framework", + "--use_coreml", + "--use_xnnpack", + "--skip_tests", + "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF" + ], + "iphoneos": [ + "--ios", + "--apple_deploy_target=13.0" + ], + "iphonesimulator": [ + "--ios", + "--apple_deploy_target=13.0" + ] + } +} diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 1703490992fb4..a4a3d0e6b334b 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -31,7 +31,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.22.0.240425 + default: 2.23.0.240531 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml index 54e83b03aa614..10d9a9a24d88a 100644 --- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml @@ -48,11 +48,12 @@ parameters: stages: # Separate stage for building CPU vs NNAPI as we only want CodeQL to run on one of them so we don't get duplicate # issues for code that is built in both. We pick NNAPI as that includes the NNAPI EP code. -- stage: BUILD_CPU_STAGE +- stage: BUILD_AND_TEST_CPU + dependsOn: [] variables: Codeql.Enabled: false jobs: - - job: Build_CPU_EP + - job: BUILD_AND_TEST_CPU pool: onnxruntime-Ubuntu2204-AMD-CPU workspace: clean: all @@ -77,14 +78,17 @@ stages: - script: sudo apt-get update -y && sudo apt-get install -y coreutils ninja-build displayName: Install coreutils and ninja - - template: "templates/use-android-ndk.yml" - + - template: templates/use-android-ndk.yml + - template: templates/use-android-emulator.yml + parameters: + create: true + start: true - script: | env | grep ANDROID displayName: View Android ENVs - - script: | python3 tools/ci_build/build.py \ + --enable_lto \ --android \ --build_dir build \ --android_sdk_path $ANDROID_HOME \ @@ -94,41 +98,17 @@ stages: --skip_submodule_sync \ --parallel \ --cmake_generator=Ninja \ - --build_java \ - --skip_tests - displayName: CPU EP, Build - - - task: CopyFiles@2 - displayName: Copy apks - inputs: - contents: 'build/**/*.apk' - targetFolder: $(Build.ArtifactStagingDirectory) - overWrite: true - - - task: CopyFiles@2 - displayName: Copy test data - inputs: - contents: 'build/**/testdata/**' - targetFolder: $(Build.ArtifactStagingDirectory) - overWrite: true - - - task: CopyFiles@2 - displayName: Copy test executables - inputs: - contents: | - build/Debug/* - build/Debug/java/androidtest/android/** - targetFolder: $(Build.ArtifactStagingDirectory) - overWrite: true - - - task: PublishBuildArtifacts@1 - inputs: - pathToPublish: $(Build.ArtifactStagingDirectory) - artifactName: CPUBuildOutput + --build_java + displayName: CPU EP, Build and Test + - template: templates/use-android-emulator.yml + parameters: + stop: true - template: templates/clean-agent-build-directory-step.yml -- stage: BUILD_NNAPI_STAGE +- stage: BUILD_AND_TEST_NNAPI_EP + dependsOn: [] + condition: notIn(variables['Build.Reason'], 'IndividualCI', 'BatchedCI') variables: Codeql.ProjectConfigPath: .github/workflows Codeql.Enabled: true @@ -137,14 +117,12 @@ stages: JobsTimeout: 120 ${{ else }}: JobsTimeout: 60 - jobs: - - job: Build_NNAPI_EP + - job: BUILD_AND_TEST_NNAPI_EP pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: ${{ variables.JobsTimeout }} workspace: clean: all - condition: notIn(variables['Build.Reason'], 'IndividualCI', 'BatchedCI') steps: - task: UsePythonVersion@0 displayName: Use Python $(pythonVersion) @@ -160,8 +138,10 @@ stages: - script: sudo apt-get update -y && sudo apt-get install -y coreutils ninja-build displayName: Install coreutils and ninja - - - template: "templates/use-android-ndk.yml" + - template: templates/use-android-emulator.yml + parameters: + create: true + start: true - script: | env | grep ANDROID @@ -169,191 +149,31 @@ stages: - script: | python3 tools/ci_build/build.py \ - --android \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME \ - --android_ndk_path $ANDROID_NDK_HOME \ - --android_abi=x86_64 \ - --android_api=29 \ - --skip_submodule_sync \ - --parallel \ - --use_nnapi \ - --cmake_generator=Ninja \ - --build_java \ - --skip_tests - displayName: NNAPI EP, Build - - - task: CopyFiles@2 - displayName: Copy apks - inputs: - contents: 'build_nnapi/**/*.apk' - targetFolder: $(Build.ArtifactStagingDirectory) - overWrite: true - - - task: CopyFiles@2 - displayName: Copy test data - inputs: - contents: 'build_nnapi/**/testdata/**' - targetFolder: $(Build.ArtifactStagingDirectory) - overWrite: true - - - task: CopyFiles@2 - displayName: Copy Test Executables - inputs: - contents: | - build_nnapi/Debug/* - build_nnapi/Debug/java/androidtest/android/** - targetFolder: $(Build.ArtifactStagingDirectory) - overWrite: true - - - task: PublishBuildArtifacts@1 - inputs: - pathToPublish: $(Build.ArtifactStagingDirectory) - artifactName: NNAPIBuildOutput + --enable_lto \ + --android \ + --build_dir build_nnapi \ + --android_sdk_path $ANDROID_HOME \ + --android_ndk_path $ANDROID_NDK_HOME \ + --android_abi=x86_64 \ + --android_api=29 \ + --skip_submodule_sync \ + --parallel \ + --use_nnapi \ + --build_shared_lib \ + --cmake_generator=Ninja \ + --build_java + displayName: NNAPI EP, Build, Test on Android Emulator + + - script: /bin/bash tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh $(pwd) + # Build Minimal ORT with NNAPI and reduced Ops, run unit tests on Android Emulator + displayName: Build Minimal ORT with NNAPI and run tests + + - template: templates/use-android-emulator.yml + parameters: + stop: true - template: templates/clean-agent-build-directory-step.yml -- stage: TEST_STAGE - dependsOn: [BUILD_CPU_STAGE, BUILD_NNAPI_STAGE] - jobs: - - job: Test_CPU_EP - pool: - # We need macOS-12 to run the Android emulator for now. - # https://github.com/actions/runner-images/issues/7671 - vmImage: 'macOS-12' - workspace: - clean: all - condition: succeeded() - steps: - - script: | - set -ex - system_profiler SPSoftwareDataType SPHardwareDataType - displayName: 'Mac Agent Info' - - - task: DownloadPipelineArtifact@2 - inputs: - ${{ if eq(parameters.specificArtifact, true) }}: - source: 'specific' - project: 'onnxruntime' - pipeline: $(Build.DefinitionName) - runVersion: 'specific' - runId: ${{ parameters.runId }} - ${{ if ne(parameters.specificArtifact, true) }}: - source: 'current' - artifact: 'CPUBuildOutput' - path: $(Build.SourcesDirectory) - - - task: UsePythonVersion@0 - displayName: Use Python $(pythonVersion) - inputs: - versionSpec: $(pythonVersion) - - - task: JavaToolInstaller@0 - displayName: Use jdk 11 - inputs: - versionSpec: '11' - jdkArchitectureOption: 'x64' - jdkSourceOption: 'PreInstalled' - - - template: "templates/use-android-ndk.yml" - - - template: templates/use-android-emulator.yml - parameters: - create: true - start: true - - - script: | - python3 tools/ci_build/build.py \ - --android \ - --build_dir build \ - --android_sdk_path $ANDROID_HOME \ - --android_ndk_path $ANDROID_NDK_HOME \ - --android_abi=x86_64 \ - --android_api=30 \ - --build_java \ - --test - displayName: CPU EP, Test on Android Emulator - - - template: templates/use-android-emulator.yml - parameters: - stop: true - - - template: templates/clean-agent-build-directory-step.yml - - - job: Test_NNAPI_EP - pool: - # We need macOS-12 to run the Android emulator for now. - # https://github.com/actions/runner-images/issues/7671 - vmImage: 'macOS-12' - timeoutInMinutes: 90 - workspace: - clean: all - condition: and(succeeded(), notIn(variables['Build.Reason'], 'IndividualCI', 'BatchedCI')) - steps: - - script: | - set -ex - system_profiler SPSoftwareDataType SPHardwareDataType - displayName: 'Mac Agent Info' - - - task: DownloadPipelineArtifact@2 - inputs: - ${{ if eq(parameters.specificArtifact, true) }}: - source: 'specific' - project: 'onnxruntime' - pipeline: $(Build.DefinitionName) - runVersion: 'specific' - runId: ${{ parameters.runId }} - ${{ if ne(parameters.specificArtifact, true) }}: - source: 'current' - artifact: 'NNAPIBuildOutput' - path: $(Build.SourcesDirectory) - - - task: UsePythonVersion@0 - displayName: Use Python $(pythonVersion) - inputs: - versionSpec: $(pythonVersion) - - - task: JavaToolInstaller@0 - displayName: Use jdk 11 - inputs: - versionSpec: '11' - jdkArchitectureOption: 'x64' - jdkSourceOption: 'PreInstalled' - - - template: "templates/use-android-ndk.yml" - - - template: templates/use-android-emulator.yml - parameters: - create: true - start: true - - - script: | - python3 tools/ci_build/build.py \ - --android \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME \ - --android_ndk_path $ANDROID_NDK_HOME \ - --android_abi=x86_64 \ - --android_api=29 \ - --build_java \ - --use_nnapi \ - --test - displayName: NNAPI EP, Test, CodeCoverage on Android Emulator - - # used by Build Minimal ORT - - script: brew install coreutils ninja - displayName: Install coreutils and ninja - - - script: /bin/bash tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh $(pwd) - # Build Minimal ORT with NNAPI and reduced Ops, run unit tests on Android Emulator - displayName: Build Minimal ORT with NNAPI and run tests - - - template: templates/use-android-emulator.yml - parameters: - stop: true - - - template: templates/clean-agent-build-directory-step.yml - - stage: MASTER_BUILD_STAGE # The below jobs only run on master build. # because coverage report is hard to support in cross machines. @@ -362,20 +182,12 @@ stages: condition: in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI') jobs: - job: NNAPI_EP_MASTER - pool: - # We need macOS-12 to run the Android emulator for now. - # https://github.com/actions/runner-images/issues/7671 - vmImage: 'macOS-12' + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: 180 workspace: clean: all condition: in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI') steps: - - script: | - set -ex - system_profiler SPSoftwareDataType SPHardwareDataType - displayName: 'Mac Agent Info' - - task: UsePythonVersion@0 displayName: Use Python $(pythonVersion) inputs: @@ -388,11 +200,7 @@ stages: jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' - - template: "templates/use-android-ndk.yml" - - # used by Build Minimal ORT - - script: brew install coreutils ninja - displayName: Install coreutils and ninja + - template: templates/use-android-ndk.yml - template: templates/use-android-emulator.yml parameters: @@ -401,6 +209,7 @@ stages: - script: | python3 tools/ci_build/build.py \ + --enable_lto \ --android \ --build_dir build_nnapi \ --android_sdk_path $ANDROID_HOME \ @@ -422,50 +231,25 @@ stages: --build_dir build_nnapi \ --android_sdk_path $ANDROID_HOME displayName: Retrieve runtime code coverage files from the emulator and analyze + - script: cat '$(Build.SourcesDirectory)/build_nnapi/Debug/coverage_rpt.txt' displayName: Print coverage report - - task: PublishPipelineArtifact@0 - displayName: 'Publish code coverage report' - inputs: - artifactName: "coverage_rpt.txt" - targetPath: '$(Build.SourcesDirectory)/build_nnapi/Debug/coverage_rpt.txt' - publishLocation: 'pipeline' - - script: /bin/bash tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh $(pwd) # Build Minimal ORT with NNAPI and reduced Ops, run unit tests on Android Emulator displayName: Build Minimal ORT with NNAPI and run tests - - template: templates/use-android-emulator.yml - parameters: - stop: true - - - template: templates/clean-agent-build-directory-step.yml - - - job: Update_Dashboard - workspace: - clean: all - variables: - - name: skipComponentGovernanceDetection - value: true - pool: 'onnxruntime-Ubuntu2204-AMD-CPU' - condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI')) - dependsOn: - - NNAPI_EP_MASTER - steps: - - task: DownloadPipelineArtifact@0 - displayName: 'Download code coverage report' - inputs: - artifactName: 'coverage_rpt.txt' - targetPath: '$(Build.BinariesDirectory)' - - task: AzureCLI@2 displayName: 'Post Android Code Coverage To DashBoard' inputs: azureSubscription: AIInfraBuild scriptType: bash scriptPath: $(Build.SourcesDirectory)/tools/ci_build/github/linux/upload_code_coverage_data.sh - arguments: '"$(Build.BinariesDirectory)/coverage_rpt.txt" "https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=$(Build.BuildId)" arm android nnapi' + arguments: '"$(Build.SourcesDirectory)/build_nnapi/Debug/coverage_rpt.txt" "https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=$(Build.BuildId)" arm android nnapi' workingDirectory: '$(Build.BinariesDirectory)' + - template: templates/use-android-emulator.yml + parameters: + stop: true + - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index 0c0cd8d0a870b..a66828ee5e188 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -38,20 +38,12 @@ parameters: type: number default: 0 -resources: - repositories: - - repository: LLaMa2Onnx - type: Github - endpoint: Microsoft - name: Microsoft/Llama-2-Onnx - ref: main - variables: - template: templates/common-variables.yml - name: docker_base_image value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 - name: linux_trt_version - value: 10.0.1.6-1.cuda11.8 + value: 10.2.0.19-1.cuda11.8 - name: Repository value: 'onnxruntimecuda11manylinuxbuild' @@ -287,11 +279,12 @@ stages: workingDirectory: $(Build.SourcesDirectory) condition: ne(variables.hitAnother, 'True') -- stage: Llama2_ONNX_FP16 +- stage: Llama2_7B_ONNX dependsOn: - Build_Onnxruntime_Cuda jobs: - - job: Llama2_ONNX_FP16 + - job: Llama2_7B_ONNX + timeoutInMinutes: 120 variables: skipComponentGovernanceDetection: true workspace: @@ -319,7 +312,7 @@ stages: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch Context: tools/ci_build/github/linux/docker/ ScriptName: tools/ci_build/get_docker_image.py DockerBuildArgs: " @@ -327,7 +320,7 @@ stages: --build-arg BASEIMAGE=${{ variables.docker_base_image }} --build-arg TRT_VERSION=${{ variables.linux_trt_version }} " - Repository: onnxruntimeubi8packagestest + Repository: onnxruntimeubi8packagestest_torch UpdateDepsTxt: false - task: DownloadPackage@1 @@ -343,7 +336,7 @@ stages: docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ -v $(Agent.TempDirectory)/meta_llama2_7b_hf:/meta-llama2 \ - onnxruntimeubi8packagestest \ + onnxruntimeubi8packagestest_torch \ bash -c " set -ex; \ pushd /workspace/onnxruntime/python/tools/transformers/ ; \ @@ -352,14 +345,56 @@ stages: python3 -m pip install -r requirements.txt ; \ popd ; \ python3 -m pip install /ort-artifact/*.whl ; \ - python3 -m pip uninstall -y torch ; \ - python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ - python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --input /meta-llama2 --small_gpu ;\ + python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --small_gp;\ + ls -l llama2-7b-fp16; \ + du -sh llama2-7b-fp16; \ popd ; \ " displayName: 'Run Llama2 to Onnx F16 and parity Test' workingDirectory: $(Build.SourcesDirectory) + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/meta_llama2_7b_hf:/meta-llama2 \ + onnxruntimeubi8packagestest_torch \ + bash -c " + set -ex; \ + pushd /workspace/onnxruntime/python/tools/transformers/ ; \ + python3 -m pip install --upgrade pip ; \ + pushd models/llama ; \ + python3 -m pip install -r requirements.txt ; \ + popd ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda;\ + ls -l llama2-7b-fp32-gpu; \ + du -sh llama2-7b-fp32-gpu; \ + popd ; \ + " + displayName: 'Run Llama2 to Onnx fp32 and parity Test' + workingDirectory: $(Build.SourcesDirectory) + + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/meta_llama2_7b_hf:/meta-llama2 \ + onnxruntimeubi8packagestest_torch \ + bash -c " + set -ex; \ + pushd /workspace/onnxruntime/python/tools/transformers/ ; \ + python3 -m pip install --upgrade pip ; \ + pushd models/llama ; \ + python3 -m pip install -r requirements.txt ; \ + popd ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-int4-gpu --precision int4 --execution_provider cuda --use_gqa;\ + ls -l llama2-7b-int4-gpu; \ + du -sh llama2-7b-int4-gpu; \ + popd ; \ + " + displayName: 'Run Llama2 to Onnx INT4 and parity Test' + workingDirectory: $(Build.SourcesDirectory) + - stage: Whisper_ONNX dependsOn: - Build_Onnxruntime_Cuda diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 45dc1487f0c66..700326fe9173c 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.22.0.240425 + default: 2.23.0.240531 resources: repositories: @@ -83,7 +83,7 @@ variables: value: 11.8 - name: win_trt_home - value: $(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8 + value: $(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8 - name: win_cuda_home value: $(Agent.TempDirectory)\v11.8 @@ -94,28 +94,6 @@ stages: PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} -- stage: Debug - dependsOn: Setup - jobs: - - job: D1 - pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' - variables: - MyVar: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] - steps: - - checkout: none - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - bash: echo $(MyVar) - - bash: echo $(BuildTime) - - bash: echo $(BuildDate) - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - template: stages/download-java-tools-stage.yml - template: templates/c-api-cpu.yml @@ -167,280 +145,6 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} -# ROCm -- stage: Linux_C_API_Packaging_ROCm_x64 - dependsOn: [] - jobs: - - job: Linux_C_API_Packaging_ROCm_x64 - workspace: - clean: all - timeoutInMinutes: 240 - pool: onnxruntime-Ubuntu2204-AMD-CPU - variables: - RocmVersion: '5.6' - RocmVersionPatchSuffix: '' - steps: - - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - submodules: recursive - - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux, for get-docker-image-steps.yml - submodules: false - - # get-docker-image-steps.yml will move the $(Build.SourcesDirectory)/manylinux into $(Build.SourcesDirectory)/onnxruntime, - # then rename $(Build.SourcesDirectory)/onnxruntime as $(Build.SourcesDirectory) - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: >- - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur - --build-arg BUILD_UID=$(id -u) - --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 - --build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix) - --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root - --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: - --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib - Repository: onnxruntimetrainingrocmbuild-rocm$(RocmVersion) - CheckOutManyLinux: true - - - template: templates/set-version-number-variables-step.yml - - - task: Bash@3 - displayName: 'Build' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/build_rocm_c_api_package.sh - arguments: >- - -S $(Build.SourcesDirectory) - -B $(Build.BinariesDirectory) - -V $(RocmVersion) - -I onnxruntimetrainingrocmbuild-rocm$(RocmVersion) - -P python3.10 - - - script: | - set -e -x - mkdir $(Build.ArtifactStagingDirectory)/testdata - cp $(Build.BinariesDirectory)/Release/libcustom_op_library.so* $(Build.ArtifactStagingDirectory)/testdata - ls -al $(Build.ArtifactStagingDirectory) - displayName: 'Create Artifacts for CustomOp' # libcustom_op_library.so from cpu build is built with fp8, ROCm does not support it. - - - template: templates/c-api-artifacts-package-and-publish-steps-posix.yml - parameters: - buildConfig: 'Release' - artifactName: 'onnxruntime-linux-x64-rocm-$(OnnxRuntimeVersion)' - artifactNameNoVersionString: 'onnxruntime-linux-x64-rocm' - libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - template: templates/clean-agent-build-directory-step.yml - - -- stage: NuGet_Packaging_ROCm - dependsOn: - - Setup - - Linux_C_API_Packaging_ROCm_x64 - condition: succeeded() - jobs: - - job: NuGet_Packaging_ROCm - workspace: - clean: all - # we need to use a 2022 pool to create the nuget package with MAUI targets. - # VS2019 has no support for net6/MAUI and we need to use msbuild (from the VS install) to do the packing - pool: 'Onnxruntime-Win-CPU-2022' - variables: - breakCodesignValidationInjection: ${{ parameters.DoEsrp }} - ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] - - steps: - - checkout: self - submodules: true - fetchDepth: 1 - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'onnxruntime-linux-x64-rocm' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - task: PowerShell@2 - displayName: 'Reconstruct Build Directory' - inputs: - targetType: inline - script: | - Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tgz | % { - # *.tar will be created after *.tgz is extracted - $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\nuget-artifact" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - - Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tar | % { - $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - - $ort_dirs = Get-ChildItem -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory - foreach ($ort_dir in $ort_dirs) - { - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0, $dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname - } - - Copy-Item -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-rocm\lib\* -Destination $(Build.BinariesDirectory)\RelWithDebInfo - - - script: | - tree /F - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Inspect Build Binaries Directory' - - - script: | - mklink /D /J models C:\local\models - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Create models link' - - - task: NuGetToolInstaller@0 - displayName: Use Nuget 6.2.1 - inputs: - versionSpec: 6.2.1 - - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm"' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Build C# bindings' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: > - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm" - -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:IsLinuxBuild=true - -p:IsWindowsBuild=false - -p:IsMacOSBuild=false - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - template: templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - DisplayName: 'ESRP - Sign C# dlls' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: MSBuild@1 - displayName: 'Build Nuget Packages' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' - configuration: RelWithDebInfo - platform: 'Any CPU' - msbuildArguments: > - -t:CreatePackage - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm - -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:CurrentTime=$(BuildTime) - -p:CurrentDate=$(BuildDate) - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.snupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - template: templates/esrp_nuget.yml - parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)' - DoEsrp: ${{ parameters.DoEsrp }} - - - template: templates/validate-package.yml - parameters: - PackageType: 'nuget' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' - PlatformsSupported: 'linux-x64' - VerifyNugetSigning: false - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-ROCm' - targetPath: '$(Build.ArtifactStagingDirectory)' - - - task: MSBuild@1 - displayName: 'Clean C#' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: RoslynAnalyzers@2 - displayName: 'Run Roslyn Analyzers' - inputs: - userProvideBuildInfo: msBuildInfo - msBuildCommandline: > - "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\msbuild.exe" - $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln - -p:configuration="RelWithDebInfo" - -p:Platform="Any CPU" - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm - -p:IsLinuxBuild=true - -p:IsWindowsBuild=false - -p:IsMacOSBuild=false - condition: and(succeeded(), eq('${{ parameters.DoCompliance }}', true)) - - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - -- template: nuget/templates/test_linux.yml - parameters: - AgentPool: AMD-GPU - ArtifactSuffix: 'ROCm' - StageSuffix: 'ROCm' - NugetPackageName: 'Microsoft.ML.OnnxRuntime.ROCm' - SpecificArtifact: ${{ parameters.specificArtifact }} - CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' - BuildId: ${{ parameters.BuildId }} - template: nuget/templates/dml-vs-2022.yml parameters: @@ -606,94 +310,3 @@ stages: - template: templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' - -- template: templates/qnn-ep-win.yml - parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-x64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' - build_config: 'RelWithDebInfo' -- template: templates/qnn-ep-win.yml - parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-arm64' - buildParameter: '--arm64' - buildPlatform: 'ARM64' - buildArch: 'ARM64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' - build_config: 'RelWithDebInfo' - -- stage: NuGet_Packaging_QNN - pool: 'Onnxruntime-QNNEP-Windows-2022-CPU' - dependsOn: - - OnnxRuntime_QNN_Nuget_Win_x64 - - OnnxRuntime_QNN_Nuget_Win_Arm64 - condition: succeeded() - jobs: - - job: NuGet_Packaging_QNN - workspace: - clean: all - steps: - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - QNN NuGet x64' - inputs: - artifactName: 'drop-nuget-qnn-x64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-x64' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - QNN NuGet arm64' - inputs: - artifactName: 'drop-nuget-qnn-arm64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' - - - task: PowerShell@2 - displayName: 'Bundle NuGet' - inputs: - targetType: 'inline' - script: | - - $x64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-x64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) - $nuget_package_name = $x64_nupkgs[0].Name - $x64_nuget_package = $x64_nupkgs[0].FullName - - $nupkg_unzipped_directory = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget_unzip_merged', [System.IO.Path]::GetFileNameWithoutExtension($nuget_package_name)) - - $x64_unzip_cmd = "7z.exe x $x64_nuget_package -y -o$nupkg_unzipped_directory" - Invoke-Expression -Command $x64_unzip_cmd - - $arm64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-arm64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) - $arm64_nuget_package = $arm64_nupkgs[0].FullName - - $arm64_unzip_cmd = "7z.exe x $arm64_nuget_package -y -o$nupkg_unzipped_directory" - Invoke-Expression -Command $arm64_unzip_cmd - - $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget-artifact-merged') - if (!(Test-Path $merged_nuget_path)) { - New-Item -Path $merged_nuget_path -ItemType Directory - } - - $merged_zip = [System.IO.Path]::Combine($merged_nuget_path, 'qnn_nuget.zip') - $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" - Invoke-Expression -Command $zip_cmd - - $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $nuget_package_name) - move $merged_zip $merged_nuget - workingDirectory: $(Build.BinariesDirectory) - - - template: templates/esrp_nuget.yml - parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index daf95af438d2b..9fd13b513e5fd 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -68,9 +68,9 @@ variables: value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 - name: win_trt_home ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8 + value: $(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: $(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-12.4 + value: $(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5 - name: win_cuda_home ${{ if eq(parameters.CudaVersion, '11.8') }}: value: $(Agent.TempDirectory)\v11.8 diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 5f63339fb0d00..3f9707ff50519 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -43,9 +43,9 @@ variables: value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20240610.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.0.1.6-1.cuda11.8 + value: 10.2.0.19-1.cuda11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.0.1.6-1.cuda12.4 + value: 10.2.0.19-1.cuda12.5 jobs: - job: Linux_Build diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml index f36cd9cfbfca1..6bf6324252fb9 100644 --- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml @@ -36,7 +36,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.0 + value: 6.1 - name: RocmVersionPatchSuffix value: ".3" diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index 45d763384ee2c..43043633365b2 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -32,5 +32,5 @@ jobs: parameters: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' - RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2024.0.0 -x "--use_openvino CPU --build_wheel"' + RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2024.0.0 -x "--use_openvino CPU --build_wheel"' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index a1339652a9495..29ebf67dd3f91 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.22.0.240425 + default: 2.23.0.240531 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml index 5894631739ac8..8fa2b805dadc9 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml @@ -32,5 +32,5 @@ stages: parameters: AllowReleasedOpsetOnly: 0 BuildForAllArchs: false - AdditionalBuildFlags: --build_objc --enable_language_interop_ops --build_wheel --use_xnnpack + AdditionalBuildFlags: --build_objc --build_wheel --use_xnnpack WithCache: true diff --git a/tools/ci_build/github/azure-pipelines/nuget-cuda-publishing-pipeline.yml b/tools/ci_build/github/azure-pipelines/nuget-cuda-publishing-pipeline.yml index b0cc253ae0973..4bfd726f5c58c 100644 --- a/tools/ci_build/github/azure-pipelines/nuget-cuda-publishing-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/nuget-cuda-publishing-pipeline.yml @@ -1,7 +1,7 @@ resources: pipelines: - pipeline: build - source: 'Nuget-CUDA-Packaging-Pipeline' + source: 'CUDA-Zip-Nuget-Java-Packaging-Pipeline' trigger: branches: include: diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 344c9a8f14022..f20a1ae3e1cd9 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -135,9 +135,9 @@ stages: - task: NuGetToolInstaller@0 - displayName: Use Nuget 5.7.0 + displayName: Use Nuget 6.10.x inputs: - versionSpec: 5.7.0 + versionSpec: 6.10.x - task: MSBuild@1 displayName: 'Restore NuGet Packages' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index b9a5383836447..56e9c73a10a82 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -61,7 +61,7 @@ stages: ${{ if eq(parameters.CudaVersion, '12.2') }}: DockerBuildArgs: " --build-arg BASEIMAGE=nvidia/cuda:12.2.2-devel-ubuntu20.04 - --build-arg TRT_VERSION=10.0.1.6-1+cuda12.4 + --build-arg TRT_VERSION=10.2.0.19-1+cuda12.5 --build-arg BUILD_UID=$( id -u ) " ${{ else }}: diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml index c582a836c7dbd..869374aa5a6e7 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml @@ -45,9 +45,9 @@ stages: architecture: x64 - task: NuGetToolInstaller@0 - displayName: Use Nuget 5.7.0 + displayName: Use Nuget 6.10.x inputs: - versionSpec: 5.7.0 + versionSpec: 6.10.x - ${{ if ne( parameters.CudaVersion, '') }}: - template: ../../templates/jobs/download_win_gpu_library.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 7ada4ee6757c9..0e1afdcc5b8ca 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -25,7 +25,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.0 + value: 6.1 - name: RocmVersionPatchSuffix value: ".3" - name: BuildConfig @@ -255,6 +255,33 @@ jobs: arguments: -n $(Agent.Name) -d $HIP_VISIBLE_DEVICES -r $DRIVER_RENDER displayName: 'Check ROCm Environment' + # TODO: move to use ci_build/build.py driven tests + - task: CmdLine@2 + inputs: + script: |- + docker run --rm \ + --security-opt seccomp=unconfined \ + --shm-size=1024m \ + --device=/dev/kfd \ + --device=/dev/dri/renderD$DRIVER_RENDER \ + --group-add $(video) \ + --group-add $(render) \ + --user onnxruntimedev \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + -e OPENBLAS_NUM_THREADS=1 \ + -e OPENMP_NUM_THREADS=1 \ + -e MKL_NUM_THREADS=1 \ + -e PYTHONPATH=/build/$(BuildConfig) \ + onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ + /bin/bash -c " + set -ex; \ + pip install -r /onnxruntime_src/tools/ci_build/requirements-transformers-test.txt; \ + pytest /onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_rocm.py -v -n 4 --reruns 1" + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Run tranformers tests' + condition: succeededOrFailed() + - task: CmdLine@2 inputs: script: |- diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index d6a3fa3147a47..593d45361324e 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -226,7 +226,7 @@ stages: BuildConfig: 'RelWithDebInfo' EnvSetupScript: setup_env_trt.bat buildArch: x64 - additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 isX86: false job_name_suffix: x64_RelWithDebInfo @@ -446,14 +446,15 @@ stages: python tools/ci_build/github/apple/build_apple_framework.py \ --build_dir "$(Build.BinariesDirectory)/ios_framework" \ --build_dynamic_framework \ - tools/ci_build/github/apple/default_full_apple_framework_build_settings.json + tools/ci_build/github/apple/test_ios_framework_build_settings.json displayName: "Build iOS dynamic framework" - script: | python tools/ci_build/github/apple/test_apple_packages.py \ --framework_info_file "$(Build.BinariesDirectory)/ios_framework/xcframework_info.json" \ --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" \ - --variant Full + --variant Full \ + --skip_macos_test displayName: "Test pod with iOS framework" - stage: IosMinimalTrainingBuild diff --git a/tools/ci_build/github/azure-pipelines/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/publish-nuget.yml index 8ce7915da76d1..e0c588413415b 100644 --- a/tools/ci_build/github/azure-pipelines/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/publish-nuget.yml @@ -10,148 +10,22 @@ resources: branch: main stages: -- stage: Publish_NuGet_Package_And_Report - jobs: - - job: Publish_NuGet_Package_And_Report - workspace: - clean: all - variables: - - name: GDN_CODESIGN_TARGETDIRECTORY - value: '$(Agent.TempDirectory)\binfiles' - pool: 'onnxruntime-Win-CPU-2022' - - steps: - # https://learn.microsoft.com/en-us/azure/devops/pipelines/yaml-schema/resources-pipelines-pipeline?view=azure-pipelines#pipeline-resource-metadata-as-predefined-variables - - script: | - echo $(resources.pipeline.build.sourceBranch) - echo $(Build.Reason) - displayName: 'Print triggering sourceBranch Name in resources' - - - checkout: self - submodules: false - - template: templates/set-version-number-variables-step.yml - - - script: mkdir "$(Build.BinariesDirectory)\nuget-artifact\final-package" - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-CPU' - - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-CPU\*" "$(Build.BinariesDirectory)\nuget-artifact\final-package" - - - template: nuget/templates/get-nuget-package-version-as-variable.yml - parameters: - packageFolder: '$(Build.BinariesDirectory)/nuget-artifact/final-package' - - - task: CmdLine@2 - displayName: 'Post binary sizes to the dashboard database using command line' - inputs: - script: | - echo changing directory to artifact download path - cd $(Build.BinariesDirectory)/nuget-artifact/final-package - echo processing nupkg - SETLOCAL EnableDelayedExpansion - FOR /R %%i IN (*.nupkg) do ( - set filename=%%~ni - IF NOT "!filename:~25,7!"=="Managed" ( - echo processing %%~ni.nupkg - copy %%~ni.nupkg %%~ni.zip - echo copied to zip - echo listing lib files in the zip - REM use a single .csv file to put the data - echo os,arch,build_config,size > $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\linux-arm64\native\libonnxruntime.so | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo linux,aarch64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\osx-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\win-x64\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\win-x86\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x86,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - ) - ) - - - task: AzureCLI@2 - displayName: 'Azure CLI' - #Only report binary sizes to database if the build build was auto-triggered from the main branch - condition: and (succeeded(), and(eq(variables['resources.pipeline.build.sourceBranch'], 'refs/heads/main'), eq(variables['Build.Reason'], 'ResourceTrigger'))) - inputs: - azureSubscription: AIInfraBuildOnnxRuntimeOSS - scriptLocation: inlineScript - scriptType: batch - inlineScript: | - python.exe -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_to_dashboard\requirements.txt && ^ - python.exe $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_binary_sizes_to_dashboard.py --commit_hash=$(Build.SourceVersion) --size_data_file=binary_size_data.txt --build_project=Lotus --build_id=$(Build.BuildId) - workingDirectory: '$(Build.BinariesDirectory)' - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-dml' - - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-dml\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-Training-CPU' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-Training-CPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-GPU' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-GPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet ROCm Package' - artifact: 'drop-signed-nuget-ROCm' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' - artifact: 'drop-signed-nuget-qnn' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-qnn\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - - script: | - dir $(Build.BinariesDirectory)\nuget-artifact\final-package - cd $(Build.BinariesDirectory)\nuget-artifact\final-package - nuget verify -Signatures *.nupkg - displayName: List Downloaded Package - - - powershell: | - New-Item -Path $(Agent.TempDirectory) -Name "binfiles" -ItemType "directory" - $base_path_name = Join-Path -Path $(Agent.TempDirectory) -ChildPath "binfiles" - Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact\final-package -Filter *.nupkg | - Foreach-Object { - $dir_name = Join-Path -Path $base_path_name -ChildPath $_.Basename - $cmd = "7z.exe x $($_.FullName) -y -o$dir_name" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - dir $(Agent.TempDirectory) - tree $(Agent.TempDirectory) - workingDirectory: '$(Agent.TempDirectory)' - - - task: CodeSign@1 - displayName: 'Run Codesign Validation' - - - - task: PublishSecurityAnalysisLogs@3 - displayName: 'Publish Security Analysis Logs' - continueOnError: true - - - task: PostAnalysis@2 - inputs: - GdnBreakAllTools: true - GdnBreakPolicy: M365 - GdnBreakPolicyMinSev: Error - - #TODO: allow choosing different feeds - - task: NuGetCommand@2 - displayName: 'Copy Signed Native NuGet Package to ORT-NIGHTLY' - inputs: - command: 'push' - packagesToPush: '$(Build.BinariesDirectory)/nuget-artifact/final-package/*.nupkg' - publishVstsFeed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/7982ae20-ed19-4a35-a362-a96ac99897b7' - allowPackageConflicts: true - - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() + - template: templates/publish-nuget-steps.yml + parameters: + stage_name: 'Publish_NuGet_Packag_And_Report' + include_cpu_ep: true + download_artifacts_steps: + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-dml' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-dml\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-Training-CPU' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-Training-CPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-GPU' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-GPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 63e70fa8e6488..d57a7585f3cff 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -55,7 +55,7 @@ stages: python_wheel_suffix: '_gpu' timeout: 480 docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 - trt_version: '10.0.1.6-1.cuda11.8' + trt_version: '10.2.0.19-1.cuda11.8' cuda_version: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index df4ed796f24e7..8d1b6b7854e50 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.22.0.240425 + default: 2.23.0.240531 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index c5212bd495872..a8b12637b70f3 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,13 +2,13 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.22.0.240425 + default: 2.23.0.240531 - name: build_config displayName: Build Configuration type: string default: 'RelWithDebInfo' - + - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. type: boolean @@ -17,8 +17,19 @@ parameters: - name: DoEsrp displayName: Run code sign tasks? Must be true if you are doing an Onnx Runtime release. type: boolean + default: true + +# these 2 parameters are used for debugging. +- name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean default: false +- name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + stages: - template: templates/qnn-ep-win.yml @@ -60,25 +71,25 @@ stages: inputs: artifactName: 'drop-nuget-qnn-x64' targetPath: '$(Build.BinariesDirectory)/nuget-artifact-x64' - + - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - QNN NuGet arm64' inputs: artifactName: 'drop-nuget-qnn-arm64' targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' - + - task: PowerShell@2 displayName: 'Bundle NuGet' inputs: targetType: 'inline' script: | - + $x64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-x64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) $nuget_package_name = $x64_nupkgs[0].Name $x64_nuget_package = $x64_nupkgs[0].FullName $nupkg_unzipped_directory = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget_unzip_merged', [System.IO.Path]::GetFileNameWithoutExtension($nuget_package_name)) - + $x64_unzip_cmd = "7z.exe x $x64_nuget_package -y -o$nupkg_unzipped_directory" Invoke-Expression -Command $x64_unzip_cmd @@ -94,9 +105,9 @@ stages: } $merged_zip = [System.IO.Path]::Combine($merged_nuget_path, 'qnn_nuget.zip') - $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" + $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" Invoke-Expression -Command $zip_cmd - + $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $nuget_package_name) move $merged_zip $merged_nuget workingDirectory: $(Build.BinariesDirectory) @@ -113,3 +124,13 @@ stages: artifactName: 'drop-signed-nuget-qnn' targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' +- template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' + ArtifactName: 'drop-signed-nuget-qnn' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml new file mode 100644 index 0000000000000..f4022a80b0568 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml @@ -0,0 +1,353 @@ +parameters: +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +- name: UseIncreasedTimeoutForTests + displayName: Increase timeout for tests? Set it to false if you are doing an Onnx Runtime release. + type: boolean + default: false + +- name: DoCompliance + displayName: Run Compliance Tasks? + type: boolean + default: true + +- name: DoEsrp + displayName: Run code sign tasks? Must be true if you are doing an ONNX Runtime release + type: boolean + default: true + +- name: IsReleaseBuild + displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. + type: boolean + default: false + +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + default: none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + default: 0 + +# these 2 parameters are used for debugging. +- name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean + default: false + +- name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + +- name: NugetPackageSuffix + displayName: Suffix to append to nuget package + type: string + default: 'NONE' + +resources: + repositories: + - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step + type: github + endpoint: ort-examples + name: microsoft/onnxruntime-inference-examples + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +variables: +- name: ReleaseVersionSuffix + value: '' + +stages: +- template: stages/set_packaging_variables_stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + +# ROCm +- stage: Linux_C_API_Packaging_ROCm_x64 + dependsOn: [] + jobs: + - job: Linux_C_API_Packaging_ROCm_x64 + workspace: + clean: all + timeoutInMinutes: 240 + pool: onnxruntime-Ubuntu2204-AMD-CPU + variables: + RocmVersion: '5.6' + RocmVersionPatchSuffix: '' + steps: + - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime + submodules: recursive + - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux, for get-docker-image-steps.yml + submodules: false + + # get-docker-image-steps.yml will move the $(Build.SourcesDirectory)/manylinux into $(Build.SourcesDirectory)/onnxruntime, + # then rename $(Build.SourcesDirectory)/onnxruntime as $(Build.SourcesDirectory) + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: >- + --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur + --build-arg BUILD_UID=$(id -u) + --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 + --build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix) + --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root + --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: + --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib + Repository: onnxruntimetrainingrocmbuild-rocm$(RocmVersion) + CheckOutManyLinux: true + + - template: templates/set-version-number-variables-step.yml + + - task: Bash@3 + displayName: 'Build' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/build_rocm_c_api_package.sh + arguments: >- + -S $(Build.SourcesDirectory) + -B $(Build.BinariesDirectory) + -V $(RocmVersion) + -I onnxruntimetrainingrocmbuild-rocm$(RocmVersion) + -P python3.10 + + - script: | + set -e -x + mkdir $(Build.ArtifactStagingDirectory)/testdata + cp $(Build.BinariesDirectory)/Release/libcustom_op_library.so* $(Build.ArtifactStagingDirectory)/testdata + ls -al $(Build.ArtifactStagingDirectory) + displayName: 'Create Artifacts for CustomOp' # libcustom_op_library.so from cpu build is built with fp8, ROCm does not support it. + + - template: templates/c-api-artifacts-package-and-publish-steps-posix.yml + parameters: + buildConfig: 'Release' + artifactName: 'onnxruntime-linux-x64-rocm-$(OnnxRuntimeVersion)' + artifactNameNoVersionString: 'onnxruntime-linux-x64-rocm' + libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' + + - template: templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + - template: templates/clean-agent-build-directory-step.yml + +- stage: NuGet_Packaging_ROCm + dependsOn: + - Setup + - Linux_C_API_Packaging_ROCm_x64 + condition: succeeded() + jobs: + - job: NuGet_Packaging_ROCm + workspace: + clean: all + # we need to use a 2022 pool to create the nuget package with MAUI targets. + # VS2019 has no support for net6/MAUI and we need to use msbuild (from the VS install) to do the packing + pool: 'Onnxruntime-Win-CPU-2022' + variables: + breakCodesignValidationInjection: ${{ parameters.DoEsrp }} + ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + + steps: + - checkout: self + submodules: true + fetchDepth: 1 + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - NuGet' + ArtifactName: 'onnxruntime-linux-x64-rocm' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - task: PowerShell@2 + displayName: 'Reconstruct Build Directory' + inputs: + targetType: inline + script: | + Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tgz | % { + # *.tar will be created after *.tgz is extracted + $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\nuget-artifact" + Write-Output $cmd + Invoke-Expression -Command $cmd + } + + Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tar | % { + $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" + Write-Output $cmd + Invoke-Expression -Command $cmd + } + + $ort_dirs = Get-ChildItem -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory + foreach ($ort_dir in $ort_dirs) + { + $dirname = Split-Path -Path $ort_dir -Leaf + $dirname = $dirname.SubString(0, $dirname.LastIndexOf('-')) + Write-Output "Renaming $ort_dir to $dirname" + Rename-Item -Path $ort_dir -NewName $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname + } + + Copy-Item -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-rocm\lib\* -Destination $(Build.BinariesDirectory)\RelWithDebInfo + + - script: | + tree /F + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Inspect Build Binaries Directory' + + - script: | + mklink /D /J models C:\local\models + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Create models link' + + - task: NuGetToolInstaller@0 + displayName: Use Nuget 6.10.x + inputs: + versionSpec: 6.10.x + + - task: MSBuild@1 + displayName: 'Restore NuGet Packages and create project.assets.json' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm"' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Build C# bindings' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: > + -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" + -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm" + -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} + -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:IsLinuxBuild=true + -p:IsWindowsBuild=false + -p:IsMacOSBuild=false + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - template: templates/win-esrp-dll.yml + parameters: + FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' + DisplayName: 'ESRP - Sign C# dlls' + DoEsrp: ${{ parameters.DoEsrp }} + + - task: UsePythonVersion@0 + displayName: 'Use Python' + inputs: + versionSpec: 3.8 + + - task: MSBuild@1 + displayName: 'Build Nuget Packages' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' + configuration: RelWithDebInfo + platform: 'Any CPU' + msbuildArguments: > + -t:CreatePackage + -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" + -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm + -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} + -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:CurrentTime=$(BuildTime) + -p:CurrentDate=$(BuildDate) + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + Contents: '*.snupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - template: templates/esrp_nuget.yml + parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)' + DoEsrp: ${{ parameters.DoEsrp }} + + - template: templates/validate-package.yml + parameters: + PackageType: 'nuget' + PackagePath: '$(Build.ArtifactStagingDirectory)' + PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' + PlatformsSupported: 'linux-x64' + VerifyNugetSigning: false + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-ROCm' + targetPath: '$(Build.ArtifactStagingDirectory)' + + - task: MSBuild@1 + displayName: 'Clean C#' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + +- template: nuget/templates/test_linux.yml + parameters: + AgentPool: AMD-GPU + ArtifactSuffix: 'ROCm' + StageSuffix: 'ROCm' + NugetPackageName: 'Microsoft.ML.OnnxRuntime.ROCm' + SpecificArtifact: ${{ parameters.specificArtifact }} + CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' + BuildId: ${{ parameters.BuildId }} + +- template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet ROCm Package' + ArtifactName: 'drop-signed-nuget-ROCm' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index b6943f9e1b77b..7dfafeb67acf8 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -49,9 +49,9 @@ jobs: value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20240610.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.0.1.6-1.cuda11.8 + value: 10.2.0.19-1.cuda11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.0.1.6-1.cuda12.4 + value: 10.2.0.19-1.cuda12.5 pool: ${{ parameters.machine_pool }} steps: - checkout: self diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 02df814eb2a93..424ed6237260b 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -76,9 +76,9 @@ stages: displayName: 'Create models link' - task: NuGetToolInstaller@0 - displayName: Use Nuget 6.2.1 + displayName: Use Nuget 6.10.x inputs: - versionSpec: 6.2.1 + versionSpec: 6.10.x - task: PowerShell@2 displayName: Install MAUI workloads @@ -116,6 +116,11 @@ stages: DisplayName: 'ESRP - Sign C# dlls' DoEsrp: ${{ parameters.DoEsrp }} + - task: UsePythonVersion@0 + displayName: 'Use Python' + inputs: + versionSpec: 3.8 + - task: MSBuild@1 displayName: 'Build Nuget Packages' inputs: @@ -193,13 +198,6 @@ stages: msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: RoslynAnalyzers@2 - displayName: 'Run Roslyn Analyzers' - inputs: - userProvideBuildInfo: msBuildInfo - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\msbuild.exe" $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln -p:configuration="RelWithDebInfo" -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' - condition: and(succeeded(), eq('${{ parameters.DoCompliance }}', true)) - - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index cca53e36ebab9..2ca5129ac6e5d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -80,9 +80,9 @@ stages: - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.0.1.6-1.cuda11.8 + value: 10.2.0.19-1.cuda11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.0.1.6-1.cuda12.4 + value: 10.2.0.19-1.cuda12.5 steps: - checkout: self clean: true @@ -149,9 +149,9 @@ stages: value: '12' - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.0.1.6-1.cuda11.8 + value: 10.2.0.19-1.cuda11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.0.1.6-1.cuda12.4 + value: 10.2.0.19-1.cuda12.5 steps: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime submodules: false diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml index 01f0337be7714..dcd681bd4b915 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml @@ -65,9 +65,9 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} ${{ if eq(parameters.cuda_version, '11.8') }}: - EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ${{ if eq(parameters.cuda_version, '12.2') }}: - EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-12.4 --cuda_home=$(Agent.TempDirectory)\v12.2 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5 --cuda_home=$(Agent.TempDirectory)\v12.2 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ${{ if eq(parameters.enable_linux_gpu, true) }}: - template: ../templates/py-linux-gpu.yml @@ -79,7 +79,7 @@ stages: cuda_version: ${{ parameters.cuda_version }} ${{ if eq(parameters.cuda_version, '11.8') }}: docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 - trt_version: 10.0.1.6-1.cuda11.8 + trt_version: 10.2.0.19-1.cuda11.8 ${{ if eq(parameters.cuda_version, '12.2') }}: docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20240610.1 - trt_version: 10.0.1.6-1.cuda12.4 + trt_version: 10.2.0.19-1.cuda12.5 diff --git a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml index 3e2b3b585df9a..e0f0e6f358e70 100644 --- a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml @@ -44,3 +44,25 @@ stages: - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' + +- stage: Debug + dependsOn: Setup + jobs: + - job: D1 + pool: + name: 'onnxruntime-Ubuntu2204-AMD-CPU' + variables: + MyVar: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + steps: + - checkout: none + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + - bash: echo $(MyVar) + - bash: echo $(BuildTime) + - bash: echo $(BuildDate) + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index d7aecae18968e..7ba1179e7ad4d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -376,9 +376,9 @@ stages: workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Create models link' - task: NuGetToolInstaller@0 - displayName: Use Nuget 6.2.1 + displayName: Use Nuget 6.10.x inputs: - versionSpec: 6.2.1 + versionSpec: 6.10.x - task: PowerShell@2 displayName: Install mobile workloads @@ -478,13 +478,6 @@ stages: msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: RoslynAnalyzers@2 - displayName: 'Run Roslyn Analyzers' - inputs: - userProvideBuildInfo: msBuildInfo - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\msbuild.exe" $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln -p:configuration="RelWithDebInfo" -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId)' - condition: and(succeeded(), eq('${{ parameters.DoCompliance }}', true)) - - template: component-governance-component-detection-steps.yml parameters : condition : 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 9bd5a81181e10..cf350704f8356 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.163 + version: 1.0.165 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.163 + version: 1.0.165 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml index 31519a2cef376..c9b7c01146981 100644 --- a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml @@ -53,24 +53,35 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} - - task: Bash@3 - inputs: - targetType: 'inline' - script: | - echo "Java Version" - java --version - mkdir test - pushd test - jar xf '$(Build.BinariesDirectory)/final-jar/testing.jar' - popd - # if you want to run the tests in the power shell, you need to replace ':' to ';', that is, "-cp .;.\test;protobuf-java-3.21.7.jar;onnxruntime-$(OnnxRuntimeVersion).jar" - java -jar ./junit-platform-console-standalone-1.6.2.jar -cp .:./test:./protobuf-java-3.21.7.jar:./onnxruntime-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner - workingDirectory: '$(Build.BinariesDirectory)/final-jar' - env: - ${{ if eq(parameters.OS, 'MacOS') }}: - DYLD_LIBRARY_PATH: '$(Build.BinariesDirectory)/final-jar/test:$(DYLD_LIBRARY_PATH)' - ${{ if eq(parameters.OS, 'Linux') }}: - LD_LIBRARY_PATH: '$(Build.BinariesDirectory)/final-jar/test:$(LD_LIBRARY_PATH)' + - ${{ if eq(parameters.OS, 'Windows') }}: + - task: CmdLine@2 + inputs: + script: | + mkdir test + pushd test + jar xf $(Build.BinariesDirectory)\final-jar\testing.jar + popd + java -jar junit-platform-console-standalone-1.6.2.jar -cp .;.\test;protobuf-java-3.21.7.jar;onnxruntime-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner + workingDirectory: '$(Build.BinariesDirectory)\final-jar' + - ${{ else }}: + - task: Bash@3 + inputs: + targetType: 'inline' + script: | + echo "Java Version" + java --version + mkdir test + pushd test + jar xf '$(Build.BinariesDirectory)/final-jar/testing.jar' + popd + # if you want to run the tests in the power shell, you need to replace ':' to ';', that is, "-cp .;.\test;protobuf-java-3.21.7.jar;onnxruntime-$(OnnxRuntimeVersion).jar" + java -jar ./junit-platform-console-standalone-1.6.2.jar -cp .:./test:./protobuf-java-3.21.7.jar:./onnxruntime-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner + workingDirectory: '$(Build.BinariesDirectory)/final-jar' + env: + ${{ if eq(parameters.OS, 'MacOS') }}: + DYLD_LIBRARY_PATH: '$(Build.BinariesDirectory)/final-jar/test:$(DYLD_LIBRARY_PATH)' + ${{ if eq(parameters.OS, 'Linux') }}: + LD_LIBRARY_PATH: '$(Build.BinariesDirectory)/final-jar/test:$(LD_LIBRARY_PATH)' - ${{ if eq(parameters['OS'], 'MacOS') }}: - template: use-xcode-version.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 236998407ad16..ada3603ae8476 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.22.0.240425' + default: '2.23.0.240531' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index 0dd9ffd5282e7..de29a3de9fded 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -13,10 +13,10 @@ parameters: - 12.2 - name: TrtVersion type: string - default: '10.0.1.6' + default: '10.2.0.19' values: - 8.6.1.6 - - 10.0.1.6 + - 10.2.0.19 steps: - ${{ if eq(parameters.DownloadCUDA, true) }}: @@ -42,9 +42,9 @@ steps: - powershell: | Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.0" displayName: Set trtCudaVersion - - ${{ if and(eq(parameters.CudaVersion, '12.2'), eq(parameters.TrtVersion, '10.0.1.6')) }}: + - ${{ if and(eq(parameters.CudaVersion, '12.2'), eq(parameters.TrtVersion, '10.2.0.19')) }}: - powershell: | - Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.4" + Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.5" displayName: Set trtCudaVersion - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 0f43dfc497dff..3a68803896ab3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.22.0.240425' + default: '2.23.0.240531' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml index 6c82958fc0b78..63d521f1e7d9a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml @@ -24,17 +24,11 @@ steps: displayName: 'Download Secondary CUDA SDK v${{ parameters.SecondaryCUDAVersion }}' - ${{ if eq(parameters.DownloadTRT, 'true') }}: - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" $(Agent.TempDirectory) - displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8' + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" $(Agent.TempDirectory) + displayName: 'Download TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8' - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0" $(Agent.TempDirectory) - displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0' - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" $(Agent.TempDirectory) - displayName: 'Download TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8' - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.0.1.6.Windows10.x86_64.cuda-12.4" $(Agent.TempDirectory) - displayName: 'Download TensorRT-10.0.1.6.Windows10.x86_64.cuda-12.4' + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5" $(Agent.TempDirectory) + displayName: 'Download TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5' - task: BatchScript@1 displayName: 'setup env' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index fee4a82ceac78..9a35d7b75c708 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -123,9 +123,9 @@ jobs: displayName: 'API Documentation Check and generate' - task: NuGetToolInstaller@0 - displayName: Use Nuget 5.7.0 + displayName: Use Nuget 6.10.x inputs: - versionSpec: 5.7.0 + versionSpec: 6.10.x - task: NuGetCommand@2 displayName: 'NuGet restore' diff --git a/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml b/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml index 9a666155028cc..0d62ed7907a67 100644 --- a/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml +++ b/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml @@ -9,12 +9,16 @@ parameters: steps: - task: CmdLine@2 displayName: 'Gradle cmakeCheck' - continueOnError: ${{ parameters.buildOnly }} inputs: - script: | - @echo on - call gradlew.bat cmakeCheck -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo --warning-mode all - workingDirectory: $(Build.SourcesDirectory)\java + ${{ if eq(parameters.buildOnly, true) }}: + script: | + call gradlew.bat testClasses -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo + call gradlew.bat cmakeCheck -x test -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo --warning-mode all + workingDirectory: $(Build.SourcesDirectory)\java + ${{ else }}: + script: | + call gradlew.bat cmakeCheck -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo --warning-mode all + workingDirectory: $(Build.SourcesDirectory)\java - task: CmdLine@2 displayName: 'Add symbols and notices to Java' diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index 61cf5de8bdf17..fb9ff65fe8534 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -191,9 +191,9 @@ stages: displayName: 'Create models link' - task: NuGetToolInstaller@0 - displayName: Use Nuget 6.2.1 + displayName: Use Nuget 6.10.x inputs: - versionSpec: 6.2.1 + versionSpec: 6.10.x - task: PowerShell@2 displayName: Install mobile workloads @@ -228,6 +228,11 @@ stages: DisplayName: 'ESRP - Sign C# dlls' DoEsrp: ${{ parameters.DoEsrp }} + - task: UsePythonVersion@0 + displayName: 'Use Python' + inputs: + versionSpec: 3.8 + - task: MSBuild@1 displayName: 'Build Nuget Packages' inputs: @@ -288,13 +293,6 @@ stages: msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: RoslynAnalyzers@2 - displayName: 'Run Roslyn Analyzers' - inputs: - userProvideBuildInfo: msBuildInfo - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\msbuild.exe" $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln -p:configuration="RelWithDebInfo" -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId)' - condition: and(succeeded(), eq('${{ parameters.DoCompliance }}', true)) - - template: component-governance-component-detection-steps.yml parameters : condition : 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml b/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml new file mode 100644 index 0000000000000..8639a5ca0a55d --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/publish-nuget-steps.yml @@ -0,0 +1,139 @@ +parameters: +- name: include_cpu_ep + type: boolean + default: false +- name: download_artifacts_steps + type: stepList +- name: stage_name + type: string + default: 'Publish_NuGet_Package' + +stages: +- stage: ${{ parameters.stage_name }} + jobs: + - job: ${{ parameters.stage_name }} + workspace: + clean: all + variables: + - name: GDN_CODESIGN_TARGETDIRECTORY + value: '$(Agent.TempDirectory)\binfiles' + pool: 'onnxruntime-Win-CPU-2022' + + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + # https://learn.microsoft.com/en-us/azure/devops/pipelines/yaml-schema/resources-pipelines-pipeline?view=azure-pipelines#pipeline-resource-metadata-as-predefined-variables + - script: | + echo $(resources.pipeline.build.sourceBranch) + echo $(Build.Reason) + displayName: 'Print triggering sourceBranch Name in resources' + + - checkout: self + submodules: false + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.9' + addToPath: true + + - template: set-version-number-variables-step.yml + + - script: mkdir "$(Build.BinariesDirectory)\nuget-artifact\final-package" + + - template: ../nuget/templates/get-nuget-package-version-as-variable.yml + parameters: + packageFolder: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + + - ${{if eq(parameters.include_cpu_ep, true)}}: + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-CPU' + + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-CPU\*" "$(Build.BinariesDirectory)\nuget-artifact\final-package" + + - task: CmdLine@2 + displayName: 'Post binary sizes to the dashboard database using command line' + inputs: + script: | + echo changing directory to artifact download path + cd $(Build.BinariesDirectory)/nuget-artifact/final-package + echo processing nupkg + SETLOCAL EnableDelayedExpansion + FOR /R %%i IN (*.nupkg) do ( + set filename=%%~ni + IF NOT "!filename:~25,7!"=="Managed" ( + echo processing %%~ni.nupkg + copy %%~ni.nupkg %%~ni.zip + echo copied to zip + echo listing lib files in the zip + REM use a single .csv file to put the data + echo os,arch,build_config,size > $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\linux-arm64\native\libonnxruntime.so | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo linux,aarch64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\osx-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\win-x64\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\win-x86\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x86,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + ) + ) + + - task: AzureCLI@2 + displayName: 'Azure CLI' + #Only report binary sizes to database if the build build was auto-triggered from the main branch + condition: and (succeeded(), and(eq(variables['resources.pipeline.build.sourceBranch'], 'refs/heads/main'), eq(variables['Build.Reason'], 'ResourceTrigger'))) + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptLocation: inlineScript + scriptType: batch + inlineScript: | + python.exe -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_to_dashboard\requirements.txt && ^ + python.exe $(Build.SourcesDirectory)\tools\ci_build\github\windows\post_binary_sizes_to_dashboard.py --commit_hash=$(Build.SourceVersion) --size_data_file=binary_size_data.txt --build_project=Lotus --build_id=$(Build.BuildId) + workingDirectory: '$(Build.BinariesDirectory)' + + - ${{ parameters.download_artifacts_steps }} + + - script: | + dir $(Build.BinariesDirectory)\nuget-artifact\final-package + cd $(Build.BinariesDirectory)\nuget-artifact\final-package + nuget verify -Signatures *.nupkg + displayName: List Downloaded Package + + - powershell: | + New-Item -Path $(Agent.TempDirectory) -Name "binfiles" -ItemType "directory" + $base_path_name = Join-Path -Path $(Agent.TempDirectory) -ChildPath "binfiles" + Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact\final-package -Filter *.nupkg | + Foreach-Object { + $dir_name = Join-Path -Path $base_path_name -ChildPath $_.Basename + $cmd = "7z.exe x $($_.FullName) -y -o$dir_name" + Write-Output $cmd + Invoke-Expression -Command $cmd + } + dir $(Agent.TempDirectory) + tree $(Agent.TempDirectory) + workingDirectory: '$(Agent.TempDirectory)' + + - task: CodeSign@1 + displayName: 'Run Codesign Validation' + + + - task: PublishSecurityAnalysisLogs@3 + displayName: 'Publish Security Analysis Logs' + continueOnError: true + + - task: PostAnalysis@2 + inputs: + GdnBreakAllTools: true + GdnBreakPolicy: M365 + GdnBreakPolicyMinSev: Error + + #TODO: allow choosing different feeds + - task: NuGetCommand@2 + displayName: 'Copy Signed Native NuGet Package to ORT-NIGHTLY' + inputs: + command: 'push' + packagesToPush: '$(Build.BinariesDirectory)/nuget-artifact/final-package/*.nupkg' + publishVstsFeed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/7982ae20-ed19-4a35-a362-a96ac99897b7' + allowPackageConflicts: true + + - template: component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index 97f95797be1f1..6c66cceb33d5c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -22,10 +22,10 @@ parameters: - name: trt_version type: string - default: '10.0.1.6-1.cuda11.8' + default: '10.2.0.19-1.cuda11.8' values: - - 10.0.1.6-1.cuda11.8 - - 10.0.1.6-1.cuda12.4 + - 10.2.0.19-1.cuda11.8 + - 10.2.0.19-1.cuda12.5 - name: cuda_version type: string default: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 3081624225b12..8eca22c8c123f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -18,10 +18,10 @@ parameters: - name: trt_version type: string - default: '10.0.1.6-1.cuda11.8' + default: '10.2.0.19-1.cuda11.8' values: - - 10.0.1.6-1.cuda11.8 - - 10.0.1.6-1.cuda12.4 + - 10.2.0.19-1.cuda11.8 + - 10.2.0.19-1.cuda12.5 - name: cuda_version type: string default: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml index cc07df59da619..47980955b8798 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml @@ -152,17 +152,6 @@ stages: filename: 'C:\Program Files\Intel\openvino_2021.4.752\bin\setupvars.bat' modifyEnvironment: true - - task: PythonScript@0 - inputs: - scriptSource: inline - script: | - import sys - np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2' - import subprocess - subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' - - task: PowerShell@2 displayName: 'Install ONNX' inputs: @@ -392,7 +381,7 @@ stages: variables: CUDA_VERSION: '11.8' buildArch: x64 - EpBuildFlags: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80" + EpBuildFlags: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80" EnvSetupScript: setup_env_gpu.bat EP_NAME: gpu VSGenerator: 'Visual Studio 17 2022' @@ -419,17 +408,6 @@ stages: modifyEnvironment: true workingFolder: '$(Build.BinariesDirectory)' - - task: PythonScript@0 - inputs: - scriptSource: inline - script: | - import sys - np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2' - import subprocess - subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' - - task: PowerShell@2 displayName: 'Install ONNX' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 60c365ed32f84..27f85dc5c1648 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -63,7 +63,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.22.0.240425 + default: 2.23.0.240531 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: @@ -153,17 +153,6 @@ stages: modifyEnvironment: true workingFolder: '$(Build.BinariesDirectory)' - - task: PythonScript@0 - inputs: - scriptSource: inline - script: | - import sys - np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2' - import subprocess - subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' - - template: download-deps.yml - task: PythonScript@0 @@ -299,7 +288,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.8' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -309,7 +298,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.9' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -319,7 +308,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.10' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -329,7 +318,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.11' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -339,7 +328,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.12' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -509,7 +498,7 @@ stages: docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} - trt_version: '10.0.1.6-1.cuda11.8' + trt_version: '10.2.0.19-1.cuda11.8' cuda_version: '11.8' - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 32fdf4819bd88..af239b4384af9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.22.0.240425 + default: 2.23.0.240531 - name: PYTHON_VERSION type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index c7a74a7f0e9c7..e89227d51de32 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -89,20 +89,6 @@ stages: tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' appendSourceBranchName: false - - task: PythonScript@0 - inputs: - scriptSource: inline - script: | - import sys - np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.26' - import subprocess - try: - subprocess.check_call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) - except subprocess.CalledProcessError: - sys.exit(1) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' - - template: download-deps.yml - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 668e51c828dcd..884e6eafee965 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.22.0.240425 + default: 2.23.0.240531 - name: ENV_SETUP_SCRIPT type: string @@ -60,17 +60,6 @@ jobs: tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' appendSourceBranchName: false - - task: PythonScript@0 - inputs: - scriptSource: inline - script: | - import sys - np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2' - import subprocess - subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' - - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 6534490dd9ade..b4c4f36c5dcc6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.22.0.240425' + QnnSdk: '2.23.0.240531' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml index 231e375bedd9e..47c0ba3eb2d1e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml +++ b/tools/ci_build/github/azure-pipelines/templates/rocm.yml @@ -17,7 +17,7 @@ jobs: - job: wheels_python_${{ replace(parameters.PythonVersion,'.','_') }}_rocm_${{ replace(parameters.RocmVersion,'.','_') }}_${{ parameters.BuildConfig }} workspace: clean: all - timeoutInMinutes: 180 + timeoutInMinutes: 360 pool: Ubuntu-2204-rocm-aiinfra variables: - name: PythonVersion diff --git a/tools/ci_build/github/azure-pipelines/templates/use-android-emulator.yml b/tools/ci_build/github/azure-pipelines/templates/use-android-emulator.yml index b31882c8da18f..4251a8401f8f0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/use-android-emulator.yml +++ b/tools/ci_build/github/azure-pipelines/templates/use-android-emulator.yml @@ -15,6 +15,25 @@ parameters: steps: - ${{ if eq(parameters.create, true) }}: + - script: | + if [[ ":$PATH:" == *":${ANDROID_SDK_ROOT}/emulator:"* ]]; then + echo "${ANDROID_SDK_ROOT}/emulator is in PATH" + else + ${ANDROID_SDK_ROOT}/cmdline-tools/latest/bin/sdkmanager --install "emulator" + echo "##vso[task.prependpath]${ANDROID_SDK_ROOT}/emulator" + fi + displayName: Check if emulator are installed and add to PATH + + - script: | + if [[ ":$PATH:" == *":${ANDROID_SDK_ROOT}/platform-tools:"* ]]; then + echo "${ANDROID_SDK_ROOT}/platform-tools is in PATH" + else + ${ANDROID_SDK_ROOT}/cmdline-tools/latest/bin/sdkmanager --install "platform-tools" + echo "##vso[task.prependpath]${ANDROID_SDK_ROOT}/platform-tools" + fi + ls -R ${ANDROID_SDK_ROOT}/platform-tools + displayName: Check if platform tools are installed and add to PATH + - script: | set -e -x python3 tools/python/run_android_emulator.py \ diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index c726054d8eb10..52547fd9a796b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -196,11 +196,11 @@ stages: parameters: msbuildPlatform: ${{ parameters.msbuildPlatform }} java_artifact_id: ${{ parameters.java_artifact_id }} - ${{ if contains(parameters.ort_build_pool_name, 'CPU') }}: - buildOnly: false + ${{ if or(contains(parameters.buildparameter, 'use_cuda'), contains(parameters.buildparameter, 'use_tensorrt')) }}: # When it is a GPU build, we only assemble the java binaries, testing will be done in the later stage with GPU machine - ${{ else }}: buildOnly: true + ${{ else }}: + buildOnly: false - task: PublishBuildArtifacts@1 displayName: 'Publish Java temp binaries' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index d6fa5184f6882..b5120f01bff3e 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -39,9 +39,9 @@ jobs: versionSpec: '18.x' - task: NuGetToolInstaller@0 - displayName: Use Nuget 5.7.0 + displayName: Use Nuget 6.10.x inputs: - versionSpec: 5.7.0 + versionSpec: 6.10.x - task: PythonScript@0 displayName: 'Generate cmake config' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 39e68f5631f01..7d64f78c695fa 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -137,6 +137,25 @@ stages: WITH_CACHE: false MachinePool: 'onnxruntime-Win-CPU-2022' +# Build only. Does not run any tests. +- stage: x64_release_vitisai + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + buildArch: x64 + additionalBuildFlags: --build_wheel --use_vitisai + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_release + RunOnnxRuntimeTests: false + isTraining: false + ORT_EP_NAME: VITISAI + GenerateDocumentation: false + WITH_CACHE: false + MachinePool: 'onnxruntime-Win-CPU-2022' + - stage: x64_release_winml dependsOn: [] jobs: diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 291e2f4e19401..438e51175c5b4 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -44,7 +44,7 @@ stages: buildArch: x64 additionalBuildFlags: >- --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" - --enable_cuda_profiling + --enable_cuda_profiling --enable_transformers_tool_test --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON --cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index 1af00da01241a..70c0c7d4a04e7 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -55,7 +55,7 @@ jobs: WithCache: True Today: $(TODAY) AdditionalKey: "gpu-tensorrt | RelWithDebInfo" - BuildPyArguments: '--config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86' + BuildPyArguments: '--config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86' MsbuildArguments: $(MsbuildArguments) BuildArch: 'x64' Platform: 'x64' @@ -75,7 +75,7 @@ jobs: del wheel_filename_file python.exe -m pip install -q --upgrade %WHEEL_FILENAME% set PATH=$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo;%PATH% - python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 + python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' displayName: 'Run tests' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 0053a4a64ee02..97745fd09fbf7 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.22.0.240425 + default: 2.23.0.240531 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index ede7b3d336768..2ab81e16cd57e 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.22.0.240425 + default: 2.23.0.240531 jobs: - job: 'build' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 index 86c178aae519b..2d3dc05285e3c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 @@ -6,7 +6,7 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 -ARG TRT_VERSION=10.0.1.6-1.cuda11.8 +ARG TRT_VERSION=10.2.0.19-1.cuda11.8 FROM $BASEIMAGE AS base ARG TRT_VERSION ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch new file mode 100644 index 0000000000000..a50788e98ffe0 --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch @@ -0,0 +1,57 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to Test ONNX Runtime on UBI8 with TensorRT 10.0 and CUDA 11.8 by default + +# Build base image with required system packages +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +ARG TRT_VERSION=10.2.0.19-1.cuda11.8 +FROM $BASEIMAGE AS base +ARG TRT_VERSION +ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} + +RUN dnf install -y bash wget &&\ + dnf clean dbcache + +RUN pip3 install --upgrade pip +RUN pip3 install setuptools>=68.2.2 + +#Install TensorRT only if TRT_VERSION is not empty +RUN if [ -n "$TRT_VERSION" ]; then \ + echo "TRT_VERSION is $TRT_VERSION" && \ + dnf -y install \ + libnvinfer10-${TRT_VERSION} \ + libnvinfer-headers-devel-${TRT_VERSION} \ + libnvinfer-devel-${TRT_VERSION} \ + libnvinfer-lean10-${TRT_VERSION} \ + libnvonnxparsers10-${TRT_VERSION} \ + libnvonnxparsers-devel-${TRT_VERSION} \ + libnvinfer-dispatch10-${TRT_VERSION} \ + libnvinfer-plugin10-${TRT_VERSION} \ + libnvinfer-vc-plugin10-${TRT_VERSION} \ + libnvinfer-bin-${TRT_VERSION} \ + libnvinfer-plugin10-${TRT_VERSION} \ + libnvinfer-plugin-devel-${TRT_VERSION} \ + libnvinfer-vc-plugin-devel-${TRT_VERSION} \ + libnvinfer-lean-devel-${TRT_VERSION} \ + libnvinfer-dispatch-devel-${TRT_VERSION} \ + libnvinfer-headers-plugin-devel-${TRT_VERSION} && \ + dnf clean dbcache ; \ +else \ + echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ +fi + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && /tmp/scripts/install_java.sh && rm -rf /tmp/scripts + +RUN python3 -m pip uninstall -y torch +RUN python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 + +# Build final image from base. +FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu index 5ef56fd885ca7..1aca3e305452d 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu @@ -6,7 +6,7 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG TRT_VERSION=10.0.1.6-1+cuda11.8 +ARG TRT_VERSION=10.2.0.19-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG TRT_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg index 194a22850030c..5697120a48b2b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg @@ -6,7 +6,7 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG TRT_VERSION=10.0.1.6-1+cuda11.8 +ARG TRT_VERSION=10.2.0.19-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG TRT_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 deleted file mode 100644 index 8b32425afce1c..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 +++ /dev/null @@ -1,63 +0,0 @@ -# -------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------- -# Dockerfile to run ONNXRuntime with TensorRT integration - -FROM nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04 - - -# ONNX Runtime Variables -ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime -ARG ONNXRUNTIME_BRANCH=main -ARG CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80 - -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get update &&\ - apt-get install -y sudo git bash unattended-upgrades wget -RUN unattended-upgrade - -# Install python3 -RUN apt-get install -y --no-install-recommends \ - python3 \ - python3-pip \ - python3-dev \ - python3-wheel &&\ - cd /usr/local/bin &&\ - ln -s /usr/bin/python3 python &&\ - ln -s /usr/bin/pip3 pip; - -RUN pip install --upgrade pip -RUN pip install setuptools>=68.2.2 - -# Install TensorRT -RUN v="8.4.1-1+cuda11.6" &&\ - apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ - apt-get update &&\ - sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} \ - libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} \ - python3-libnvinfer=${v} libnvinfer-samples=${v} - -# Compile trtexec -RUN cd /usr/src/tensorrt/samples/trtexec && make - -# Install Valgrind -RUN apt-get install -y valgrind - -ARG BUILD_USER=onnxruntimedev -ARG BUILD_UID=1000 -RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID -USER $BUILD_USER -WORKDIR /code -ENV CUDA_MODULE_LOADING "LAZY" - -# Prepare onnxruntime repository & build onnxruntime with TensorRT -RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ - /bin/sh build.sh --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' &&\ - pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ - cd .. diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 deleted file mode 100644 index cfc7023ef8e61..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 +++ /dev/null @@ -1,92 +0,0 @@ -# -------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------- -# Dockerfile to run ONNXRuntime with TensorRT integration - -# Build base image with required system packages -FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base - -# The local directory into which to build and install CMAKE -ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code - -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get update &&\ - apt-get install -y sudo git bash unattended-upgrades wget -RUN unattended-upgrade - -# Install python3 -RUN apt-get install -y --no-install-recommends \ - python3 \ - python3-pip \ - python3-dev \ - python3-wheel &&\ - cd /usr/local/bin &&\ - ln -s /usr/bin/python3 python &&\ - ln -s /usr/bin/pip3 pip; - -RUN pip install --upgrade pip -RUN pip install setuptools>=68.2.2 - -# Install TensorRT -RUN v="8.5.1-1+cuda11.8" &&\ - apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ - apt-get update &&\ - sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} \ - libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} \ - python3-libnvinfer=${v} libnvinfer-samples=${v} - -# Compile trtexec -RUN cd /usr/src/tensorrt/samples/trtexec && make - -# Install Valgrind -RUN apt-get install -y valgrind - -# Build final image from base. Builds ORT. -FROM base as final -ARG BUILD_USER=onnxruntimedev -ARG BUILD_UID=1000 -RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID -USER $BUILD_USER - -# ONNX Runtime arguments - -# URL to the github repo from which to clone ORT. -ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime - -# The local directory into which to clone ORT. -ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code - -# The git branch of ORT to checkout and build. -ARG ONNXRUNTIME_BRANCH=main - -# Optional. The specific commit to pull and build from. If not set, the latest commit is used. -ARG ONNXRUNTIME_COMMIT_ID - -# The supported CUDA architecture -ARG CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80 - -WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR} - -# Clone ORT repository with branch -RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh - -WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime - -# Reset to a specific commit if specified by build args. -RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIME_BRANCH}" ;\ - else echo "Building branch ${ONNXRUNTIME_BRANCH} @ commit ${ONNXRUNTIME_COMMIT_ID}" &&\ - git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi - -# Build ORT -ENV CUDA_MODULE_LOADING "LAZY" -RUN /bin/sh build.sh --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' - -# Switch to root to continue following steps of CI -USER root - -# Intall ORT wheel -RUN pip install ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime/build/Linux/Release/dist/*.whl \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 similarity index 99% rename from tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt10_0 rename to tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 index cd168e1911d95..0bd56a1a5873f 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 @@ -31,7 +31,7 @@ RUN pip install --upgrade pip RUN pip install psutil setuptools>=68.2.2 # Install TensorRT -RUN version="10.0.1.6-1+cuda11.8" &&\ +RUN version="10.2.0.19-1+cuda11.8" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ apt-get install -y \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_4_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 similarity index 83% rename from tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_4_tensorrt10_0 rename to tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 index 3e48415118c63..7f66943dd8745 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_4_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 @@ -5,7 +5,7 @@ # Dockerfile to run ONNXRuntime with TensorRT integration # Build base image with required system packages -FROM nvidia/cuda:12.4.1-devel-ubuntu20.04 AS base +FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code @@ -30,15 +30,27 @@ RUN apt-get install -y --no-install-recommends \ RUN pip install --upgrade pip RUN pip install setuptools>=68.2.2 psutil -# Install cuDNN v9 -RUN apt-get -y install cudnn9-cuda-12 - # Install TensorRT -RUN version="10.0.1.6-1+cuda12.4" &&\ +RUN version="10.2.0.19-1+cuda12.5" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ apt-get install -y \ - tensorrt=${version} + libnvinfer-dev=${version} \ + libnvinfer-dispatch-dev=${version} \ + libnvinfer-dispatch10=${version} \ + libnvinfer-headers-dev=${version} \ + libnvinfer-headers-plugin-dev=${version} \ + libnvinfer-lean-dev=${version} \ + libnvinfer-lean10=${version} \ + libnvinfer-plugin-dev=${version} \ + libnvinfer-plugin10=${version} \ + libnvinfer-vc-plugin-dev=${version} \ + libnvinfer-vc-plugin10=${version} \ + libnvinfer10=${version} \ + libnvonnxparsers-dev=${version} \ + libnvonnxparsers10=${version} \ + tensorrt-dev=${version} \ + libnvinfer-bin=${version} # Compile trtexec if not installed RUN if [ ! -d /usr/src/tensorrt/bin ] || [ ! -f /usr/src/tensorrt/bin/trtexec ]; then \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino index 45682c797bbb8..dbd2076041b94 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino @@ -1,13 +1,13 @@ -ARG UBUNTU_VERSION=20.04 +ARG UBUNTU_VERSION=22.04 FROM ubuntu:${UBUNTU_VERSION} ARG OPENVINO_VERSION=2024.0.0 -ARG PYTHON_VERSION=3.9 +ARG PYTHON_VERSION=3.10 ADD scripts /tmp/scripts -RUN /tmp/scripts/install_ubuntu.sh -p ${PYTHON_VERSION} -d EdgeDevice && \ - /tmp/scripts/install_os_deps.sh -d EdgeDevice && \ - /tmp/scripts/install_python_deps.sh -p ${PYTHON_VERSION} -d EdgeDevice +RUN /tmp/scripts/install_ubuntu.sh -p $PYTHON_VERSION -d EdgeDevice +RUN /tmp/scripts/install_os_deps.sh -d EdgeDevice +RUN /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d EdgeDevice RUN apt update && apt install -y libnuma1 ocl-icd-libopencl1 && \ rm -rf /var/lib/apt/lists/* /tmp/scripts @@ -19,9 +19,9 @@ ENV IE_PLUGINS_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64 ENV DEBIAN_FRONTEND=noninteractive RUN cd /opt && mkdir -p intel && cd intel && \ - wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.0/linux/l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && \ - tar xzf l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && \ - mv l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64 openvino_2024.0.0 && \ + wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.0/linux/l_openvino_toolkit_ubuntu22_2024.0.0.14509.34caeefd078_x86_64.tgz && \ + tar xzf l_openvino_toolkit_ubuntu22_2024.0.0.14509.34caeefd078_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu22_2024.0.0.14509.34caeefd078_x86_64.tgz && \ + mv l_openvino_toolkit_ubuntu22_2024.0.0.14509.34caeefd078_x86_64 openvino_2024.0.0 && \ cd $INTEL_OPENVINO_DIR/install_dependencies && ./install_openvino_dependencies.sh -y WORKDIR /root diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin index a26bf88fbbdf6..0281c1c8fef25 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin @@ -5,7 +5,7 @@ # Dockerfile to run ONNXRuntime with TensorRT installed from provided binaries # Build base image with required system packages -FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 AS base +FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code @@ -30,9 +30,6 @@ RUN apt-get install -y --no-install-recommends \ RUN pip install --upgrade pip RUN pip install setuptools>=68.2.2 -# Install cuDNN v9 -RUN apt-get -y install cudnn9-cuda-12 - # Install TensorRT # Must provide version numbers used to build the name of the tar file containing TensorRT binaries. # See: https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing-tar diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt index cc47718f78a46..a977ccae1922f 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt @@ -1,6 +1,5 @@ -numpy==1.21.6 ; python_version < '3.11' -numpy==1.24.2 ; python_version == '3.11' -numpy==1.26.0 ; python_version >= '3.12' +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' mypy pytest setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt index cc47718f78a46..a977ccae1922f 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt @@ -1,6 +1,5 @@ -numpy==1.21.6 ; python_version < '3.11' -numpy==1.24.2 ; python_version == '3.11' -numpy==1.26.0 ; python_version >= '3.12' +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' mypy pytest setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile index 3a7f410d3859e..a0020a9827290 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile @@ -5,7 +5,7 @@ ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 FROM $BASEIMAGE -ARG TRT_VERSION=10.0.1.6-1.cuda11.8 +ARG TRT_VERSION=10.2.0.19-1.cuda11.8 #Install TensorRT only if TRT_VERSION is not empty RUN if [ -n "${TRT_VERSION}" ]; then \ diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/requirements.txt index cc47718f78a46..a977ccae1922f 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/requirements.txt @@ -1,6 +1,5 @@ -numpy==1.21.6 ; python_version < '3.11' -numpy==1.24.2 ; python_version == '3.11' -numpy==1.26.0 ; python_version >= '3.12' +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' mypy pytest setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh index 9c2f02dd34bf5..a980963429034 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh @@ -12,6 +12,7 @@ PYTHON_VER=${PYTHON_VER:=3.8} # Some Edge devices only have limited disk space, use this option to exclude some package DEVICE_TYPE=${DEVICE_TYPE:=Normal} +# shellcheck disable=SC2034 DEBIAN_FRONTEND=noninteractive echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections @@ -19,8 +20,6 @@ apt-get update && apt-get install -y software-properties-common lsb-release OS_VERSION=$(lsb_release -r -s) -SYS_LONG_BIT=$(getconf LONG_BIT) - PACKAGE_LIST="autotools-dev \ automake \ build-essential \ @@ -37,9 +36,8 @@ PACKAGE_LIST="autotools-dev \ gfortran \ python3-dev \ language-pack-en \ - liblttng-ust0 \ + liblttng-ust-dev \ libcurl4 \ - libssl1.1 \ libkrb5-3 \ libtinfo-dev \ libtinfo5 \ @@ -50,7 +48,7 @@ PACKAGE_LIST="autotools-dev \ unzip \ zip \ rsync libunwind8 libpng-dev libexpat1-dev \ - python3-setuptools python3-numpy python3-wheel python python3-pip python3-pytest \ + python3-setuptools python3-numpy python3-wheel python3-pip python3-pytest python3-distutils \ openjdk-11-jdk \ graphviz" @@ -59,7 +57,7 @@ if [ $DEVICE_TYPE = "Normal" ]; then PACKAGE_LIST="$PACKAGE_LIST libedit-dev libxml2-dev python3-packaging" fi -PACKAGE_LIST="$PACKAGE_LIST libicu66" +PACKAGE_LIST="$PACKAGE_LIST libicu-dev" apt-get install -y --no-install-recommends $PACKAGE_LIST @@ -67,8 +65,14 @@ locale-gen en_US.UTF-8 update-locale LANG=en_US.UTF-8 if [ "$OS_VERSION" = "20.04" ]; then + # The defaul version of python is 3.8 + major=$(echo $PYTHON_VER | cut -d. -f1) + minor=$(echo $PYTHON_VER | cut -d. -f2) + if [ "$major" -lt 3 ] || [ "$major" -eq 3 ] && [ "$minor" -lt 8 ]; then + PYTHON_VER="3.8" + fi if [ "$PYTHON_VER" != "3.8" ]; then - add-apt-repository -y ppa:deadsnakes/ppa + add-apt-repository -y ppa:deadsnakes/ppa apt-get update apt-get install -y --no-install-recommends \ python${PYTHON_VER} \ @@ -80,6 +84,23 @@ if [ "$OS_VERSION" = "20.04" ]; then #put at /usr/local/. Then there will be two pips. /usr/bin/python${PYTHON_VER} -m pip install --upgrade --force-reinstall pip==19.0.3 fi +elif [ "$OS_VERSION" = "22.04" ] ; then + # The defaul version of python is 3.10 + major=$(echo $PYTHON_VER | cut -d. -f1) + minor=$(echo $PYTHON_VER | cut -d. -f2) + if [ "$major" -lt 3 ] || [ "$major" -eq 3 ] && [ "$minor" -lt 10 ]; then + PYTHON_VER="3.10" + fi + if [ "$PYTHON_VER" != "3.10" ]; then + add-apt-repository -y ppa:deadsnakes/ppa + apt-get update + apt-get install -y --no-install-recommends \ + python${PYTHON_VER} \ + python${PYTHON_VER}-dev + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VER} 1 + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 2 + update-alternatives --set python3 /usr/bin/python${PYTHON_VER} + fi else exit 1 fi diff --git a/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt b/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt index e9b222fe09711..d76a4337e7487 100644 --- a/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt @@ -8,7 +8,8 @@ onnx==1.16.1 astunparse expecttest!=0.2.0 hypothesis -numpy +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' psutil pyyaml requests diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index bdae9d72a1a63..12db3bd132bb7 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -1,6 +1,5 @@ -numpy==1.21.6 ; python_version < '3.11' -numpy==1.24.2 ; python_version == '3.11' -numpy==1.26.0 ; python_version >= '3.12' +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' mypy pytest setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 3e619ea3dfb56..36af6aa71b075 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -1,7 +1,6 @@ cerberus -numpy==1.21.6 ; python_version < '3.11' -numpy==1.24.2 ; python_version == '3.11' -numpy==1.26.0 ; python_version >= '3.12' +numpy==1.24.4 ; python_version < '3.9' +numpy==2.0.0; python_version >= '3.9' mypy pytest setuptools==69.0.3 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt index 57331d6df97d9..89bda11737d10 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt @@ -1,3 +1,2 @@ -numpy==1.21.6 ; python_version < '3.11' -numpy==1.24.2 ; python_version == '3.11' -numpy==1.26.0 ; python_version >= '3.12' \ No newline at end of file +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt index 08e251eddbf96..ee4f8bd586804 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt @@ -5,6 +5,7 @@ setuptools>=68.2.2 cerberus h5py scikit-learn -numpy +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' pandas parameterized diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt index 47f64568f424a..d7fab6a1c8a27 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt @@ -1,8 +1,7 @@ pandas scikit-learn -numpy==1.21.6 ; python_version < '3.11' -numpy==1.24.2 ; python_version == '3.11' -numpy==1.26.0 ; python_version >= '3.12' +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' transformers==v4.36.0 accelerate==0.25.0 rsa==4.9 diff --git a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config index 215bc4015acde..3f1691f47e706 100644 --- a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config +++ b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config @@ -3,6 +3,7 @@ "os": "android", "arch": "arm64-v8a", "build_params": [ + "--enable_lto", "--android", "--android_sdk_path=/android_home", "--android_ndk_path=/ndk_home", diff --git a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config index 1348707a071cb..dbebec5788ddb 100644 --- a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config +++ b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config @@ -3,6 +3,7 @@ "os": "android", "arch": "arm64-v8a", "build_params": [ + "--enable_lto", "--android", "--android_sdk_path=/android_home", "--android_ndk_path=/ndk_home", diff --git a/tools/ci_build/github/linux/run_dockerbuild.sh b/tools/ci_build/github/linux/run_dockerbuild.sh index 440752bc819de..9944861f519f4 100755 --- a/tools/ci_build/github/linux/run_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_dockerbuild.sh @@ -20,16 +20,16 @@ ORTMODULE_BUILD=false #Training only USE_CONDA=false ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV="ALLOW_RELEASED_ONNX_OPSET_ONLY="$ALLOW_RELEASED_ONNX_OPSET_ONLY -echo "ALLOW_RELEASED_ONNX_OPSET_ONLY environment variable is set as "$ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV +echo "ALLOW_RELEASED_ONNX_OPSET_ONLY environment variable is set as $ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV" while getopts o:d:p:x:v:y:t:i:mue parameter_Option do case "${parameter_Option}" in -#yocto, ubuntu20.04 +#yocto, ubuntu22.04 o) BUILD_OS=${OPTARG};; #gpu, tensorrt or openvino. It is ignored when BUILD_OS is yocto. d) BUILD_DEVICE=${OPTARG};; -#python version: 3.6 3.7 (absence means default 3.6) +#python version: 3.8 3.9 3.10 3.11 3.12 (absence means default 3.8) p) PYTHON_VER=${OPTARG};; # "--build_wheel --use_openblas" x) BUILD_EXTR_PAR=${OPTARG};; @@ -48,9 +48,11 @@ m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; u) ORTMODULE_BUILD=true;; # install and use conda e) USE_CONDA=true;; +*) echo "Invalid option";; esac done +# shellcheck disable=SC2034 EXIT_CODE=1 DEFAULT_PYTHON_VER="3.8" @@ -62,7 +64,10 @@ if [[ -n "${IMAGE_CACHE_CONTAINER_REGISTRY_NAME}" ]]; then GET_DOCKER_IMAGE_CMD="${GET_DOCKER_IMAGE_CMD} --container-registry ${IMAGE_CACHE_CONTAINER_REGISTRY_NAME}" fi DOCKER_CMD="docker" - +# If BUILD_OS is ubuntu, then UBUNTU_VERSION is set to the version string after ubuntu +if [[ $BUILD_OS == ubuntu* ]]; then + UBUNTU_VERSION=${BUILD_OS#ubuntu} +fi NEED_BUILD_SHARED_LIB=true cd $SCRIPT_DIR/docker @@ -96,10 +101,9 @@ elif [ $BUILD_DEVICE = "gpu" ]; then --docker-build-args="--build-arg BASEIMAGE=nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-${BUILD_OS} --build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg INSTALL_DEPS_EXTRA_ARGS=\"${INSTALL_DEPS_EXTRA_ARGS}\" --build-arg USE_CONDA=${USE_CONDA} --network=host" \ --dockerfile Dockerfile.ubuntu_gpu_training --context . elif [[ $BUILD_DEVICE = "openvino"* ]]; then - BUILD_ARGS="--build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=3.8" + BUILD_ARGS="--build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg OPENVINO_VERSION=${OPENVINO_VERSION} --build-arg UBUNTU_VERSION=${UBUNTU_VERSION}" IMAGE="$BUILD_OS-openvino" DOCKER_FILE=Dockerfile.ubuntu_openvino - BUILD_ARGS+=" --build-arg OPENVINO_VERSION=${OPENVINO_VERSION}" $GET_DOCKER_IMAGE_CMD --repository "onnxruntime-$IMAGE" \ --docker-build-args="${BUILD_ARGS}" \ --dockerfile $DOCKER_FILE --context . diff --git a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh index 56f5ff9f9eac0..9cd1222cabfa6 100755 --- a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh +++ b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh @@ -2,7 +2,7 @@ pip3 install --user --upgrade pip -pip3 install --user numpy==1.19.0 torch pytest +pip3 install --user numpy torch pytest pip3 install --user /build/Release/dist/*.whl export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.8/site-packages:$PYTHONPATH diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 59f6c0ab2136c..bf21a65314985 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG ROCM_VERSION=6.0 +ARG ROCM_VERSION=6.1 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' @@ -77,7 +77,7 @@ RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bi RUN export MAJOR=$(cut -d '.' -f 1 <<< "$ROCM_VERSION") && \ export MINOR=$(cut -d '.' -f 2 <<< "$ROCM_VERSION") && \ export PATCH=$(cut -d '.' -f 3 <<< "$ROCM_VERSION") && \ - pip install torch==2.0.1 torchvision==0.15.2 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ && \ + pip install torch==2.1.2 torchvision==0.16.1 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ && \ pip install torch-ort --no-dependencies ##### Install Cupy to decrease CPU utilization diff --git a/tools/ci_build/github/windows/eager/requirements.txt b/tools/ci_build/github/windows/eager/requirements.txt index a820174957185..08e7baa76471b 100644 --- a/tools/ci_build/github/windows/eager/requirements.txt +++ b/tools/ci_build/github/windows/eager/requirements.txt @@ -1,6 +1,7 @@ setuptools wheel -numpy +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' typing_extensions torch==1.13.1 parameterized diff --git a/tools/ci_build/github/windows/helpers.ps1 b/tools/ci_build/github/windows/helpers.ps1 index 0e7d279c9fa49..95a36aa24e904 100644 --- a/tools/ci_build/github/windows/helpers.ps1 +++ b/tools/ci_build/github/windows/helpers.ps1 @@ -635,16 +635,18 @@ function Install-ONNX { if ($lastExitCode -ne 0) { exit $lastExitCode } - + $temp_dir = Get-TempDirectory + $new_requirements_text_file = Join-Path $temp_dir "new_requirements.txt" Write-Host "Installing python packages..." - [string[]]$pip_args = "-m", "pip", "install", "-qq", "--disable-pip-version-check", "setuptools>=68.2.2", "wheel", "numpy", "protobuf==$protobuf_version" + Get-Content "$src_root\tools\ci_build\github\linux\docker\inference\x86_64\python\cpu\scripts\requirements.txt" | Select-String -pattern 'onnx' -notmatch | Out-File $new_requirements_text_file + + [string[]]$pip_args = "-m", "pip", "install", "-qq", "--disable-pip-version-check", "-r", $new_requirements_text_file &"python.exe" $pip_args if ($lastExitCode -ne 0) { exit $lastExitCode } $url=Get-DownloadURL -name onnx -src_root $src_root - $temp_dir = Get-TempDirectory $onnx_src_dir = Join-Path $temp_dir "onnx" $download_finished = DownloadAndExtract -Uri $url -InstallDirectory $onnx_src_dir -Force if(-Not $download_finished){ diff --git a/tools/ci_build/github/windows/post_to_dashboard/requirements.txt b/tools/ci_build/github/windows/post_to_dashboard/requirements.txt index b8c00a610b781..6ece3c1f92c4e 100644 --- a/tools/ci_build/github/windows/post_to_dashboard/requirements.txt +++ b/tools/ci_build/github/windows/post_to_dashboard/requirements.txt @@ -1,2 +1,2 @@ -azure-kusto-data[pandas]==3.0.1 -azure-kusto-ingest[pandas]==3.0.1 +azure-kusto-data[pandas]==4.5.1 +azure-kusto-ingest[pandas]==4.5.1 diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat index b753cdae16b90..6c59866ea925a 100644 --- a/tools/ci_build/github/windows/setup_env_gpu.bat +++ b/tools/ci_build/github/windows/setup_env_gpu.bat @@ -6,10 +6,10 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( ) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% ) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8\lib;%PATH% +set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8\lib;%PATH% @REM The default version is still cuda v11.8, because set cuda v12.2 after it -set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.0.1.6.Windows10.x86_64.cuda-12.4\lib +set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5\lib if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 ) else ( diff --git a/tools/ci_build/github/windows/setup_env_trt.bat b/tools/ci_build/github/windows/setup_env_trt.bat index 4e43b5999a315..249bb98815897 100644 --- a/tools/ci_build/github/windows/setup_env_trt.bat +++ b/tools/ci_build/github/windows/setup_env_trt.bat @@ -6,6 +6,6 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( ) else ( set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64 ) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8\lib;%PATH% +set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8\lib;%PATH% set GRADLE_OPTS=-Dorg.gradle.daemon=false set CUDA_MODULE_LOADING=LAZY \ No newline at end of file diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index d200a2f666939..88d1cebc84f8d 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -1015,57 +1015,31 @@ def generate_files(line_list, args): # Process Training specific targets and props if args.package_name == "Microsoft.ML.OnnxRuntime.Training": - monoandroid_source_targets = os.path.join( - args.sources_path, - "csharp", - "src", - "Microsoft.ML.OnnxRuntime", - "targets", - "monoandroid11.0", - "targets.xml", - ) - monoandroid_target_targets = os.path.join( - args.sources_path, - "csharp", - "src", - "Microsoft.ML.OnnxRuntime", - "targets", - "monoandroid11.0", - args.package_name + ".targets", - ) - - net6_android_source_targets = os.path.join( + net8_android_source_targets = os.path.join( args.sources_path, "csharp", "src", "Microsoft.ML.OnnxRuntime", "targets", - "net6.0-android", + "net8.0-android", "targets.xml", ) - net6_android_target_targets = os.path.join( + net8_android_target_targets = os.path.join( args.sources_path, "csharp", "src", "Microsoft.ML.OnnxRuntime", "targets", - "net6.0-android", + "net8.0-android", args.package_name + ".targets", ) - os.system(copy_command + " " + monoandroid_source_targets + " " + monoandroid_target_targets) - os.system(copy_command + " " + net6_android_source_targets + " " + net6_android_target_targets) - - files_list.append("') - files_list.append( - "' - ) - + os.system(copy_command + " " + net8_android_source_targets + " " + net8_android_target_targets) files_list.append( - "' + "' ) files_list.append( - "' + "' ) # README diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.h b/winml/lib/Api.Ort/OnnxruntimeEngine.h index eae7dc37941c7..88945b75c75e4 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.h @@ -3,7 +3,7 @@ #include "iengine.h" #include "UniqueOrtPtr.h" -#include "core/common/gsl.h" +#include #include #include